feat: support memo filter for mysql and postgres

pull/4390/head^2
johnnyjoy 3 months ago
parent 0f8b7b7fe6
commit e0e735d14d

@ -0,0 +1,20 @@
package filter
import (
"strings"
)
type ConvertContext struct {
Buffer strings.Builder
Args []any
// The offset of the next argument in the condition string.
// Mainly using for PostgreSQL.
ArgsOffset int
}
func NewConvertContext() *ConvertContext {
return &ConvertContext{
Buffer: strings.Builder{},
Args: []any{},
}
}

@ -0,0 +1,39 @@
package filter
import (
"errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// GetConstValue returns the constant value of the expression.
func GetConstValue(expr *exprv1.Expr) (any, error) {
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
if !ok {
return nil, errors.New("invalid constant expression")
}
switch v.ConstExpr.ConstantKind.(type) {
case *exprv1.Constant_StringValue:
return v.ConstExpr.GetStringValue(), nil
case *exprv1.Constant_Int64Value:
return v.ConstExpr.GetInt64Value(), nil
case *exprv1.Constant_Uint64Value:
return v.ConstExpr.GetUint64Value(), nil
case *exprv1.Constant_DoubleValue:
return v.ConstExpr.GetDoubleValue(), nil
case *exprv1.Constant_BoolValue:
return v.ConstExpr.GetBoolValue(), nil
default:
return nil, errors.New("unexpected constant type")
}
}
// GetIdentExprName returns the name of the identifier expression.
func GetIdentExprName(expr *exprv1.Expr) (string, error) {
_, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr)
if !ok {
return "", errors.New("invalid identifier expression")
}
return expr.GetIdentExpr().GetName(), nil
}

@ -30,26 +30,3 @@ func Parse(filter string, opts ...cel.EnvOption) (expr *exprv1.ParsedExpr, err e
} }
return cel.AstToParsedExpr(ast) return cel.AstToParsedExpr(ast)
} }
// GetConstValue returns the constant value of the expression.
func GetConstValue(expr *exprv1.Expr) (any, error) {
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
if !ok {
return nil, errors.New("invalid constant expression")
}
switch v.ConstExpr.ConstantKind.(type) {
case *exprv1.Constant_StringValue:
return v.ConstExpr.GetStringValue(), nil
case *exprv1.Constant_Int64Value:
return v.ConstExpr.GetInt64Value(), nil
case *exprv1.Constant_Uint64Value:
return v.ConstExpr.GetUint64Value(), nil
case *exprv1.Constant_DoubleValue:
return v.ConstExpr.GetDoubleValue(), nil
case *exprv1.Constant_BoolValue:
return v.ConstExpr.GetBoolValue(), nil
default:
return nil, errors.New("unexpected constant type")
}
}

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"github.com/usememos/memos/plugin/filter"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
@ -108,6 +109,21 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE") where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE")
} }
} }
if v := find.Filter; v != nil {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
args = append(args, convertCtx.Args...)
}
if find.ExcludeComments { if find.ExcludeComments {
having = append(having, "`parent_id` IS NULL") having = append(having, "`parent_id` IS NULL")
} }

