feat: attachments by id (#5008)

pull/5012/head
varsnotwars 2 months ago committed by GitHub
parent a3add85c95
commit 4eb5b67baf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -207,7 +207,7 @@ func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exp
return err
}
if !slices.Contains([]string{"tag", "visibility", "content_id"}, identifier) {
if !slices.Contains([]string{"tag", "visibility", "content_id", "memo_id"}, identifier) {
return errors.Errorf("invalid identifier for %s", callExpr.Function)
}
@ -226,6 +226,8 @@ func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exp
return c.handleVisibilityInList(ctx, values)
} else if identifier == "content_id" {
return c.handleContentIDInList(ctx, values)
} else if identifier == "memo_id" {
return c.handleMemoIDInList(ctx, values)
}
return nil
@ -333,6 +335,28 @@ func (c *CommonSQLConverter) handleContentIDInList(ctx *ConvertContext, values [
return nil
}
func (c *CommonSQLConverter) handleMemoIDInList(ctx *ConvertContext, values []any) error {
placeholders := []string{}
for range values {
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
c.paramIndex++
}
tablePrefix := c.dialect.GetTablePrefix("resource")
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.memo_id IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
return err
}
} else {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`memo_id` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, values...)
return nil
}
func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
if len(callExpr.Args) != 1 {
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)

