mirror of https://github.com/usememos/memos
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
181 lines
3.8 KiB
Go
181 lines
3.8 KiB
Go
package filter
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/google/cel-go/cel"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// Engine parses CEL filters into a dialect-agnostic condition tree.
//
// An Engine pairs the Schema it was built from with the CEL environment that
// NewEngine creates from that schema's EnvOptions.
type Engine struct {
	// schema is the filter schema supplied to NewEngine. Compile passes it to
	// buildCondition and copies it into every Program it returns.
	schema Schema
	// env is the CEL environment used by Compile to compile filter expressions.
	env *cel.Env
}
|
|
|
|
// NewEngine builds a new Engine for the provided schema.
|
|
func NewEngine(schema Schema) (*Engine, error) {
|
|
env, err := cel.NewEnv(schema.EnvOptions...)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to create CEL environment")
|
|
}
|
|
return &Engine{
|
|
schema: schema,
|
|
env: env,
|
|
}, nil
|
|
}
|
|
|
|
// Program stores a compiled filter condition.
//
// Programs are produced by Engine.Compile and turned into SQL by Render.
type Program struct {
	// schema is the schema of the Engine that compiled this program; Render
	// hands it to the SQL renderer.
	schema Schema
	// condition is the dialect-agnostic condition tree built from the filter.
	condition Condition
}
|
|
|
|
// ConditionTree exposes the underlying condition tree.
|
|
func (p *Program) ConditionTree() Condition {
|
|
return p.condition
|
|
}
|
|
|
|
// Compile parses the filter string into an executable program.
|
|
func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
|
|
if strings.TrimSpace(filter) == "" {
|
|
return nil, errors.New("filter expression is empty")
|
|
}
|
|
|
|
filter = normalizeLegacyFilter(filter)
|
|
|
|
ast, issues := e.env.Compile(filter)
|
|
if issues != nil && issues.Err() != nil {
|
|
return nil, errors.Wrap(issues.Err(), "failed to compile filter")
|
|
}
|
|
parsed, err := cel.AstToParsedExpr(ast)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to convert AST")
|
|
}
|
|
|
|
cond, err := buildCondition(parsed.GetExpr(), e.schema)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Program{
|
|
schema: e.schema,
|
|
condition: cond,
|
|
}, nil
|
|
}
|
|
|
|
// CompileToStatement compiles and renders the filter in a single step.
|
|
func (e *Engine) CompileToStatement(ctx context.Context, filter string, opts RenderOptions) (Statement, error) {
|
|
program, err := e.Compile(ctx, filter)
|
|
if err != nil {
|
|
return Statement{}, err
|
|
}
|
|
return program.Render(opts)
|
|
}
|
|
|
|
// RenderOptions configure SQL rendering.
//
// Program.Render passes these options, together with the program's schema, to
// the renderer. How each field is interpreted is decided by the renderer.
type RenderOptions struct {
	// Dialect selects the target SQL dialect.
	Dialect DialectName
	// PlaceholderOffset presumably shifts the numbering of generated bind
	// placeholders (useful when the fragment is embedded in a larger query).
	// NOTE(review): confirm against the renderer implementation.
	PlaceholderOffset int
	// DisableNullChecks presumably suppresses NULL guards in the generated SQL.
	// NOTE(review): semantics live in the renderer; confirm.
	DisableNullChecks bool
}
|
|
|
|
// Statement contains the rendered SQL fragment and its args.
type Statement struct {
	// SQL is the rendered, dialect-specific SQL fragment.
	SQL string
	// Args holds the bind arguments referenced by SQL. The values are presumably
	// in placeholder order; confirm against the renderer.
	Args []any
}
|
|
|
|
// Render converts the program into a dialect-specific SQL fragment.
|
|
func (p *Program) Render(opts RenderOptions) (Statement, error) {
|
|
renderer := newRenderer(p.schema, opts)
|
|
return renderer.Render(p.condition)
|
|
}
|
|
|
|
var (
|
|
defaultOnce sync.Once
|
|
defaultInst *Engine
|
|
defaultErr error
|
|
)
|
|
|
|
// DefaultEngine returns the process-wide memo filter engine.
|
|
func DefaultEngine() (*Engine, error) {
|
|
defaultOnce.Do(func() {
|
|
defaultInst, defaultErr = NewEngine(NewSchema())
|
|
})
|
|
return defaultInst, defaultErr
|
|
}
|
|
|
|
func normalizeLegacyFilter(expr string) string {
|
|
expr = rewriteNumericLogicalOperand(expr, "&&")
|
|
expr = rewriteNumericLogicalOperand(expr, "||")
|
|
return expr
|
|
}
|
|
|
|
// rewriteNumericLogicalOperand rewrites a bare integer literal used as the
// right-hand operand of op (expected to be "&&" or "||") into an explicit
// boolean test. For example, `pinned && 1` becomes `pinned && (1 != 0)`.
//
// A literal is rewritten only when it is a complete operand on its own:
//   - It is an optional '+'/'-' sign followed by decimal digits.
//   - The next non-whitespace character is end of input, one of ')', ']',
//     '}', ',', '?', ':', or the start of another "&&"/"||".
//
// Operands that merely begin with digits are left untouched. Examples are
// `5 < priority`, `1.5 > x`, `0x10` and `-count > 0`. Rewriting them would
// turn a valid CEL expression into an invalid one.
//
// Text inside single- or double-quoted strings is copied verbatim, honouring
// backslash escapes. Whitespace after the operator is preserved. An empty op
// returns expr unchanged.
func rewriteNumericLogicalOperand(expr, op string) string {
	if op == "" {
		// An empty operator would match at every position without consuming
		// input, so there is nothing meaningful to rewrite.
		return expr
	}

	var builder strings.Builder
	builder.Grow(len(expr))
	n := len(expr)
	var inQuote byte // opening quote while inside a string literal, else 0

	for i := 0; i < n; {
		ch := expr[i]

		if inQuote != 0 {
			builder.WriteByte(ch)
			if ch == '\\' && i+1 < n {
				// Copy the escaped byte too so an escaped quote does not end the string.
				builder.WriteByte(expr[i+1])
				i += 2
				continue
			}
			if ch == inQuote {
				inQuote = 0
			}
			i++
			continue
		}

		if ch == '\'' || ch == '"' {
			inQuote = ch
			builder.WriteByte(ch)
			i++
			continue
		}

		if !strings.HasPrefix(expr[i:], op) {
			builder.WriteByte(ch)
			i++
			continue
		}

		builder.WriteString(op)
		i += len(op)

		// Preserve whitespace following the operator.
		wsEnd := skipLegacyFilterSpace(expr, i)
		builder.WriteString(expr[i:wsEnd])
		i = wsEnd

		if end, ok := scanStandaloneIntegerOperand(expr, i); ok {
			fmt.Fprintf(&builder, "(%s != 0)", expr[i:end])
			i = end
		}
	}

	return builder.String()
}

