mirror of https://github.com/usememos/memos
feat: memo filter for sqlite
parent
0af08d9c06
commit
2d731c5cc5
@ -0,0 +1,47 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
// MemoFilterCELAttributes are the CEL attributes for memo.
|
||||
var MemoFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("content", cel.StringType),
|
||||
cel.Variable("tag", cel.StringType),
|
||||
}
|
||||
|
||||
// Parse parses the filter string and returns the parsed expression.
|
||||
// The filter string should be a CEL expression.
|
||||
func Parse(filter string) (expr *exprv1.ParsedExpr, err error) {
|
||||
e, err := cel.NewEnv(MemoFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create CEL environment")
|
||||
}
|
||||
ast, issues := e.Compile(filter)
|
||||
if issues != nil {
|
||||
return nil, errors.Errorf("failed to compile filter: %v", issues)
|
||||
}
|
||||
return cel.AstToParsedExpr(ast)
|
||||
}
|
||||
|
||||
// GetConstValue returns the constant value of the expression.
|
||||
func GetConstValue(expr *exprv1.Expr) (any, error) {
|
||||
switch v := expr.ExprKind.(type) {
|
||||
case *exprv1.Expr_ConstExpr:
|
||||
switch v.ConstExpr.ConstantKind.(type) {
|
||||
case *exprv1.Constant_StringValue:
|
||||
return v.ConstExpr.GetStringValue(), nil
|
||||
case *exprv1.Constant_Int64Value:
|
||||
return v.ConstExpr.GetInt64Value(), nil
|
||||
case *exprv1.Constant_Uint64Value:
|
||||
return v.ConstExpr.GetUint64Value(), nil
|
||||
case *exprv1.Constant_DoubleValue:
|
||||
return v.ConstExpr.GetDoubleValue(), nil
|
||||
case *exprv1.Constant_BoolValue:
|
||||
return v.ConstExpr.GetBoolValue(), nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("invalid constant expression")
|
||||
}
|
@ -0,0 +1,96 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
func RestoreExprToSQL(expr *exprv1.Expr) (string, error) {
|
||||
var condition string
|
||||
switch v := expr.ExprKind.(type) {
|
||||
case *exprv1.Expr_CallExpr:
|
||||
switch v.CallExpr.Function {
|
||||
case "_||_", "_&&_":
|
||||
for _, arg := range v.CallExpr.Args {
|
||||
left, err := RestoreExprToSQL(arg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
right, err := RestoreExprToSQL(arg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
operator := "AND"
|
||||
if v.CallExpr.Function == "_||_" {
|
||||
operator = "OR"
|
||||
}
|
||||
condition = fmt.Sprintf("(%s %s %s)", left, operator, right)
|
||||
}
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
// TODO(j): Implement this part.
|
||||
case "@in":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
factor := v.CallExpr.Args[0].GetIdentExpr().Name
|
||||
if !slices.Contains([]string{"tag"}, factor) {
|
||||
return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function)
|
||||
}
|
||||
values := []any{}
|
||||
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||
value, err := filter.GetConstValue(element)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
if factor == "tag" {
|
||||
t := []string{}
|
||||
for _, v := range values {
|
||||
t = append(t, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`%%"%s"%%`, v)))
|
||||
}
|
||||
if len(t) == 1 {
|
||||
condition = t[0]
|
||||
} else {
|
||||
condition = fmt.Sprintf("(%s)", strings.Join(t, " OR "))
|
||||
}
|
||||
}
|
||||
case "contains":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
factor, err := RestoreExprToSQL(v.CallExpr.Target)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if factor != "content" {
|
||||
return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function)
|
||||
}
|
||||
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
condition = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %s", fmt.Sprintf(`%%"%s"%%`, arg))
|
||||
case "!_":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
arg, err := RestoreExprToSQL(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
condition = fmt.Sprintf("NOT (%s)", arg)
|
||||
}
|
||||
case *exprv1.Expr_IdentExpr:
|
||||
return v.IdentExpr.GetName(), nil
|
||||
}
|
||||
return condition, nil
|
||||
}
|
@ -0,0 +1,40 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestRestoreExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
filter: `tag in ["tag1", "tag2"]`,
|
||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag1\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag2\"%)",
|
||||
},
|
||||
{
|
||||
filter: `!(tag in ["tag1", "tag2"])`,
|
||||
want: "NOT ((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag1\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag2\"%))",
|
||||
},
|
||||
{
|
||||
filter: `tag in ["tag1", "tag2"] || tag in ["tag3", "tag4"]`,
|
||||
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%))",
|
||||
},
|
||||
{
|
||||
filter: `content.contains("hello")`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %\"hello\"%",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
parsedExpr, err := filter.Parse(tt.filter)
|
||||
require.NoError(t, err)
|
||||
result, err := RestoreExprToSQL(parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, result)
|
||||
}
|
||||
}
|
@ -1,13 +0,0 @@
|
||||
package store
|
||||
|
||||
type LogicOperator string
|
||||
|
||||
const (
|
||||
AND LogicOperator = "AND"
|
||||
OR LogicOperator = "OR"
|
||||
)
|
||||
|
||||
type QueryExpression struct {
|
||||
Operator LogicOperator
|
||||
Children []*QueryExpression
|
||||
}
|
Loading…
Reference in New Issue