@ -0,0 +1,175 @@
package mysql
import (
"fmt"
"slices"
"strings"
"time"
"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/usememos/memos/plugin/filter"
)
func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
switch v := expr.ExprKind.(type) {
case *exprv1.Expr_CallExpr:
switch v.CallExpr.Function {
case "_||_", "_&&_":
if len(v.CallExpr.Args) != 2 {
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
if _, err := ctx.Buffer.WriteString("("); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
return err
}
operator := "AND"
if v.CallExpr.Function == "_||_" {
operator = "OR"
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
return err
}
case "!_":
if len(v.CallExpr.Args) != 1 {
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
return err
}
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
if len(v.CallExpr.Args) != 2 {
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
if err != nil {
return err
}
if !slices.Contains([]string{"create_time", "update_time"}, identifier) {
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
value, err := filter.GetConstValue(v.CallExpr.Args[1])
if err != nil {
return err
}
operator := "="
switch v.CallExpr.Function {
case "_==_":
operator = "="
case "_!=_":
operator = "!="
case "_<_":
operator = "<"
case "_>_":
operator = ">"
case "_<=_":
operator = "<="
case "_>=_":
operator = ">="
}
if identifier == "create_time" || identifier == "update_time" {
timestampStr, ok := value.(string)
if !ok {
return errors.New("invalid timestamp value")
}
timestamp, err := time.Parse(time.RFC3339, timestampStr)
if err != nil {
return errors.Wrap(err, "failed to parse timestamp")
}
var factor string
if identifier == "create_time" {
factor = "`memo`.`created_ts`"
} else if identifier == "update_time" {
factor = "`memo`.`updated_ts`"
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("UNIX_TIMESTAMP(%s) %s ?", factor, operator)); err != nil {
return err
}
ctx.Args = append(ctx.Args, timestamp.Unix())
}
case "@in":
if len(v.CallExpr.Args) != 2 {
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
if err != nil {
return err
}
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
values := []any{}
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
value, err := filter.GetConstValue(element)
if err != nil {
return err
}
values = append(values, value)
}
if identifier == "tag" {
subcodition := []string{}
args := []any{}
for _, v := range values {
subcodition, args = append(subcodition, "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)"), append(args, v)
}
if len(subcodition) == 1 {
if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil {
return err
}
} else {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, args...)
} else if identifier == "visibility" {
placeholder := []string{}
for range values {
placeholder = append(placeholder, "?")
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(placeholder, ","))); err != nil {
return err
}
ctx.Args = append(ctx.Args, values...)
}
case "contains":
if len(v.CallExpr.Args) != 1 {
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
if err != nil {
return err
}
if identifier != "content" {
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
if err != nil {
return err
}
if _, err := ctx.Buffer.WriteString("`memo`.`content` LIKE ?"); err != nil {
return err
}
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
}
}
return nil
}

@ -0,0 +1,63 @@
package mysql
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `tag in ["tag1", "tag2"]`,
want: "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?))",
args: []any{"tag1", "tag2"},
},
{
filter: `!(tag in ["tag1", "tag2"])`,
want: "NOT ((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)))",
args: []any{"tag1", "tag2"},
},
{
filter: `content.contains("memos")`,
want: "`memo`.`content` LIKE ?",
args: []any{"%memos%"},
},
{
filter: `visibility in ["PUBLIC"]`,
want: "`memo`.`visibility` IN (?)",
args: []any{"PUBLIC"},
},
{
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
want: "`memo`.`visibility` IN (?,?)",
args: []any{"PUBLIC", "PRIVATE"},
},
{
filter: `create_time == "2006-01-02T15:04:05+07:00"`,
want: "UNIX_TIMESTAMP(`memo`.`created_ts`) = ?",
args: []any{int64(1136189045)},
},
{
filter: `tag in ['tag1'] || content.contains('hello')`,
want: "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR `memo`.`content` LIKE ?)",
args: []any{"tag1", "%hello%"},
},
}
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"github.com/usememos/memos/plugin/filter"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
@ -99,6 +100,22 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
where = append(where, "(memo.payload->'property'->>'hasIncompleteTasks')::BOOLEAN IS TRUE") where = append(where, "(memo.payload->'property'->>'hasIncompleteTasks')::BOOLEAN IS TRUE")
} }
} }
if v := find.Filter; v != nil {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
convertCtx.ArgsOffset = len(args)
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
args = append(args, convertCtx.Args...)
}
if find.ExcludeComments { if find.ExcludeComments {
where = append(where, "memo_relation.related_memo_id IS NULL") where = append(where, "memo_relation.related_memo_id IS NULL")
} }