// skipLegacyFilterSpace returns the index of the first byte at or after i in
// expr that is not a space, tab, newline or carriage return.
func skipLegacyFilterSpace(expr string, i int) int {
	for i < len(expr) {
		switch expr[i] {
		case ' ', '\t', '\n', '\r':
			i++
		default:
			return i
		}
	}
	return i
}

// scanStandaloneIntegerOperand reports whether expr[start:] begins with an
// optionally signed decimal integer literal that forms a complete logical
// operand. When it does, it returns the index just past the literal.
func scanStandaloneIntegerOperand(expr string, start int) (int, bool) {
	i := start
	if i < len(expr) && (expr[i] == '+' || expr[i] == '-') {
		i++
	}
	digitsStart := i
	for i < len(expr) && expr[i] >= '0' && expr[i] <= '9' {
		i++
	}
	if i == digitsStart {
		// A sign without digits (e.g. `-count`) is not a literal.
		return start, false
	}

	next := skipLegacyFilterSpace(expr, i)
	if next == len(expr) {
		return i, true
	}
	switch expr[next] {
	case ')', ']', '}', ',', '?', ':':
		return i, true
	}
	if rest := expr[next:]; strings.HasPrefix(rest, "&&") || strings.HasPrefix(rest, "||") {
		return i, true
	}
	// Anything else means the digits are only part of a larger operand. That
	// includes '.', a suffix or identifier character, and any operator that
	// binds tighter than && and ||, such as '<', '==' or '+'.
	return start, false
}
|