From f4bdfa28a00514e71644980bd6dcf588da9798cb Mon Sep 17 00:00:00 2001 From: varsnotwars Date: Fri, 8 Aug 2025 02:00:51 +1000 Subject: [PATCH] feat: filter/method for reactions by content_id (#4969) --- plugin/filter/common_converter.go | 36 ++++++++++++++--- plugin/filter/dialect.go | 32 +++++++-------- plugin/filter/filter.go | 5 +++ server/router/api/v1/memo_service.go | 32 ++++++++++++--- .../router/api/v1/memo_service_converter.go | 25 +++++++++--- store/db/mysql/reaction.go | 23 +++++++++++ store/db/mysql/reaction_filter_test.go | 39 +++++++++++++++++++ store/db/postgres/reaction.go | 23 +++++++++++ store/db/postgres/reaction_filter_test.go | 39 +++++++++++++++++++ store/db/sqlite/reaction.go | 23 +++++++++++ store/db/sqlite/reaction_filter_test.go | 39 +++++++++++++++++++ store/reaction.go | 1 + 12 files changed, 284 insertions(+), 33 deletions(-) create mode 100644 store/db/mysql/reaction_filter_test.go create mode 100644 store/db/postgres/reaction_filter_test.go create mode 100644 store/db/sqlite/reaction_filter_test.go diff --git a/plugin/filter/common_converter.go b/plugin/filter/common_converter.go index 44270e5b5..73c693911 100644 --- a/plugin/filter/common_converter.go +++ b/plugin/filter/common_converter.go @@ -205,7 +205,7 @@ func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exp return err } - if !slices.Contains([]string{"tag", "visibility"}, identifier) { + if !slices.Contains([]string{"tag", "visibility", "content_id"}, identifier) { return errors.Errorf("invalid identifier for %s", callExpr.Function) } @@ -222,6 +222,8 @@ func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exp return c.handleTagInList(ctx, values) } else if identifier == "visibility" { return c.handleVisibilityInList(ctx, values) + } else if identifier == "content_id" { + return c.handleContentIDInList(ctx, values) } return nil @@ -292,7 +294,7 @@ func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values c.paramIndex++ } - tablePrefix := c.dialect.GetTablePrefix() + tablePrefix := c.dialect.GetTablePrefix("memo") if _, ok := c.dialect.(*PostgreSQLDialect); ok { if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.visibility IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { return err @@ -307,6 +309,28 @@ func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values return nil } +func (c *CommonSQLConverter) handleContentIDInList(ctx *ConvertContext, values []any) error { + placeholders := []string{} + for range values { + placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex)) + c.paramIndex++ + } + + tablePrefix := c.dialect.GetTablePrefix("reaction") + if _, ok := c.dialect.(*PostgreSQLDialect); ok { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.content_id IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { + return err + } + } else { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content_id` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { + return err + } + } + + ctx.Args = append(ctx.Args, values...) + return nil +} + func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error { if len(callExpr.Args) != 1 { return errors.Errorf("invalid number of arguments for %s", callExpr.Function) @@ -326,7 +350,7 @@ func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExp return err } - tablePrefix := c.dialect.GetTablePrefix() + tablePrefix := c.dialect.GetTablePrefix("memo") // PostgreSQL uses ILIKE and no backticks if _, ok := c.dialect.(*PostgreSQLDialect); ok { @@ -353,7 +377,7 @@ func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *ex } if identifier == "pinned" { - tablePrefix := c.dialect.GetTablePrefix() + tablePrefix := c.dialect.GetTablePrefix("memo") if _, ok := c.dialect.(*PostgreSQLDialect); ok { if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.pinned IS TRUE", tablePrefix)); err != nil { return err @@ -411,7 +435,7 @@ func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field, return errors.New("invalid string value") } - tablePrefix := c.dialect.GetTablePrefix() + tablePrefix := c.dialect.GetTablePrefix("memo") if _, ok := c.dialect.(*PostgreSQLDialect); ok { // PostgreSQL doesn't use backticks @@ -447,7 +471,7 @@ func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, ope return errors.New("invalid int value") } - tablePrefix := c.dialect.GetTablePrefix() + tablePrefix := c.dialect.GetTablePrefix("memo") if _, ok := c.dialect.(*PostgreSQLDialect); ok { // PostgreSQL doesn't use backticks diff --git a/plugin/filter/dialect.go b/plugin/filter/dialect.go index da2de82f2..293d7d078 100644 --- a/plugin/filter/dialect.go +++ b/plugin/filter/dialect.go @@ -8,7 +8,7 @@ import ( // SQLDialect defines database-specific SQL generation methods. type SQLDialect interface { // Basic field access - GetTablePrefix() string + GetTablePrefix(entityName string) string GetParameterPlaceholder(index int) string // JSON operations @@ -53,8 +53,8 @@ func GetDialect(dbType DatabaseType) SQLDialect { // SQLiteDialect implements SQLDialect for SQLite. type SQLiteDialect struct{} -func (*SQLiteDialect) GetTablePrefix() string { - return "`memo`" +func (*SQLiteDialect) GetTablePrefix(entityName string) string { + return fmt.Sprintf("`%s`", entityName) } func (*SQLiteDialect) GetParameterPlaceholder(_ int) string { @@ -62,7 +62,7 @@ func (*SQLiteDialect) GetParameterPlaceholder(_ int) string { } func (d *SQLiteDialect) GetJSONExtract(path string) string { - return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path) + return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix("memo"), path) } func (d *SQLiteDialect) GetJSONArrayLength(path string) string { @@ -96,7 +96,7 @@ func (d *SQLiteDialect) GetBooleanCheck(path string) string { } func (d *SQLiteDialect) GetTimestampComparison(field string) string { - return fmt.Sprintf("%s.`%s`", d.GetTablePrefix(), field) + return fmt.Sprintf("%s.`%s`", d.GetTablePrefix("memo"), field) } func (*SQLiteDialect) GetCurrentTimestamp() string { @@ -106,8 +106,8 @@ func (*SQLiteDialect) GetCurrentTimestamp() string { // MySQLDialect implements SQLDialect for MySQL. type MySQLDialect struct{} -func (*MySQLDialect) GetTablePrefix() string { - return "`memo`" +func (*MySQLDialect) GetTablePrefix(entityName string) string { + return fmt.Sprintf("`%s`", entityName) } func (*MySQLDialect) GetParameterPlaceholder(_ int) string { @@ -115,7 +115,7 @@ func (*MySQLDialect) GetParameterPlaceholder(_ int) string { } func (d *MySQLDialect) GetJSONExtract(path string) string { - return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path) + return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix("memo"), path) } func (d *MySQLDialect) GetJSONArrayLength(path string) string { @@ -146,7 +146,7 @@ func (d *MySQLDialect) GetBooleanCheck(path string) string { } func (d *MySQLDialect) GetTimestampComparison(field string) string { - return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix(), field) + return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix("memo"), field) } func (*MySQLDialect) GetCurrentTimestamp() string { @@ -156,8 +156,8 @@ func (*MySQLDialect) GetCurrentTimestamp() string { // PostgreSQLDialect implements SQLDialect for PostgreSQL. type PostgreSQLDialect struct{} -func (*PostgreSQLDialect) GetTablePrefix() string { - return "memo" +func (*PostgreSQLDialect) GetTablePrefix(entityName string) string { + return entityName } func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string { @@ -167,7 +167,7 @@ func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string { func (d *PostgreSQLDialect) GetJSONExtract(path string) string { // Convert $.property.hasTaskList to memo.payload->'property'->>'hasTaskList' parts := strings.Split(strings.TrimPrefix(path, "$."), ".") - result := fmt.Sprintf("%s.payload", d.GetTablePrefix()) + result := fmt.Sprintf("%s.payload", d.GetTablePrefix("memo")) for i, part := range parts { if i == len(parts)-1 { result += fmt.Sprintf("->>'%s'", part) @@ -180,17 +180,17 @@ func (d *PostgreSQLDialect) GetJSONExtract(path string) string { func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string { jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1) - return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix(), jsonPath) + return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix("memo"), jsonPath) } func (d *PostgreSQLDialect) GetJSONContains(path, _ string) string { jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1) - return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix(), jsonPath) + return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix("memo"), jsonPath) } func (d *PostgreSQLDialect) GetJSONLike(path, _ string) string { jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1) - return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix(), jsonPath) + return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix("memo"), jsonPath) } func (*PostgreSQLDialect) GetBooleanValue(value bool) interface{} { @@ -207,7 +207,7 @@ func (d *PostgreSQLDialect) GetBooleanCheck(path string) string { } func (d *PostgreSQLDialect) GetTimestampComparison(field string) string { - return fmt.Sprintf("EXTRACT(EPOCH FROM TO_TIMESTAMP(%s.%s))", d.GetTablePrefix(), field) + return fmt.Sprintf("EXTRACT(EPOCH FROM TO_TIMESTAMP(%s.%s))", d.GetTablePrefix("memo"), field) } func (*PostgreSQLDialect) GetCurrentTimestamp() string { diff --git a/plugin/filter/filter.go b/plugin/filter/filter.go index 44f50d638..e30ef27c9 100644 --- a/plugin/filter/filter.go +++ b/plugin/filter/filter.go @@ -36,6 +36,11 @@ var MemoFilterCELAttributes = []cel.EnvOption{ ), } +// ReactionFilterCELAttributes are the CEL attributes for reaction. +var ReactionFilterCELAttributes = []cel.EnvOption{ + cel.Variable("content_id", cel.StringType), +} + // UserFilterCELAttributes are the CEL attributes for user. var UserFilterCELAttributes = []cel.EnvOption{ cel.Variable("username", cel.StringType), diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index 7375a1f15..41f6f238c 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -82,7 +82,7 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR } } - memoMessage, err := s.convertMemoFromStore(ctx, memo) + memoMessage, err := s.convertMemoFromStore(ctx, memo, nil) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -178,8 +178,28 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err) } } + + reactionMap := make(map[string][]*store.Reaction) + + memoNames := make([]string, 0, len(memos)) + for _, m := range memos { + memoNames = append(memoNames, fmt.Sprintf("'%s/%s'", MemoNamePrefix, m.UID)) + } + + reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ + Filters: []string{fmt.Sprintf("content_id in [%s]", strings.Join(memoNames, ", "))}, + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list reactions") + } + + for _, reaction := range reactions { + reactionMap[reaction.ContentID] = append(reactionMap[reaction.ContentID], reaction) + } + for _, memo := range memos { - memoMessage, err := s.convertMemoFromStore(ctx, memo) + name := fmt.Sprintf("'%s/%s'", MemoNamePrefix, memo.UID) + memoMessage, err := s.convertMemoFromStore(ctx, memo, reactionMap[name]) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -220,7 +240,7 @@ func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest } } - memoMessage, err := s.convertMemoFromStore(ctx, memo) + memoMessage, err := s.convertMemoFromStore(ctx, memo, nil) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -339,7 +359,7 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR if err != nil { return nil, errors.Wrap(err, "failed to get memo") } - memoMessage, err := s.convertMemoFromStore(ctx, memo) + memoMessage, err := s.convertMemoFromStore(ctx, memo, nil) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -375,7 +395,7 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR return nil, status.Errorf(codes.PermissionDenied, "permission denied") } - if memoMessage, err := s.convertMemoFromStore(ctx, memo); err == nil { + if memoMessage, err := s.convertMemoFromStore(ctx, memo, nil); err == nil { // Try to dispatch webhook when memo is deleted. if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil { slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err)) @@ -530,7 +550,7 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM return nil, status.Errorf(codes.Internal, "failed to get memo") } if memo != nil { - memoMessage, err := s.convertMemoFromStore(ctx, memo) + memoMessage, err := s.convertMemoFromStore(ctx, memo, nil) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } diff --git a/server/router/api/v1/memo_service_converter.go b/server/router/api/v1/memo_service_converter.go index 139cbfaba..bf4b3febc 100644 --- a/server/router/api/v1/memo_service_converter.go +++ b/server/router/api/v1/memo_service_converter.go @@ -6,6 +6,8 @@ import ( "time" "github.com/pkg/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" "github.com/usememos/gomark/parser" @@ -16,7 +18,7 @@ import ( "github.com/usememos/memos/store" ) -func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo) (*v1pb.Memo, error) { +func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction) (*v1pb.Memo, error) { displayTs := memo.CreatedTs workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx) if err != nil { @@ -61,11 +63,24 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem } memoMessage.Attachments = listMemoAttachmentsResponse.Attachments - listMemoReactionsResponse, err := s.ListMemoReactions(ctx, &v1pb.ListMemoReactionsRequest{Name: name}) - if err != nil { - return nil, errors.Wrap(err, "failed to list memo reactions") + if len(reactions) > 0 { + for _, reaction := range reactions { + reactionMessage, err := s.convertReactionFromStore(ctx, reaction) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to convert reaction") + } + memoMessage.Reactions = append(memoMessage.Reactions, reactionMessage) + } + } else { + // done for backwards compatibility + // can remove once convertMemoFromStore is only responsible for mapping + // and all related DB entities are passed in as arguments purely for converting to request entities + listMemoReactionsResponse, err := s.ListMemoReactions(ctx, &v1pb.ListMemoReactionsRequest{Name: name}) + if err != nil { + return nil, errors.Wrap(err, "failed to list memo reactions") + } + memoMessage.Reactions = listMemoReactionsResponse.Reactions } - memoMessage.Reactions = listMemoReactionsResponse.Reactions nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content)) if err != nil { diff --git a/store/db/mysql/reaction.go b/store/db/mysql/reaction.go index d59937e01..b2878b4e4 100644 --- a/store/db/mysql/reaction.go +++ b/store/db/mysql/reaction.go @@ -2,10 +2,12 @@ package mysql import ( "context" + "fmt" "strings" "github.com/pkg/errors" + "github.com/usememos/memos/plugin/filter" "github.com/usememos/memos/store" ) @@ -36,6 +38,27 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) { where, args := []string{"1 = 1"}, []interface{}{} + + for _, filterStr := range find.Filters { + // Parse filter string and return the parsed expression. + // The filter string should be a CEL expression. + parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...) + if err != nil { + return nil, err + } + convertCtx := filter.NewConvertContext() + // ConvertExprToSQL converts the parsed expression to a SQL condition string. + converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) + if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + return nil, err + } + condition := convertCtx.Buffer.String() + if condition != "" { + where = append(where, fmt.Sprintf("(%s)", condition)) + args = append(args, convertCtx.Args...) + } + } + if find.ID != nil { where, args = append(where, "`id` = ?"), append(args, *find.ID) } diff --git a/store/db/mysql/reaction_filter_test.go b/store/db/mysql/reaction_filter_test.go new file mode 100644 index 000000000..1ea4621df --- /dev/null +++ b/store/db/mysql/reaction_filter_test.go @@ -0,0 +1,39 @@ +package mysql + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/plugin/filter" +) + +func TestReactionConvertExprToSQL(t *testing.T) { + tests := []struct { + filter string + want string + args []any + }{ + { + filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`, + want: "`reaction`.`content_id` IN (?)", + args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"}, + }, + { + filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`, + want: "`reaction`.`content_id` IN (?,?)", + args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"}, + }, + } + + for _, tt := range tests { + parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...) + require.NoError(t, err) + convertCtx := filter.NewConvertContext() + converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) + err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + require.NoError(t, err) + require.Equal(t, tt.want, convertCtx.Buffer.String()) + require.Equal(t, tt.args, convertCtx.Args) + } +} diff --git a/store/db/postgres/reaction.go b/store/db/postgres/reaction.go index 295c34dd4..4bfb9f7df 100644 --- a/store/db/postgres/reaction.go +++ b/store/db/postgres/reaction.go @@ -2,8 +2,10 @@ package postgres import ( "context" + "fmt" "strings" + "github.com/usememos/memos/plugin/filter" "github.com/usememos/memos/store" ) @@ -24,6 +26,27 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) { where, args := []string{"1 = 1"}, []interface{}{} + + for _, filterStr := range find.Filters { + // Parse filter string and return the parsed expression. + // The filter string should be a CEL expression. + parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...) + if err != nil { + return nil, err + } + convertCtx := filter.NewConvertContext() + // ConvertExprToSQL converts the parsed expression to a SQL condition string. + converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{}) + if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + return nil, err + } + condition := convertCtx.Buffer.String() + if condition != "" { + where = append(where, fmt.Sprintf("(%s)", condition)) + args = append(args, convertCtx.Args...) + } + } + if find.ID != nil { where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID) } diff --git a/store/db/postgres/reaction_filter_test.go b/store/db/postgres/reaction_filter_test.go new file mode 100644 index 000000000..05f801699 --- /dev/null +++ b/store/db/postgres/reaction_filter_test.go @@ -0,0 +1,39 @@ +package postgres + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/plugin/filter" +) + +func TestReactionConvertExprToSQL(t *testing.T) { + tests := []struct { + filter string + want string + args []any + }{ + { + filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`, + want: "reaction.content_id IN ($1)", + args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"}, + }, + { + filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`, + want: "reaction.content_id IN ($1,$2)", + args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"}, + }, + } + + for _, tt := range tests { + parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...) + require.NoError(t, err) + convertCtx := filter.NewConvertContext() + converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{}) + err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + require.NoError(t, err) + require.Equal(t, tt.want, convertCtx.Buffer.String()) + require.Equal(t, tt.args, convertCtx.Args) + } +} diff --git a/store/db/sqlite/reaction.go b/store/db/sqlite/reaction.go index d95a54c95..10f86bbd8 100644 --- a/store/db/sqlite/reaction.go +++ b/store/db/sqlite/reaction.go @@ -2,8 +2,10 @@ package sqlite import ( "context" + "fmt" "strings" + "github.com/usememos/memos/plugin/filter" "github.com/usememos/memos/store" ) @@ -25,6 +27,27 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) { where, args := []string{"1 = 1"}, []interface{}{} + + for _, filterStr := range find.Filters { + // Parse filter string and return the parsed expression. + // The filter string should be a CEL expression. + parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...) + if err != nil { + return nil, err + } + convertCtx := filter.NewConvertContext() + // ConvertExprToSQL converts the parsed expression to a SQL condition string. + converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) + if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + return nil, err + } + condition := convertCtx.Buffer.String() + if condition != "" { + where = append(where, fmt.Sprintf("(%s)", condition)) + args = append(args, convertCtx.Args...) + } + } + if find.ID != nil { where, args = append(where, "id = ?"), append(args, *find.ID) } diff --git a/store/db/sqlite/reaction_filter_test.go b/store/db/sqlite/reaction_filter_test.go new file mode 100644 index 000000000..d07f7cbbb --- /dev/null +++ b/store/db/sqlite/reaction_filter_test.go @@ -0,0 +1,39 @@ +package sqlite + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/plugin/filter" +) + +func TestReactionConvertExprToSQL(t *testing.T) { + tests := []struct { + filter string + want string + args []any + }{ + { + filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`, + want: "`reaction`.`content_id` IN (?)", + args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"}, + }, + { + filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`, + want: "`reaction`.`content_id` IN (?,?)", + args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"}, + }, + } + + for _, tt := range tests { + parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...) + require.NoError(t, err) + convertCtx := filter.NewConvertContext() + converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) + err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + require.NoError(t, err) + require.Equal(t, tt.want, convertCtx.Buffer.String()) + require.Equal(t, tt.args, convertCtx.Args) + } +} diff --git a/store/reaction.go b/store/reaction.go index be25b5fa4..7354cd9e0 100644 --- a/store/reaction.go +++ b/store/reaction.go @@ -17,6 +17,7 @@ type FindReaction struct { ID *int32 CreatorID *int32 ContentID *string + Filters []string } type DeleteReaction struct {