@ -0,0 +1,175 @@
package postgres
import (
"fmt"
"slices"
"strings"
"time"
"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/usememos/memos/plugin/filter"
)
func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
switch v := expr.ExprKind.(type) {
case *exprv1.Expr_CallExpr:
switch v.CallExpr.Function {
case "_||_", "_&&_":
if len(v.CallExpr.Args) != 2 {
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
if _, err := ctx.Buffer.WriteString("("); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
return err
}
operator := "AND"
if v.CallExpr.Function == "_||_" {
operator = "OR"
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
return err
}
case "!_":
if len(v.CallExpr.Args) != 1 {
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
return err
}
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
if len(v.CallExpr.Args) != 2 {
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
if err != nil {
return err
}
if !slices.Contains([]string{"create_time", "update_time"}, identifier) {
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
value, err := filter.GetConstValue(v.CallExpr.Args[1])
if err != nil {
return err
}
operator := "="
switch v.CallExpr.Function {
case "_==_":
operator = "="
case "_!=_":
operator = "!="
case "_<_":
operator = "<"
case "_>_":
operator = ">"
case "_<=_":
operator = "<="
case "_>=_":
operator = ">="
}
if identifier == "create_time" || identifier == "update_time" {
timestampStr, ok := value.(string)
if !ok {
return errors.New("invalid timestamp value")
}
timestamp, err := time.Parse(time.RFC3339, timestampStr)
if err != nil {
return errors.Wrap(err, "failed to parse timestamp")
}
var factor string
if identifier == "create_time" {
factor = "memo.created_ts"
} else if identifier == "update_time" {
factor = "memo.updated_ts"
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", factor, operator, placeholder(len(ctx.Args)+ctx.ArgsOffset+1))); err != nil {
return err
}
ctx.Args = append(ctx.Args, timestamp.Unix())
}
case "@in":
if len(v.CallExpr.Args) != 2 {
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
if err != nil {
return err
}
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
values := []any{}
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
value, err := filter.GetConstValue(element)
if err != nil {
return err
}
values = append(values, value)
}
if identifier == "tag" {
subcodition := []string{}
args := []any{}
for _, v := range values {
subcodition, args = append(subcodition, fmt.Sprintf(`memo.payload->'tags' @> %s::jsonb`, placeholder(len(ctx.Args)+ctx.ArgsOffset+len(args)+1))), append(args, []any{v})
}
if len(subcodition) == 1 {
if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil {
return err
}
} else {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, args...)
} else if identifier == "visibility" {
placeholders := []string{}
for i := range values {
placeholders = append(placeholders, placeholder(len(ctx.Args)+ctx.ArgsOffset+i+1))
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("memo.visibility IN (%s)", strings.Join(placeholders, ","))); err != nil {
return err
}
ctx.Args = append(ctx.Args, values...)
}
case "contains":
if len(v.CallExpr.Args) != 1 {
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
if err != nil {
return err
}
if identifier != "content" {
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
if err != nil {
return err
}
if _, err := ctx.Buffer.WriteString("memo.content ILIKE LIKE " + placeholder(len(ctx.Args)+ctx.ArgsOffset+1)); err != nil {
return err
}
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
}
}
return nil
}

@ -0,0 +1,63 @@
package postgres
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestRestoreExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `tag in ["tag1", "tag2"]`,
want: "(memo.payload->'tags' @> $1::jsonb OR memo.payload->'tags' @> $2::jsonb)",
args: []any{[]any{"tag1"}, []any{"tag2"}},
},
{
filter: `!(tag in ["tag1", "tag2"])`,
want: `NOT ((memo.payload->'tags' @> $1::jsonb OR memo.payload->'tags' @> $2::jsonb))`,
args: []any{[]any{"tag1"}, []any{"tag2"}},
},
{
filter: `content.contains("memos")`,
want: "memo.content ILIKE LIKE $1",
args: []any{"%memos%"},
},
{
filter: `visibility in ["PUBLIC"]`,
want: "memo.visibility IN ($1)",
args: []any{"PUBLIC"},
},
{
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
want: "memo.visibility IN ($1,$2)",
args: []any{"PUBLIC", "PRIVATE"},
},
{
filter: `create_time == "2006-01-02T15:04:05+07:00"`,
want: "memo.created_ts = $1",
args: []any{int64(1136189045)},
},
{
filter: `tag in ['tag1'] || content.contains('hello')`,
want: "(memo.payload->'tags' @> $1::jsonb OR memo.content ILIKE LIKE $2)",
args: []any{[]any{"tag1"}, "%hello%"},
},
}
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}

