fix(memo): enforce parent visibility for comments

pull/5947/head
boojack 3 weeks ago
parent 1df6479443
commit 4a1e401bd9

@ -38,6 +38,37 @@ func isSSESuppressed(ctx context.Context) bool {
return ok && v
}
func (s *APIV1Service) checkMemoReadAccess(ctx context.Context, memo *store.Memo) error {
if memo == nil {
return status.Errorf(codes.NotFound, "memo not found")
}
// Archived memos are only visible to their creator.
if memo.RowStatus == store.Archived {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return status.Errorf(codes.Internal, "failed to get user")
}
if user == nil || memo.CreatorID != user.ID {
return status.Errorf(codes.NotFound, "memo not found")
}
}
if memo.Visibility != store.Public {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
return status.Errorf(codes.PermissionDenied, "permission denied")
}
}
return nil
}
func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
@ -335,27 +366,19 @@ func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Archived memos are only visible to their creator.
if memo.RowStatus == store.Archived {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil || memo.CreatorID != user.ID {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if err := s.checkMemoReadAccess(ctx, memo); err != nil {
return nil, err
}
if memo.Visibility != store.Public {
user, err := s.fetchCurrentUser(ctx)
if memo.ParentUID != nil {
parentMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
return nil, status.Errorf(codes.Internal, "failed to get parent memo")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
if parentMemo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
if err := s.checkMemoReadAccess(ctx, parentMemo); err != nil {
return nil, err
}
}
@ -486,6 +509,16 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
update.Payload = memo.Payload
} else if path == "visibility" {
visibility := convertVisibilityToStore(request.Memo.Visibility)
if memo.ParentUID != nil {
parentMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get parent memo")
}
if parentMemo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
visibility = parentMemo.Visibility
}
update.Visibility = &visibility
} else if path == "pinned" {
update.Pinned = &request.Memo.Pinned
@ -641,11 +674,17 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea
if relatedMemo.Visibility == store.Private && relatedMemo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if request.Comment == nil {
return nil, status.Errorf(codes.InvalidArgument, "comment is required")
}
comment := *request.Comment
comment.Visibility = convertVisibilityFromStore(relatedMemo.Visibility)
// Create the memo comment first; suppress the generic memo.created SSE event
// since CreateMemoComment broadcasts memo.comment.created for the parent instead.
memoComment, err := s.CreateMemo(withSuppressMentionNotifications(withSuppressSSE(ctx)), &v1pb.CreateMemoRequest{
Memo: request.Comment,
Memo: &comment,
MemoId: request.CommentId,
})
if err != nil {
@ -722,6 +761,12 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if err := s.checkMemoReadAccess(ctx, memo); err != nil {
return nil, err
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {

@ -8,6 +8,9 @@ import (
"time"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
@ -527,6 +530,113 @@ func TestListMemoCommentsPaginates(t *testing.T) {
require.Empty(t, secondPage.NextPageToken)
}
func TestCreateMemoCommentInheritsParentVisibility(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
owner, err := ts.CreateRegularUser(ctx, "private-comment-owner")
require.NoError(t, err)
ownerCtx := ts.CreateUserContext(ctx, owner.ID)
parent, err := ts.Service.CreateMemo(ownerCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "private parent",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
comment, err := ts.Service.CreateMemoComment(ownerCtx, &apiv1.CreateMemoCommentRequest{
Name: parent.Name,
Comment: &apiv1.Memo{
Content: "client requested public comment",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.Equal(t, apiv1.Visibility_PRIVATE, comment.Visibility)
updatedComment, err := ts.Service.UpdateMemo(ownerCtx, &apiv1.UpdateMemoRequest{
Memo: &apiv1.Memo{
Name: comment.Name,
Visibility: apiv1.Visibility_PUBLIC,
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"visibility"}},
})
require.NoError(t, err)
require.Equal(t, apiv1.Visibility_PRIVATE, updatedComment.Visibility)
_, err = ts.Service.GetMemo(ctx, &apiv1.GetMemoRequest{Name: comment.Name})
require.Equal(t, codes.Unauthenticated, status.Code(err))
}
func TestGetMemoCommentRequiresParentReadAccess(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
owner, err := ts.CreateRegularUser(ctx, "legacy-comment-owner")
require.NoError(t, err)
ownerCtx := ts.CreateUserContext(ctx, owner.ID)
other, err := ts.CreateRegularUser(ctx, "legacy-comment-other")
require.NoError(t, err)
otherCtx := ts.CreateUserContext(ctx, other.ID)
parent, err := ts.Service.CreateMemo(ownerCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "private parent for legacy comment",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
legacyComment, err := ts.Store.CreateMemo(ctx, &store.Memo{
UID: "legacy-public-comment",
CreatorID: owner.ID,
Content: "legacy public comment under private parent",
Visibility: store.Public,
})
require.NoError(t, err)
parentUID := parent.Name[len("memos/"):]
parentMemo, err := ts.Store.GetMemo(ctx, &store.FindMemo{UID: &parentUID})
require.NoError(t, err)
require.NotNil(t, parentMemo)
_, err = ts.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: legacyComment.ID,
RelatedMemoID: parentMemo.ID,
Type: store.MemoRelationComment,
})
require.NoError(t, err)
commentName := "memos/" + legacyComment.UID
_, err = ts.Service.GetMemo(ctx, &apiv1.GetMemoRequest{Name: commentName})
require.Equal(t, codes.Unauthenticated, status.Code(err))
_, err = ts.Service.GetMemo(otherCtx, &apiv1.GetMemoRequest{Name: commentName})
require.Equal(t, codes.PermissionDenied, status.Code(err))
comment, err := ts.Service.GetMemo(ownerCtx, &apiv1.GetMemoRequest{Name: commentName})
require.NoError(t, err)
require.Equal(t, parent.Name, comment.GetParent())
_, err = ts.Service.ListMemoComments(ctx, &apiv1.ListMemoCommentsRequest{Name: parent.Name})
require.Equal(t, codes.Unauthenticated, status.Code(err))
_, err = ts.Service.ListMemoComments(otherCtx, &apiv1.ListMemoCommentsRequest{Name: parent.Name})
require.Equal(t, codes.PermissionDenied, status.Code(err))
comments, err := ts.Service.ListMemoComments(ownerCtx, &apiv1.ListMemoCommentsRequest{Name: parent.Name})
require.NoError(t, err)
require.Len(t, comments.Memos, 1)
require.Equal(t, commentName, comments.Memos[0].Name)
}
// TestCreateMemoWithCustomTimestamps tests that custom timestamps can be set when creating memos and comments.
// This addresses issue #5483: https://github.com/usememos/memos/issues/5483
func TestCreateMemoWithCustomTimestamps(t *testing.T) {

@ -86,9 +86,10 @@ func (s *RSSService) GetExploreRSS(c *echo.Context) error {
normalStatus := store.Normal
limit := maxRSSItemCount
memoFind := store.FindMemo{
RowStatus: &normalStatus,
VisibilityList: []store.Visibility{store.Public},
Limit: &limit,
RowStatus: &normalStatus,
VisibilityList: []store.Visibility{store.Public},
ExcludeComments: true,
Limit: &limit,
}
memoList, err := s.Store.ListMemos(ctx, &memoFind)
if err != nil {
@ -135,10 +136,11 @@ func (s *RSSService) GetUserRSS(c *echo.Context) error {
normalStatus := store.Normal
limit := maxRSSItemCount
memoFind := store.FindMemo{
CreatorID: &user.ID,
RowStatus: &normalStatus,
VisibilityList: []store.Visibility{store.Public},
Limit: &limit,
CreatorID: &user.ID,
RowStatus: &normalStatus,
VisibilityList: []store.Visibility{store.Public},
ExcludeComments: true,
Limit: &limit,
}
memoList, err := s.Store.ListMemos(ctx, &memoFind)
if err != nil {

@ -0,0 +1,86 @@
package rss
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/internal/markdown"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/store"
teststore "github.com/usememos/memos/store/test"
)
func TestPublicRSSExcludesComments(t *testing.T) {
ctx := context.Background()
stores := teststore.NewTestingStore(ctx, t)
defer stores.Close()
user, err := stores.CreateUser(ctx, &store.User{
Username: "rss-comment-owner",
Role: store.RoleUser,
Email: "rss-comment-owner@example.com",
})
require.NoError(t, err)
parent, err := stores.CreateMemo(ctx, &store.Memo{
UID: "rss-public-parent",
CreatorID: user.ID,
Content: "public parent should stay in rss",
Visibility: store.Public,
})
require.NoError(t, err)
comment, err := stores.CreateMemo(ctx, &store.Memo{
UID: "rss-public-comment",
CreatorID: user.ID,
Content: "public comment should not be in rss",
Visibility: store.Public,
})
require.NoError(t, err)
_, err = stores.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: comment.ID,
RelatedMemoID: parent.ID,
Type: store.MemoRelationComment,
})
require.NoError(t, err)
service := NewRSSService(&profile.Profile{}, stores, markdown.NewService())
exploreRSS := renderRSS(t, service, "/explore/rss.xml", "")
require.Contains(t, exploreRSS, "public parent should stay in rss")
require.NotContains(t, exploreRSS, "public comment should not be in rss")
userRSS := renderRSS(t, service, "/u/rss-comment-owner/rss.xml", user.Username)
require.Contains(t, userRSS, "public parent should stay in rss")
require.NotContains(t, userRSS, "public comment should not be in rss")
}
func renderRSS(t *testing.T, service *RSSService, target string, username string) string {
t.Helper()
e := echo.New()
req := httptest.NewRequest(http.MethodGet, target, strings.NewReader(""))
req.Host = "example.com"
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if username != "" {
c.SetPathValues(echo.PathValues{{Name: "username", Value: username}})
}
var err error
if username == "" {
err = service.GetExploreRSS(c)
} else {
err = service.GetUserRSS(c)
}
require.NoError(t, err)
require.Equal(t, http.StatusOK, rec.Code)
return rec.Body.String()
}
Loading…
Cancel
Save