From 2d731c5cc56abe5516f30be1215fcb9ee4e807cf Mon Sep 17 00:00:00 2001 From: johnnyjoy Date: Sun, 2 Feb 2025 13:35:57 +0800 Subject: [PATCH] feat: memo filter for sqlite --- plugin/filter/filter.go | 47 ++++++++++ server/router/api/v1/memo_service_filter.go | 68 +++++++-------- store/db/sqlite/memo_filter.go | 96 +++++++++++++++++++++ store/db/sqlite/memo_filter_test.go | 40 +++++++++ store/memo_filter.go | 13 --- 5 files changed, 217 insertions(+), 47 deletions(-) create mode 100644 plugin/filter/filter.go create mode 100644 store/db/sqlite/memo_filter.go create mode 100644 store/db/sqlite/memo_filter_test.go delete mode 100644 store/memo_filter.go diff --git a/plugin/filter/filter.go b/plugin/filter/filter.go new file mode 100644 index 00000000..145e8c9e --- /dev/null +++ b/plugin/filter/filter.go @@ -0,0 +1,47 @@ +package filter + +import ( + "github.com/google/cel-go/cel" + "github.com/pkg/errors" + exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +// MemoFilterCELAttributes are the CEL attributes for memo. +var MemoFilterCELAttributes = []cel.EnvOption{ + cel.Variable("content", cel.StringType), + cel.Variable("tag", cel.StringType), +} + +// Parse parses the filter string and returns the parsed expression. +// The filter string should be a CEL expression. +func Parse(filter string) (expr *exprv1.ParsedExpr, err error) { + e, err := cel.NewEnv(MemoFilterCELAttributes...) + if err != nil { + return nil, errors.Wrap(err, "failed to create CEL environment") + } + ast, issues := e.Compile(filter) + if issues != nil { + return nil, errors.Errorf("failed to compile filter: %v", issues) + } + return cel.AstToParsedExpr(ast) +} + +// GetConstValue returns the constant value of the expression. +func GetConstValue(expr *exprv1.Expr) (any, error) { + switch v := expr.ExprKind.(type) { + case *exprv1.Expr_ConstExpr: + switch v.ConstExpr.ConstantKind.(type) { + case *exprv1.Constant_StringValue: + return v.ConstExpr.GetStringValue(), nil + case *exprv1.Constant_Int64Value: + return v.ConstExpr.GetInt64Value(), nil + case *exprv1.Constant_Uint64Value: + return v.ConstExpr.GetUint64Value(), nil + case *exprv1.Constant_DoubleValue: + return v.ConstExpr.GetDoubleValue(), nil + case *exprv1.Constant_BoolValue: + return v.ConstExpr.GetBoolValue(), nil + } + } + return nil, errors.New("invalid constant expression") +} diff --git a/server/router/api/v1/memo_service_filter.go b/server/router/api/v1/memo_service_filter.go index c79e4241..3908f610 100644 --- a/server/router/api/v1/memo_service_filter.go +++ b/server/router/api/v1/memo_service_filter.go @@ -5,7 +5,7 @@ import ( "github.com/google/cel-go/cel" "github.com/pkg/errors" - expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -18,52 +18,52 @@ func (s *APIV1Service) buildMemoFindWithFilter(ctx context.Context, find *store. find.PayloadFind = &store.FindMemoPayload{} } if filter != "" { - filter, err := parseMemoFilter(filter) + filterExpr, err := parseMemoFilter(filter) if err != nil { return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err) } - if len(filter.ContentSearch) > 0 { - find.ContentSearch = filter.ContentSearch + if len(filterExpr.ContentSearch) > 0 { + find.ContentSearch = filterExpr.ContentSearch } - if len(filter.Visibilities) > 0 { - find.VisibilityList = filter.Visibilities + if len(filterExpr.Visibilities) > 0 { + find.VisibilityList = filterExpr.Visibilities } - if filter.TagSearch != nil { + if filterExpr.TagSearch != nil { if find.PayloadFind == nil { find.PayloadFind = &store.FindMemoPayload{} } - find.PayloadFind.TagSearch = filter.TagSearch + find.PayloadFind.TagSearch = filterExpr.TagSearch } - if filter.OrderByPinned { - find.OrderByPinned = filter.OrderByPinned + if filterExpr.OrderByPinned { + find.OrderByPinned = filterExpr.OrderByPinned } - if filter.OrderByTimeAsc { - find.OrderByTimeAsc = filter.OrderByTimeAsc + if filterExpr.OrderByTimeAsc { + find.OrderByTimeAsc = filterExpr.OrderByTimeAsc } - if filter.DisplayTimeAfter != nil { + if filterExpr.DisplayTimeAfter != nil { workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx) if err != nil { return status.Errorf(codes.Internal, "failed to get workspace memo related setting") } if workspaceMemoRelatedSetting.DisplayWithUpdateTime { - find.UpdatedTsAfter = filter.DisplayTimeAfter + find.UpdatedTsAfter = filterExpr.DisplayTimeAfter } else { - find.CreatedTsAfter = filter.DisplayTimeAfter + find.CreatedTsAfter = filterExpr.DisplayTimeAfter } } - if filter.DisplayTimeBefore != nil { + if filterExpr.DisplayTimeBefore != nil { workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx) if err != nil { return status.Errorf(codes.Internal, "failed to get workspace memo related setting") } if workspaceMemoRelatedSetting.DisplayWithUpdateTime { - find.UpdatedTsBefore = filter.DisplayTimeBefore + find.UpdatedTsBefore = filterExpr.DisplayTimeBefore } else { - find.CreatedTsBefore = filter.DisplayTimeBefore + find.CreatedTsBefore = filterExpr.DisplayTimeBefore } } - if filter.Creator != nil { - userID, err := ExtractUserIDFromName(*filter.Creator) + if filterExpr.Creator != nil { + userID, err := ExtractUserIDFromName(*filterExpr.Creator) if err != nil { return errors.Wrap(err, "invalid user name") } @@ -78,28 +78,28 @@ func (s *APIV1Service) buildMemoFindWithFilter(ctx context.Context, find *store. } find.CreatorID = &user.ID } - if filter.RowStatus != nil { - find.RowStatus = filter.RowStatus + if filterExpr.RowStatus != nil { + find.RowStatus = filterExpr.RowStatus } - if filter.Random { - find.Random = filter.Random + if filterExpr.Random { + find.Random = filterExpr.Random } - if filter.Limit != nil { - find.Limit = filter.Limit + if filterExpr.Limit != nil { + find.Limit = filterExpr.Limit } - if filter.IncludeComments { + if filterExpr.IncludeComments { find.ExcludeComments = false } - if filter.HasLink { + if filterExpr.HasLink { find.PayloadFind.HasLink = true } - if filter.HasTaskList { + if filterExpr.HasTaskList { find.PayloadFind.HasTaskList = true } - if filter.HasCode { + if filterExpr.HasCode { find.PayloadFind.HasCode = true } - if filter.HasIncompleteTasks { + if filterExpr.HasIncompleteTasks { find.PayloadFind.HasIncompleteTasks = true } } @@ -181,16 +181,16 @@ func parseMemoFilter(expression string) (*MemoFilter, error) { return nil, errors.Errorf("found issue %v", issues) } filter := &MemoFilter{} - expr, err := cel.AstToParsedExpr(ast) + parsedExpr, err := cel.AstToParsedExpr(ast) if err != nil { return nil, err } - callExpr := expr.GetExpr().GetCallExpr() + callExpr := parsedExpr.GetExpr().GetCallExpr() findMemoField(callExpr, filter) return filter, nil } -func findMemoField(callExpr *expr.Expr_Call, filter *MemoFilter) { +func findMemoField(callExpr *exprv1.Expr_Call, filter *MemoFilter) { if len(callExpr.Args) == 2 { idExpr := callExpr.Args[0].GetIdentExpr() if idExpr != nil { diff --git a/store/db/sqlite/memo_filter.go b/store/db/sqlite/memo_filter.go new file mode 100644 index 00000000..b888e2d4 --- /dev/null +++ b/store/db/sqlite/memo_filter.go @@ -0,0 +1,96 @@ +package sqlite + +import ( + "fmt" + "slices" + "strings" + + "github.com/pkg/errors" + "github.com/usememos/memos/plugin/filter" + exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +func RestoreExprToSQL(expr *exprv1.Expr) (string, error) { + var condition string + switch v := expr.ExprKind.(type) { + case *exprv1.Expr_CallExpr: + switch v.CallExpr.Function { + case "_||_", "_&&_": + for _, arg := range v.CallExpr.Args { + left, err := RestoreExprToSQL(arg) + if err != nil { + return "", err + } + right, err := RestoreExprToSQL(arg) + if err != nil { + return "", err + } + operator := "AND" + if v.CallExpr.Function == "_||_" { + operator = "OR" + } + condition = fmt.Sprintf("(%s %s %s)", left, operator, right) + } + case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_": + if len(v.CallExpr.Args) != 2 { + return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + // TODO(j): Implement this part. + case "@in": + if len(v.CallExpr.Args) != 2 { + return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + factor := v.CallExpr.Args[0].GetIdentExpr().Name + if !slices.Contains([]string{"tag"}, factor) { + return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function) + } + values := []any{} + for _, element := range v.CallExpr.Args[1].GetListExpr().Elements { + value, err := filter.GetConstValue(element) + if err != nil { + return "", err + } + values = append(values, value) + } + if factor == "tag" { + t := []string{} + for _, v := range values { + t = append(t, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`%%"%s"%%`, v))) + } + if len(t) == 1 { + condition = t[0] + } else { + condition = fmt.Sprintf("(%s)", strings.Join(t, " OR ")) + } + } + case "contains": + if len(v.CallExpr.Args) != 1 { + return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + factor, err := RestoreExprToSQL(v.CallExpr.Target) + if err != nil { + return "", err + } + if factor != "content" { + return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function) + } + arg, err := filter.GetConstValue(v.CallExpr.Args[0]) + if err != nil { + return "", err + } + condition = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %s", fmt.Sprintf(`%%"%s"%%`, arg)) + case "!_": + if len(v.CallExpr.Args) != 1 { + return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + } + arg, err := RestoreExprToSQL(v.CallExpr.Args[0]) + if err != nil { + return "", err + } + condition = fmt.Sprintf("NOT (%s)", arg) + } + case *exprv1.Expr_IdentExpr: + return v.IdentExpr.GetName(), nil + } + return condition, nil +} diff --git a/store/db/sqlite/memo_filter_test.go b/store/db/sqlite/memo_filter_test.go new file mode 100644 index 00000000..617b192c --- /dev/null +++ b/store/db/sqlite/memo_filter_test.go @@ -0,0 +1,40 @@ +package sqlite + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/usememos/memos/plugin/filter" +) + +func TestRestoreExprToSQL(t *testing.T) { + tests := []struct { + filter string + want string + }{ + { + filter: `tag in ["tag1", "tag2"]`, + want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag1\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag2\"%)", + }, + { + filter: `!(tag in ["tag1", "tag2"])`, + want: "NOT ((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag1\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag2\"%))", + }, + { + filter: `tag in ["tag1", "tag2"] || tag in ["tag3", "tag4"]`, + want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%))", + }, + { + filter: `content.contains("hello")`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %\"hello\"%", + }, + } + + for _, tt := range tests { + parsedExpr, err := filter.Parse(tt.filter) + require.NoError(t, err) + result, err := RestoreExprToSQL(parsedExpr.GetExpr()) + require.NoError(t, err) + require.Equal(t, tt.want, result) + } +} diff --git a/store/memo_filter.go b/store/memo_filter.go deleted file mode 100644 index 82939661..00000000 --- a/store/memo_filter.go +++ /dev/null @@ -1,13 +0,0 @@ -package store - -type LogicOperator string - -const ( - AND LogicOperator = "AND" - OR LogicOperator = "OR" -) - -type QueryExpression struct { - Operator LogicOperator - Children []*QueryExpression -}