@ -108,12 +108,13 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
if err != nil { if err != nil {
return nil, err return nil, err
} }
// RestoreExprToSQL parses the expression and returns the SQL condition. convertCtx := filter.NewConvertContext()
condition, err := RestoreExprToSQL(parsedExpr.GetExpr()) // ConvertExprToSQL converts the parsed expression to a SQL condition string.
if err != nil { if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err return nil, err
} }
where = append(where, fmt.Sprintf("(%s)", condition)) where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
args = append(args, convertCtx.Args...)
} }
if find.ExcludeComments { if find.ExcludeComments {
where = append(where, "`parent_id` IS NULL") where = append(where, "`parent_id` IS NULL")

@ -12,39 +12,60 @@ import (
"github.com/usememos/memos/plugin/filter" "github.com/usememos/memos/plugin/filter"
) )
func RestoreExprToSQL(expr *exprv1.Expr) (string, error) { func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
var condition string
switch v := expr.ExprKind.(type) { switch v := expr.ExprKind.(type) {
case *exprv1.Expr_CallExpr: case *exprv1.Expr_CallExpr:
switch v.CallExpr.Function { switch v.CallExpr.Function {
case "_||_", "_&&_": case "_||_", "_&&_":
if len(v.CallExpr.Args) != 2 { if len(v.CallExpr.Args) != 2 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
} }
left, err := RestoreExprToSQL(v.CallExpr.Args[0]) if _, err := ctx.Buffer.WriteString("("); err != nil {
if err != nil { return err
return "", err
} }
right, err := RestoreExprToSQL(v.CallExpr.Args[1]) if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
if err != nil { return err
return "", err
} }
operator := "AND" operator := "AND"
if v.CallExpr.Function == "_||_" { if v.CallExpr.Function == "_||_" {
operator = "OR" operator = "OR"
} }
condition = fmt.Sprintf("(%s %s %s)", left, operator, right) if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
return err
}
case "!_":
if len(v.CallExpr.Args) != 1 {
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
return err
}
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_": case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
if len(v.CallExpr.Args) != 2 { if len(v.CallExpr.Args) != 2 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
if err != nil {
return err
} }
identifier := v.CallExpr.Args[0].GetIdentExpr().GetName()
if !slices.Contains([]string{"create_time", "update_time"}, identifier) { if !slices.Contains([]string{"create_time", "update_time"}, identifier) {
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function) return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
} }
value, err := filter.GetConstValue(v.CallExpr.Args[1]) value, err := filter.GetConstValue(v.CallExpr.Args[1])
if err != nil { if err != nil {
return "", err return err
} }
operator := "=" operator := "="
switch v.CallExpr.Function { switch v.CallExpr.Function {
@ -65,85 +86,90 @@ func RestoreExprToSQL(expr *exprv1.Expr) (string, error) {
if identifier == "create_time" || identifier == "update_time" { if identifier == "create_time" || identifier == "update_time" {
timestampStr, ok := value.(string) timestampStr, ok := value.(string)
if !ok { if !ok {
return "", errors.New("invalid timestamp value") return errors.New("invalid timestamp value")
} }
timestamp, err := time.Parse(time.RFC3339, timestampStr) timestamp, err := time.Parse(time.RFC3339, timestampStr)
if err != nil { if err != nil {
return "", errors.Wrap(err, "failed to parse timestamp") return errors.Wrap(err, "failed to parse timestamp")
} }
var factor string
if identifier == "create_time" { if identifier == "create_time" {
condition = fmt.Sprintf("`memo`.`created_ts` %s %d", operator, timestamp.Unix()) factor = "`memo`.`created_ts`"
} else if identifier == "update_time" { } else if identifier == "update_time" {
condition = fmt.Sprintf("`memo`.`updated_ts` %s %d", operator, timestamp.Unix()) factor = "`memo`.`updated_ts`"
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
return err
} }
ctx.Args = append(ctx.Args, timestamp.Unix())
} }
case "@in": case "@in":
if len(v.CallExpr.Args) != 2 { if len(v.CallExpr.Args) != 2 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
if err != nil {
return err
} }
identifier := v.CallExpr.Args[0].GetIdentExpr().GetName()
if !slices.Contains([]string{"tag", "visibility"}, identifier) { if !slices.Contains([]string{"tag", "visibility"}, identifier) {
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function) return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
} }
values := []any{} values := []any{}
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements { for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
value, err := filter.GetConstValue(element) value, err := filter.GetConstValue(element)
if err != nil { if err != nil {
return "", err return err
} }
values = append(values, value) values = append(values, value)
} }
if identifier == "tag" { if identifier == "tag" {
subcodition := []string{} subcodition := []string{}
args := []any{}
for _, v := range values { for _, v := range values {
subcodition = append(subcodition, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`'%%"%s"%%'`, v))) subcodition, args = append(subcodition, "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?"), append(args, fmt.Sprintf(`%%"%s"%%`, v))
} }
if len(subcodition) == 1 { if len(subcodition) == 1 {
condition = subcodition[0] if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil {
return err
}
} else { } else {
condition = fmt.Sprintf("(%s)", strings.Join(subcodition, " OR ")) if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil {
return err
}
} }
ctx.Args = append(ctx.Args, args...)
} else if identifier == "visibility" { } else if identifier == "visibility" {
vs := []string{} placeholder := []string{}
for _, v := range values { for range values {
vs = append(vs, fmt.Sprintf(`"%s"`, v)) placeholder = append(placeholder, "?")
} }
if len(vs) == 1 { if _, err := ctx.Buffer.WriteString(fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(placeholder, ","))); err != nil {
condition = fmt.Sprintf("`memo`.`visibility` = %s", vs[0]) return err
} else {
condition = fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(vs, ","))
} }
ctx.Args = append(ctx.Args, values...)
} }
case "contains": case "contains":
if len(v.CallExpr.Args) != 1 { if len(v.CallExpr.Args) != 1 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
} }
identifier, err := RestoreExprToSQL(v.CallExpr.Target) identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
if err != nil { if err != nil {
return "", err return err
} }
if identifier != "content" { if identifier != "content" {
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function) return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
} }
arg, err := filter.GetConstValue(v.CallExpr.Args[0]) arg, err := filter.GetConstValue(v.CallExpr.Args[0])
if err != nil { if err != nil {
return "", err return err
} }
condition = fmt.Sprintf("`memo`.`content` LIKE %s", fmt.Sprintf(`'%%%s%%'`, arg)) if _, err := ctx.Buffer.WriteString("`memo`.`content` LIKE ?"); err != nil {
case "!_": return err
if len(v.CallExpr.Args) != 1 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
arg, err := RestoreExprToSQL(v.CallExpr.Args[0])
if err != nil {
return "", err
} }
condition = fmt.Sprintf("NOT (%s)", arg) ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
} }
case *exprv1.Expr_IdentExpr:
return v.IdentExpr.GetName(), nil
} }
return condition, nil return nil
} }

