From 08f9b18ced6f252dff2e2ed64d85120050534d25 Mon Sep 17 00:00:00 2001 From: Johnny Date: Sat, 12 Apr 2025 22:02:13 +0800 Subject: [PATCH] fix: list memo relations --- server/router/api/v1/memo_relation_service.go | 15 ++++++++++++- store/db/sqlite/memo_relation.go | 21 +++++++++++++++++++ store/memo_relation.go | 1 + 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/server/router/api/v1/memo_relation_service.go b/server/router/api/v1/memo_relation_service.go index c4ff1db7a..3a4c1a551 100644 --- a/server/router/api/v1/memo_relation_service.go +++ b/server/router/api/v1/memo_relation_service.go @@ -70,9 +70,21 @@ func (s *APIV1Service) ListMemoRelations(ctx context.Context, request *v1pb.List if err != nil { return nil, status.Errorf(codes.Internal, "failed to get memo") } + + currentUser, err := s.GetCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get user") + } + var memoFilter string + if currentUser == nil { + memoFilter = `visibility == "PUBLIC"` + } else { + memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID) + } relationList := []*v1pb.MemoRelation{} tempList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{ - MemoID: &memo.ID, + MemoID: &memo.ID, + MemoFilter: &memoFilter, }) if err != nil { return nil, err @@ -86,6 +98,7 @@ func (s *APIV1Service) ListMemoRelations(ctx context.Context, request *v1pb.List } tempList, err = s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{ RelatedMemoID: &memo.ID, + MemoFilter: &memoFilter, }) if err != nil { return nil, err diff --git a/store/db/sqlite/memo_relation.go b/store/db/sqlite/memo_relation.go index 8d9d716d5..9507b163a 100644 --- a/store/db/sqlite/memo_relation.go +++ b/store/db/sqlite/memo_relation.go @@ -2,8 +2,10 @@ package sqlite import ( "context" + "fmt" "strings" + "github.com/usememos/memos/plugin/filter" "github.com/usememos/memos/store" ) @@ -46,6 +48,25 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation if find.Type != nil { where, args = append(where, "type = ?"), append(args, find.Type) } + if find.MemoFilter != nil { + // Parse filter string and return the parsed expression. + // The filter string should be a CEL expression. + parsedExpr, err := filter.Parse(*find.MemoFilter, filter.MemoFilterCELAttributes...) + if err != nil { + return nil, err + } + convertCtx := filter.NewConvertContext() + // ConvertExprToSQL converts the parsed expression to a SQL condition string. + if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + return nil, err + } + condition := convertCtx.Buffer.String() + if condition != "" { + where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", condition)) + where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", condition)) + args = append(args, append(convertCtx.Args, convertCtx.Args...)...) + } + } rows, err := d.db.QueryContext(ctx, ` SELECT diff --git a/store/memo_relation.go b/store/memo_relation.go index 3d68049df..61b022884 100644 --- a/store/memo_relation.go +++ b/store/memo_relation.go @@ -23,6 +23,7 @@ type FindMemoRelation struct { MemoID *int32 RelatedMemoID *int32 Type *MemoRelationType + MemoFilter *string } type DeleteMemoRelation struct {