From 6e4d1d91008fcc26676e2b92dc1ff15b980ed1fd Mon Sep 17 00:00:00 2001 From: Johnny Date: Sun, 22 Jun 2025 22:58:00 +0800 Subject: [PATCH] fix: auth context --- server/router/api/v1/acl.go | 45 ++++++++--------- server/router/api/v1/acl_config.go | 1 + server/router/api/v1/auth_service.go | 7 ++- server/router/api/v1/test/idp_service_test.go | 16 +++--- .../router/api/v1/test/inbox_service_test.go | 32 ++++++------ .../api/v1/test/shortcut_service_test.go | 50 +++++++++---------- server/router/api/v1/test/test_helper.go | 6 +-- .../api/v1/test/webhook_service_test.go | 22 ++++---- .../api/v1/test/workspace_service_test.go | 2 +- server/router/api/v1/test_auth.go | 8 +-- 10 files changed, 94 insertions(+), 95 deletions(-) diff --git a/server/router/api/v1/acl.go b/server/router/api/v1/acl.go index 14171b594..bf8b06538 100644 --- a/server/router/api/v1/acl.go +++ b/server/router/api/v1/acl.go @@ -23,9 +23,8 @@ import ( type ContextKey int const ( - // The key name used to store username in the context - // user id is extracted from the jwt token subject field. - usernameContextKey ContextKey = iota + // The key name used to store user's ID in the context (for user-based auth). + userIDContextKey ContextKey = iota // The key name used to store session ID in the context (for session-based auth). sessionIDContextKey // The key name used to store access token in the context (for token-based auth). @@ -48,11 +47,6 @@ func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthIntercep // AuthenticationInterceptor is the unary interceptor for gRPC API. func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - // Check if this method is in the allowlist first - if isUnauthorizeAllowedMethod(serverInfo.FullMethod) { - return handler(ctx, request) - } - md, ok := metadata.FromIncomingContext(ctx) if !ok { return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context") @@ -65,21 +59,25 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re } // Authenticate using access token (which also validates sessions when it's from cookie) - username, user, err := in.authenticateByAccessToken(ctx, accessToken) + user, err := in.authenticateByAccessToken(ctx, accessToken) if err != nil { + // Check if this method is in the allowlist first + if isUnauthorizeAllowedMethod(serverInfo.FullMethod) { + return handler(ctx, request) + } return nil, err } // Check user status if user.RowStatus == store.Archived { - return nil, errors.Errorf("user %q is archived", username) + return nil, errors.Errorf("user %q is archived", user.Username) } if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin { - return nil, errors.Errorf("user %q is not admin", username) + return nil, errors.Errorf("user %q is not admin", user.Username) } // Set context values - ctx = context.WithValue(ctx, usernameContextKey, username) + ctx = context.WithValue(ctx, userIDContextKey, user.ID) // Determine if this came from cookie (session) or header (API token) if _, headerErr := getAccessTokenFromMetadata(md); headerErr != nil { @@ -96,9 +94,9 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re } // authenticateByAccessToken authenticates a user using access token from Authorization header or cookie. -func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, accessToken string) (string, *store.User, error) { +func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, accessToken string) (*store.User, error) { if accessToken == "" { - return "", nil, status.Errorf(codes.Unauthenticated, "access token not found") + return nil, status.Errorf(codes.Unauthenticated, "access token not found") } claims := &ClaimsMessage{} _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { @@ -113,33 +111,33 @@ func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, ac return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"]) }) if err != nil { - return "", nil, status.Errorf(codes.Unauthenticated, "Invalid or expired access token") + return nil, status.Errorf(codes.Unauthenticated, "Invalid or expired access token") } // We either have a valid access token or we will attempt to generate new access token. userID, err := util.ConvertStringToInt32(claims.Subject) if err != nil { - return "", nil, errors.Wrap(err, "malformed ID in the token") + return nil, errors.Wrap(err, "malformed ID in the token") } user, err := in.Store.GetUser(ctx, &store.FindUser{ ID: &userID, }) if err != nil { - return "", nil, errors.Wrap(err, "failed to get user") + return nil, errors.Wrap(err, "failed to get user") } if user == nil { - return "", nil, errors.Errorf("user %q not exists", userID) + return nil, errors.Errorf("user %q not exists", userID) } if user.RowStatus == store.Archived { - return "", nil, errors.Errorf("user %q is archived", userID) + return nil, errors.Errorf("user %q is archived", userID) } accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID) if err != nil { - return "", nil, errors.Wrapf(err, "failed to get user access tokens") + return nil, errors.Wrapf(err, "failed to get user access tokens") } if !validateAccessToken(accessToken, accessTokens) { - return "", nil, status.Errorf(codes.Unauthenticated, "invalid access token") + return nil, status.Errorf(codes.Unauthenticated, "invalid access token") } // For tokens that might be used as session IDs (from cookies), also validate session existence @@ -148,7 +146,7 @@ func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, ac validateUserSession(accessToken, sessions) // Result doesn't matter for API tokens } - return user.Username, user, nil + return user, nil } // updateSessionLastAccessed updates the last accessed time for a user session. @@ -204,9 +202,6 @@ func getTokenFromMetadata(md metadata.MD) (string, error) { accessToken = v.Value } } - if accessToken == "" { - return "", errors.New("access token not found") - } return accessToken, nil } diff --git a/server/router/api/v1/acl_config.go b/server/router/api/v1/acl_config.go index 3db136b50..d14d6b305 100644 --- a/server/router/api/v1/acl_config.go +++ b/server/router/api/v1/acl_config.go @@ -6,6 +6,7 @@ var authenticationAllowlistMethods = map[string]bool{ "/memos.api.v1.IdentityProviderService/GetIdentityProvider": true, "/memos.api.v1.IdentityProviderService/ListIdentityProviders": true, "/memos.api.v1.AuthService/CreateSession": true, + "/memos.api.v1.AuthService/GetCurrentSession": true, "/memos.api.v1.AuthService/SignUp": true, "/memos.api.v1.UserService/GetUser": true, "/memos.api.v1.UserService/GetUserAvatar": true, diff --git a/server/router/api/v1/auth_service.go b/server/router/api/v1/auth_service.go index e4db57b9c..96f1ba5c0 100644 --- a/server/router/api/v1/auth_service.go +++ b/server/router/api/v1/auth_service.go @@ -331,16 +331,19 @@ func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken str } func (s *APIV1Service) GetCurrentUser(ctx context.Context) (*store.User, error) { - username, ok := ctx.Value(usernameContextKey).(string) + userID, ok := ctx.Value(userIDContextKey).(int32) if !ok { return nil, nil } user, err := s.Store.GetUser(ctx, &store.FindUser{ - Username: &username, + ID: &userID, }) if err != nil { return nil, err } + if user == nil { + return nil, errors.Errorf("user %d not found", userID) + } return user, nil } diff --git a/server/router/api/v1/test/idp_service_test.go b/server/router/api/v1/test/idp_service_test.go index 5b0b05a93..8c3da8c84 100644 --- a/server/router/api/v1/test/idp_service_test.go +++ b/server/router/api/v1/test/idp_service_test.go @@ -22,7 +22,7 @@ func TestCreateIdentityProvider(t *testing.T) { require.NoError(t, err) // Set user context - ctx := ts.CreateUserContext(ctx, hostUser.Username) + ctx := ts.CreateUserContext(ctx, hostUser.ID) // Create OAuth2 identity provider req := &v1pb.CreateIdentityProviderRequest{ @@ -71,7 +71,7 @@ func TestCreateIdentityProvider(t *testing.T) { require.NoError(t, err) // Set user context - ctx := ts.CreateUserContext(ctx, regularUser.Username) + ctx := ts.CreateUserContext(ctx, regularUser.ID) req := &v1pb.CreateIdentityProviderRequest{ IdentityProvider: &v1pb.IdentityProvider{ @@ -125,7 +125,7 @@ func TestListIdentityProviders(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Create a couple of identity providers createReq1 := &v1pb.CreateIdentityProviderRequest{ @@ -199,7 +199,7 @@ func TestGetIdentityProvider(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Create identity provider createReq := &v1pb.CreateIdentityProviderRequest{ @@ -284,7 +284,7 @@ func TestUpdateIdentityProvider(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Create identity provider createReq := &v1pb.CreateIdentityProviderRequest{ @@ -398,7 +398,7 @@ func TestDeleteIdentityProvider(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Create identity provider createReq := &v1pb.CreateIdentityProviderRequest{ @@ -464,7 +464,7 @@ func TestDeleteIdentityProvider(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) req := &v1pb.DeleteIdentityProviderRequest{ Name: "identityProviders/999", @@ -488,7 +488,7 @@ func TestIdentityProviderPermissions(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, regularUser.Username) + userCtx := ts.CreateUserContext(ctx, regularUser.ID) req := &v1pb.CreateIdentityProviderRequest{ IdentityProvider: &v1pb.IdentityProvider{ diff --git a/server/router/api/v1/test/inbox_service_test.go b/server/router/api/v1/test/inbox_service_test.go index caeef3db7..44cc82afc 100644 --- a/server/router/api/v1/test/inbox_service_test.go +++ b/server/router/api/v1/test/inbox_service_test.go @@ -27,7 +27,7 @@ func TestListInboxes(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // List inboxes (should be empty initially) req := &v1pb.ListInboxesRequest{ @@ -64,7 +64,7 @@ func TestListInboxes(t *testing.T) { } // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // List inboxes with page size limit req := &v1pb.ListInboxesRequest{ @@ -90,7 +90,7 @@ func TestListInboxes(t *testing.T) { require.NoError(t, err) // Set user1 context but try to list user2's inboxes - userCtx := ts.CreateUserContext(ctx, user1.Username) + userCtx := ts.CreateUserContext(ctx, user1.ID) req := &v1pb.ListInboxesRequest{ Parent: fmt.Sprintf("users/%d", user2.ID), @@ -124,7 +124,7 @@ func TestListInboxes(t *testing.T) { require.NoError(t, err) // Set host user context and try to list regular user's inboxes - hostCtx := ts.CreateUserContext(ctx, hostUser.Username) + hostCtx := ts.CreateUserContext(ctx, hostUser.ID) req := &v1pb.ListInboxesRequest{ Parent: fmt.Sprintf("users/%d", regularUser.ID), @@ -145,7 +145,7 @@ func TestListInboxes(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.ListInboxesRequest{ Parent: "invalid-parent-format", @@ -194,7 +194,7 @@ func TestUpdateInbox(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // Update inbox status req := &v1pb.UpdateInboxRequest{ @@ -236,7 +236,7 @@ func TestUpdateInbox(t *testing.T) { require.NoError(t, err) // Set user1 context but try to update user2's inbox - userCtx := ts.CreateUserContext(ctx, user1.Username) + userCtx := ts.CreateUserContext(ctx, user1.ID) req := &v1pb.UpdateInboxRequest{ Inbox: &v1pb.Inbox{ @@ -262,7 +262,7 @@ func TestUpdateInbox(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.UpdateInboxRequest{ Inbox: &v1pb.Inbox{ @@ -285,7 +285,7 @@ func TestUpdateInbox(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.UpdateInboxRequest{ Inbox: &v1pb.Inbox{ @@ -311,7 +311,7 @@ func TestUpdateInbox(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.UpdateInboxRequest{ Inbox: &v1pb.Inbox{ @@ -351,7 +351,7 @@ func TestUpdateInbox(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.UpdateInboxRequest{ Inbox: &v1pb.Inbox{ @@ -393,7 +393,7 @@ func TestDeleteInbox(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // Delete inbox req := &v1pb.DeleteInboxRequest{ @@ -434,7 +434,7 @@ func TestDeleteInbox(t *testing.T) { require.NoError(t, err) // Set user1 context but try to delete user2's inbox - userCtx := ts.CreateUserContext(ctx, user1.Username) + userCtx := ts.CreateUserContext(ctx, user1.ID) req := &v1pb.DeleteInboxRequest{ Name: fmt.Sprintf("inboxes/%d", inbox.ID), @@ -454,7 +454,7 @@ func TestDeleteInbox(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.DeleteInboxRequest{ Name: "invalid-inbox-name", @@ -474,7 +474,7 @@ func TestDeleteInbox(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.DeleteInboxRequest{ Name: "inboxes/99999", // Non-existent inbox @@ -512,7 +512,7 @@ func TestInboxCRUDComplete(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // 1. List inboxes - should have 1 listReq := &v1pb.ListInboxesRequest{ diff --git a/server/router/api/v1/test/shortcut_service_test.go b/server/router/api/v1/test/shortcut_service_test.go index 6f210789f..90921cdff 100644 --- a/server/router/api/v1/test/shortcut_service_test.go +++ b/server/router/api/v1/test/shortcut_service_test.go @@ -23,7 +23,7 @@ func TestListShortcuts(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // List shortcuts (should be empty initially) req := &v1pb.ListShortcutsRequest{ @@ -47,7 +47,7 @@ func TestListShortcuts(t *testing.T) { require.NoError(t, err) // Set user1 context but try to list user2's shortcuts - userCtx := ts.CreateUserContext(ctx, user1.Username) + userCtx := ts.CreateUserContext(ctx, user1.ID) req := &v1pb.ListShortcutsRequest{ Parent: fmt.Sprintf("users/%d", user2.ID), @@ -67,7 +67,7 @@ func TestListShortcuts(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.ListShortcutsRequest{ Parent: "invalid-parent-format", @@ -104,7 +104,7 @@ func TestGetShortcut(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // First create a shortcut createReq := &v1pb.CreateShortcutRequest{ @@ -142,7 +142,7 @@ func TestGetShortcut(t *testing.T) { require.NoError(t, err) // Create shortcut as user1 - user1Ctx := ts.CreateUserContext(ctx, user1.Username) + user1Ctx := ts.CreateUserContext(ctx, user1.ID) createReq := &v1pb.CreateShortcutRequest{ Parent: fmt.Sprintf("users/%d", user1.ID), Shortcut: &v1pb.Shortcut{ @@ -155,7 +155,7 @@ func TestGetShortcut(t *testing.T) { require.NoError(t, err) // Try to get shortcut as user2 - user2Ctx := ts.CreateUserContext(ctx, user2.Username) + user2Ctx := ts.CreateUserContext(ctx, user2.ID) getReq := &v1pb.GetShortcutRequest{ Name: created.Name, } @@ -174,7 +174,7 @@ func TestGetShortcut(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.GetShortcutRequest{ Name: "invalid-shortcut-name", @@ -194,7 +194,7 @@ func TestGetShortcut(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.GetShortcutRequest{ Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent", @@ -218,7 +218,7 @@ func TestCreateShortcut(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.CreateShortcutRequest{ Parent: fmt.Sprintf("users/%d", user.ID), @@ -257,7 +257,7 @@ func TestCreateShortcut(t *testing.T) { require.NoError(t, err) // Set user1 context but try to create shortcut for user2 - userCtx := ts.CreateUserContext(ctx, user1.Username) + userCtx := ts.CreateUserContext(ctx, user1.ID) req := &v1pb.CreateShortcutRequest{ Parent: fmt.Sprintf("users/%d", user2.ID), @@ -281,7 +281,7 @@ func TestCreateShortcut(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.CreateShortcutRequest{ Parent: "invalid-parent", @@ -305,7 +305,7 @@ func TestCreateShortcut(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.CreateShortcutRequest{ Parent: fmt.Sprintf("users/%d", user.ID), @@ -329,7 +329,7 @@ func TestCreateShortcut(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.CreateShortcutRequest{ Parent: fmt.Sprintf("users/%d", user.ID), @@ -356,7 +356,7 @@ func TestUpdateShortcut(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // Create a shortcut first createReq := &v1pb.CreateShortcutRequest{ @@ -401,7 +401,7 @@ func TestUpdateShortcut(t *testing.T) { require.NoError(t, err) // Create shortcut as user1 - user1Ctx := ts.CreateUserContext(ctx, user1.Username) + user1Ctx := ts.CreateUserContext(ctx, user1.ID) createReq := &v1pb.CreateShortcutRequest{ Parent: fmt.Sprintf("users/%d", user1.ID), Shortcut: &v1pb.Shortcut{ @@ -414,7 +414,7 @@ func TestUpdateShortcut(t *testing.T) { require.NoError(t, err) // Try to update shortcut as user2 - user2Ctx := ts.CreateUserContext(ctx, user2.Username) + user2Ctx := ts.CreateUserContext(ctx, user2.ID) updateReq := &v1pb.UpdateShortcutRequest{ Shortcut: &v1pb.Shortcut{ Name: created.Name, @@ -438,7 +438,7 @@ func TestUpdateShortcut(t *testing.T) { // Create a user and context for authentication user, err := ts.CreateRegularUser(ctx, "testuser") require.NoError(t, err) - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.UpdateShortcutRequest{ Shortcut: &v1pb.Shortcut{ @@ -480,7 +480,7 @@ func TestUpdateShortcut(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // Create a shortcut first createReq := &v1pb.CreateShortcutRequest{ @@ -523,7 +523,7 @@ func TestDeleteShortcut(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // Create a shortcut first createReq := &v1pb.CreateShortcutRequest{ @@ -575,7 +575,7 @@ func TestDeleteShortcut(t *testing.T) { require.NoError(t, err) // Create shortcut as user1 - user1Ctx := ts.CreateUserContext(ctx, user1.Username) + user1Ctx := ts.CreateUserContext(ctx, user1.ID) createReq := &v1pb.CreateShortcutRequest{ Parent: fmt.Sprintf("users/%d", user1.ID), Shortcut: &v1pb.Shortcut{ @@ -588,7 +588,7 @@ func TestDeleteShortcut(t *testing.T) { require.NoError(t, err) // Try to delete shortcut as user2 - user2Ctx := ts.CreateUserContext(ctx, user2.Username) + user2Ctx := ts.CreateUserContext(ctx, user2.ID) deleteReq := &v1pb.DeleteShortcutRequest{ Name: created.Name, } @@ -620,7 +620,7 @@ func TestDeleteShortcut(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) req := &v1pb.DeleteShortcutRequest{ Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent", @@ -644,7 +644,7 @@ func TestShortcutFiltering(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // Test various valid filter formats validFilters := []string{ @@ -681,7 +681,7 @@ func TestShortcutFiltering(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // Test various invalid filter formats invalidFilters := []string{ @@ -723,7 +723,7 @@ func TestShortcutCRUDComplete(t *testing.T) { require.NoError(t, err) // Set user context - userCtx := ts.CreateUserContext(ctx, user.Username) + userCtx := ts.CreateUserContext(ctx, user.ID) // 1. Create multiple shortcuts shortcut1Req := &v1pb.CreateShortcutRequest{ diff --git a/server/router/api/v1/test/test_helper.go b/server/router/api/v1/test/test_helper.go index f4ea12f7e..b40e9c7cb 100644 --- a/server/router/api/v1/test/test_helper.go +++ b/server/router/api/v1/test/test_helper.go @@ -74,8 +74,8 @@ func (ts *TestService) CreateRegularUser(ctx context.Context, username string) ( }) } -// CreateUserContext creates a context with the given username for authentication. -func (*TestService) CreateUserContext(ctx context.Context, username string) context.Context { +// CreateUserContext creates a context with the given user's ID for authentication. +func (*TestService) CreateUserContext(ctx context.Context, userID int32) context.Context { // Use the real context key from the parent package - return apiv1.CreateTestUserContext(ctx, username) + return apiv1.CreateTestUserContext(ctx, userID) } diff --git a/server/router/api/v1/test/webhook_service_test.go b/server/router/api/v1/test/webhook_service_test.go index 0a0c1bb24..4ea2cdf15 100644 --- a/server/router/api/v1/test/webhook_service_test.go +++ b/server/router/api/v1/test/webhook_service_test.go @@ -23,7 +23,7 @@ func TestCreateWebhook(t *testing.T) { hostUser, err := ts.CreateHostUser(ctx, "admin") require.NoError(t, err) - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Create a webhook req := &v1pb.CreateWebhookRequest{ @@ -72,7 +72,7 @@ func TestCreateWebhook(t *testing.T) { regularUser, err := ts.CreateRegularUser(ctx, "user1") require.NoError(t, err) - userCtx := ts.CreateUserContext(ctx, regularUser.Username) + userCtx := ts.CreateUserContext(ctx, regularUser.ID) // Try to create webhook as regular user req := &v1pb.CreateWebhookRequest{ @@ -98,7 +98,7 @@ func TestCreateWebhook(t *testing.T) { hostUser, err := ts.CreateHostUser(ctx, "admin") require.NoError(t, err) - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Try to create webhook with missing URL req := &v1pb.CreateWebhookRequest{ @@ -127,7 +127,7 @@ func TestListWebhooks(t *testing.T) { hostUser, err := ts.CreateHostUser(ctx, "admin") require.NoError(t, err) - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // List webhooks req := &v1pb.ListWebhooksRequest{} @@ -147,7 +147,7 @@ func TestListWebhooks(t *testing.T) { // Create host user and authenticate hostUser, err := ts.CreateHostUser(ctx, "admin") require.NoError(t, err) - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Create a webhook createReq := &v1pb.CreateWebhookRequest{ @@ -196,7 +196,7 @@ func TestGetWebhook(t *testing.T) { // Create host user and authenticate hostUser, err := ts.CreateHostUser(ctx, "admin") require.NoError(t, err) - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Create a webhook createReq := &v1pb.CreateWebhookRequest{ @@ -230,7 +230,7 @@ func TestGetWebhook(t *testing.T) { // Create host user and authenticate hostUser, err := ts.CreateHostUser(ctx, "admin") require.NoError(t, err) - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Try to get webhook with invalid name req := &v1pb.GetWebhookRequest{ @@ -250,7 +250,7 @@ func TestGetWebhook(t *testing.T) { // Create host user and authenticate hostUser, err := ts.CreateHostUser(ctx, "admin") require.NoError(t, err) - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Try to get non-existent webhook req := &v1pb.GetWebhookRequest{ @@ -275,7 +275,7 @@ func TestUpdateWebhook(t *testing.T) { // Create host user and authenticate hostUser, err := ts.CreateHostUser(ctx, "admin") require.NoError(t, err) - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Create a webhook createReq := &v1pb.CreateWebhookRequest{ @@ -337,7 +337,7 @@ func TestDeleteWebhook(t *testing.T) { // Create host user and authenticate hostUser, err := ts.CreateHostUser(ctx, "admin") require.NoError(t, err) - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Create a webhook createReq := &v1pb.CreateWebhookRequest{ @@ -393,7 +393,7 @@ func TestDeleteWebhook(t *testing.T) { // Create host user and authenticate hostUser, err := ts.CreateHostUser(ctx, "admin") require.NoError(t, err) - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Try to delete non-existent webhook req := &v1pb.DeleteWebhookRequest{ diff --git a/server/router/api/v1/test/workspace_service_test.go b/server/router/api/v1/test/workspace_service_test.go index 2971a2d79..95a93a0dc 100644 --- a/server/router/api/v1/test/workspace_service_test.go +++ b/server/router/api/v1/test/workspace_service_test.go @@ -149,7 +149,7 @@ func TestGetWorkspaceSetting(t *testing.T) { require.NoError(t, err) // Add user to context - userCtx := ts.CreateUserContext(ctx, hostUser.Username) + userCtx := ts.CreateUserContext(ctx, hostUser.ID) // Call GetWorkspaceSetting for storage setting req := &v1pb.GetWorkspaceSettingRequest{ diff --git a/server/router/api/v1/test_auth.go b/server/router/api/v1/test_auth.go index 840785200..f2f09bd1e 100644 --- a/server/router/api/v1/test_auth.go +++ b/server/router/api/v1/test_auth.go @@ -6,14 +6,14 @@ import ( "github.com/usememos/memos/store" ) -// CreateTestUserContext creates a context with username for testing purposes. +// CreateTestUserContext creates a context with user's ID for testing purposes. // This function is only intended for use in tests. -func CreateTestUserContext(ctx context.Context, username string) context.Context { - return context.WithValue(ctx, usernameContextKey, username) +func CreateTestUserContext(ctx context.Context, userID int32) context.Context { + return context.WithValue(ctx, userIDContextKey, userID) } // CreateTestUserContextWithUser creates a context and ensures the user exists for testing. // This function is only intended for use in tests. func CreateTestUserContextWithUser(ctx context.Context, _ *APIV1Service, user *store.User) context.Context { - return context.WithValue(ctx, usernameContextKey, user.Username) + return context.WithValue(ctx, userIDContextKey, user.ID) }