feat: support more factors in filter

pull/4379/head
johnnyjoy 4 weeks ago
parent 2a392b8f8d
commit b9a0c56163

@ -9,7 +9,12 @@ import (
// MemoFilterCELAttributes are the CEL attributes for memo.
var MemoFilterCELAttributes = []cel.EnvOption{
cel.Variable("content", cel.StringType),
// As the built-in timestamp type is deprecated, we use string type for now.
// e.g., "2021-01-01T00:00:00Z"
cel.Variable("create_time", cel.StringType),
cel.Variable("tag", cel.StringType),
cel.Variable("update_time", cel.StringType),
cel.Variable("visibility", cel.StringType),
}
// Parse parses the filter string and returns the parsed expression.

@ -141,7 +141,6 @@ var MemoFilterCELAttributes = []cel.EnvOption{
cel.Variable("display_time_before", cel.IntType),
cel.Variable("display_time_after", cel.IntType),
cel.Variable("creator", cel.StringType),
cel.Variable("uid", cel.StringType),
cel.Variable("state", cel.StringType),
cel.Variable("random", cel.BoolType),
cel.Variable("limit", cel.IntType),

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
"github.com/usememos/memos/plugin/filter"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
@ -100,6 +101,20 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE")
}
}
if v := find.Filter; v != nil {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...)
if err != nil {
return nil, err
}
// RestoreExprToSQL parses the expression and returns the SQL condition.
condition, err := RestoreExprToSQL(parsedExpr.GetExpr())
if err != nil {
return nil, err
}
where = append(where, condition)
}
if find.ExcludeComments {
where = append(where, "`parent_id` IS NULL")
}

@ -4,6 +4,7 @@ import (
"fmt"
"slices"
"strings"
"time"
"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@ -36,15 +37,55 @@ func RestoreExprToSQL(expr *exprv1.Expr) (string, error) {
if len(v.CallExpr.Args) != 2 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
// TODO(j): Implement this part.
identifier := v.CallExpr.Args[0].GetIdentExpr().GetName()
if !slices.Contains([]string{"create_time", "update_time"}, identifier) {
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
value, err := filter.GetConstValue(v.CallExpr.Args[1])
if err != nil {
return "", err
}
operator := "="
switch v.CallExpr.Function {
case "_==_":
operator = "="
case "_!=_":
operator = "!="
case "_<_":
operator = "<"
case "_>_":
operator = ">"
case "_<=_":
operator = "<="
case "_>=_":
operator = ">="
}
if identifier == "create_time" || identifier == "update_time" {
timestampStr, ok := value.(string)
if !ok {
return "", errors.New("invalid timestamp value")
}
timestamp, err := time.Parse(time.RFC3339, timestampStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse timestamp")
}
if identifier == "create_time" {
condition = fmt.Sprintf("`memo`.`created_ts` %s %d", operator, timestamp.Unix())
} else if identifier == "update_time" {
condition = fmt.Sprintf("`memo`.`updated_ts` %s %d", operator, timestamp.Unix())
}
}
case "@in":
if len(v.CallExpr.Args) != 2 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
factor := v.CallExpr.Args[0].GetIdentExpr().Name
if !slices.Contains([]string{"tag"}, factor) {
return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function)
identifier := v.CallExpr.Args[0].GetIdentExpr().GetName()
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
values := []any{}
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
value, err := filter.GetConstValue(element)
@ -53,33 +94,43 @@ func RestoreExprToSQL(expr *exprv1.Expr) (string, error) {
}
values = append(values, value)
}
if factor == "tag" {
t := []string{}
if identifier == "tag" {
subcodition := []string{}
for _, v := range values {
subcodition = append(subcodition, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`%%"%s"%%`, v)))
}
if len(subcodition) == 1 {
condition = subcodition[0]
} else {
condition = fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))
}
} else if identifier == "visibility" {
vs := []string{}
for _, v := range values {
t = append(t, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`%%"%s"%%`, v)))
vs = append(vs, fmt.Sprintf(`"%s"`, v))
}
if len(t) == 1 {
condition = t[0]
if len(vs) == 1 {
condition = fmt.Sprintf("`memo`.`visibility` = %s", vs[0])
} else {
condition = fmt.Sprintf("(%s)", strings.Join(t, " OR "))
condition = fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(vs, ","))
}
}
case "contains":
if len(v.CallExpr.Args) != 1 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
factor, err := RestoreExprToSQL(v.CallExpr.Target)
identifier, err := RestoreExprToSQL(v.CallExpr.Target)
if err != nil {
return "", err
}
if factor != "content" {
return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function)
if identifier != "content" {
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
if err != nil {
return "", err
}
condition = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %s", fmt.Sprintf(`%%"%s"%%`, arg))
condition = fmt.Sprintf("`memo`.`content` LIKE %s", fmt.Sprintf(`%%"%s"%%`, arg))
case "!_":
if len(v.CallExpr.Args) != 1 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)

@ -26,8 +26,20 @@ func TestRestoreExprToSQL(t *testing.T) {
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%))",
},
{
filter: `content.contains("hello")`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %\"hello\"%",
filter: `content.contains("memos")`,
want: "`memo`.`content` LIKE %\"memos\"%",
},
{
filter: `visibility in ["PUBLIC"]`,
want: "`memo`.`visibility` = \"PUBLIC\"",
},
{
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
want: "`memo`.`visibility` IN (\"PUBLIC\",\"PRIVATE\")",
},
{
filter: `create_time == "2006-01-02T15:04:05+07:00"`,
want: "`memo`.`created_ts` = 1136189045",
},
}

@ -74,6 +74,7 @@ type FindMemo struct {
ExcludeContent bool
ExcludeComments bool
Random bool
Filter *string
// Pagination
Limit *int

Loading…
Cancel
Save