diff --git a/plugin/filter/filter.go b/plugin/filter/filter.go index 0a075b48..696ba30b 100644 --- a/plugin/filter/filter.go +++ b/plugin/filter/filter.go @@ -9,7 +9,12 @@ import ( // MemoFilterCELAttributes are the CEL attributes for memo. var MemoFilterCELAttributes = []cel.EnvOption{ cel.Variable("content", cel.StringType), + // As the built-in timestamp type is deprecated, we use string type for now. + // e.g., "2021-01-01T00:00:00Z" + cel.Variable("create_time", cel.StringType), cel.Variable("tag", cel.StringType), + cel.Variable("update_time", cel.StringType), + cel.Variable("visibility", cel.StringType), } // Parse parses the filter string and returns the parsed expression. diff --git a/server/router/api/v1/memo_service_filter.go b/server/router/api/v1/memo_service_filter.go index 3908f610..e25d3bd0 100644 --- a/server/router/api/v1/memo_service_filter.go +++ b/server/router/api/v1/memo_service_filter.go @@ -141,7 +141,6 @@ var MemoFilterCELAttributes = []cel.EnvOption{ cel.Variable("display_time_before", cel.IntType), cel.Variable("display_time_after", cel.IntType), cel.Variable("creator", cel.StringType), - cel.Variable("uid", cel.StringType), cel.Variable("state", cel.StringType), cel.Variable("random", cel.BoolType), cel.Variable("limit", cel.IntType), diff --git a/store/db/sqlite/memo.go b/store/db/sqlite/memo.go index c0bdc9bf..c5c65321 100644 --- a/store/db/sqlite/memo.go +++ b/store/db/sqlite/memo.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" + "github.com/usememos/memos/plugin/filter" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -100,6 +101,20 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE") } } + if v := find.Filter; v != nil { + // Parse filter string and return the parsed expression. + // The filter string should be a CEL expression. + parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...) + if err != nil { + return nil, err + } + // RestoreExprToSQL parses the expression and returns the SQL condition. + condition, err := RestoreExprToSQL(parsedExpr.GetExpr()) + if err != nil { + return nil, err + } + where = append(where, condition) + } if find.ExcludeComments { where = append(where, "`parent_id` IS NULL") } diff --git a/store/db/sqlite/memo_filter.go b/store/db/sqlite/memo_filter.go index ab19b7d0..00e83911 100644 --- a/store/db/sqlite/memo_filter.go +++ b/store/db/sqlite/memo_filter.go @@ -4,6 +4,7 @@ import ( "fmt" "slices" "strings" + "time" "github.com/pkg/errors" exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -36,15 +37,55 @@ func RestoreExprToSQL(expr *exprv1.Expr) (string, error) { if len(v.CallExpr.Args) != 2 { return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } - // TODO(j): Implement this part. + identifier := v.CallExpr.Args[0].GetIdentExpr().GetName() + if !slices.Contains([]string{"create_time", "update_time"}, identifier) { + return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + } + value, err := filter.GetConstValue(v.CallExpr.Args[1]) + if err != nil { + return "", err + } + operator := "=" + switch v.CallExpr.Function { + case "_==_": + operator = "=" + case "_!=_": + operator = "!=" + case "_<_": + operator = "<" + case "_>_": + operator = ">" + case "_<=_": + operator = "<=" + case "_>=_": + operator = ">=" + } + + if identifier == "create_time" || identifier == "update_time" { + timestampStr, ok := value.(string) + if !ok { + return "", errors.New("invalid timestamp value") + } + timestamp, err := time.Parse(time.RFC3339, timestampStr) + if err != nil { + return "", errors.Wrap(err, "failed to parse timestamp") + } + + if identifier == "create_time" { + condition = fmt.Sprintf("`memo`.`created_ts` %s %d", operator, timestamp.Unix()) + } else if identifier == "update_time" { + condition = fmt.Sprintf("`memo`.`updated_ts` %s %d", operator, timestamp.Unix()) + } + } case "@in": if len(v.CallExpr.Args) != 2 { return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } - factor := v.CallExpr.Args[0].GetIdentExpr().Name - if !slices.Contains([]string{"tag"}, factor) { - return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function) + identifier := v.CallExpr.Args[0].GetIdentExpr().GetName() + if !slices.Contains([]string{"tag", "visibility"}, identifier) { + return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } + values := []any{} for _, element := range v.CallExpr.Args[1].GetListExpr().Elements { value, err := filter.GetConstValue(element) @@ -53,33 +94,43 @@ func RestoreExprToSQL(expr *exprv1.Expr) (string, error) { } values = append(values, value) } - if factor == "tag" { - t := []string{} + if identifier == "tag" { + subcodition := []string{} + for _, v := range values { + subcodition = append(subcodition, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`%%"%s"%%`, v))) + } + if len(subcodition) == 1 { + condition = subcodition[0] + } else { + condition = fmt.Sprintf("(%s)", strings.Join(subcodition, " OR ")) + } + } else if identifier == "visibility" { + vs := []string{} for _, v := range values { - t = append(t, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`%%"%s"%%`, v))) + vs = append(vs, fmt.Sprintf(`"%s"`, v)) } - if len(t) == 1 { - condition = t[0] + if len(vs) == 1 { + condition = fmt.Sprintf("`memo`.`visibility` = %s", vs[0]) } else { - condition = fmt.Sprintf("(%s)", strings.Join(t, " OR ")) + condition = fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(vs, ",")) } } case "contains": if len(v.CallExpr.Args) != 1 { return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } - factor, err := RestoreExprToSQL(v.CallExpr.Target) + identifier, err := RestoreExprToSQL(v.CallExpr.Target) if err != nil { return "", err } - if factor != "content" { - return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function) + if identifier != "content" { + return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } arg, err := filter.GetConstValue(v.CallExpr.Args[0]) if err != nil { return "", err } - condition = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %s", fmt.Sprintf(`%%"%s"%%`, arg)) + condition = fmt.Sprintf("`memo`.`content` LIKE %s", fmt.Sprintf(`%%"%s"%%`, arg)) case "!_": if len(v.CallExpr.Args) != 1 { return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) diff --git a/store/db/sqlite/memo_filter_test.go b/store/db/sqlite/memo_filter_test.go index a8997807..3a08cc42 100644 --- a/store/db/sqlite/memo_filter_test.go +++ b/store/db/sqlite/memo_filter_test.go @@ -26,8 +26,20 @@ func TestRestoreExprToSQL(t *testing.T) { want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%))", }, { - filter: `content.contains("hello")`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %\"hello\"%", + filter: `content.contains("memos")`, + want: "`memo`.`content` LIKE %\"memos\"%", + }, + { + filter: `visibility in ["PUBLIC"]`, + want: "`memo`.`visibility` = \"PUBLIC\"", + }, + { + filter: `visibility in ["PUBLIC", "PRIVATE"]`, + want: "`memo`.`visibility` IN (\"PUBLIC\",\"PRIVATE\")", + }, + { + filter: `create_time == "2006-01-02T15:04:05+07:00"`, + want: "`memo`.`created_ts` = 1136189045", }, } diff --git a/store/memo.go b/store/memo.go index 3388695c..58172fa2 100644 --- a/store/memo.go +++ b/store/memo.go @@ -74,6 +74,7 @@ type FindMemo struct { ExcludeContent bool ExcludeComments bool Random bool + Filter *string // Pagination Limit *int