From 4eb5b67bafe717fc01765bb0ef3ee9366e5dad10 Mon Sep 17 00:00:00 2001 From: varsnotwars Date: Sat, 16 Aug 2025 00:02:29 +1000 Subject: [PATCH] feat: attachments by id (#5008) --- plugin/filter/common_converter.go | 26 +++++- plugin/filter/filter.go | 5 ++ server/router/api/v1/attachment_service.go | 21 ++--- .../router/api/v1/memo_attachment_service.go | 2 +- server/router/api/v1/memo_service.go | 81 ++++++++++++++++--- .../router/api/v1/memo_service_converter.go | 11 +-- store/attachment.go | 4 + store/db/mysql/attachment.go | 63 ++++++++++++--- store/db/mysql/attachment_filter_test.go | 39 +++++++++ store/db/postgres/attachment.go | 61 +++++++++++--- store/db/postgres/attachment_filter_test.go | 39 +++++++++ store/db/sqlite/attachment.go | 63 ++++++++++++--- store/db/sqlite/attachment_filter_test.go | 39 +++++++++ 13 files changed, 387 insertions(+), 67 deletions(-) create mode 100644 store/db/mysql/attachment_filter_test.go create mode 100644 store/db/postgres/attachment_filter_test.go create mode 100644 store/db/sqlite/attachment_filter_test.go diff --git a/plugin/filter/common_converter.go b/plugin/filter/common_converter.go index 6ad0a04c3..5b277d195 100644 --- a/plugin/filter/common_converter.go +++ b/plugin/filter/common_converter.go @@ -207,7 +207,7 @@ func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exp return err } - if !slices.Contains([]string{"tag", "visibility", "content_id"}, identifier) { + if !slices.Contains([]string{"tag", "visibility", "content_id", "memo_id"}, identifier) { return errors.Errorf("invalid identifier for %s", callExpr.Function) } @@ -226,6 +226,8 @@ func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exp return c.handleVisibilityInList(ctx, values) } else if identifier == "content_id" { return c.handleContentIDInList(ctx, values) + } else if identifier == "memo_id" { + return c.handleMemoIDInList(ctx, values) } return nil @@ -333,6 +335,28 @@ func (c *CommonSQLConverter) handleContentIDInList(ctx *ConvertContext, values [ return nil } +func (c *CommonSQLConverter) handleMemoIDInList(ctx *ConvertContext, values []any) error { + placeholders := []string{} + for range values { + placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex)) + c.paramIndex++ + } + + tablePrefix := c.dialect.GetTablePrefix("resource") + if _, ok := c.dialect.(*PostgreSQLDialect); ok { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.memo_id IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { + return err + } + } else { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`memo_id` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { + return err + } + } + + ctx.Args = append(ctx.Args, values...) + return nil +} + func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error { if len(callExpr.Args) != 1 { return errors.Errorf("invalid number of arguments for %s", callExpr.Function) diff --git a/plugin/filter/filter.go b/plugin/filter/filter.go index e30ef27c9..dc4190deb 100644 --- a/plugin/filter/filter.go +++ b/plugin/filter/filter.go @@ -46,6 +46,11 @@ var UserFilterCELAttributes = []cel.EnvOption{ cel.Variable("username", cel.StringType), } +// AttachmentFilterCELAttributes are the CEL attributes for user. +var AttachmentFilterCELAttributes = []cel.EnvOption{ + cel.Variable("memo_id", cel.StringType), +} + // Parse parses the filter string and returns the parsed expression. // The filter string should be a CEL expression. func Parse(filter string, opts ...cel.EnvOption) (expr *exprv1.ParsedExpr, err error) { diff --git a/server/router/api/v1/attachment_service.go b/server/router/api/v1/attachment_service.go index 07e29daff..3ca6799dc 100644 --- a/server/router/api/v1/attachment_service.go +++ b/server/router/api/v1/attachment_service.go @@ -116,7 +116,7 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat return nil, status.Errorf(codes.Internal, "failed to create attachment: %v", err) } - return s.convertAttachmentFromStore(ctx, attachment), nil + return convertAttachmentFromStore(attachment), nil } func (s *APIV1Service) ListAttachments(ctx context.Context, request *v1pb.ListAttachmentsRequest) (*v1pb.ListAttachmentsResponse, error) { @@ -182,7 +182,7 @@ func (s *APIV1Service) ListAttachments(ctx context.Context, request *v1pb.ListAt response := &v1pb.ListAttachmentsResponse{} for _, attachment := range attachments { - response.Attachments = append(response.Attachments, s.convertAttachmentFromStore(ctx, attachment)) + response.Attachments = append(response.Attachments, convertAttachmentFromStore(attachment)) } // For simplicity, set total size to the number of returned attachments. @@ -209,7 +209,7 @@ func (s *APIV1Service) GetAttachment(ctx context.Context, request *v1pb.GetAttac if attachment == nil { return nil, status.Errorf(codes.NotFound, "attachment not found") } - return s.convertAttachmentFromStore(ctx, attachment), nil + return convertAttachmentFromStore(attachment), nil } func (s *APIV1Service) GetAttachmentBinary(ctx context.Context, request *v1pb.GetAttachmentBinaryRequest) (*httpbody.HttpBody, error) { @@ -383,7 +383,7 @@ func (s *APIV1Service) DeleteAttachment(ctx context.Context, request *v1pb.Delet return &emptypb.Empty{}, nil } -func (s *APIV1Service) convertAttachmentFromStore(ctx context.Context, attachment *store.Attachment) *v1pb.Attachment { +func convertAttachmentFromStore(attachment *store.Attachment) *v1pb.Attachment { attachmentMessage := &v1pb.Attachment{ Name: fmt.Sprintf("%s%s", AttachmentNamePrefix, attachment.UID), CreateTime: timestamppb.New(time.Unix(attachment.CreatedTs, 0)), @@ -391,18 +391,13 @@ func (s *APIV1Service) convertAttachmentFromStore(ctx context.Context, attachmen Type: attachment.Type, Size: attachment.Size, } + if attachment.MemoUID != nil && *attachment.MemoUID != "" { + memoName := fmt.Sprintf("%s%s", MemoNamePrefix, *attachment.MemoUID) + attachmentMessage.Memo = &memoName + } if attachment.StorageType == storepb.AttachmentStorageType_EXTERNAL || attachment.StorageType == storepb.AttachmentStorageType_S3 { attachmentMessage.ExternalLink = attachment.Reference } - if attachment.MemoID != nil { - memo, _ := s.Store.GetMemo(ctx, &store.FindMemo{ - ID: attachment.MemoID, - }) - if memo != nil { - memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID) - attachmentMessage.Memo = &memoName - } - } return attachmentMessage } diff --git a/server/router/api/v1/memo_attachment_service.go b/server/router/api/v1/memo_attachment_service.go index 95e0d480b..e7c7f18ac 100644 --- a/server/router/api/v1/memo_attachment_service.go +++ b/server/router/api/v1/memo_attachment_service.go @@ -96,7 +96,7 @@ func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.Li Attachments: []*v1pb.Attachment{}, } for _, attachment := range attachments { - response.Attachments = append(response.Attachments, s.convertAttachmentFromStore(ctx, attachment)) + response.Attachments = append(response.Attachments, convertAttachmentFromStore(attachment)) } return response, nil } diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index 378ea7bcf..57f6bfab8 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -63,6 +63,9 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR if err != nil { return nil, err } + + attachments := []*store.Attachment{} + if len(request.Memo.Attachments) > 0 { _, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{ Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID), @@ -71,6 +74,14 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR if err != nil { return nil, errors.Wrap(err, "failed to set memo attachments") } + + a, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ + MemoID: &memo.ID, + }) + if err != nil { + return nil, errors.Wrap(err, "failed to get memo attachments") + } + attachments = a } if len(request.Memo.Relations) > 0 { _, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{ @@ -82,7 +93,7 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR } } - memoMessage, err := s.convertMemoFromStore(ctx, memo, []*store.Reaction{}) + memoMessage, err := s.convertMemoFromStore(ctx, memo, nil, attachments) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -190,8 +201,12 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq reactionMap := make(map[string][]*store.Reaction) memoNames := make([]string, 0, len(memos)) + attachmentMap := make(map[int32][]*store.Attachment) + memoIDs := make([]string, 0, len(memos)) + for _, m := range memos { memoNames = append(memoNames, fmt.Sprintf("'%s%s'", MemoNamePrefix, m.UID)) + memoIDs = append(memoIDs, fmt.Sprintf("'%d'", m.ID)) } // REACTIONS @@ -205,9 +220,23 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq reactionMap[reaction.ContentID] = append(reactionMap[reaction.ContentID], reaction) } + // ATTACHMENTS + attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ + Filters: []string{fmt.Sprintf("memo_id in [%s]", strings.Join(memoIDs, ", "))}, + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list attachments") + } + for _, attachment := range attachments { + attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment) + } + for _, memo := range memos { - name := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID) - memoMessage, err := s.convertMemoFromStore(ctx, memo, reactionMap[name]) + memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID) + reactions := reactionMap[memoName] + attachments := attachmentMap[memo.ID] + + memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -256,7 +285,14 @@ func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest return nil, status.Errorf(codes.Internal, "failed to list reactions") } - memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions) + attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ + MemoID: &memo.ID, + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list attachments") + } + + memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -381,8 +417,14 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR if err != nil { return nil, status.Errorf(codes.Internal, "failed to list reactions") } + attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ + MemoID: &memo.ID, + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list attachments") + } - memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions) + memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -425,7 +467,14 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR return nil, status.Errorf(codes.Internal, "failed to list reactions") } - if memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions); err == nil { + attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ + MemoID: &memo.ID, + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list attachments") + } + + if memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments); err == nil { // Try to dispatch webhook when memo is deleted. if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil { slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err)) @@ -442,10 +491,6 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR } // Delete related attachments. - attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoID: &memo.ID}) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to list attachments") - } for _, attachment := range attachments { if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{ID: attachment.ID}); err != nil { return nil, status.Errorf(codes.Internal, "failed to delete attachment") @@ -591,11 +636,13 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM memoIDToNameMap := make(map[int32]string) memoNamesForQuery := make([]string, 0, len(memos)) + memoIDsForQuery := make([]string, 0, len(memos)) for _, memo := range memos { memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID) memoIDToNameMap[memo.ID] = memoName memoNamesForQuery = append(memoNamesForQuery, fmt.Sprintf("'%s'", memoName)) + memoIDsForQuery = append(memoIDsForQuery, fmt.Sprintf("'%d'", memo.ID)) } reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ Filters: []string{fmt.Sprintf("content_id in [%s]", strings.Join(memoNamesForQuery, ", "))}, @@ -609,12 +656,24 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM memoReactionsMap[reaction.ContentID] = append(memoReactionsMap[reaction.ContentID], reaction) } + attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ + Filters: []string{fmt.Sprintf("memo_id in [%s]", strings.Join(memoIDsForQuery, ", "))}, + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list attachments") + } + attachmentMap := make(map[int32][]*store.Attachment) + for _, attachment := range attachments { + attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment) + } + var memosResponse []*v1pb.Memo for _, m := range memos { memoName := memoIDToNameMap[m.ID] reactions := memoReactionsMap[memoName] + attachments := attachmentMap[m.ID] - memoMessage, err := s.convertMemoFromStore(ctx, m, reactions) + memoMessage, err := s.convertMemoFromStore(ctx, m, reactions, attachments) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } diff --git a/server/router/api/v1/memo_service_converter.go b/server/router/api/v1/memo_service_converter.go index 06500e274..cca1eef92 100644 --- a/server/router/api/v1/memo_service_converter.go +++ b/server/router/api/v1/memo_service_converter.go @@ -16,7 +16,7 @@ import ( "github.com/usememos/memos/store" ) -func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction) (*v1pb.Memo, error) { +func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment) (*v1pb.Memo, error) { displayTs := memo.CreatedTs workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx) if err != nil { @@ -62,11 +62,12 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem } memoMessage.Relations = listMemoRelationsResponse.Relations - listMemoAttachmentsResponse, err := s.ListMemoAttachments(ctx, &v1pb.ListMemoAttachmentsRequest{Name: name}) - if err != nil { - return nil, errors.Wrap(err, "failed to list memo attachments") + memoMessage.Attachments = []*v1pb.Attachment{} + + for _, attachment := range attachments { + attachmentResponse := convertAttachmentFromStore(attachment) + memoMessage.Attachments = append(memoMessage.Attachments, attachmentResponse) } - memoMessage.Attachments = listMemoAttachmentsResponse.Attachments nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content)) if err != nil { diff --git a/store/attachment.go b/store/attachment.go index acbc1777e..5b57d8c31 100644 --- a/store/attachment.go +++ b/store/attachment.go @@ -35,6 +35,9 @@ type Attachment struct { // The related memo ID. MemoID *int32 + + // Composed field + MemoUID *string } type FindAttachment struct { @@ -49,6 +52,7 @@ type FindAttachment struct { StorageType *storepb.AttachmentStorageType Limit *int Offset *int + Filters []string } type UpdateAttachment struct { diff --git a/store/db/mysql/attachment.go b/store/db/mysql/attachment.go index 468e9032e..525bc1573 100644 --- a/store/db/mysql/attachment.go +++ b/store/db/mysql/attachment.go @@ -9,6 +9,7 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" + "github.com/usememos/memos/plugin/filter" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -48,37 +49,74 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) { where, args := []string{"1 = 1"}, []any{} + for _, filterStr := range find.Filters { + // Parse filter string and return the parsed expression. + // The filter string should be a CEL expression. + parsedExpr, err := filter.Parse(filterStr, filter.AttachmentFilterCELAttributes...) + if err != nil { + return nil, err + } + convertCtx := filter.NewConvertContext() + // ConvertExprToSQL converts the parsed expression to a SQL condition string. + converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) + if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + return nil, err + } + condition := convertCtx.Buffer.String() + if condition != "" { + where = append(where, fmt.Sprintf("(%s)", condition)) + args = append(args, convertCtx.Args...) + } + } + if v := find.ID; v != nil { - where, args = append(where, "`id` = ?"), append(args, *v) + where, args = append(where, "`resource`.`id` = ?"), append(args, *v) } if v := find.UID; v != nil { - where, args = append(where, "`uid` = ?"), append(args, *v) + where, args = append(where, "`resource`.`uid` = ?"), append(args, *v) } if v := find.CreatorID; v != nil { - where, args = append(where, "`creator_id` = ?"), append(args, *v) + where, args = append(where, "`resource`.`creator_id` = ?"), append(args, *v) } if v := find.Filename; v != nil { - where, args = append(where, "`filename` = ?"), append(args, *v) + where, args = append(where, "`resource`.`filename` = ?"), append(args, *v) } if v := find.FilenameSearch; v != nil { - where, args = append(where, "`filename` LIKE ?"), append(args, "%"+*v+"%") + where, args = append(where, "`resource`.`filename` LIKE ?"), append(args, "%"+*v+"%") } if v := find.MemoID; v != nil { - where, args = append(where, "`memo_id` = ?"), append(args, *v) + where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v) } if find.HasRelatedMemo { - where = append(where, "`memo_id` IS NOT NULL") + where = append(where, "`resource`.`memo_id` IS NOT NULL") } if find.StorageType != nil { - where, args = append(where, "`storage_type` = ?"), append(args, find.StorageType.String()) + where, args = append(where, "`resource`.`storage_type` = ?"), append(args, find.StorageType.String()) + } + + fields := []string{ + "`resource`.`id` AS `id`", + "`resource`.`uid` AS `uid`", + "`resource`.`filename` AS `filename`", + "`resource`.`type` AS `type`", + "`resource`.`size` AS `size`", + "`resource`.`creator_id` AS `creator_id`", + "UNIX_TIMESTAMP(`resource`.`created_ts`) AS `created_ts`", + "UNIX_TIMESTAMP(`resource`.`updated_ts`) AS `updated_ts`", + "`resource`.`memo_id` AS `memo_id`", + "`resource`.`storage_type` AS `storage_type`", + "`resource`.`reference` AS `reference`", + "`resource`.`payload` AS `payload`", + "CASE WHEN `memo`.`uid` IS NOT NULL THEN `memo`.`uid` ELSE NULL END AS `memo_uid`", } - - fields := []string{"`id`", "`uid`", "`filename`", "`type`", "`size`", "`creator_id`", "UNIX_TIMESTAMP(`created_ts`)", "UNIX_TIMESTAMP(`updated_ts`)", "`memo_id`", "`storage_type`", "`reference`", "`payload`"} if find.GetBlob { - fields = append(fields, "`blob`") + fields = append(fields, "`resource`.`blob` AS `blob`") } - query := fmt.Sprintf("SELECT %s FROM `resource` WHERE %s ORDER BY `updated_ts` DESC", strings.Join(fields, ", "), strings.Join(where, " AND ")) + query := "SELECT " + strings.Join(fields, ", ") + " FROM `resource`" + " " + + "LEFT JOIN `memo` ON `resource`.`memo_id` = `memo`.`id`" + " " + + "WHERE " + strings.Join(where, " AND ") + " " + + "ORDER BY `updated_ts` DESC" if find.Limit != nil { query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit) if find.Offset != nil { @@ -111,6 +149,7 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ &storageType, &attachment.Reference, &payloadBytes, + &attachment.MemoUID, } if find.GetBlob { dests = append(dests, &attachment.Blob) diff --git a/store/db/mysql/attachment_filter_test.go b/store/db/mysql/attachment_filter_test.go new file mode 100644 index 000000000..ea43b8bb0 --- /dev/null +++ b/store/db/mysql/attachment_filter_test.go @@ -0,0 +1,39 @@ +package mysql + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/plugin/filter" +) + +func TestAttachmentConvertExprToSQL(t *testing.T) { + tests := []struct { + filter string + want string + args []any + }{ + { + filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`, + want: "`resource`.`memo_id` IN (?)", + args: []any{"5atZAj8GcvkSuUA3X2KLaY"}, + }, + { + filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`, + want: "`resource`.`memo_id` IN (?,?)", + args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"}, + }, + } + + for _, tt := range tests { + parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...) + require.NoError(t, err) + convertCtx := filter.NewConvertContext() + converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) + err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + require.NoError(t, err) + require.Equal(t, tt.want, convertCtx.Buffer.String()) + require.Equal(t, tt.args, convertCtx.Args) + } +} diff --git a/store/db/postgres/attachment.go b/store/db/postgres/attachment.go index da24c3710..90311c92d 100644 --- a/store/db/postgres/attachment.go +++ b/store/db/postgres/attachment.go @@ -9,6 +9,7 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" + "github.com/usememos/memos/plugin/filter" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -39,42 +40,77 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) { where, args := []string{"1 = 1"}, []any{} + for _, filterStr := range find.Filters { + // Parse filter string and return the parsed expression. + // The filter string should be a CEL expression. + parsedExpr, err := filter.Parse(filterStr, filter.AttachmentFilterCELAttributes...) + if err != nil { + return nil, err + } + convertCtx := filter.NewConvertContext() + // ConvertExprToSQL converts the parsed expression to a SQL condition string. + converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{}) + if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + return nil, err + } + condition := convertCtx.Buffer.String() + if condition != "" { + where = append(where, fmt.Sprintf("(%s)", condition)) + args = append(args, convertCtx.Args...) + } + } + if v := find.ID; v != nil { - where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "resource.id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.UID; v != nil { - where, args = append(where, "uid = "+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "resource.uid = "+placeholder(len(args)+1)), append(args, *v) } if v := find.CreatorID; v != nil { - where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "resource.creator_id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.Filename; v != nil { - where, args = append(where, "filename = "+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "resource.filename = "+placeholder(len(args)+1)), append(args, *v) } if v := find.FilenameSearch; v != nil { - where, args = append(where, "filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v)) + where, args = append(where, "resource.filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v)) } if v := find.MemoID; v != nil { - where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "resource.memo_id = "+placeholder(len(args)+1)), append(args, *v) } if find.HasRelatedMemo { - where = append(where, "memo_id IS NOT NULL") + where = append(where, "resource.memo_id IS NOT NULL") } if v := find.StorageType; v != nil { - where, args = append(where, "storage_type = "+placeholder(len(args)+1)), append(args, v.String()) + where, args = append(where, "resource.storage_type = "+placeholder(len(args)+1)), append(args, v.String()) + } + + fields := []string{ + "resource.id AS id", + "resource.uid AS uid", + "resource.filename AS filename", + "resource.type AS type", + "resource.size AS size", + "resource.creator_id AS creator_id", + "resource.created_ts AS created_ts", + "resource.updated_ts AS updated_ts", + "resource.memo_id AS memo_id", + "resource.storage_type AS storage_type", + "resource.reference AS reference", + "resource.payload AS payload", + "CASE WHEN memo.uid IS NOT NULL THEN memo.uid ELSE NULL END AS memo_uid", } - - fields := []string{"id", "uid", "filename", "type", "size", "creator_id", "created_ts", "updated_ts", "memo_id", "storage_type", "reference", "payload"} if find.GetBlob { - fields = append(fields, "blob") + fields = append(fields, "resource.blob AS blob") } query := fmt.Sprintf(` SELECT %s FROM resource + LEFT JOIN memo ON resource.memo_id = memo.id WHERE %s - ORDER BY updated_ts DESC + ORDER BY resource.updated_ts DESC `, strings.Join(fields, ", "), strings.Join(where, " AND ")) if find.Limit != nil { query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit) @@ -108,6 +144,7 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ &storageType, &attachment.Reference, &payloadBytes, + &attachment.MemoUID, } if find.GetBlob { dests = append(dests, &attachment.Blob) diff --git a/store/db/postgres/attachment_filter_test.go b/store/db/postgres/attachment_filter_test.go new file mode 100644 index 000000000..788962d68 --- /dev/null +++ b/store/db/postgres/attachment_filter_test.go @@ -0,0 +1,39 @@ +package postgres + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/plugin/filter" +) + +func TestAttachmentConvertExprToSQL(t *testing.T) { + tests := []struct { + filter string + want string + args []any + }{ + { + filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`, + want: "resource.memo_id IN ($1)", + args: []any{"5atZAj8GcvkSuUA3X2KLaY"}, + }, + { + filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`, + want: "resource.memo_id IN ($1,$2)", + args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"}, + }, + } + + for _, tt := range tests { + parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...) + require.NoError(t, err) + convertCtx := filter.NewConvertContext() + converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{}) + err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + require.NoError(t, err) + require.Equal(t, tt.want, convertCtx.Buffer.String()) + require.Equal(t, tt.args, convertCtx.Args) + } +} diff --git a/store/db/sqlite/attachment.go b/store/db/sqlite/attachment.go index ee547b8f4..34aaac0b7 100644 --- a/store/db/sqlite/attachment.go +++ b/store/db/sqlite/attachment.go @@ -9,6 +9,7 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" + "github.com/usememos/memos/plugin/filter" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -41,37 +42,74 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) { where, args := []string{"1 = 1"}, []any{} + for _, filterStr := range find.Filters { + // Parse filter string and return the parsed expression. + // The filter string should be a CEL expression. + parsedExpr, err := filter.Parse(filterStr, filter.AttachmentFilterCELAttributes...) + if err != nil { + return nil, err + } + convertCtx := filter.NewConvertContext() + // ConvertExprToSQL converts the parsed expression to a SQL condition string. + converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) + if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + return nil, err + } + condition := convertCtx.Buffer.String() + if condition != "" { + where = append(where, fmt.Sprintf("(%s)", condition)) + args = append(args, convertCtx.Args...) + } + } + if v := find.ID; v != nil { - where, args = append(where, "`id` = ?"), append(args, *v) + where, args = append(where, "`resource`.`id` = ?"), append(args, *v) } if v := find.UID; v != nil { - where, args = append(where, "`uid` = ?"), append(args, *v) + where, args = append(where, "`resource`.`uid` = ?"), append(args, *v) } if v := find.CreatorID; v != nil { - where, args = append(where, "`creator_id` = ?"), append(args, *v) + where, args = append(where, "`resource`.`creator_id` = ?"), append(args, *v) } if v := find.Filename; v != nil { - where, args = append(where, "`filename` = ?"), append(args, *v) + where, args = append(where, "`resource`.`filename` = ?"), append(args, *v) } if v := find.FilenameSearch; v != nil { - where, args = append(where, "`filename` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", *v)) + where, args = append(where, "`resource`.`filename` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", *v)) } if v := find.MemoID; v != nil { - where, args = append(where, "`memo_id` = ?"), append(args, *v) + where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v) } if find.HasRelatedMemo { - where = append(where, "`memo_id` IS NOT NULL") + where = append(where, "`resource`.`memo_id` IS NOT NULL") } if find.StorageType != nil { - where, args = append(where, "`storage_type` = ?"), append(args, find.StorageType.String()) + where, args = append(where, "`resource`.`storage_type` = ?"), append(args, find.StorageType.String()) + } + + fields := []string{ + "`resource`.`id` AS `id`", + "`resource`.`uid` AS `uid`", + "`resource`.`filename` AS `filename`", + "`resource`.`type` AS `type`", + "`resource`.`size` AS `size`", + "`resource`.`creator_id` AS `creator_id`", + "`resource`.`created_ts` AS `created_ts`", + "`resource`.`updated_ts` AS `updated_ts`", + "`resource`.`memo_id` AS `memo_id`", + "`resource`.`storage_type` AS `storage_type`", + "`resource`.`reference` AS `reference`", + "`resource`.`payload` AS `payload`", + "CASE WHEN `memo`.`uid` IS NOT NULL THEN `memo`.`uid` ELSE NULL END AS `memo_uid`", } - - fields := []string{"`id`", "`uid`", "`filename`", "`type`", "`size`", "`creator_id`", "`created_ts`", "`updated_ts`", "`memo_id`", "`storage_type`", "`reference`", "`payload`"} if find.GetBlob { - fields = append(fields, "`blob`") + fields = append(fields, "`resource`.`blob` AS `blob`") } - query := fmt.Sprintf("SELECT %s FROM `resource` WHERE %s ORDER BY `updated_ts` DESC", strings.Join(fields, ", "), strings.Join(where, " AND ")) + query := "SELECT " + strings.Join(fields, ", ") + " FROM `resource`" + " " + + "LEFT JOIN `memo` ON `resource`.`memo_id` = `memo`.`id`" + " " + + "WHERE " + strings.Join(where, " AND ") + " " + + "ORDER BY `resource`.`updated_ts` DESC" if find.Limit != nil { query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit) if find.Offset != nil { @@ -104,6 +142,7 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ &storageType, &attachment.Reference, &payloadBytes, + &attachment.MemoUID, } if find.GetBlob { dests = append(dests, &attachment.Blob) diff --git a/store/db/sqlite/attachment_filter_test.go b/store/db/sqlite/attachment_filter_test.go new file mode 100644 index 000000000..efe7b0c6f --- /dev/null +++ b/store/db/sqlite/attachment_filter_test.go @@ -0,0 +1,39 @@ +package sqlite + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/plugin/filter" +) + +func TestAttachmentConvertExprToSQL(t *testing.T) { + tests := []struct { + filter string + want string + args []any + }{ + { + filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`, + want: "`resource`.`memo_id` IN (?)", + args: []any{"5atZAj8GcvkSuUA3X2KLaY"}, + }, + { + filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`, + want: "`resource`.`memo_id` IN (?,?)", + args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"}, + }, + } + + for _, tt := range tests { + parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...) + require.NoError(t, err) + convertCtx := filter.NewConvertContext() + converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) + err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + require.NoError(t, err) + require.Equal(t, tt.want, convertCtx.Buffer.String()) + require.Equal(t, tt.args, convertCtx.Args) + } +}