From e0e735d14dbac24f1cc864712cbf7088ee85f5a6 Mon Sep 17 00:00:00 2001 From: johnnyjoy Date: Mon, 3 Feb 2025 17:14:53 +0800 Subject: [PATCH] feat: support memo filter for mysql and postgres --- plugin/filter/converter.go | 20 +++ plugin/filter/expr.go | 39 ++++++ plugin/filter/filter.go | 23 ---- store/db/mysql/memo.go | 16 +++ store/db/mysql/memo_filter.go | 175 ++++++++++++++++++++++++++ store/db/mysql/memo_filter_test.go | 63 ++++++++++ store/db/postgres/memo.go | 17 +++ store/db/postgres/memo_filter.go | 175 ++++++++++++++++++++++++++ store/db/postgres/memo_filter_test.go | 63 ++++++++++ store/db/sqlite/memo.go | 9 +- store/db/sqlite/memo_filter.go | 124 ++++++++++-------- store/db/sqlite/memo_filter_test.go | 33 +++-- 12 files changed, 670 insertions(+), 87 deletions(-) create mode 100644 plugin/filter/converter.go create mode 100644 plugin/filter/expr.go create mode 100644 store/db/mysql/memo_filter.go create mode 100644 store/db/mysql/memo_filter_test.go create mode 100644 store/db/postgres/memo_filter.go create mode 100644 store/db/postgres/memo_filter_test.go diff --git a/plugin/filter/converter.go b/plugin/filter/converter.go new file mode 100644 index 00000000..c55a395b --- /dev/null +++ b/plugin/filter/converter.go @@ -0,0 +1,20 @@ +package filter + +import ( + "strings" +) + +type ConvertContext struct { + Buffer strings.Builder + Args []any + // The offset of the next argument in the condition string. + // Mainly using for PostgreSQL. + ArgsOffset int +} + +func NewConvertContext() *ConvertContext { + return &ConvertContext{ + Buffer: strings.Builder{}, + Args: []any{}, + } +} diff --git a/plugin/filter/expr.go b/plugin/filter/expr.go new file mode 100644 index 00000000..5bdcdc27 --- /dev/null +++ b/plugin/filter/expr.go @@ -0,0 +1,39 @@ +package filter + +import ( + "errors" + + exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +// GetConstValue returns the constant value of the expression. +func GetConstValue(expr *exprv1.Expr) (any, error) { + v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr) + if !ok { + return nil, errors.New("invalid constant expression") + } + + switch v.ConstExpr.ConstantKind.(type) { + case *exprv1.Constant_StringValue: + return v.ConstExpr.GetStringValue(), nil + case *exprv1.Constant_Int64Value: + return v.ConstExpr.GetInt64Value(), nil + case *exprv1.Constant_Uint64Value: + return v.ConstExpr.GetUint64Value(), nil + case *exprv1.Constant_DoubleValue: + return v.ConstExpr.GetDoubleValue(), nil + case *exprv1.Constant_BoolValue: + return v.ConstExpr.GetBoolValue(), nil + default: + return nil, errors.New("unexpected constant type") + } +} + +// GetIdentExprName returns the name of the identifier expression. +func GetIdentExprName(expr *exprv1.Expr) (string, error) { + _, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr) + if !ok { + return "", errors.New("invalid identifier expression") + } + return expr.GetIdentExpr().GetName(), nil +} diff --git a/plugin/filter/filter.go b/plugin/filter/filter.go index 696ba30b..7d7ae5c0 100644 --- a/plugin/filter/filter.go +++ b/plugin/filter/filter.go @@ -30,26 +30,3 @@ func Parse(filter string, opts ...cel.EnvOption) (expr *exprv1.ParsedExpr, err e } return cel.AstToParsedExpr(ast) } - -// GetConstValue returns the constant value of the expression. -func GetConstValue(expr *exprv1.Expr) (any, error) { - v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr) - if !ok { - return nil, errors.New("invalid constant expression") - } - - switch v.ConstExpr.ConstantKind.(type) { - case *exprv1.Constant_StringValue: - return v.ConstExpr.GetStringValue(), nil - case *exprv1.Constant_Int64Value: - return v.ConstExpr.GetInt64Value(), nil - case *exprv1.Constant_Uint64Value: - return v.ConstExpr.GetUint64Value(), nil - case *exprv1.Constant_DoubleValue: - return v.ConstExpr.GetDoubleValue(), nil - case *exprv1.Constant_BoolValue: - return v.ConstExpr.GetBoolValue(), nil - default: - return nil, errors.New("unexpected constant type") - } -} diff --git a/store/db/mysql/memo.go b/store/db/mysql/memo.go index af17a493..e519c26f 100644 --- a/store/db/mysql/memo.go +++ b/store/db/mysql/memo.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" + "github.com/usememos/memos/plugin/filter" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -108,6 +109,21 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE") } } + if v := find.Filter; v != nil { + // Parse filter string and return the parsed expression. + // The filter string should be a CEL expression. + parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...) + if err != nil { + return nil, err + } + convertCtx := filter.NewConvertContext() + // ConvertExprToSQL converts the parsed expression to a SQL condition string. + if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + return nil, err + } + where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String())) + args = append(args, convertCtx.Args...) + } if find.ExcludeComments { having = append(having, "`parent_id` IS NULL") } diff --git a/store/db/mysql/memo_filter.go b/store/db/mysql/memo_filter.go new file mode 100644 index 00000000..8807c1e8 --- /dev/null +++ b/store/db/mysql/memo_filter.go @@ -0,0 +1,175 @@ +package mysql + +import ( + "fmt" + "slices" + "strings" + "time" + + "github.com/pkg/errors" + exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/usememos/memos/plugin/filter" +) + +func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { + switch v := expr.ExprKind.(type) { + case *exprv1.Expr_CallExpr: + switch v.CallExpr.Function { + case "_||_", "_&&_": + if len(v.CallExpr.Args) != 2 { + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + if _, err := ctx.Buffer.WriteString("("); err != nil { + return err + } + if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + return err + } + operator := "AND" + if v.CallExpr.Function == "_||_" { + operator = "OR" + } + if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { + return err + } + if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil { + return err + } + if _, err := ctx.Buffer.WriteString(")"); err != nil { + return err + } + case "!_": + if len(v.CallExpr.Args) != 1 { + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { + return err + } + if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + return err + } + if _, err := ctx.Buffer.WriteString(")"); err != nil { + return err + } + case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_": + if len(v.CallExpr.Args) != 2 { + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) + if err != nil { + return err + } + if !slices.Contains([]string{"create_time", "update_time"}, identifier) { + return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + } + value, err := filter.GetConstValue(v.CallExpr.Args[1]) + if err != nil { + return err + } + operator := "=" + switch v.CallExpr.Function { + case "_==_": + operator = "=" + case "_!=_": + operator = "!=" + case "_<_": + operator = "<" + case "_>_": + operator = ">" + case "_<=_": + operator = "<=" + case "_>=_": + operator = ">=" + } + + if identifier == "create_time" || identifier == "update_time" { + timestampStr, ok := value.(string) + if !ok { + return errors.New("invalid timestamp value") + } + timestamp, err := time.Parse(time.RFC3339, timestampStr) + if err != nil { + return errors.Wrap(err, "failed to parse timestamp") + } + + var factor string + if identifier == "create_time" { + factor = "`memo`.`created_ts`" + } else if identifier == "update_time" { + factor = "`memo`.`updated_ts`" + } + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("UNIX_TIMESTAMP(%s) %s ?", factor, operator)); err != nil { + return err + } + ctx.Args = append(ctx.Args, timestamp.Unix()) + } + case "@in": + if len(v.CallExpr.Args) != 2 { + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) + if err != nil { + return err + } + if !slices.Contains([]string{"tag", "visibility"}, identifier) { + return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + } + + values := []any{} + for _, element := range v.CallExpr.Args[1].GetListExpr().Elements { + value, err := filter.GetConstValue(element) + if err != nil { + return err + } + values = append(values, value) + } + if identifier == "tag" { + subcodition := []string{} + args := []any{} + for _, v := range values { + subcodition, args = append(subcodition, "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)"), append(args, v) + } + if len(subcodition) == 1 { + if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil { + return err + } + } else { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil { + return err + } + } + ctx.Args = append(ctx.Args, args...) + } else if identifier == "visibility" { + placeholder := []string{} + for range values { + placeholder = append(placeholder, "?") + } + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(placeholder, ","))); err != nil { + return err + } + ctx.Args = append(ctx.Args, values...) + } + case "contains": + if len(v.CallExpr.Args) != 1 { + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + identifier, err := filter.GetIdentExprName(v.CallExpr.Target) + if err != nil { + return err + } + if identifier != "content" { + return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + } + arg, err := filter.GetConstValue(v.CallExpr.Args[0]) + if err != nil { + return err + } + if _, err := ctx.Buffer.WriteString("`memo`.`content` LIKE ?"); err != nil { + return err + } + ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg)) + } + } + return nil +} diff --git a/store/db/mysql/memo_filter_test.go b/store/db/mysql/memo_filter_test.go new file mode 100644 index 00000000..b5a6e71e --- /dev/null +++ b/store/db/mysql/memo_filter_test.go @@ -0,0 +1,63 @@ +package mysql + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/plugin/filter" +) + +func TestConvertExprToSQL(t *testing.T) { + tests := []struct { + filter string + want string + args []any + }{ + { + filter: `tag in ["tag1", "tag2"]`, + want: "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?))", + args: []any{"tag1", "tag2"}, + }, + { + filter: `!(tag in ["tag1", "tag2"])`, + want: "NOT ((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)))", + args: []any{"tag1", "tag2"}, + }, + { + filter: `content.contains("memos")`, + want: "`memo`.`content` LIKE ?", + args: []any{"%memos%"}, + }, + { + filter: `visibility in ["PUBLIC"]`, + want: "`memo`.`visibility` IN (?)", + args: []any{"PUBLIC"}, + }, + { + filter: `visibility in ["PUBLIC", "PRIVATE"]`, + want: "`memo`.`visibility` IN (?,?)", + args: []any{"PUBLIC", "PRIVATE"}, + }, + { + filter: `create_time == "2006-01-02T15:04:05+07:00"`, + want: "UNIX_TIMESTAMP(`memo`.`created_ts`) = ?", + args: []any{int64(1136189045)}, + }, + { + filter: `tag in ['tag1'] || content.contains('hello')`, + want: "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR `memo`.`content` LIKE ?)", + args: []any{"tag1", "%hello%"}, + }, + } + + for _, tt := range tests { + parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) + require.NoError(t, err) + convertCtx := filter.NewConvertContext() + err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + require.NoError(t, err) + require.Equal(t, tt.want, convertCtx.Buffer.String()) + require.Equal(t, tt.args, convertCtx.Args) + } +} diff --git a/store/db/postgres/memo.go b/store/db/postgres/memo.go index 8f605bc2..f5e1560c 100644 --- a/store/db/postgres/memo.go +++ b/store/db/postgres/memo.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" + "github.com/usememos/memos/plugin/filter" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -99,6 +100,22 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo where = append(where, "(memo.payload->'property'->>'hasIncompleteTasks')::BOOLEAN IS TRUE") } } + if v := find.Filter; v != nil { + // Parse filter string and return the parsed expression. + // The filter string should be a CEL expression. + parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...) + if err != nil { + return nil, err + } + convertCtx := filter.NewConvertContext() + convertCtx.ArgsOffset = len(args) + // ConvertExprToSQL converts the parsed expression to a SQL condition string. + if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + return nil, err + } + where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String())) + args = append(args, convertCtx.Args...) + } if find.ExcludeComments { where = append(where, "memo_relation.related_memo_id IS NULL") } diff --git a/store/db/postgres/memo_filter.go b/store/db/postgres/memo_filter.go new file mode 100644 index 00000000..06daaad2 --- /dev/null +++ b/store/db/postgres/memo_filter.go @@ -0,0 +1,175 @@ +package postgres + +import ( + "fmt" + "slices" + "strings" + "time" + + "github.com/pkg/errors" + exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/usememos/memos/plugin/filter" +) + +func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { + switch v := expr.ExprKind.(type) { + case *exprv1.Expr_CallExpr: + switch v.CallExpr.Function { + case "_||_", "_&&_": + if len(v.CallExpr.Args) != 2 { + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + if _, err := ctx.Buffer.WriteString("("); err != nil { + return err + } + if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + return err + } + operator := "AND" + if v.CallExpr.Function == "_||_" { + operator = "OR" + } + if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { + return err + } + if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil { + return err + } + if _, err := ctx.Buffer.WriteString(")"); err != nil { + return err + } + case "!_": + if len(v.CallExpr.Args) != 1 { + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { + return err + } + if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + return err + } + if _, err := ctx.Buffer.WriteString(")"); err != nil { + return err + } + case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_": + if len(v.CallExpr.Args) != 2 { + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) + if err != nil { + return err + } + if !slices.Contains([]string{"create_time", "update_time"}, identifier) { + return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + } + value, err := filter.GetConstValue(v.CallExpr.Args[1]) + if err != nil { + return err + } + operator := "=" + switch v.CallExpr.Function { + case "_==_": + operator = "=" + case "_!=_": + operator = "!=" + case "_<_": + operator = "<" + case "_>_": + operator = ">" + case "_<=_": + operator = "<=" + case "_>=_": + operator = ">=" + } + + if identifier == "create_time" || identifier == "update_time" { + timestampStr, ok := value.(string) + if !ok { + return errors.New("invalid timestamp value") + } + timestamp, err := time.Parse(time.RFC3339, timestampStr) + if err != nil { + return errors.Wrap(err, "failed to parse timestamp") + } + + var factor string + if identifier == "create_time" { + factor = "memo.created_ts" + } else if identifier == "update_time" { + factor = "memo.updated_ts" + } + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", factor, operator, placeholder(len(ctx.Args)+ctx.ArgsOffset+1))); err != nil { + return err + } + ctx.Args = append(ctx.Args, timestamp.Unix()) + } + case "@in": + if len(v.CallExpr.Args) != 2 { + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) + if err != nil { + return err + } + if !slices.Contains([]string{"tag", "visibility"}, identifier) { + return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + } + + values := []any{} + for _, element := range v.CallExpr.Args[1].GetListExpr().Elements { + value, err := filter.GetConstValue(element) + if err != nil { + return err + } + values = append(values, value) + } + if identifier == "tag" { + subcodition := []string{} + args := []any{} + for _, v := range values { + subcodition, args = append(subcodition, fmt.Sprintf(`memo.payload->'tags' @> %s::jsonb`, placeholder(len(ctx.Args)+ctx.ArgsOffset+len(args)+1))), append(args, []any{v}) + } + if len(subcodition) == 1 { + if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil { + return err + } + } else { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil { + return err + } + } + ctx.Args = append(ctx.Args, args...) + } else if identifier == "visibility" { + placeholders := []string{} + for i := range values { + placeholders = append(placeholders, placeholder(len(ctx.Args)+ctx.ArgsOffset+i+1)) + } + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("memo.visibility IN (%s)", strings.Join(placeholders, ","))); err != nil { + return err + } + ctx.Args = append(ctx.Args, values...) + } + case "contains": + if len(v.CallExpr.Args) != 1 { + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + identifier, err := filter.GetIdentExprName(v.CallExpr.Target) + if err != nil { + return err + } + if identifier != "content" { + return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + } + arg, err := filter.GetConstValue(v.CallExpr.Args[0]) + if err != nil { + return err + } + if _, err := ctx.Buffer.WriteString("memo.content ILIKE LIKE " + placeholder(len(ctx.Args)+ctx.ArgsOffset+1)); err != nil { + return err + } + ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg)) + } + } + return nil +} diff --git a/store/db/postgres/memo_filter_test.go b/store/db/postgres/memo_filter_test.go new file mode 100644 index 00000000..c5deb5a1 --- /dev/null +++ b/store/db/postgres/memo_filter_test.go @@ -0,0 +1,63 @@ +package postgres + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/plugin/filter" +) + +func TestRestoreExprToSQL(t *testing.T) { + tests := []struct { + filter string + want string + args []any + }{ + { + filter: `tag in ["tag1", "tag2"]`, + want: "(memo.payload->'tags' @> $1::jsonb OR memo.payload->'tags' @> $2::jsonb)", + args: []any{[]any{"tag1"}, []any{"tag2"}}, + }, + { + filter: `!(tag in ["tag1", "tag2"])`, + want: `NOT ((memo.payload->'tags' @> $1::jsonb OR memo.payload->'tags' @> $2::jsonb))`, + args: []any{[]any{"tag1"}, []any{"tag2"}}, + }, + { + filter: `content.contains("memos")`, + want: "memo.content ILIKE LIKE $1", + args: []any{"%memos%"}, + }, + { + filter: `visibility in ["PUBLIC"]`, + want: "memo.visibility IN ($1)", + args: []any{"PUBLIC"}, + }, + { + filter: `visibility in ["PUBLIC", "PRIVATE"]`, + want: "memo.visibility IN ($1,$2)", + args: []any{"PUBLIC", "PRIVATE"}, + }, + { + filter: `create_time == "2006-01-02T15:04:05+07:00"`, + want: "memo.created_ts = $1", + args: []any{int64(1136189045)}, + }, + { + filter: `tag in ['tag1'] || content.contains('hello')`, + want: "(memo.payload->'tags' @> $1::jsonb OR memo.content ILIKE LIKE $2)", + args: []any{[]any{"tag1"}, "%hello%"}, + }, + } + + for _, tt := range tests { + parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) + require.NoError(t, err) + convertCtx := filter.NewConvertContext() + err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + require.NoError(t, err) + require.Equal(t, tt.want, convertCtx.Buffer.String()) + require.Equal(t, tt.args, convertCtx.Args) + } +} diff --git a/store/db/sqlite/memo.go b/store/db/sqlite/memo.go index 3f8e290e..58db5c57 100644 --- a/store/db/sqlite/memo.go +++ b/store/db/sqlite/memo.go @@ -108,12 +108,13 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo if err != nil { return nil, err } - // RestoreExprToSQL parses the expression and returns the SQL condition. - condition, err := RestoreExprToSQL(parsedExpr.GetExpr()) - if err != nil { + convertCtx := filter.NewConvertContext() + // ConvertExprToSQL converts the parsed expression to a SQL condition string. + if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err } - where = append(where, fmt.Sprintf("(%s)", condition)) + where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String())) + args = append(args, convertCtx.Args...) } if find.ExcludeComments { where = append(where, "`parent_id` IS NULL") diff --git a/store/db/sqlite/memo_filter.go b/store/db/sqlite/memo_filter.go index f4954db2..c56530c0 100644 --- a/store/db/sqlite/memo_filter.go +++ b/store/db/sqlite/memo_filter.go @@ -12,39 +12,60 @@ import ( "github.com/usememos/memos/plugin/filter" ) -func RestoreExprToSQL(expr *exprv1.Expr) (string, error) { - var condition string +func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { switch v := expr.ExprKind.(type) { case *exprv1.Expr_CallExpr: switch v.CallExpr.Function { case "_||_", "_&&_": if len(v.CallExpr.Args) != 2 { - return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } - left, err := RestoreExprToSQL(v.CallExpr.Args[0]) - if err != nil { - return "", err + if _, err := ctx.Buffer.WriteString("("); err != nil { + return err } - right, err := RestoreExprToSQL(v.CallExpr.Args[1]) - if err != nil { - return "", err + if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + return err } operator := "AND" if v.CallExpr.Function == "_||_" { operator = "OR" } - condition = fmt.Sprintf("(%s %s %s)", left, operator, right) + if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { + return err + } + if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil { + return err + } + if _, err := ctx.Buffer.WriteString(")"); err != nil { + return err + } + case "!_": + if len(v.CallExpr.Args) != 1 { + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { + return err + } + if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + return err + } + if _, err := ctx.Buffer.WriteString(")"); err != nil { + return err + } case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_": if len(v.CallExpr.Args) != 2 { - return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) + if err != nil { + return err } - identifier := v.CallExpr.Args[0].GetIdentExpr().GetName() if !slices.Contains([]string{"create_time", "update_time"}, identifier) { - return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } value, err := filter.GetConstValue(v.CallExpr.Args[1]) if err != nil { - return "", err + return err } operator := "=" switch v.CallExpr.Function { @@ -65,85 +86,90 @@ func RestoreExprToSQL(expr *exprv1.Expr) (string, error) { if identifier == "create_time" || identifier == "update_time" { timestampStr, ok := value.(string) if !ok { - return "", errors.New("invalid timestamp value") + return errors.New("invalid timestamp value") } timestamp, err := time.Parse(time.RFC3339, timestampStr) if err != nil { - return "", errors.Wrap(err, "failed to parse timestamp") + return errors.Wrap(err, "failed to parse timestamp") } + var factor string if identifier == "create_time" { - condition = fmt.Sprintf("`memo`.`created_ts` %s %d", operator, timestamp.Unix()) + factor = "`memo`.`created_ts`" } else if identifier == "update_time" { - condition = fmt.Sprintf("`memo`.`updated_ts` %s %d", operator, timestamp.Unix()) + factor = "`memo`.`updated_ts`" + } + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil { + return err } + ctx.Args = append(ctx.Args, timestamp.Unix()) } case "@in": if len(v.CallExpr.Args) != 2 { - return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) + if err != nil { + return err } - identifier := v.CallExpr.Args[0].GetIdentExpr().GetName() if !slices.Contains([]string{"tag", "visibility"}, identifier) { - return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } values := []any{} for _, element := range v.CallExpr.Args[1].GetListExpr().Elements { value, err := filter.GetConstValue(element) if err != nil { - return "", err + return err } values = append(values, value) } if identifier == "tag" { subcodition := []string{} + args := []any{} for _, v := range values { - subcodition = append(subcodition, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`'%%"%s"%%'`, v))) + subcodition, args = append(subcodition, "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?"), append(args, fmt.Sprintf(`%%"%s"%%`, v)) } if len(subcodition) == 1 { - condition = subcodition[0] + if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil { + return err + } } else { - condition = fmt.Sprintf("(%s)", strings.Join(subcodition, " OR ")) + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil { + return err + } } + ctx.Args = append(ctx.Args, args...) } else if identifier == "visibility" { - vs := []string{} - for _, v := range values { - vs = append(vs, fmt.Sprintf(`"%s"`, v)) + placeholder := []string{} + for range values { + placeholder = append(placeholder, "?") } - if len(vs) == 1 { - condition = fmt.Sprintf("`memo`.`visibility` = %s", vs[0]) - } else { - condition = fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(vs, ",")) + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(placeholder, ","))); err != nil { + return err } + ctx.Args = append(ctx.Args, values...) } case "contains": if len(v.CallExpr.Args) != 1 { - return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } - identifier, err := RestoreExprToSQL(v.CallExpr.Target) + identifier, err := filter.GetIdentExprName(v.CallExpr.Target) if err != nil { - return "", err + return err } if identifier != "content" { - return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } arg, err := filter.GetConstValue(v.CallExpr.Args[0]) if err != nil { - return "", err + return err } - condition = fmt.Sprintf("`memo`.`content` LIKE %s", fmt.Sprintf(`'%%%s%%'`, arg)) - case "!_": - if len(v.CallExpr.Args) != 1 { - return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - arg, err := RestoreExprToSQL(v.CallExpr.Args[0]) - if err != nil { - return "", err + if _, err := ctx.Buffer.WriteString("`memo`.`content` LIKE ?"); err != nil { + return err } - condition = fmt.Sprintf("NOT (%s)", arg) + ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg)) } - case *exprv1.Expr_IdentExpr: - return v.IdentExpr.GetName(), nil } - return condition, nil + return nil } diff --git a/store/db/sqlite/memo_filter_test.go b/store/db/sqlite/memo_filter_test.go index 6cd56635..ae5b4414 100644 --- a/store/db/sqlite/memo_filter_test.go +++ b/store/db/sqlite/memo_filter_test.go @@ -8,50 +8,61 @@ import ( "github.com/usememos/memos/plugin/filter" ) -func TestRestoreExprToSQL(t *testing.T) { +func TestConvertExprToSQL(t *testing.T) { tests := []struct { filter string want string + args []any }{ { filter: `tag in ["tag1", "tag2"]`, - want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag1\"%' OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag2\"%')", + want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?)", + args: []any{`%"tag1"%`, `%"tag2"%`}, }, { filter: `!(tag in ["tag1", "tag2"])`, - want: "NOT ((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag1\"%' OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag2\"%'))", + want: "NOT ((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))", + args: []any{`%"tag1"%`, `%"tag2"%`}, }, { filter: `tag in ["tag1", "tag2"] || tag in ["tag3", "tag4"]`, - want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag1\"%' OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag2\"%') OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag3\"%' OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag4\"%'))", + want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))", + args: []any{`%"tag1"%`, `%"tag2"%`, `%"tag3"%`, `%"tag4"%`}, }, { filter: `content.contains("memos")`, - want: "`memo`.`content` LIKE '%memos%'", + want: "`memo`.`content` LIKE ?", + args: []any{"%memos%"}, }, { filter: `visibility in ["PUBLIC"]`, - want: "`memo`.`visibility` = \"PUBLIC\"", + want: "`memo`.`visibility` IN (?)", + args: []any{"PUBLIC"}, }, { filter: `visibility in ["PUBLIC", "PRIVATE"]`, - want: "`memo`.`visibility` IN (\"PUBLIC\",\"PRIVATE\")", + want: "`memo`.`visibility` IN (?,?)", + args: []any{"PUBLIC", "PRIVATE"}, }, { filter: `create_time == "2006-01-02T15:04:05+07:00"`, - want: "`memo`.`created_ts` = 1136189045", + want: "`memo`.`created_ts` = ?", + args: []any{int64(1136189045)}, }, { filter: `tag in ['tag1'] || content.contains('hello')`, - want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag1\"%' OR `memo`.`content` LIKE '%hello%')", + want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR `memo`.`content` LIKE ?)", + args: []any{`%"tag1"%`, "%hello%"}, }, } for _, tt := range tests { parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) require.NoError(t, err) - result, err := RestoreExprToSQL(parsedExpr.GetExpr()) + convertCtx := filter.NewConvertContext() + err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) require.NoError(t, err) - require.Equal(t, tt.want, result) + require.Equal(t, tt.want, convertCtx.Buffer.String()) + require.Equal(t, tt.args, convertCtx.Args) } }