chore: update idp store (#1856)

pull/1862/head
boojack 2 years ago committed by GitHub
parent b44f2b5ffb
commit 7226a9ad47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -74,16 +74,19 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
}
identityProviderMessage, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &signin.IdentityProviderID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider").SetInternal(err)
}
if identityProvider == nil {
return echo.NewHTTPError(http.StatusNotFound, "Identity provider not found")
}
var userInfo *idp.IdentityProviderUserInfo
if identityProviderMessage.Type == store.IdentityProviderOAuth2 {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProviderMessage.Config.OAuth2Config)
if identityProvider.Type == store.IdentityProviderOAuth2 {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.OAuth2Config)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider instance").SetInternal(err)
}
@ -97,7 +100,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) {
}
}
identifierFilter := identityProviderMessage.IdentifierFilter
identifierFilter := identityProvider.IdentifierFilter
if identifierFilter != "" {
identifierFilterRegex, err := regexp.Compile(identifierFilter)
if err != nil {

@ -83,7 +83,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err)
}
identityProviderMessage, err := s.Store.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{
identityProvider, err := s.Store.CreateIdentityProvider(ctx, &store.IdentityProvider{
Name: identityProviderCreate.Name,
Type: store.IdentityProviderType(identityProviderCreate.Type),
IdentifierFilter: identityProviderCreate.IdentifierFilter,
@ -92,7 +92,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err)
}
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider))
})
g.PATCH("/idp/:idpId", func(c echo.Context) error {
@ -124,7 +124,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch identity provider request").SetInternal(err)
}
identityProviderMessage, err := s.Store.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderMessage{
identityProvider, err := s.Store.UpdateIdentityProvider(ctx, &store.UpdateIdentityProvider{
ID: identityProviderPatch.ID,
Type: store.IdentityProviderType(identityProviderPatch.Type),
Name: identityProviderPatch.Name,
@ -134,12 +134,12 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err)
}
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider))
})
g.GET("/idp", func(c echo.Context) error {
ctx := c.Request().Context()
identityProviderMessageList, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProviderMessage{})
list, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err)
}
@ -159,8 +159,8 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
}
identityProviderList := []*IdentityProvider{}
for _, identityProviderMessage := range identityProviderMessageList {
identityProvider := convertIdentityProviderFromStore(identityProviderMessage)
for _, item := range list {
identityProvider := convertIdentityProviderFromStore(item)
// data desensitize
if !isHostUser {
identityProvider.Config.OAuth2Config.ClientSecret = ""
@ -191,13 +191,17 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
}
identityProviderMessage, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &identityProviderID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err)
}
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
if identityProvider == nil {
return echo.NewHTTPError(http.StatusNotFound, "Identity provider not found")
}
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider))
})
g.DELETE("/idp/:idpId", func(c echo.Context) error {
@ -222,7 +226,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
}
if err = s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProviderMessage{ID: identityProviderID}); err != nil {
if err = s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: identityProviderID}); err != nil {
if common.ErrorCode(err) == common.NotFound {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Identity provider ID not found: %d", identityProviderID))
}
@ -232,13 +236,13 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
})
}
func convertIdentityProviderFromStore(identityProviderMessage *store.IdentityProviderMessage) *IdentityProvider {
func convertIdentityProviderFromStore(identityProvider *store.IdentityProvider) *IdentityProvider {
return &IdentityProvider{
ID: identityProviderMessage.ID,
Name: identityProviderMessage.Name,
Type: IdentityProviderType(identityProviderMessage.Type),
IdentifierFilter: identityProviderMessage.IdentifierFilter,
Config: convertIdentityProviderConfigFromStore(identityProviderMessage.Config),
ID: identityProvider.ID,
Name: identityProvider.Name,
Type: IdentityProviderType(identityProvider.Type),
IdentifierFilter: identityProvider.IdentifierFilter,
Config: convertIdentityProviderConfigFromStore(identityProvider.Config),
}
}

@ -6,8 +6,6 @@ import (
"encoding/json"
"fmt"
"strings"
"github.com/usememos/memos/common"
)
type IdentityProviderType string
@ -36,7 +34,7 @@ type FieldMapping struct {
Email string `json:"email"`
}
type IdentityProviderMessage struct {
type IdentityProvider struct {
ID int
Name string
Type IdentityProviderType
@ -44,11 +42,11 @@ type IdentityProviderMessage struct {
Config *IdentityProviderConfig
}
type FindIdentityProviderMessage struct {
type FindIdentityProvider struct {
ID *int
}
type UpdateIdentityProviderMessage struct {
type UpdateIdentityProvider struct {
ID int
Type IdentityProviderType
Name *string
@ -56,14 +54,14 @@ type UpdateIdentityProviderMessage struct {
Config *IdentityProviderConfig
}
type DeleteIdentityProviderMessage struct {
type DeleteIdentityProvider struct {
ID int
}
func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProviderMessage) (*IdentityProviderMessage, error) {
func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
return nil, err
}
defer tx.Rollback()
@ -76,6 +74,7 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv
} else {
return nil, fmt.Errorf("unsupported idp type %s", string(create.Type))
}
query := `
INSERT INTO idp (
name,
@ -96,20 +95,22 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv
).Scan(
&create.ID,
); err != nil {
return nil, FormatError(err)
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
return nil, err
}
identityProviderMessage := create
s.idpCache.Store(identityProviderMessage.ID, identityProviderMessage)
return identityProviderMessage, nil
identityProvider := create
s.idpCache.Store(identityProvider.ID, identityProvider)
return identityProvider, nil
}
func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProviderMessage) ([]*IdentityProviderMessage, error) {
func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
return nil, err
}
defer tx.Rollback()
@ -124,16 +125,16 @@ func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityPro
return list, nil
}
func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProviderMessage) (*IdentityProviderMessage, error) {
func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) {
if find.ID != nil {
if cache, ok := s.idpCache.Load(*find.ID); ok {
return cache.(*IdentityProviderMessage), nil
return cache.(*IdentityProvider), nil
}
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
return nil, err
}
defer tx.Rollback()
@ -142,18 +143,18 @@ func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvi
return nil, err
}
if len(list) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
return nil, nil
}
identityProviderMessage := list[0]
s.idpCache.Store(identityProviderMessage.ID, identityProviderMessage)
return identityProviderMessage, nil
identityProvider := list[0]
s.idpCache.Store(identityProvider.ID, identityProvider)
return identityProvider, nil
}
func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProviderMessage) (*IdentityProviderMessage, error) {
func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
return nil, err
}
defer tx.Rollback()
@ -184,39 +185,42 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti
WHERE id = ?
RETURNING id, name, type, identifier_filter, config
`
var identityProviderMessage IdentityProviderMessage
var identityProvider IdentityProvider
var identityProviderConfig string
if err := tx.QueryRowContext(ctx, query, args...).Scan(
&identityProviderMessage.ID,
&identityProviderMessage.Name,
&identityProviderMessage.Type,
&identityProviderMessage.IdentifierFilter,
&identityProvider.ID,
&identityProvider.Name,
&identityProvider.Type,
&identityProvider.IdentifierFilter,
&identityProviderConfig,
); err != nil {
return nil, FormatError(err)
return nil, err
}
if identityProviderMessage.Type == IdentityProviderOAuth2 {
if identityProvider.Type == IdentityProviderOAuth2 {
oauth2Config := &IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err
}
identityProviderMessage.Config = &IdentityProviderConfig{
identityProvider.Config = &IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type))
return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
return nil, err
}
s.idpCache.Store(identityProviderMessage.ID, identityProviderMessage)
return &identityProviderMessage, nil
s.idpCache.Store(identityProvider.ID, identityProvider)
return &identityProvider, nil
}
func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProviderMessage) error {
func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
return err
}
defer tx.Rollback()
@ -224,24 +228,22 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
result, err := tx.ExecContext(ctx, stmt, args...)
if err != nil {
return FormatError(err)
return err
}
rows, err := result.RowsAffected()
if err != nil {
if _, err = result.RowsAffected(); err != nil {
return err
}
if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("idp not found")}
}
if err := tx.Commit(); err != nil {
return err
}
s.idpCache.Delete(delete.ID)
return nil
}
func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProviderMessage) ([]*IdentityProviderMessage, error) {
func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProvider) ([]*IdentityProvider, error) {
where, args := []string{"TRUE"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
@ -259,40 +261,41 @@ func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityPr
args...,
)
if err != nil {
return nil, FormatError(err)
return nil, err
}
defer rows.Close()
var identityProviderMessages []*IdentityProviderMessage
var identityProviders []*IdentityProvider
for rows.Next() {
var identityProviderMessage IdentityProviderMessage
var identityProvider IdentityProvider
var identityProviderConfig string
if err := rows.Scan(
&identityProviderMessage.ID,
&identityProviderMessage.Name,
&identityProviderMessage.Type,
&identityProviderMessage.IdentifierFilter,
&identityProvider.ID,
&identityProvider.Name,
&identityProvider.Type,
&identityProvider.IdentifierFilter,
&identityProviderConfig,
); err != nil {
return nil, FormatError(err)
return nil, err
}
if identityProviderMessage.Type == IdentityProviderOAuth2 {
if identityProvider.Type == IdentityProviderOAuth2 {
oauth2Config := &IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err
}
identityProviderMessage.Config = &IdentityProviderConfig{
identityProvider.Config = &IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type))
return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
identityProviderMessages = append(identityProviderMessages, &identityProviderMessage)
identityProviders = append(identityProviders, &identityProvider)
}
if err := rows.Err(); err != nil {
return nil, err
}
return identityProviderMessages, nil
return identityProviders, nil
}