@ -8,50 +8,61 @@ import (
"github.com/usememos/memos/plugin/filter" "github.com/usememos/memos/plugin/filter"
) )
func TestRestoreExprToSQL(t *testing.T) { func TestConvertExprToSQL(t *testing.T) {
tests := []struct { tests := []struct {
filter string filter string
want string want string
args []any
}{ }{
{ {
filter: `tag in ["tag1", "tag2"]`, filter: `tag in ["tag1", "tag2"]`,
want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag1\"%' OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag2\"%')", want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?)",
args: []any{`%"tag1"%`, `%"tag2"%`},
}, },
{ {
filter: `!(tag in ["tag1", "tag2"])`, filter: `!(tag in ["tag1", "tag2"])`,
want: "NOT ((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag1\"%' OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag2\"%'))", want: "NOT ((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))",
args: []any{`%"tag1"%`, `%"tag2"%`},
}, },
{ {
filter: `tag in ["tag1", "tag2"] || tag in ["tag3", "tag4"]`, filter: `tag in ["tag1", "tag2"] || tag in ["tag3", "tag4"]`,
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag1\"%' OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag2\"%') OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag3\"%' OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag4\"%'))", want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))",
args: []any{`%"tag1"%`, `%"tag2"%`, `%"tag3"%`, `%"tag4"%`},
}, },
{ {
filter: `content.contains("memos")`, filter: `content.contains("memos")`,
want: "`memo`.`content` LIKE '%memos%'", want: "`memo`.`content` LIKE ?",
args: []any{"%memos%"},
}, },
{ {
filter: `visibility in ["PUBLIC"]`, filter: `visibility in ["PUBLIC"]`,
want: "`memo`.`visibility` = \"PUBLIC\"", want: "`memo`.`visibility` IN (?)",
args: []any{"PUBLIC"},
}, },
{ {
filter: `visibility in ["PUBLIC", "PRIVATE"]`, filter: `visibility in ["PUBLIC", "PRIVATE"]`,
want: "`memo`.`visibility` IN (\"PUBLIC\",\"PRIVATE\")", want: "`memo`.`visibility` IN (?,?)",
args: []any{"PUBLIC", "PRIVATE"},
}, },
{ {
filter: `create_time == "2006-01-02T15:04:05+07:00"`, filter: `create_time == "2006-01-02T15:04:05+07:00"`,
want: "`memo`.`created_ts` = 1136189045", want: "`memo`.`created_ts` = ?",
args: []any{int64(1136189045)},
}, },
{ {
filter: `tag in ['tag1'] || content.contains('hello')`, filter: `tag in ['tag1'] || content.contains('hello')`,
want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag1\"%' OR `memo`.`content` LIKE '%hello%')", want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR `memo`.`content` LIKE ?)",
args: []any{`%"tag1"%`, "%hello%"},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
require.NoError(t, err) require.NoError(t, err)
result, err := RestoreExprToSQL(parsedExpr.GetExpr()) convertCtx := filter.NewConvertContext()
err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, tt.want, result) require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
} }
} }

Loading…
Cancel
Save