You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
memos/plugin/filter/engine.go

181 lines
3.8 KiB
Go

package filter
import (
"context"
"fmt"
"strings"
"sync"
"github.com/google/cel-go/cel"
"github.com/pkg/errors"
)
// Engine parses CEL filters into a dialect-agnostic condition tree.
type Engine struct {
schema Schema
env *cel.Env
}
// NewEngine builds a new Engine for the provided schema.
func NewEngine(schema Schema) (*Engine, error) {
env, err := cel.NewEnv(schema.EnvOptions...)
if err != nil {
return nil, errors.Wrap(err, "failed to create CEL environment")
}
return &Engine{
schema: schema,
env: env,
}, nil
}
// Program stores a compiled filter condition.
type Program struct {
schema Schema
condition Condition
}
// ConditionTree exposes the underlying condition tree.
func (p *Program) ConditionTree() Condition {
return p.condition
}
// Compile parses the filter string into an executable program.
func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
if strings.TrimSpace(filter) == "" {
return nil, errors.New("filter expression is empty")
}
filter = normalizeLegacyFilter(filter)
ast, issues := e.env.Compile(filter)
if issues != nil && issues.Err() != nil {
return nil, errors.Wrap(issues.Err(), "failed to compile filter")
}
parsed, err := cel.AstToParsedExpr(ast)
if err != nil {
return nil, errors.Wrap(err, "failed to convert AST")
}
cond, err := buildCondition(parsed.GetExpr(), e.schema)
if err != nil {
return nil, err
}
return &Program{
schema: e.schema,
condition: cond,
}, nil
}
// CompileToStatement compiles and renders the filter in a single step.
func (e *Engine) CompileToStatement(ctx context.Context, filter string, opts RenderOptions) (Statement, error) {
program, err := e.Compile(ctx, filter)
if err != nil {
return Statement{}, err
}
return program.Render(opts)
}
// RenderOptions configure SQL rendering.
type RenderOptions struct {
Dialect DialectName
PlaceholderOffset int
DisableNullChecks bool
}
// Statement contains the rendered SQL fragment and its args.
type Statement struct {
SQL string
Args []any
}
// Render converts the program into a dialect-specific SQL fragment.
func (p *Program) Render(opts RenderOptions) (Statement, error) {
renderer := newRenderer(p.schema, opts)
return renderer.Render(p.condition)
}
var (
defaultOnce sync.Once
defaultInst *Engine
defaultErr error
)
// DefaultEngine returns the process-wide memo filter engine.
func DefaultEngine() (*Engine, error) {
defaultOnce.Do(func() {
defaultInst, defaultErr = NewEngine(NewSchema())
})
return defaultInst, defaultErr
}
func normalizeLegacyFilter(expr string) string {
expr = rewriteNumericLogicalOperand(expr, "&&")
expr = rewriteNumericLogicalOperand(expr, "||")
return expr
}
func rewriteNumericLogicalOperand(expr, op string) string {
var builder strings.Builder
n := len(expr)
i := 0
var inQuote rune
for i < n {
ch := expr[i]
if inQuote != 0 {
builder.WriteByte(ch)
if ch == '\\' && i+1 < n {
builder.WriteByte(expr[i+1])
i += 2
continue
}
if ch == byte(inQuote) {
inQuote = 0
}
i++
continue
}
if ch == '\'' || ch == '"' {
inQuote = rune(ch)
builder.WriteByte(ch)
i++
continue
}
if strings.HasPrefix(expr[i:], op) {
builder.WriteString(op)
i += len(op)
// Preserve whitespace following the operator.
wsStart := i
for i < n && (expr[i] == ' ' || expr[i] == '\t') {
i++
}
builder.WriteString(expr[wsStart:i])
signStart := i
if i < n && (expr[i] == '+' || expr[i] == '-') {
i++
}
for i < n && expr[i] >= '0' && expr[i] <= '9' {
i++
}
if i > signStart {
numLiteral := expr[signStart:i]
builder.WriteString(fmt.Sprintf("(%s != 0)", numLiteral))
} else {
builder.WriteString(expr[signStart:i])
}
continue
}
builder.WriteByte(ch)
i++
}
return builder.String()
}