mirror of https://github.com/usememos/memos
feat(auth): add SSO user identity linkage (#5883)
parent
50638040f6
commit
d688914b28
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,20 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
)
|
||||
|
||||
// deriveSSOUsername produces the local username for a new SSO-created user.
|
||||
//
|
||||
// The current policy is to use a standard UUID string directly. This keeps the
|
||||
// username independent of IdP profile fields and avoids availability probes or
|
||||
// retry loops around concurrent first-time logins.
|
||||
func deriveSSOUsername() (string, error) {
|
||||
username := util.GenUUID()
|
||||
if err := validateUsername(username); err != nil {
|
||||
return "", errors.Wrap(err, "generated UUID did not satisfy username constraints")
|
||||
}
|
||||
return username, nil
|
||||
}
|
||||
@ -0,0 +1,254 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
apiv1 "github.com/usememos/memos/server/router/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestCreateLinkedIdentityBindsCurrentUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
currentUser, err := ts.CreateRegularUser(ctx, "alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
mockIDP := newMockOAuthServer(t, "bind-code", "bind-access-token", map[string]any{
|
||||
"sub": "google-sub-1",
|
||||
"name": "Alice Example",
|
||||
"email": "alice@example.com",
|
||||
})
|
||||
defer mockIDP.Close()
|
||||
|
||||
idpName := createTestingOAuthIdentityProvider(ctx, t, ts, mockIDP.URL, "google-bind")
|
||||
beforeUsers, err := ts.Store.ListUsers(ctx, &store.FindUser{})
|
||||
require.NoError(t, err)
|
||||
|
||||
authCtx := ts.CreateUserContext(apiv1.WithHeaderCarrier(ctx), currentUser.ID)
|
||||
response, err := ts.Service.CreateLinkedIdentity(authCtx, &v1pb.CreateLinkedIdentityRequest{
|
||||
Parent: apiv1.BuildUserName(currentUser.Username),
|
||||
IdpName: idpName,
|
||||
Code: "bind-code",
|
||||
RedirectUri: "http://localhost:8080/auth/callback",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
require.Equal(t, apiv1.BuildUserName(currentUser.Username)+"/linkedIdentities/google-bind", response.Name)
|
||||
require.Equal(t, apiv1.IdentityProviderNamePrefix+"google-bind", response.IdpName)
|
||||
require.Equal(t, "google-sub-1", response.ExternUid)
|
||||
|
||||
afterUsers, err := ts.Store.ListUsers(ctx, &store.FindUser{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, afterUsers, len(beforeUsers))
|
||||
|
||||
provider := "google-bind"
|
||||
externUID := "google-sub-1"
|
||||
identity, err := ts.Store.GetUserIdentity(ctx, &store.FindUserIdentity{
|
||||
Provider: &provider,
|
||||
ExternUID: &externUID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, identity)
|
||||
require.Equal(t, currentUser.ID, identity.UserID)
|
||||
}
|
||||
|
||||
func TestCreateLinkedIdentityRejectsBindingIdentityLinkedToAnotherUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
owner, err := ts.CreateRegularUser(ctx, "owner")
|
||||
require.NoError(t, err)
|
||||
binder, err := ts.CreateRegularUser(ctx, "binder")
|
||||
require.NoError(t, err)
|
||||
|
||||
mockIDP := newMockOAuthServer(t, "conflict-code", "conflict-access-token", map[string]any{
|
||||
"sub": "google-sub-2",
|
||||
"name": "Conflict Example",
|
||||
"email": "conflict@example.com",
|
||||
})
|
||||
defer mockIDP.Close()
|
||||
|
||||
idpName := createTestingOAuthIdentityProvider(ctx, t, ts, mockIDP.URL, "google-conflict")
|
||||
_, err = ts.Store.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: owner.ID,
|
||||
Provider: "google-conflict",
|
||||
ExternUID: "google-sub-2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
authCtx := ts.CreateUserContext(apiv1.WithHeaderCarrier(ctx), binder.ID)
|
||||
_, err = ts.Service.CreateLinkedIdentity(authCtx, &v1pb.CreateLinkedIdentityRequest{
|
||||
Parent: apiv1.BuildUserName(binder.Username),
|
||||
IdpName: idpName,
|
||||
Code: "conflict-code",
|
||||
RedirectUri: "http://localhost:8080/auth/callback",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.AlreadyExists, status.Code(err))
|
||||
}
|
||||
|
||||
func TestListAndDeleteLinkedIdentities(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
currentUser, err := ts.CreateRegularUser(ctx, "alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.Store.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: currentUser.ID,
|
||||
Provider: "google",
|
||||
ExternUID: "alice@gmail.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
authCtx := ts.CreateUserContext(ctx, currentUser.ID)
|
||||
listResp, err := ts.Service.ListLinkedIdentities(authCtx, &v1pb.ListLinkedIdentitiesRequest{
|
||||
Parent: apiv1.BuildUserName(currentUser.Username),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, listResp.LinkedIdentities, 1)
|
||||
linkedIdentityName := apiv1.BuildUserName(currentUser.Username) + "/linkedIdentities/google"
|
||||
require.Equal(t, linkedIdentityName, listResp.LinkedIdentities[0].Name)
|
||||
require.Equal(t, apiv1.IdentityProviderNamePrefix+"google", listResp.LinkedIdentities[0].IdpName)
|
||||
require.Equal(t, "alice@gmail.com", listResp.LinkedIdentities[0].ExternUid)
|
||||
|
||||
got, err := ts.Service.GetLinkedIdentity(authCtx, &v1pb.GetLinkedIdentityRequest{
|
||||
Name: linkedIdentityName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, linkedIdentityName, got.Name)
|
||||
require.Equal(t, apiv1.IdentityProviderNamePrefix+"google", got.IdpName)
|
||||
require.Equal(t, "alice@gmail.com", got.ExternUid)
|
||||
|
||||
_, err = ts.Service.DeleteLinkedIdentity(authCtx, &v1pb.DeleteLinkedIdentityRequest{
|
||||
Name: linkedIdentityName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
listResp, err = ts.Service.ListLinkedIdentities(authCtx, &v1pb.ListLinkedIdentitiesRequest{
|
||||
Parent: apiv1.BuildUserName(currentUser.Username),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, listResp.LinkedIdentities)
|
||||
}
|
||||
|
||||
func TestCreateLinkedIdentityRejectsSecondIdentityForSameProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
currentUser, err := ts.CreateRegularUser(ctx, "alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.Store.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: currentUser.ID,
|
||||
Provider: "google-provider",
|
||||
ExternUID: "google-sub-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mockIDP := newMockOAuthServer(t, "second-code", "second-access-token", map[string]any{
|
||||
"sub": "google-sub-2",
|
||||
"name": "Alice Example",
|
||||
"email": "alice@example.com",
|
||||
})
|
||||
defer mockIDP.Close()
|
||||
|
||||
idpName := createTestingOAuthIdentityProvider(ctx, t, ts, mockIDP.URL, "google-provider")
|
||||
authCtx := ts.CreateUserContext(apiv1.WithHeaderCarrier(ctx), currentUser.ID)
|
||||
|
||||
_, err = ts.Service.CreateLinkedIdentity(authCtx, &v1pb.CreateLinkedIdentityRequest{
|
||||
Parent: apiv1.BuildUserName(currentUser.Username),
|
||||
IdpName: idpName,
|
||||
Code: "second-code",
|
||||
RedirectUri: "http://localhost:8080/auth/callback",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.AlreadyExists, status.Code(err))
|
||||
}
|
||||
|
||||
func createTestingOAuthIdentityProvider(ctx context.Context, t *testing.T, ts *TestService, serverURL, uid string) string {
|
||||
t.Helper()
|
||||
|
||||
idp, err := ts.Store.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
Uid: uid,
|
||||
Name: "Google",
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Config: &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthUrl: serverURL + "/oauth2/authorize",
|
||||
TokenUrl: serverURL + "/oauth2/token",
|
||||
UserInfoUrl: serverURL + "/oauth2/userinfo",
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "sub",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return apiv1.IdentityProviderNamePrefix + idp.Uid
|
||||
}
|
||||
|
||||
func newMockOAuthServer(t *testing.T, code, accessToken string, userInfo map[string]any) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
userInfoBytes, err := json.Marshal(userInfo)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodPost, r.Method)
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
values, err := url.ParseQuery(string(body))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, code, values.Get("code"))
|
||||
require.Equal(t, "authorization_code", values.Get("grant_type"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": accessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err := w.Write(userInfoBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
return httptest.NewServer(mux)
|
||||
}
|
||||
@ -0,0 +1,67 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
apiv1 "github.com/usememos/memos/server/router/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestDeleteUserSelfDeleteCleansAccountDataAndAuthCookies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
user, err := ts.CreateRegularUser(ctx, "alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.Store.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: user.ID,
|
||||
Provider: "google",
|
||||
ExternUID: "alice-google-sub",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: "refresh-token-id",
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
headerCtx := apiv1.WithHeaderCarrier(ctx)
|
||||
authCtx := ts.CreateUserContext(headerCtx, user.ID)
|
||||
_, err = ts.Service.DeleteUser(authCtx, &v1pb.DeleteUserRequest{
|
||||
Name: apiv1.BuildUserName(user.Username),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
deletedUser, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &user.ID})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, deletedUser)
|
||||
|
||||
identities, err := ts.Store.ListUserIdentities(ctx, &store.FindUserIdentity{UserID: &user.ID})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, identities)
|
||||
|
||||
refreshSetting, err := ts.Store.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_REFRESH_TOKENS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, refreshSetting)
|
||||
|
||||
carrier := apiv1.GetHeaderCarrier(authCtx)
|
||||
require.NotNil(t, carrier)
|
||||
require.Contains(t, strings.ToLower(carrier.Get("Set-Cookie")), "memos_refresh=")
|
||||
}
|
||||
@ -0,0 +1,107 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateUserIdentity(ctx context.Context, create *store.UserIdentity) (*store.UserIdentity, error) {
|
||||
stmt := "INSERT INTO `user_identity` (`user_id`, `provider`, `extern_uid`) VALUES (?, ?, ?)"
|
||||
result, err := d.db.ExecContext(ctx, stmt, create.UserID, create.Provider, create.ExternUID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id := int32(rawID)
|
||||
list, err := d.ListUserIdentities(ctx, &store.FindUserIdentity{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, errors.Errorf("failed to create user identity")
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUserIdentities(ctx context.Context, find *store.FindUserIdentity) ([]*store.UserIdentity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.UserID != nil {
|
||||
where, args = append(where, "`user_id` = ?"), append(args, *find.UserID)
|
||||
}
|
||||
if find.Provider != nil {
|
||||
where, args = append(where, "`provider` = ?"), append(args, *find.Provider)
|
||||
}
|
||||
if find.ExternUID != nil {
|
||||
where, args = append(where, "`extern_uid` = ?"), append(args, *find.ExternUID)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
user_id,
|
||||
provider,
|
||||
extern_uid,
|
||||
created_ts,
|
||||
updated_ts
|
||||
FROM user_identity
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.UserIdentity{}
|
||||
for rows.Next() {
|
||||
ui := &store.UserIdentity{}
|
||||
if err := rows.Scan(
|
||||
&ui.ID,
|
||||
&ui.UserID,
|
||||
&ui.Provider,
|
||||
&ui.ExternUID,
|
||||
&ui.CreatedTs,
|
||||
&ui.UpdatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, ui)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteUserIdentities(ctx context.Context, delete *store.DeleteUserIdentity) error {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if delete.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *delete.ID)
|
||||
}
|
||||
if delete.UserID != nil {
|
||||
where, args = append(where, "`user_id` = ?"), append(args, *delete.UserID)
|
||||
}
|
||||
if delete.Provider != nil {
|
||||
where, args = append(where, "`provider` = ?"), append(args, *delete.Provider)
|
||||
}
|
||||
|
||||
if _, err := d.db.ExecContext(ctx, "DELETE FROM `user_identity` WHERE "+strings.Join(where, " AND "), args...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,95 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateUserIdentity(ctx context.Context, create *store.UserIdentity) (*store.UserIdentity, error) {
|
||||
stmt := "INSERT INTO user_identity (user_id, provider, extern_uid) VALUES (" + placeholders(3) + ") RETURNING id, created_ts, updated_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, create.UserID, create.Provider, create.ExternUID).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUserIdentities(ctx context.Context, find *store.FindUserIdentity) ([]*store.UserIdentity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
|
||||
}
|
||||
if find.UserID != nil {
|
||||
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *find.UserID)
|
||||
}
|
||||
if find.Provider != nil {
|
||||
where, args = append(where, "provider = "+placeholder(len(args)+1)), append(args, *find.Provider)
|
||||
}
|
||||
if find.ExternUID != nil {
|
||||
where, args = append(where, "extern_uid = "+placeholder(len(args)+1)), append(args, *find.ExternUID)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
user_id,
|
||||
provider,
|
||||
extern_uid,
|
||||
created_ts,
|
||||
updated_ts
|
||||
FROM user_identity
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.UserIdentity{}
|
||||
for rows.Next() {
|
||||
ui := &store.UserIdentity{}
|
||||
if err := rows.Scan(
|
||||
&ui.ID,
|
||||
&ui.UserID,
|
||||
&ui.Provider,
|
||||
&ui.ExternUID,
|
||||
&ui.CreatedTs,
|
||||
&ui.UpdatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, ui)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteUserIdentities(ctx context.Context, delete *store.DeleteUserIdentity) error {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if delete.ID != nil {
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *delete.ID)
|
||||
}
|
||||
if delete.UserID != nil {
|
||||
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *delete.UserID)
|
||||
}
|
||||
if delete.Provider != nil {
|
||||
where, args = append(where, "provider = "+placeholder(len(args)+1)), append(args, *delete.Provider)
|
||||
}
|
||||
|
||||
if _, err := d.db.ExecContext(ctx, "DELETE FROM user_identity WHERE "+strings.Join(where, " AND "), args...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,95 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateUserIdentity(ctx context.Context, create *store.UserIdentity) (*store.UserIdentity, error) {
|
||||
stmt := "INSERT INTO `user_identity` (`user_id`, `provider`, `extern_uid`) VALUES (?, ?, ?) RETURNING `id`, `created_ts`, `updated_ts`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, create.UserID, create.Provider, create.ExternUID).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUserIdentities(ctx context.Context, find *store.FindUserIdentity) ([]*store.UserIdentity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.UserID != nil {
|
||||
where, args = append(where, "`user_id` = ?"), append(args, *find.UserID)
|
||||
}
|
||||
if find.Provider != nil {
|
||||
where, args = append(where, "`provider` = ?"), append(args, *find.Provider)
|
||||
}
|
||||
if find.ExternUID != nil {
|
||||
where, args = append(where, "`extern_uid` = ?"), append(args, *find.ExternUID)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
user_id,
|
||||
provider,
|
||||
extern_uid,
|
||||
created_ts,
|
||||
updated_ts
|
||||
FROM user_identity
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.UserIdentity{}
|
||||
for rows.Next() {
|
||||
ui := &store.UserIdentity{}
|
||||
if err := rows.Scan(
|
||||
&ui.ID,
|
||||
&ui.UserID,
|
||||
&ui.Provider,
|
||||
&ui.ExternUID,
|
||||
&ui.CreatedTs,
|
||||
&ui.UpdatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, ui)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteUserIdentities(ctx context.Context, delete *store.DeleteUserIdentity) error {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if delete.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *delete.ID)
|
||||
}
|
||||
if delete.UserID != nil {
|
||||
where, args = append(where, "`user_id` = ?"), append(args, *delete.UserID)
|
||||
}
|
||||
if delete.Provider != nil {
|
||||
where, args = append(where, "`provider` = ?"), append(args, *delete.Provider)
|
||||
}
|
||||
|
||||
if _, err := d.db.ExecContext(ctx, "DELETE FROM `user_identity` WHERE "+strings.Join(where, " AND "), args...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,15 @@
|
||||
-- user_identity stores the linkage between an external identity subject and a local user.
|
||||
-- (provider, extern_uid) is unique across the table; provider stores the idp.uid.
|
||||
-- Each local user can link at most one external account per provider.
|
||||
CREATE TABLE `user_identity` (
|
||||
`id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
`user_id` INT NOT NULL,
|
||||
`provider` VARCHAR(256) NOT NULL,
|
||||
`extern_uid` VARCHAR(256) NOT NULL,
|
||||
`created_ts` BIGINT NOT NULL DEFAULT (UNIX_TIMESTAMP()),
|
||||
`updated_ts` BIGINT NOT NULL DEFAULT (UNIX_TIMESTAMP()),
|
||||
UNIQUE (`provider`, `extern_uid`),
|
||||
UNIQUE (`user_id`, `provider`)
|
||||
);
|
||||
|
||||
CREATE INDEX `idx_user_identity_user_id` ON `user_identity`(`user_id`);
|
||||
@ -0,0 +1,15 @@
|
||||
-- user_identity stores the linkage between an external identity subject and a local user.
|
||||
-- (provider, extern_uid) is unique across the table; provider stores the idp.uid.
|
||||
-- Each local user can link at most one external account per provider.
|
||||
CREATE TABLE user_identity (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
extern_uid TEXT NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
UNIQUE (provider, extern_uid),
|
||||
UNIQUE (user_id, provider)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_user_identity_user_id ON user_identity(user_id);
|
||||
@ -0,0 +1,15 @@
|
||||
-- user_identity stores the linkage between an external identity subject and a local user.
|
||||
-- (provider, extern_uid) is unique across the table; provider stores the idp.uid.
|
||||
-- Each local user can link at most one external account per provider.
|
||||
CREATE TABLE user_identity (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
extern_uid TEXT NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
|
||||
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
|
||||
UNIQUE (provider, extern_uid),
|
||||
UNIQUE (user_id, provider)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_user_identity_user_id ON user_identity(user_id);
|
||||
@ -0,0 +1,199 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestUserIdentityCreateAndGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
defer ts.Close()
|
||||
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
provider := "idp-uid-1"
|
||||
externUID := "jane@example.com"
|
||||
created, err := ts.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: user.ID,
|
||||
Provider: provider,
|
||||
ExternUID: externUID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, created.ID)
|
||||
require.NotZero(t, created.CreatedTs)
|
||||
require.Equal(t, user.ID, created.UserID)
|
||||
require.Equal(t, provider, created.Provider)
|
||||
require.Equal(t, externUID, created.ExternUID)
|
||||
|
||||
got, err := ts.GetUserIdentity(ctx, &store.FindUserIdentity{
|
||||
Provider: &provider,
|
||||
ExternUID: &externUID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, created.ID, got.ID)
|
||||
require.Equal(t, user.ID, got.UserID)
|
||||
|
||||
// Miss returns (nil, nil).
|
||||
missingProvider := "idp-uid-missing"
|
||||
notFound, err := ts.GetUserIdentity(ctx, &store.FindUserIdentity{
|
||||
Provider: &missingProvider,
|
||||
ExternUID: &externUID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFound)
|
||||
}
|
||||
|
||||
func TestUserIdentityListByUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
defer ts.Close()
|
||||
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: user.ID,
|
||||
Provider: "idp-A",
|
||||
ExternUID: "sub-a-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: user.ID,
|
||||
Provider: "idp-B",
|
||||
ExternUID: "sub-b-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
list, err := ts.ListUserIdentities(ctx, &store.FindUserIdentity{
|
||||
UserID: &user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, list, 2)
|
||||
}
|
||||
|
||||
func TestUserIdentityUniqueConflict(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
defer ts.Close()
|
||||
|
||||
userA, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
userB, err := createTestingUserWithRole(ctx, ts, "conflict_user", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: userA.ID,
|
||||
Provider: "idp-A",
|
||||
ExternUID: "sub-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second insert with the same (provider, extern_uid) must fail regardless of user_id.
|
||||
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: userB.ID,
|
||||
Provider: "idp-A",
|
||||
ExternUID: "sub-1",
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUserIdentitySameExternUIDDifferentProviders(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
defer ts.Close()
|
||||
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: user.ID,
|
||||
Provider: "idp-A",
|
||||
ExternUID: "sub-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: user.ID,
|
||||
Provider: "idp-B",
|
||||
ExternUID: "sub-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
externUID := "sub-1"
|
||||
list, err := ts.ListUserIdentities(ctx, &store.FindUserIdentity{
|
||||
ExternUID: &externUID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, list, 2)
|
||||
}
|
||||
|
||||
func TestUserIdentitySameUserSameProviderConflicts(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
defer ts.Close()
|
||||
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: user.ID,
|
||||
Provider: "idp-A",
|
||||
ExternUID: "sub-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: user.ID,
|
||||
Provider: "idp-A",
|
||||
ExternUID: "sub-2",
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUserIdentityDeleteByUserAndProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
defer ts.Close()
|
||||
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: user.ID,
|
||||
Provider: "idp-A",
|
||||
ExternUID: "sub-a-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
|
||||
UserID: user.ID,
|
||||
Provider: "idp-B",
|
||||
ExternUID: "sub-b-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
provider := "idp-A"
|
||||
err = ts.DeleteUserIdentities(ctx, &store.DeleteUserIdentity{
|
||||
UserID: &user.ID,
|
||||
Provider: &provider,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
list, err := ts.ListUserIdentities(ctx, &store.FindUserIdentity{
|
||||
UserID: &user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, list, 1)
|
||||
require.Equal(t, "idp-B", list[0].Provider)
|
||||
}
|
||||
@ -0,0 +1,59 @@
|
||||
package store
|
||||
|
||||
import "context"
|
||||
|
||||
// UserIdentity is the linkage between an external identity subject and a local user.
|
||||
// Uniqueness is enforced on (Provider, ExternUID); one local user may have multiple
|
||||
// identities across different providers.
|
||||
type UserIdentity struct {
|
||||
ID int32
|
||||
UserID int32
|
||||
Provider string
|
||||
ExternUID string
|
||||
CreatedTs int64
|
||||
UpdatedTs int64
|
||||
}
|
||||
|
||||
// FindUserIdentity is used to filter user identities in list/get queries.
|
||||
type FindUserIdentity struct {
|
||||
ID *int32
|
||||
UserID *int32
|
||||
Provider *string
|
||||
ExternUID *string
|
||||
}
|
||||
|
||||
// DeleteUserIdentity is used to delete user identity linkage rows.
|
||||
type DeleteUserIdentity struct {
|
||||
ID *int32
|
||||
UserID *int32
|
||||
Provider *string
|
||||
}
|
||||
|
||||
// CreateUserIdentity creates a new external-identity linkage record.
|
||||
// Returns the driver error on unique-constraint violation; callers are responsible
|
||||
// for reconciling concurrent first-login races on (Provider, ExternUID).
|
||||
func (s *Store) CreateUserIdentity(ctx context.Context, create *UserIdentity) (*UserIdentity, error) {
|
||||
return s.driver.CreateUserIdentity(ctx, create)
|
||||
}
|
||||
|
||||
// ListUserIdentities returns all linkage records matching the filter.
|
||||
func (s *Store) ListUserIdentities(ctx context.Context, find *FindUserIdentity) ([]*UserIdentity, error) {
|
||||
return s.driver.ListUserIdentities(ctx, find)
|
||||
}
|
||||
|
||||
// DeleteUserIdentities deletes all linkage records matching the filter.
|
||||
func (s *Store) DeleteUserIdentities(ctx context.Context, delete *DeleteUserIdentity) error {
|
||||
return s.driver.DeleteUserIdentities(ctx, delete)
|
||||
}
|
||||
|
||||
// GetUserIdentity returns the first linkage record matching the filter, or nil if none found.
|
||||
func (s *Store) GetUserIdentity(ctx context.Context, find *FindUserIdentity) (*UserIdentity, error) {
|
||||
list, err := s.ListUserIdentities(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
@ -0,0 +1,184 @@
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { toast } from "react-hot-toast";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { identityProviderServiceClient, userServiceClient } from "@/connect";
|
||||
import { absolutifyLink } from "@/helpers/utils";
|
||||
import useCurrentUser from "@/hooks/useCurrentUser";
|
||||
import { handleError } from "@/lib/error";
|
||||
import { IdentityProvider, IdentityProvider_Type } from "@/types/proto/api/v1/idp_service_pb";
|
||||
import { LinkedIdentity } from "@/types/proto/api/v1/user_service_pb";
|
||||
import { useTranslate } from "@/utils/i18n";
|
||||
import { storeOAuthState } from "@/utils/oauth";
|
||||
import SettingGroup from "./SettingGroup";
|
||||
import SettingTable from "./SettingTable";
|
||||
|
||||
interface LinkedIdentityRow extends Record<string, unknown> {
|
||||
name: string;
|
||||
title: string;
|
||||
externUid: string;
|
||||
linkedIdentity?: LinkedIdentity;
|
||||
identityProvider: IdentityProvider;
|
||||
}
|
||||
|
||||
const LinkedIdentitySection = () => {
|
||||
const t = useTranslate();
|
||||
const currentUser = useCurrentUser();
|
||||
const [identityProviderList, setIdentityProviderList] = useState<IdentityProvider[]>([]);
|
||||
const [linkedIdentityList, setLinkedIdentityList] = useState<LinkedIdentity[]>([]);
|
||||
|
||||
const fetchData = async () => {
|
||||
if (!currentUser?.name) {
|
||||
return;
|
||||
}
|
||||
const [{ identityProviders }, { linkedIdentities }] = await Promise.all([
|
||||
identityProviderServiceClient.listIdentityProviders({}),
|
||||
userServiceClient.listLinkedIdentities({ parent: currentUser.name }),
|
||||
]);
|
||||
setIdentityProviderList(identityProviders);
|
||||
setLinkedIdentityList(linkedIdentities);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (!currentUser?.name) {
|
||||
return;
|
||||
}
|
||||
fetchData().catch((error: unknown) => {
|
||||
handleError(error, toast.error, {
|
||||
context: "Load linked identities",
|
||||
});
|
||||
});
|
||||
}, [currentUser?.name]);
|
||||
|
||||
const oauthIdentityProviders = useMemo(
|
||||
() => identityProviderList.filter((identityProvider) => identityProvider.type === IdentityProvider_Type.OAUTH2),
|
||||
[identityProviderList],
|
||||
);
|
||||
|
||||
const linkedIdentityByProviderName = useMemo(() => {
|
||||
const mapping = new Map<string, LinkedIdentity>();
|
||||
for (const linkedIdentity of linkedIdentityList) {
|
||||
if (!mapping.has(linkedIdentity.idpName)) {
|
||||
mapping.set(linkedIdentity.idpName, linkedIdentity);
|
||||
}
|
||||
}
|
||||
return mapping;
|
||||
}, [linkedIdentityList]);
|
||||
|
||||
const rows = useMemo<LinkedIdentityRow[]>(
|
||||
() =>
|
||||
oauthIdentityProviders.map((identityProvider) => {
|
||||
const linkedIdentity = linkedIdentityByProviderName.get(identityProvider.name);
|
||||
return {
|
||||
name: identityProvider.name,
|
||||
title: identityProvider.title,
|
||||
externUid: linkedIdentity?.externUid ?? "",
|
||||
linkedIdentity,
|
||||
identityProvider,
|
||||
};
|
||||
}),
|
||||
[linkedIdentityByProviderName, oauthIdentityProviders],
|
||||
);
|
||||
|
||||
const handleLinkIdentityProvider = async (identityProvider: IdentityProvider) => {
|
||||
if (!currentUser?.name) {
|
||||
return;
|
||||
}
|
||||
const redirectUri = absolutifyLink("/auth/callback");
|
||||
const oauth2Config = identityProvider.config?.config?.case === "oauth2Config" ? identityProvider.config.config.value : undefined;
|
||||
if (!oauth2Config) {
|
||||
toast.error("Identity provider configuration is invalid.");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const returnUrl = `${window.location.pathname}${window.location.search}${window.location.hash}`;
|
||||
const { state, codeChallenge } = await storeOAuthState(identityProvider.name, "link", returnUrl, currentUser.name);
|
||||
|
||||
let authUrl = `${oauth2Config.authUrl}?client_id=${
|
||||
oauth2Config.clientId
|
||||
}&redirect_uri=${encodeURIComponent(redirectUri)}&state=${state}&response_type=code&scope=${encodeURIComponent(
|
||||
oauth2Config.scopes.join(" "),
|
||||
)}`;
|
||||
|
||||
if (codeChallenge) {
|
||||
authUrl += `&code_challenge=${codeChallenge}&code_challenge_method=S256`;
|
||||
}
|
||||
|
||||
window.location.href = authUrl;
|
||||
} catch (error) {
|
||||
handleError(error, toast.error, {
|
||||
context: "Failed to initiate OAuth flow",
|
||||
fallbackMessage: "Failed to initiate account linking. Please try again.",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const handleUnlinkIdentityProvider = async (row: LinkedIdentityRow) => {
|
||||
if (!row.linkedIdentity?.name) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await userServiceClient.deleteLinkedIdentity({
|
||||
name: row.linkedIdentity.name,
|
||||
});
|
||||
await fetchData();
|
||||
toast.success(`Unlinked ${row.title}.`);
|
||||
} catch (error) {
|
||||
handleError(error, toast.error, {
|
||||
context: "Delete linked identity",
|
||||
fallbackMessage: "Failed to unlink identity provider.",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if (oauthIdentityProviders.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<SettingGroup
|
||||
showSeparator
|
||||
title="SSO accounts"
|
||||
description="Each provider can be linked to this account at most once. A linked row shows the current extern_uid and can be unlinked."
|
||||
>
|
||||
<SettingTable<LinkedIdentityRow>
|
||||
columns={[
|
||||
{
|
||||
key: "title",
|
||||
header: "SSO provider",
|
||||
render: (_, row: LinkedIdentityRow) => <span className="text-foreground">{row.title}</span>,
|
||||
},
|
||||
{
|
||||
key: "externUid",
|
||||
header: "extern_uid",
|
||||
render: (_, row: LinkedIdentityRow) => (
|
||||
<span className={row.externUid ? "text-foreground" : "text-muted-foreground"}>
|
||||
{row.externUid || t("attachment-library.labels.not-linked")}
|
||||
</span>
|
||||
),
|
||||
},
|
||||
{
|
||||
key: "actions",
|
||||
header: "",
|
||||
className: "text-right",
|
||||
render: (_, row: LinkedIdentityRow) =>
|
||||
row.linkedIdentity ? (
|
||||
<Button variant="outline" size="sm" onClick={() => handleUnlinkIdentityProvider(row)}>
|
||||
Unlink
|
||||
</Button>
|
||||
) : (
|
||||
<Button variant="outline" size="sm" onClick={() => handleLinkIdentityProvider(row.identityProvider)}>
|
||||
{t("common.link")}
|
||||
</Button>
|
||||
),
|
||||
},
|
||||
]}
|
||||
data={rows}
|
||||
emptyMessage="No SSO providers found."
|
||||
getRowKey={(row) => row.name}
|
||||
/>
|
||||
</SettingGroup>
|
||||
);
|
||||
};
|
||||
|
||||
export default LinkedIdentitySection;
|
||||
File diff suppressed because one or more lines are too long
@ -0,0 +1,44 @@
|
||||
import { afterEach, beforeEach, describe, expect, it } from "vitest";
|
||||
import { storeOAuthState, validateOAuthState } from "@/utils/oauth";
|
||||
|
||||
describe("oauth state", () => {
|
||||
beforeEach(() => {
|
||||
sessionStorage.clear();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
sessionStorage.clear();
|
||||
});
|
||||
|
||||
it("round-trips the linking user for link flows", async () => {
|
||||
const { state } = await storeOAuthState("identity-providers/google", "link", "/settings", "users/alice");
|
||||
|
||||
expect(validateOAuthState(state)).toEqual({
|
||||
identityProviderName: "identity-providers/google",
|
||||
flowMode: "link",
|
||||
returnUrl: "/settings",
|
||||
linkingUserName: "users/alice",
|
||||
codeVerifier: expect.any(String),
|
||||
});
|
||||
});
|
||||
|
||||
it("defaults older states to signin without a linking user", () => {
|
||||
sessionStorage.setItem(
|
||||
"oauth_state",
|
||||
JSON.stringify({
|
||||
state: "legacy-state",
|
||||
identityProviderName: "identity-providers/google",
|
||||
timestamp: Date.now(),
|
||||
returnUrl: "/auth",
|
||||
}),
|
||||
);
|
||||
|
||||
expect(validateOAuthState("legacy-state")).toEqual({
|
||||
identityProviderName: "identity-providers/google",
|
||||
flowMode: "signin",
|
||||
returnUrl: "/auth",
|
||||
linkingUserName: undefined,
|
||||
codeVerifier: undefined,
|
||||
});
|
||||
});
|
||||
});
|
||||
Loading…
Reference in New Issue