fix: auth context

pull/4784/head
Johnny 5 months ago
parent 45df653f37
commit 6e4d1d9100

@ -23,9 +23,8 @@ import (
type ContextKey int
const (
// The key name used to store username in the context
// user id is extracted from the jwt token subject field.
usernameContextKey ContextKey = iota
// The key name used to store user's ID in the context (for user-based auth).
userIDContextKey ContextKey = iota
// The key name used to store session ID in the context (for session-based auth).
sessionIDContextKey
// The key name used to store access token in the context (for token-based auth).
@ -48,11 +47,6 @@ func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthIntercep
// AuthenticationInterceptor is the unary interceptor for gRPC API.
func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
// Check if this method is in the allowlist first
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
return handler(ctx, request)
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
@ -65,21 +59,25 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
}
// Authenticate using access token (which also validates sessions when it's from cookie)
username, user, err := in.authenticateByAccessToken(ctx, accessToken)
user, err := in.authenticateByAccessToken(ctx, accessToken)
if err != nil {
// Check if this method is in the allowlist first
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
return handler(ctx, request)
}
return nil, err
}
// Check user status
if user.RowStatus == store.Archived {
return nil, errors.Errorf("user %q is archived", username)
return nil, errors.Errorf("user %q is archived", user.Username)
}
if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin {
return nil, errors.Errorf("user %q is not admin", username)
return nil, errors.Errorf("user %q is not admin", user.Username)
}
// Set context values
ctx = context.WithValue(ctx, usernameContextKey, username)
ctx = context.WithValue(ctx, userIDContextKey, user.ID)
// Determine if this came from cookie (session) or header (API token)
if _, headerErr := getAccessTokenFromMetadata(md); headerErr != nil {
@ -96,9 +94,9 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
}
// authenticateByAccessToken authenticates a user using access token from Authorization header or cookie.
func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, accessToken string) (string, *store.User, error) {
func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, accessToken string) (*store.User, error) {
if accessToken == "" {
return "", nil, status.Errorf(codes.Unauthenticated, "access token not found")
return nil, status.Errorf(codes.Unauthenticated, "access token not found")
}
claims := &ClaimsMessage{}
_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
@ -113,33 +111,33 @@ func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, ac
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return "", nil, status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
return nil, status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
}
// We either have a valid access token or we will attempt to generate new access token.
userID, err := util.ConvertStringToInt32(claims.Subject)
if err != nil {
return "", nil, errors.Wrap(err, "malformed ID in the token")
return nil, errors.Wrap(err, "malformed ID in the token")
}
user, err := in.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return "", nil, errors.Wrap(err, "failed to get user")
return nil, errors.Wrap(err, "failed to get user")
}
if user == nil {
return "", nil, errors.Errorf("user %q not exists", userID)
return nil, errors.Errorf("user %q not exists", userID)
}
if user.RowStatus == store.Archived {
return "", nil, errors.Errorf("user %q is archived", userID)
return nil, errors.Errorf("user %q is archived", userID)
}
accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return "", nil, errors.Wrapf(err, "failed to get user access tokens")
return nil, errors.Wrapf(err, "failed to get user access tokens")
}
if !validateAccessToken(accessToken, accessTokens) {
return "", nil, status.Errorf(codes.Unauthenticated, "invalid access token")
return nil, status.Errorf(codes.Unauthenticated, "invalid access token")
}
// For tokens that might be used as session IDs (from cookies), also validate session existence
@ -148,7 +146,7 @@ func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, ac
validateUserSession(accessToken, sessions) // Result doesn't matter for API tokens
}
return user.Username, user, nil
return user, nil
}
// updateSessionLastAccessed updates the last accessed time for a user session.
@ -204,9 +202,6 @@ func getTokenFromMetadata(md metadata.MD) (string, error) {
accessToken = v.Value
}
}
if accessToken == "" {
return "", errors.New("access token not found")
}
return accessToken, nil
}