@ -16,7 +16,7 @@ type Store struct {
userCache sync.Map // map[int]*userRaw
userSettingCache sync.Map // map[string]*UserSettingMessage
shortcutCache sync.Map // map[int]*shortcutRaw
idpCache sync.Map // map[int]*IdentityProviderMessage
idpCache sync.Map // map[int]*IdentityProvider
resourceCache sync.Map // map[int]*resourceRaw
}

@ -12,14 +12,14 @@ import (
func TestIdentityProviderStore(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{
createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProvider{
Name: "GitHub OAuth",
Type: store.IdentityProviderOAuth2,
IdentifierFilter: "",
Config: &store.IdentityProviderConfig{
OAuth2Config: &store.IdentityProviderOAuth2Config{
ClientID: "asd",
ClientSecret: "123",
ClientID: "client_id",
ClientSecret: "client_secret",
AuthURL: "https://github.com/auth",
TokenURL: "https://github.com/token",
UserInfoURL: "https://github.com/user",
@ -33,16 +33,23 @@ func TestIdentityProviderStore(t *testing.T) {
},
})
require.NoError(t, err)
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &createdIDP.ID,
})
require.NoError(t, err)
require.Equal(t, createdIDP, idp)
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProviderMessage{
newName := "My GitHub OAuth"
updatedIdp, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProvider{
ID: idp.ID,
Name: &newName,
})
require.NoError(t, err)
require.Equal(t, newName, updatedIdp.Name)
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{
ID: idp.ID,
})
require.NoError(t, err)
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProviderMessage{})
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
require.NoError(t, err)
require.Equal(t, 0, len(idpList))
}

Loading…
Cancel
Save