From 3fd305dce72ca7423a8e4fbf141e6f80335bceea Mon Sep 17 00:00:00 2001 From: varsnotwars Date: Tue, 12 Aug 2025 00:57:52 +1000 Subject: [PATCH] fix: preferences being overwritten (#4990) --- server/router/api/v1/user_service.go | 86 ++++++++++++- server/router/api/v1/user_service_test.go | 146 ++++++++++++++++++++++ 2 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 server/router/api/v1/user_service_test.go diff --git a/server/router/api/v1/user_service.go b/server/router/api/v1/user_service.go index b1bebf983..14b802a9d 100644 --- a/server/router/api/v1/user_service.go +++ b/server/router/api/v1/user_service.go @@ -373,8 +373,23 @@ func (s *APIV1Service) UpdateUserSetting(ctx context.Context, request *v1pb.Upda return nil, status.Errorf(codes.InvalidArgument, "invalid setting key: %v", err) } + // get existing user setting + existingUserSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{ + UserID: &userID, + Key: storeKey, + }) + if err != nil { + return nil, err + } + if existingUserSetting == nil { + return nil, status.Errorf(codes.NotFound, "%s not found", storeKey.String()) + } + + // merge only the fields specified by UpdateMask + merged := mergeUserSettingWithMask(existingUserSetting, request.Setting, storeKey, request.UpdateMask.Paths) + // Convert API setting to store setting - storeSetting, err := convertUserSettingToStore(request.Setting, userID, storeKey) + storeSetting, err := convertUserSettingToStore(merged, userID, storeKey) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "failed to convert setting: %v", err) } @@ -1320,3 +1335,72 @@ func (s *APIV1Service) validateUserFilter(_ context.Context, filterStr string) e } return nil } + +func mergeUserSettingWithMask(existing *storepb.UserSetting, incoming *v1pb.UserSetting, key storepb.UserSetting_Key, paths []string) *v1pb.UserSetting { + if incoming == nil { + return &v1pb.UserSetting{} + } + + switch key { + case storepb.UserSetting_GENERAL: + var gs *v1pb.UserSetting_GeneralSetting + + if existing == nil { + gs = &v1pb.UserSetting_GeneralSetting{ + Locale: "en", + Appearance: "system", + MemoVisibility: "PRIVATE", + Theme: "", + } + } else { + gs = &v1pb.UserSetting_GeneralSetting{ + Appearance: existing.GetGeneral().GetAppearance(), + MemoVisibility: existing.GetGeneral().GetMemoVisibility(), + Locale: existing.GetGeneral().GetLocale(), + Theme: existing.GetGeneral().GetTheme(), + } + } + + for _, field := range paths { + switch field { + case "appearance": + gs.Appearance = incoming.GetGeneralSetting().Appearance + case "memoVisibility": + gs.MemoVisibility = incoming.GetGeneralSetting().MemoVisibility + case "theme": + gs.Theme = incoming.GetGeneralSetting().Theme + case "locale": + gs.Locale = incoming.GetGeneralSetting().Locale + } + } + + return &v1pb.UserSetting{ + Name: incoming.Name, + Value: &v1pb.UserSetting_GeneralSetting_{ + GeneralSetting: gs, + }, + } + + case storepb.UserSetting_SHORTCUTS: + // handled by the FE calling shortcut_service.CreateShortcut + // if the FE wants to modify shortcuts by calling the user_service we need to handle below + + return incoming + case storepb.UserSetting_WEBHOOKS: + // handled by the FE calling user_service.CreateUserWebhook + // if the FE wants to modify webhooks by calling the user_service we need to handle below + + return incoming + case storepb.UserSetting_ACCESS_TOKENS: + // handled by the FE calling user_service.CreateUserAccessToken + // if the FE wants to modify access tokens by calling the user_service we need to handle below + + return incoming + case storepb.UserSetting_SESSIONS: + // handled by the FE calling auth_service.CreateSession + // if the FE wants to modify sessions by calling the user_service we need to handle below + return incoming + default: + return incoming + } +} diff --git a/server/router/api/v1/user_service_test.go b/server/router/api/v1/user_service_test.go new file mode 100644 index 000000000..ae5348281 --- /dev/null +++ b/server/router/api/v1/user_service_test.go @@ -0,0 +1,146 @@ +package v1 + +import ( + "reflect" + "testing" + + v1pb "github.com/usememos/memos/proto/gen/api/v1" + storepb "github.com/usememos/memos/proto/gen/store" +) + +func TestMergeUserSettingWithMask(t *testing.T) { + tests := []struct { + name string + existing *storepb.UserSetting + incoming *v1pb.UserSetting + key storepb.UserSetting_Key + paths []string + expected *v1pb.UserSetting + }{ + { + name: "adds new field without removing existing fields", + existing: &storepb.UserSetting{ + UserId: 1, + Key: storepb.UserSetting_GENERAL, + Value: &storepb.UserSetting_General{ + General: &storepb.GeneralUserSetting{ + MemoVisibility: "PROTECTED", + }, + }, + }, + incoming: &v1pb.UserSetting{ + Value: &v1pb.UserSetting_GeneralSetting_{ + GeneralSetting: &v1pb.UserSetting_GeneralSetting{ + Appearance: "light", + }, + }, + }, + key: storepb.UserSetting_GENERAL, + paths: []string{"appearance"}, + expected: &v1pb.UserSetting{ + Value: &v1pb.UserSetting_GeneralSetting_{ + GeneralSetting: &v1pb.UserSetting_GeneralSetting{ + Appearance: "light", + MemoVisibility: "PROTECTED", + }, + }, + }, + }, + { + name: "adds new field when no existing fields exist", + existing: &storepb.UserSetting{ + UserId: 1, + Key: storepb.UserSetting_GENERAL, + Value: &storepb.UserSetting_General{}, + }, + incoming: &v1pb.UserSetting{ + Value: &v1pb.UserSetting_GeneralSetting_{ + GeneralSetting: &v1pb.UserSetting_GeneralSetting{ + Theme: "whitewall", + }, + }, + }, + key: storepb.UserSetting_GENERAL, + paths: []string{"theme"}, + expected: &v1pb.UserSetting{ + Value: &v1pb.UserSetting_GeneralSetting_{ + GeneralSetting: &v1pb.UserSetting_GeneralSetting{ + Theme: "whitewall", + }, + }, + }, + }, + { + name: "updates existing field without removing existing fields", + existing: &storepb.UserSetting{ + UserId: 1, + Key: storepb.UserSetting_GENERAL, + Value: &storepb.UserSetting_General{ + General: &storepb.GeneralUserSetting{ + Appearance: "dark", + MemoVisibility: "PUBLIC", + }, + }, + }, + incoming: &v1pb.UserSetting{ + Value: &v1pb.UserSetting_GeneralSetting_{ + GeneralSetting: &v1pb.UserSetting_GeneralSetting{ + Appearance: "light", + }, + }, + }, + key: storepb.UserSetting_GENERAL, + paths: []string{"appearance"}, + expected: &v1pb.UserSetting{ + Value: &v1pb.UserSetting_GeneralSetting_{ + GeneralSetting: &v1pb.UserSetting_GeneralSetting{ + Appearance: "light", + MemoVisibility: "PUBLIC", + }, + }, + }, + }, + { + name: "updates multiple fields without removing existing fields", + existing: &storepb.UserSetting{ + UserId: 1, + Key: storepb.UserSetting_GENERAL, + Value: &storepb.UserSetting_General{ + General: &storepb.GeneralUserSetting{ + Appearance: "system", + }, + }, + }, + incoming: &v1pb.UserSetting{ + Value: &v1pb.UserSetting_GeneralSetting_{ + GeneralSetting: &v1pb.UserSetting_GeneralSetting{ + Appearance: "dark", + Theme: "paper", + MemoVisibility: "PROTECTED", + }, + }, + }, + key: storepb.UserSetting_GENERAL, + paths: []string{"theme", "memoVisibility", "appearance"}, + expected: &v1pb.UserSetting{ + Value: &v1pb.UserSetting_GeneralSetting_{ + GeneralSetting: &v1pb.UserSetting_GeneralSetting{ + Appearance: "dark", + MemoVisibility: "PROTECTED", + Theme: "paper", + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := mergeUserSettingWithMask(tt.existing, tt.incoming, tt.key, tt.paths) + + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("expected %v but got %v", tt.expected, actual) + } + }) + } +}