@ -6,6 +6,7 @@ var authenticationAllowlistMethods = map[string]bool{
"/memos.api.v1.IdentityProviderService/GetIdentityProvider": true,
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": true,
"/memos.api.v1.AuthService/CreateSession": true,
"/memos.api.v1.AuthService/GetCurrentSession": true,
"/memos.api.v1.AuthService/SignUp": true,
"/memos.api.v1.UserService/GetUser": true,
"/memos.api.v1.UserService/GetUserAvatar": true,

@ -331,16 +331,19 @@ func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken str
}
func (s *APIV1Service) GetCurrentUser(ctx context.Context) (*store.User, error) {
username, ok := ctx.Value(usernameContextKey).(string)
userID, ok := ctx.Value(userIDContextKey).(int32)
if !ok {
return nil, nil
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &username,
ID: &userID,
})
if err != nil {
return nil, err
}
if user == nil {
return nil, errors.Errorf("user %d not found", userID)
}
return user, nil
}

@ -22,7 +22,7 @@ func TestCreateIdentityProvider(t *testing.T) {
require.NoError(t, err)
// Set user context
ctx := ts.CreateUserContext(ctx, hostUser.Username)
ctx := ts.CreateUserContext(ctx, hostUser.ID)
// Create OAuth2 identity provider
req := &v1pb.CreateIdentityProviderRequest{
@ -71,7 +71,7 @@ func TestCreateIdentityProvider(t *testing.T) {
require.NoError(t, err)
// Set user context
ctx := ts.CreateUserContext(ctx, regularUser.Username)
ctx := ts.CreateUserContext(ctx, regularUser.ID)
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
@ -125,7 +125,7 @@ func TestListIdentityProviders(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a couple of identity providers
createReq1 := &v1pb.CreateIdentityProviderRequest{
@ -199,7 +199,7 @@ func TestGetIdentityProvider(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create identity provider
createReq := &v1pb.CreateIdentityProviderRequest{
@ -284,7 +284,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create identity provider
createReq := &v1pb.CreateIdentityProviderRequest{
@ -398,7 +398,7 @@ func TestDeleteIdentityProvider(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create identity provider
createReq := &v1pb.CreateIdentityProviderRequest{
@ -464,7 +464,7 @@ func TestDeleteIdentityProvider(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
req := &v1pb.DeleteIdentityProviderRequest{
Name: "identityProviders/999",
@ -488,7 +488,7 @@ func TestIdentityProviderPermissions(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, regularUser.Username)
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{

@ -27,7 +27,7 @@ func TestListInboxes(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// List inboxes (should be empty initially)
req := &v1pb.ListInboxesRequest{
@ -64,7 +64,7 @@ func TestListInboxes(t *testing.T) {
}
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// List inboxes with page size limit
req := &v1pb.ListInboxesRequest{
@ -90,7 +90,7 @@ func TestListInboxes(t *testing.T) {
require.NoError(t, err)
// Set user1 context but try to list user2's inboxes
userCtx := ts.CreateUserContext(ctx, user1.Username)
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.ListInboxesRequest{
Parent: fmt.Sprintf("users/%d", user2.ID),
@ -124,7 +124,7 @@ func TestListInboxes(t *testing.T) {
require.NoError(t, err)
// Set host user context and try to list regular user's inboxes
hostCtx := ts.CreateUserContext(ctx, hostUser.Username)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
req := &v1pb.ListInboxesRequest{
Parent: fmt.Sprintf("users/%d", regularUser.ID),
@ -145,7 +145,7 @@ func TestListInboxes(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.ListInboxesRequest{
Parent: "invalid-parent-format",
@ -194,7 +194,7 @@ func TestUpdateInbox(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Update inbox status
req := &v1pb.UpdateInboxRequest{
@ -236,7 +236,7 @@ func TestUpdateInbox(t *testing.T) {
require.NoError(t, err)
// Set user1 context but try to update user2's inbox
userCtx := ts.CreateUserContext(ctx, user1.Username)
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.UpdateInboxRequest{
Inbox: &v1pb.Inbox{
@ -262,7 +262,7 @@ func TestUpdateInbox(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.UpdateInboxRequest{
Inbox: &v1pb.Inbox{
@ -285,7 +285,7 @@ func TestUpdateInbox(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.UpdateInboxRequest{
Inbox: &v1pb.Inbox{
@ -311,7 +311,7 @@ func TestUpdateInbox(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.UpdateInboxRequest{
Inbox: &v1pb.Inbox{
@ -351,7 +351,7 @@ func TestUpdateInbox(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.UpdateInboxRequest{
Inbox: &v1pb.Inbox{
@ -393,7 +393,7 @@ func TestDeleteInbox(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Delete inbox
req := &v1pb.DeleteInboxRequest{
@ -434,7 +434,7 @@ func TestDeleteInbox(t *testing.T) {
require.NoError(t, err)
// Set user1 context but try to delete user2's inbox
userCtx := ts.CreateUserContext(ctx, user1.Username)
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.DeleteInboxRequest{
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
@ -454,7 +454,7 @@ func TestDeleteInbox(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.DeleteInboxRequest{
Name: "invalid-inbox-name",
@ -474,7 +474,7 @@ func TestDeleteInbox(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.DeleteInboxRequest{
Name: "inboxes/99999", // Non-existent inbox
@ -512,7 +512,7 @@ func TestInboxCRUDComplete(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// 1. List inboxes - should have 1
listReq := &v1pb.ListInboxesRequest{

@ -23,7 +23,7 @@ func TestListShortcuts(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// List shortcuts (should be empty initially)
req := &v1pb.ListShortcutsRequest{
@ -47,7 +47,7 @@ func TestListShortcuts(t *testing.T) {
require.NoError(t, err)
// Set user1 context but try to list user2's shortcuts
userCtx := ts.CreateUserContext(ctx, user1.Username)
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user2.ID),
@ -67,7 +67,7 @@ func TestListShortcuts(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.ListShortcutsRequest{
Parent: "invalid-parent-format",
@ -104,7 +104,7 @@ func TestGetShortcut(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// First create a shortcut
createReq := &v1pb.CreateShortcutRequest{
@ -142,7 +142,7 @@ func TestGetShortcut(t *testing.T) {
require.NoError(t, err)
// Create shortcut as user1
user1Ctx := ts.CreateUserContext(ctx, user1.Username)
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user1.ID),
Shortcut: &v1pb.Shortcut{
@ -155,7 +155,7 @@ func TestGetShortcut(t *testing.T) {
require.NoError(t, err)
// Try to get shortcut as user2
user2Ctx := ts.CreateUserContext(ctx, user2.Username)
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
getReq := &v1pb.GetShortcutRequest{
Name: created.Name,
}
@ -174,7 +174,7 @@ func TestGetShortcut(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.GetShortcutRequest{
Name: "invalid-shortcut-name",
@ -194,7 +194,7 @@ func TestGetShortcut(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.GetShortcutRequest{
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
@ -218,7 +218,7 @@ func TestCreateShortcut(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
@ -257,7 +257,7 @@ func TestCreateShortcut(t *testing.T) {
require.NoError(t, err)
// Set user1 context but try to create shortcut for user2
userCtx := ts.CreateUserContext(ctx, user1.Username)
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user2.ID),
@ -281,7 +281,7 @@ func TestCreateShortcut(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: "invalid-parent",
@ -305,7 +305,7 @@ func TestCreateShortcut(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
@ -329,7 +329,7 @@ func TestCreateShortcut(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
@ -356,7 +356,7 @@ func TestUpdateShortcut(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a shortcut first
createReq := &v1pb.CreateShortcutRequest{
@ -401,7 +401,7 @@ func TestUpdateShortcut(t *testing.T) {
require.NoError(t, err)
// Create shortcut as user1
user1Ctx := ts.CreateUserContext(ctx, user1.Username)
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user1.ID),
Shortcut: &v1pb.Shortcut{
@ -414,7 +414,7 @@ func TestUpdateShortcut(t *testing.T) {
require.NoError(t, err)
// Try to update shortcut as user2
user2Ctx := ts.CreateUserContext(ctx, user2.Username)
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created.Name,
@ -438,7 +438,7 @@ func TestUpdateShortcut(t *testing.T) {
// Create a user and context for authentication
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
@ -480,7 +480,7 @@ func TestUpdateShortcut(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a shortcut first
createReq := &v1pb.CreateShortcutRequest{
@ -523,7 +523,7 @@ func TestDeleteShortcut(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a shortcut first
createReq := &v1pb.CreateShortcutRequest{
@ -575,7 +575,7 @@ func TestDeleteShortcut(t *testing.T) {
require.NoError(t, err)
// Create shortcut as user1
user1Ctx := ts.CreateUserContext(ctx, user1.Username)
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user1.ID),
Shortcut: &v1pb.Shortcut{
@ -588,7 +588,7 @@ func TestDeleteShortcut(t *testing.T) {
require.NoError(t, err)
// Try to delete shortcut as user2
user2Ctx := ts.CreateUserContext(ctx, user2.Username)
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
deleteReq := &v1pb.DeleteShortcutRequest{
Name: created.Name,
}
@ -620,7 +620,7 @@ func TestDeleteShortcut(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.DeleteShortcutRequest{
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
@ -644,7 +644,7 @@ func TestShortcutFiltering(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Test various valid filter formats
validFilters := []string{
@ -681,7 +681,7 @@ func TestShortcutFiltering(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Test various invalid filter formats
invalidFilters := []string{
@ -723,7 +723,7 @@ func TestShortcutCRUDComplete(t *testing.T) {
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.Username)
userCtx := ts.CreateUserContext(ctx, user.ID)
// 1. Create multiple shortcuts
shortcut1Req := &v1pb.CreateShortcutRequest{

@ -74,8 +74,8 @@ func (ts *TestService) CreateRegularUser(ctx context.Context, username string) (
})
}
// CreateUserContext creates a context with the given username for authentication.
func (*TestService) CreateUserContext(ctx context.Context, username string) context.Context {
// CreateUserContext creates a context with the given user's ID for authentication.
func (*TestService) CreateUserContext(ctx context.Context, userID int32) context.Context {
// Use the real context key from the parent package
return apiv1.CreateTestUserContext(ctx, username)
return apiv1.CreateTestUserContext(ctx, userID)
}

@ -23,7 +23,7 @@ func TestCreateWebhook(t *testing.T) {
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a webhook
req := &v1pb.CreateWebhookRequest{
@ -72,7 +72,7 @@ func TestCreateWebhook(t *testing.T) {
regularUser, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, regularUser.Username)
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Try to create webhook as regular user
req := &v1pb.CreateWebhookRequest{
@ -98,7 +98,7 @@ func TestCreateWebhook(t *testing.T) {
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Try to create webhook with missing URL
req := &v1pb.CreateWebhookRequest{
@ -127,7 +127,7 @@ func TestListWebhooks(t *testing.T) {
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// List webhooks
req := &v1pb.ListWebhooksRequest{}
@ -147,7 +147,7 @@ func TestListWebhooks(t *testing.T) {
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a webhook
createReq := &v1pb.CreateWebhookRequest{
@ -196,7 +196,7 @@ func TestGetWebhook(t *testing.T) {
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a webhook
createReq := &v1pb.CreateWebhookRequest{
@ -230,7 +230,7 @@ func TestGetWebhook(t *testing.T) {
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Try to get webhook with invalid name
req := &v1pb.GetWebhookRequest{
@ -250,7 +250,7 @@ func TestGetWebhook(t *testing.T) {
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Try to get non-existent webhook
req := &v1pb.GetWebhookRequest{
@ -275,7 +275,7 @@ func TestUpdateWebhook(t *testing.T) {
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a webhook
createReq := &v1pb.CreateWebhookRequest{
@ -337,7 +337,7 @@ func TestDeleteWebhook(t *testing.T) {
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a webhook
createReq := &v1pb.CreateWebhookRequest{
@ -393,7 +393,7 @@ func TestDeleteWebhook(t *testing.T) {
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Try to delete non-existent webhook
req := &v1pb.DeleteWebhookRequest{

@ -149,7 +149,7 @@ func TestGetWorkspaceSetting(t *testing.T) {
require.NoError(t, err)
// Add user to context
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Call GetWorkspaceSetting for storage setting
req := &v1pb.GetWorkspaceSettingRequest{

@ -6,14 +6,14 @@ import (
"github.com/usememos/memos/store"
)
// CreateTestUserContext creates a context with username for testing purposes.
// CreateTestUserContext creates a context with user's ID for testing purposes.
// This function is only intended for use in tests.
func CreateTestUserContext(ctx context.Context, username string) context.Context {
return context.WithValue(ctx, usernameContextKey, username)
func CreateTestUserContext(ctx context.Context, userID int32) context.Context {
return context.WithValue(ctx, userIDContextKey, userID)
}
// CreateTestUserContextWithUser creates a context and ensures the user exists for testing.
// This function is only intended for use in tests.
func CreateTestUserContextWithUser(ctx context.Context, _ *APIV1Service, user *store.User) context.Context {
return context.WithValue(ctx, usernameContextKey, user.Username)
return context.WithValue(ctx, userIDContextKey, user.ID)
}

Loading…
Cancel
Save