@ -46,6 +46,11 @@ var UserFilterCELAttributes = []cel.EnvOption{
cel.Variable("username", cel.StringType),
}
// AttachmentFilterCELAttributes are the CEL attributes for user.
var AttachmentFilterCELAttributes = []cel.EnvOption{
cel.Variable("memo_id", cel.StringType),
}
// Parse parses the filter string and returns the parsed expression.
// The filter string should be a CEL expression.
func Parse(filter string, opts ...cel.EnvOption) (expr *exprv1.ParsedExpr, err error) {

@ -116,7 +116,7 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat
return nil, status.Errorf(codes.Internal, "failed to create attachment: %v", err)
}
return s.convertAttachmentFromStore(ctx, attachment), nil
return convertAttachmentFromStore(attachment), nil
}
func (s *APIV1Service) ListAttachments(ctx context.Context, request *v1pb.ListAttachmentsRequest) (*v1pb.ListAttachmentsResponse, error) {
@ -182,7 +182,7 @@ func (s *APIV1Service) ListAttachments(ctx context.Context, request *v1pb.ListAt
response := &v1pb.ListAttachmentsResponse{}
for _, attachment := range attachments {
response.Attachments = append(response.Attachments, s.convertAttachmentFromStore(ctx, attachment))
response.Attachments = append(response.Attachments, convertAttachmentFromStore(attachment))
}
// For simplicity, set total size to the number of returned attachments.
@ -209,7 +209,7 @@ func (s *APIV1Service) GetAttachment(ctx context.Context, request *v1pb.GetAttac
if attachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found")
}
return s.convertAttachmentFromStore(ctx, attachment), nil
return convertAttachmentFromStore(attachment), nil
}
func (s *APIV1Service) GetAttachmentBinary(ctx context.Context, request *v1pb.GetAttachmentBinaryRequest) (*httpbody.HttpBody, error) {
@ -383,7 +383,7 @@ func (s *APIV1Service) DeleteAttachment(ctx context.Context, request *v1pb.Delet
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) convertAttachmentFromStore(ctx context.Context, attachment *store.Attachment) *v1pb.Attachment {
func convertAttachmentFromStore(attachment *store.Attachment) *v1pb.Attachment {
attachmentMessage := &v1pb.Attachment{
Name: fmt.Sprintf("%s%s", AttachmentNamePrefix, attachment.UID),
CreateTime: timestamppb.New(time.Unix(attachment.CreatedTs, 0)),
@ -391,18 +391,13 @@ func (s *APIV1Service) convertAttachmentFromStore(ctx context.Context, attachmen
Type: attachment.Type,
Size: attachment.Size,
}
if attachment.MemoUID != nil && *attachment.MemoUID != "" {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, *attachment.MemoUID)
attachmentMessage.Memo = &memoName
}
if attachment.StorageType == storepb.AttachmentStorageType_EXTERNAL || attachment.StorageType == storepb.AttachmentStorageType_S3 {
attachmentMessage.ExternalLink = attachment.Reference
}
if attachment.MemoID != nil {
memo, _ := s.Store.GetMemo(ctx, &store.FindMemo{
ID: attachment.MemoID,
})
if memo != nil {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
attachmentMessage.Memo = &memoName
}
}
return attachmentMessage
}

@ -96,7 +96,7 @@ func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.Li
Attachments: []*v1pb.Attachment{},
}
for _, attachment := range attachments {
response.Attachments = append(response.Attachments, s.convertAttachmentFromStore(ctx, attachment))
response.Attachments = append(response.Attachments, convertAttachmentFromStore(attachment))
}
return response, nil
}

@ -63,6 +63,9 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
if err != nil {
return nil, err
}
attachments := []*store.Attachment{}
if len(request.Memo.Attachments) > 0 {
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
@ -71,6 +74,14 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
if err != nil {
return nil, errors.Wrap(err, "failed to set memo attachments")
}
a, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get memo attachments")
}
attachments = a
}
if len(request.Memo.Relations) > 0 {
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
@ -82,7 +93,7 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
}
}
memoMessage, err := s.convertMemoFromStore(ctx, memo, []*store.Reaction{})
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
@ -190,8 +201,12 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
reactionMap := make(map[string][]*store.Reaction)
memoNames := make([]string, 0, len(memos))
attachmentMap := make(map[int32][]*store.Attachment)
memoIDs := make([]string, 0, len(memos))
for _, m := range memos {
memoNames = append(memoNames, fmt.Sprintf("'%s%s'", MemoNamePrefix, m.UID))
memoIDs = append(memoIDs, fmt.Sprintf("'%d'", m.ID))
}
// REACTIONS
@ -205,9 +220,23 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
reactionMap[reaction.ContentID] = append(reactionMap[reaction.ContentID], reaction)
}
// ATTACHMENTS
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
Filters: []string{fmt.Sprintf("memo_id in [%s]", strings.Join(memoIDs, ", "))},
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
for _, attachment := range attachments {
attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment)
}
for _, memo := range memos {
name := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactionMap[name])
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
reactions := reactionMap[memoName]
attachments := attachmentMap[memo.ID]
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
@ -256,7 +285,14 @@ func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions)
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
@ -381,8 +417,14 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions)
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
@ -425,7 +467,14 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
if memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions); err == nil {
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
if memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments); err == nil {
// Try to dispatch webhook when memo is deleted.
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err))
@ -442,10 +491,6 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
}
// Delete related attachments.
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoID: &memo.ID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
for _, attachment := range attachments {
if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{ID: attachment.ID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete attachment")
@ -591,11 +636,13 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
memoIDToNameMap := make(map[int32]string)
memoNamesForQuery := make([]string, 0, len(memos))
memoIDsForQuery := make([]string, 0, len(memos))
for _, memo := range memos {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
memoIDToNameMap[memo.ID] = memoName
memoNamesForQuery = append(memoNamesForQuery, fmt.Sprintf("'%s'", memoName))
memoIDsForQuery = append(memoIDsForQuery, fmt.Sprintf("'%d'", memo.ID))
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
Filters: []string{fmt.Sprintf("content_id in [%s]", strings.Join(memoNamesForQuery, ", "))},
@ -609,12 +656,24 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
memoReactionsMap[reaction.ContentID] = append(memoReactionsMap[reaction.ContentID], reaction)
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
Filters: []string{fmt.Sprintf("memo_id in [%s]", strings.Join(memoIDsForQuery, ", "))},
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
attachmentMap := make(map[int32][]*store.Attachment)
for _, attachment := range attachments {
attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment)
}
var memosResponse []*v1pb.Memo
for _, m := range memos {
memoName := memoIDToNameMap[m.ID]
reactions := memoReactionsMap[memoName]
attachments := attachmentMap[m.ID]
memoMessage, err := s.convertMemoFromStore(ctx, m, reactions)
memoMessage, err := s.convertMemoFromStore(ctx, m, reactions, attachments)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}

@ -16,7 +16,7 @@ import (
"github.com/usememos/memos/store"
)
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction) (*v1pb.Memo, error) {
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment) (*v1pb.Memo, error) {
displayTs := memo.CreatedTs
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
@ -62,11 +62,12 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem
}
memoMessage.Relations = listMemoRelationsResponse.Relations
listMemoAttachmentsResponse, err := s.ListMemoAttachments(ctx, &v1pb.ListMemoAttachmentsRequest{Name: name})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo attachments")
memoMessage.Attachments = []*v1pb.Attachment{}
for _, attachment := range attachments {
attachmentResponse := convertAttachmentFromStore(attachment)
memoMessage.Attachments = append(memoMessage.Attachments, attachmentResponse)
}
memoMessage.Attachments = listMemoAttachmentsResponse.Attachments
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
if err != nil {

@ -35,6 +35,9 @@ type Attachment struct {
// The related memo ID.
MemoID *int32
// Composed field
MemoUID *string
}
type FindAttachment struct {
@ -49,6 +52,7 @@ type FindAttachment struct {
StorageType *storepb.AttachmentStorageType
Limit *int
Offset *int
Filters []string
}
type UpdateAttachment struct {

@ -9,6 +9,7 @@ import (
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
"github.com/usememos/memos/plugin/filter"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
@ -48,37 +49,74 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
where, args := []string{"1 = 1"}, []any{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.AttachmentFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
}
if v := find.ID; v != nil {
where, args = append(where, "`id` = ?"), append(args, *v)
where, args = append(where, "`resource`.`id` = ?"), append(args, *v)
}
if v := find.UID; v != nil {
where, args = append(where, "`uid` = ?"), append(args, *v)
where, args = append(where, "`resource`.`uid` = ?"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "`creator_id` = ?"), append(args, *v)
where, args = append(where, "`resource`.`creator_id` = ?"), append(args, *v)
}
if v := find.Filename; v != nil {
where, args = append(where, "`filename` = ?"), append(args, *v)
where, args = append(where, "`resource`.`filename` = ?"), append(args, *v)
}
if v := find.FilenameSearch; v != nil {
where, args = append(where, "`filename` LIKE ?"), append(args, "%"+*v+"%")
where, args = append(where, "`resource`.`filename` LIKE ?"), append(args, "%"+*v+"%")
}
if v := find.MemoID; v != nil {
where, args = append(where, "`memo_id` = ?"), append(args, *v)
where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v)
}
if find.HasRelatedMemo {
where = append(where, "`memo_id` IS NOT NULL")
where = append(where, "`resource`.`memo_id` IS NOT NULL")
}
if find.StorageType != nil {
where, args = append(where, "`storage_type` = ?"), append(args, find.StorageType.String())
where, args = append(where, "`resource`.`storage_type` = ?"), append(args, find.StorageType.String())
}
fields := []string{
"`resource`.`id` AS `id`",
"`resource`.`uid` AS `uid`",
"`resource`.`filename` AS `filename`",
"`resource`.`type` AS `type`",
"`resource`.`size` AS `size`",
"`resource`.`creator_id` AS `creator_id`",
"UNIX_TIMESTAMP(`resource`.`created_ts`) AS `created_ts`",
"UNIX_TIMESTAMP(`resource`.`updated_ts`) AS `updated_ts`",
"`resource`.`memo_id` AS `memo_id`",
"`resource`.`storage_type` AS `storage_type`",
"`resource`.`reference` AS `reference`",
"`resource`.`payload` AS `payload`",
"CASE WHEN `memo`.`uid` IS NOT NULL THEN `memo`.`uid` ELSE NULL END AS `memo_uid`",
}
fields := []string{"`id`", "`uid`", "`filename`", "`type`", "`size`", "`creator_id`", "UNIX_TIMESTAMP(`created_ts`)", "UNIX_TIMESTAMP(`updated_ts`)", "`memo_id`", "`storage_type`", "`reference`", "`payload`"}
if find.GetBlob {
fields = append(fields, "`blob`")
fields = append(fields, "`resource`.`blob` AS `blob`")
}
query := fmt.Sprintf("SELECT %s FROM `resource` WHERE %s ORDER BY `updated_ts` DESC", strings.Join(fields, ", "), strings.Join(where, " AND "))
query := "SELECT " + strings.Join(fields, ", ") + " FROM `resource`" + " " +
"LEFT JOIN `memo` ON `resource`.`memo_id` = `memo`.`id`" + " " +
"WHERE " + strings.Join(where, " AND ") + " " +
"ORDER BY `updated_ts` DESC"
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
@ -111,6 +149,7 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
&storageType,
&attachment.Reference,
&payloadBytes,
&attachment.MemoUID,
}
if find.GetBlob {
dests = append(dests, &attachment.Blob)

@ -0,0 +1,39 @@
package mysql
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestAttachmentConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`,
want: "`resource`.`memo_id` IN (?)",
args: []any{"5atZAj8GcvkSuUA3X2KLaY"},
},
{
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`,
want: "`resource`.`memo_id` IN (?,?)",
args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"},
},
}
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}

@ -9,6 +9,7 @@ import (
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
"github.com/usememos/memos/plugin/filter"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
@ -39,42 +40,77 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
where, args := []string{"1 = 1"}, []any{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.AttachmentFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
}
if v := find.ID; v != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "resource.id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.UID; v != nil {
where, args = append(where, "uid = "+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "resource.uid = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "resource.creator_id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Filename; v != nil {
where, args = append(where, "filename = "+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "resource.filename = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.FilenameSearch; v != nil {
where, args = append(where, "filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v))
where, args = append(where, "resource.filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v))
}
if v := find.MemoID; v != nil {
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "resource.memo_id = "+placeholder(len(args)+1)), append(args, *v)
}
if find.HasRelatedMemo {
where = append(where, "memo_id IS NOT NULL")
where = append(where, "resource.memo_id IS NOT NULL")
}
if v := find.StorageType; v != nil {
where, args = append(where, "storage_type = "+placeholder(len(args)+1)), append(args, v.String())
where, args = append(where, "resource.storage_type = "+placeholder(len(args)+1)), append(args, v.String())
}
fields := []string{
"resource.id AS id",
"resource.uid AS uid",
"resource.filename AS filename",
"resource.type AS type",
"resource.size AS size",
"resource.creator_id AS creator_id",
"resource.created_ts AS created_ts",
"resource.updated_ts AS updated_ts",
"resource.memo_id AS memo_id",
"resource.storage_type AS storage_type",
"resource.reference AS reference",
"resource.payload AS payload",
"CASE WHEN memo.uid IS NOT NULL THEN memo.uid ELSE NULL END AS memo_uid",
}
fields := []string{"id", "uid", "filename", "type", "size", "creator_id", "created_ts", "updated_ts", "memo_id", "storage_type", "reference", "payload"}
if find.GetBlob {
fields = append(fields, "blob")
fields = append(fields, "resource.blob AS blob")
}
query := fmt.Sprintf(`
SELECT
%s
FROM resource
LEFT JOIN memo ON resource.memo_id = memo.id
WHERE %s
ORDER BY updated_ts DESC
ORDER BY resource.updated_ts DESC
`, strings.Join(fields, ", "), strings.Join(where, " AND "))
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
@ -108,6 +144,7 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
&storageType,
&attachment.Reference,
&payloadBytes,
&attachment.MemoUID,
}
if find.GetBlob {
dests = append(dests, &attachment.Blob)

@ -0,0 +1,39 @@
package postgres
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestAttachmentConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`,
want: "resource.memo_id IN ($1)",
args: []any{"5atZAj8GcvkSuUA3X2KLaY"},
},
{
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`,
want: "resource.memo_id IN ($1,$2)",
args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"},
},
}
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{})
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}

@ -9,6 +9,7 @@ import (
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
"github.com/usememos/memos/plugin/filter"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
@ -41,37 +42,74 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
where, args := []string{"1 = 1"}, []any{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.AttachmentFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
}
if v := find.ID; v != nil {
where, args = append(where, "`id` = ?"), append(args, *v)
where, args = append(where, "`resource`.`id` = ?"), append(args, *v)
}
if v := find.UID; v != nil {
where, args = append(where, "`uid` = ?"), append(args, *v)
where, args = append(where, "`resource`.`uid` = ?"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "`creator_id` = ?"), append(args, *v)
where, args = append(where, "`resource`.`creator_id` = ?"), append(args, *v)
}
if v := find.Filename; v != nil {
where, args = append(where, "`filename` = ?"), append(args, *v)
where, args = append(where, "`resource`.`filename` = ?"), append(args, *v)
}
if v := find.FilenameSearch; v != nil {
where, args = append(where, "`filename` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", *v))
where, args = append(where, "`resource`.`filename` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", *v))
}
if v := find.MemoID; v != nil {
where, args = append(where, "`memo_id` = ?"), append(args, *v)
where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v)
}
if find.HasRelatedMemo {
where = append(where, "`memo_id` IS NOT NULL")
where = append(where, "`resource`.`memo_id` IS NOT NULL")
}
if find.StorageType != nil {
where, args = append(where, "`storage_type` = ?"), append(args, find.StorageType.String())
where, args = append(where, "`resource`.`storage_type` = ?"), append(args, find.StorageType.String())
}
fields := []string{
"`resource`.`id` AS `id`",
"`resource`.`uid` AS `uid`",
"`resource`.`filename` AS `filename`",
"`resource`.`type` AS `type`",
"`resource`.`size` AS `size`",
"`resource`.`creator_id` AS `creator_id`",
"`resource`.`created_ts` AS `created_ts`",
"`resource`.`updated_ts` AS `updated_ts`",
"`resource`.`memo_id` AS `memo_id`",
"`resource`.`storage_type` AS `storage_type`",
"`resource`.`reference` AS `reference`",
"`resource`.`payload` AS `payload`",
"CASE WHEN `memo`.`uid` IS NOT NULL THEN `memo`.`uid` ELSE NULL END AS `memo_uid`",
}
fields := []string{"`id`", "`uid`", "`filename`", "`type`", "`size`", "`creator_id`", "`created_ts`", "`updated_ts`", "`memo_id`", "`storage_type`", "`reference`", "`payload`"}
if find.GetBlob {
fields = append(fields, "`blob`")
fields = append(fields, "`resource`.`blob` AS `blob`")
}
query := fmt.Sprintf("SELECT %s FROM `resource` WHERE %s ORDER BY `updated_ts` DESC", strings.Join(fields, ", "), strings.Join(where, " AND "))
query := "SELECT " + strings.Join(fields, ", ") + " FROM `resource`" + " " +
"LEFT JOIN `memo` ON `resource`.`memo_id` = `memo`.`id`" + " " +
"WHERE " + strings.Join(where, " AND ") + " " +
"ORDER BY `resource`.`updated_ts` DESC"
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
@ -104,6 +142,7 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
&storageType,
&attachment.Reference,
&payloadBytes,
&attachment.MemoUID,
}
if find.GetBlob {
dests = append(dests, &attachment.Blob)

@ -0,0 +1,39 @@
package sqlite
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestAttachmentConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`,
want: "`resource`.`memo_id` IN (?)",
args: []any{"5atZAj8GcvkSuUA3X2KLaY"},
},
{
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`,
want: "`resource`.`memo_id` IN (?,?)",
args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"},
},
}
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}
Loading…
Cancel
Save