From d68da34eec3d3512cc6fb5b254665f0462fb11f9 Mon Sep 17 00:00:00 2001 From: Steven Date: Tue, 26 Sep 2023 19:17:17 +0800 Subject: [PATCH] refactor: migrate idp to driver --- store/driver.go | 6 ++ store/idp.go | 150 ++-------------------------------- store/sqlite/idp.go | 190 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 204 insertions(+), 142 deletions(-) create mode 100644 store/sqlite/idp.go diff --git a/store/driver.go b/store/driver.go index b3bd77f9..30f082e6 100644 --- a/store/driver.go +++ b/store/driver.go @@ -26,4 +26,10 @@ type Driver interface { ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) UpsertUserSettingV1(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) ListUserSettingsV1(ctx context.Context, find *FindUserSettingV1) ([]*storepb.UserSetting, error) + + CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) + ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) + GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) + UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) + DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error } diff --git a/store/idp.go b/store/idp.go index a0e31f0f..bb0d0976 100644 --- a/store/idp.go +++ b/store/idp.go @@ -2,11 +2,6 @@ package store import ( "context" - "encoding/json" - "fmt" - "strings" - - "github.com/pkg/errors" ) type IdentityProviderType string @@ -64,98 +59,20 @@ type DeleteIdentityProvider struct { } func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) { - var configBytes []byte - if create.Type == IdentityProviderOAuth2Type { - bytes, err := json.Marshal(create.Config.OAuth2Config) - if err != nil { - return nil, err - } - configBytes = bytes - } else { - return nil, errors.Errorf("unsupported idp type %s", string(create.Type)) - } - - stmt := ` - INSERT INTO idp ( - name, - type, - identifier_filter, - config - ) - VALUES (?, ?, ?, ?) - RETURNING id - ` - if err := s.db.QueryRowContext( - ctx, - stmt, - create.Name, - create.Type, - create.IdentifierFilter, - string(configBytes), - ).Scan( - &create.ID, - ); err != nil { + identityProvider, err := s.driver.CreateIdentityProvider(ctx, create) + if err != nil { return nil, err } - identityProvider := create s.idpCache.Store(identityProvider.ID, identityProvider) return identityProvider, nil } func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) { - where, args := []string{"1 = 1"}, []any{} - if v := find.ID; v != nil { - where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) - } - - rows, err := s.db.QueryContext(ctx, ` - SELECT - id, - name, - type, - identifier_filter, - config - FROM idp - WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`, - args..., - ) + identityProviders, err := s.driver.ListIdentityProviders(ctx, find) if err != nil { return nil, err } - defer rows.Close() - - var identityProviders []*IdentityProvider - for rows.Next() { - var identityProvider IdentityProvider - var identityProviderConfig string - if err := rows.Scan( - &identityProvider.ID, - &identityProvider.Name, - &identityProvider.Type, - &identityProvider.IdentifierFilter, - &identityProviderConfig, - ); err != nil { - return nil, err - } - - if identityProvider.Type == IdentityProviderOAuth2Type { - oauth2Config := &IdentityProviderOAuth2Config{} - if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { - return nil, err - } - identityProvider.Config = &IdentityProviderConfig{ - OAuth2Config: oauth2Config, - } - } else { - return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) - } - identityProviders = append(identityProviders, &identityProvider) - } - - if err := rows.Err(); err != nil { - return nil, err - } for _, item := range identityProviders { s.idpCache.Store(item.ID, item) @@ -184,72 +101,21 @@ func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvi } func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) { - set, args := []string{}, []any{} - if v := update.Name; v != nil { - set, args = append(set, "name = ?"), append(args, *v) - } - if v := update.IdentifierFilter; v != nil { - set, args = append(set, "identifier_filter = ?"), append(args, *v) - } - if v := update.Config; v != nil { - var configBytes []byte - if update.Type == IdentityProviderOAuth2Type { - bytes, err := json.Marshal(update.Config.OAuth2Config) - if err != nil { - return nil, err - } - configBytes = bytes - } else { - return nil, errors.Errorf("unsupported idp type %s", string(update.Type)) - } - set, args = append(set, "config = ?"), append(args, string(configBytes)) - } - args = append(args, update.ID) - - stmt := ` - UPDATE idp - SET ` + strings.Join(set, ", ") + ` - WHERE id = ? - RETURNING id, name, type, identifier_filter, config - ` - var identityProvider IdentityProvider - var identityProviderConfig string - if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( - &identityProvider.ID, - &identityProvider.Name, - &identityProvider.Type, - &identityProvider.IdentifierFilter, - &identityProviderConfig, - ); err != nil { + identityProvider, err := s.driver.UpdateIdentityProvider(ctx, update) + if err != nil { return nil, err } - if identityProvider.Type == IdentityProviderOAuth2Type { - oauth2Config := &IdentityProviderOAuth2Config{} - if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { - return nil, err - } - identityProvider.Config = &IdentityProviderConfig{ - OAuth2Config: oauth2Config, - } - } else { - return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) - } - s.idpCache.Store(identityProvider.ID, identityProvider) - return &identityProvider, nil + return identityProvider, nil } func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error { - where, args := []string{"id = ?"}, []any{delete.ID} - stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ") - result, err := s.db.ExecContext(ctx, stmt, args...) + err := s.driver.DeleteIdentityProvider(ctx, delete) if err != nil { return err } - if _, err = result.RowsAffected(); err != nil { - return err - } + s.idpCache.Delete(delete.ID) return nil } diff --git a/store/sqlite/idp.go b/store/sqlite/idp.go new file mode 100644 index 00000000..31d48768 --- /dev/null +++ b/store/sqlite/idp.go @@ -0,0 +1,190 @@ +package sqlite + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/pkg/errors" + + "github.com/usememos/memos/store" +) + +func (d *Driver) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { + var configBytes []byte + if create.Type == store.IdentityProviderOAuth2Type { + bytes, err := json.Marshal(create.Config.OAuth2Config) + if err != nil { + return nil, err + } + configBytes = bytes + } else { + return nil, errors.Errorf("unsupported idp type %s", string(create.Type)) + } + + stmt := ` + INSERT INTO idp ( + name, + type, + identifier_filter, + config + ) + VALUES (?, ?, ?, ?) + RETURNING id + ` + if err := d.db.QueryRowContext( + ctx, + stmt, + create.Name, + create.Type, + create.IdentifierFilter, + string(configBytes), + ).Scan( + &create.ID, + ); err != nil { + return nil, err + } + + identityProvider := create + return identityProvider, nil +} + +func (d *Driver) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) { + where, args := []string{"1 = 1"}, []any{} + if v := find.ID; v != nil { + where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) + } + + rows, err := d.db.QueryContext(ctx, ` + SELECT + id, + name, + type, + identifier_filter, + config + FROM idp + WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`, + args..., + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var identityProviders []*store.IdentityProvider + for rows.Next() { + var identityProvider store.IdentityProvider + var identityProviderConfig string + if err := rows.Scan( + &identityProvider.ID, + &identityProvider.Name, + &identityProvider.Type, + &identityProvider.IdentifierFilter, + &identityProviderConfig, + ); err != nil { + return nil, err + } + + if identityProvider.Type == store.IdentityProviderOAuth2Type { + oauth2Config := &store.IdentityProviderOAuth2Config{} + if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { + return nil, err + } + identityProvider.Config = &store.IdentityProviderConfig{ + OAuth2Config: oauth2Config, + } + } else { + return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) + } + identityProviders = append(identityProviders, &identityProvider) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return identityProviders, nil +} + +func (d *Driver) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) { + list, err := d.ListIdentityProviders(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + identityProvider := list[0] + return identityProvider, nil +} + +func (d *Driver) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) { + set, args := []string{}, []any{} + if v := update.Name; v != nil { + set, args = append(set, "name = ?"), append(args, *v) + } + if v := update.IdentifierFilter; v != nil { + set, args = append(set, "identifier_filter = ?"), append(args, *v) + } + if v := update.Config; v != nil { + var configBytes []byte + if update.Type == store.IdentityProviderOAuth2Type { + bytes, err := json.Marshal(update.Config.OAuth2Config) + if err != nil { + return nil, err + } + configBytes = bytes + } else { + return nil, errors.Errorf("unsupported idp type %s", string(update.Type)) + } + set, args = append(set, "config = ?"), append(args, string(configBytes)) + } + args = append(args, update.ID) + + stmt := ` + UPDATE idp + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + RETURNING id, name, type, identifier_filter, config + ` + var identityProvider store.IdentityProvider + var identityProviderConfig string + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &identityProvider.ID, + &identityProvider.Name, + &identityProvider.Type, + &identityProvider.IdentifierFilter, + &identityProviderConfig, + ); err != nil { + return nil, err + } + + if identityProvider.Type == store.IdentityProviderOAuth2Type { + oauth2Config := &store.IdentityProviderOAuth2Config{} + if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { + return nil, err + } + identityProvider.Config = &store.IdentityProviderConfig{ + OAuth2Config: oauth2Config, + } + } else { + return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) + } + + return &identityProvider, nil +} + +func (d *Driver) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error { + where, args := []string{"id = ?"}, []any{delete.ID} + stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ") + result, err := d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err = result.RowsAffected(); err != nil { + return err + } + return nil +}