feat: implement user-specific SQL converter for filtering in user service

pull/5091/head v0.25.1
Steven 1 month ago
parent 383553d3c8
commit c3d4f8e9d1

@ -11,23 +11,39 @@ import (
// CommonSQLConverter handles the common CEL to SQL conversion logic. // CommonSQLConverter handles the common CEL to SQL conversion logic.
type CommonSQLConverter struct { type CommonSQLConverter struct {
dialect SQLDialect dialect SQLDialect
paramIndex int paramIndex int
allowedFields []string
entityType string
} }
// NewCommonSQLConverter creates a new converter with the specified dialect. // NewCommonSQLConverter creates a new converter with the specified dialect for memo filters.
func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter { func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter {
return &CommonSQLConverter{ return &CommonSQLConverter{
dialect: dialect, dialect: dialect,
paramIndex: 1, paramIndex: 1,
allowedFields: []string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"},
entityType: "memo",
} }
} }
// NewCommonSQLConverterWithOffset creates a new converter with the specified dialect and parameter offset. // NewCommonSQLConverterWithOffset creates a new converter with the specified dialect and parameter offset for memo filters.
func NewCommonSQLConverterWithOffset(dialect SQLDialect, offset int) *CommonSQLConverter { func NewCommonSQLConverterWithOffset(dialect SQLDialect, offset int) *CommonSQLConverter {
return &CommonSQLConverter{ return &CommonSQLConverter{
dialect: dialect, dialect: dialect,
paramIndex: offset + 1, paramIndex: offset + 1,
allowedFields: []string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"},
entityType: "memo",
}
}
// NewUserSQLConverter creates a new converter for user filters.
func NewUserSQLConverter(dialect SQLDialect) *CommonSQLConverter {
return &CommonSQLConverter{
dialect: dialect,
paramIndex: 1,
allowedFields: []string{"username"},
entityType: "user",
} }
} }
@ -124,7 +140,7 @@ func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callE
return err return err
} }
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { if !slices.Contains(c.allowedFields, identifier) {
return errors.Errorf("invalid identifier for %s", callExpr.Function) return errors.Errorf("invalid identifier for %s", callExpr.Function)
} }
@ -135,20 +151,35 @@ func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callE
operator := c.getComparisonOperator(callExpr.Function) operator := c.getComparisonOperator(callExpr.Function)
switch identifier { // Handle memo fields
case "created_ts", "updated_ts": if c.entityType == "memo" {
return c.handleTimestampComparison(ctx, identifier, operator, value) switch identifier {
case "visibility", "content": case "created_ts", "updated_ts":
return c.handleStringComparison(ctx, identifier, operator, value) return c.handleTimestampComparison(ctx, identifier, operator, value)
case "creator_id": case "visibility", "content":
return c.handleIntComparison(ctx, identifier, operator, value) return c.handleStringComparison(ctx, identifier, operator, value)
case "pinned": case "creator_id":
return c.handlePinnedComparison(ctx, operator, value) return c.handleIntComparison(ctx, identifier, operator, value)
case "has_task_list", "has_link", "has_code", "has_incomplete_tasks": case "pinned":
return c.handleBooleanComparison(ctx, identifier, operator, value) return c.handlePinnedComparison(ctx, operator, value)
default: case "has_task_list", "has_link", "has_code", "has_incomplete_tasks":
return errors.Errorf("unsupported identifier in comparison: %s", identifier) return c.handleBooleanComparison(ctx, identifier, operator, value)
default:
return errors.Errorf("unsupported identifier in comparison: %s", identifier)
}
}
// Handle user fields
if c.entityType == "user" {
switch identifier {
case "username":
return c.handleUserStringComparison(ctx, identifier, operator, value)
default:
return errors.Errorf("unsupported user identifier in comparison: %s", identifier)
}
} }
return errors.Errorf("unsupported entity type: %s", c.entityType)
} }
func (c *CommonSQLConverter) handleSizeComparison(ctx *ConvertContext, callExpr *exprv1.Expr_Call, sizeCall *exprv1.Expr_Call) error { func (c *CommonSQLConverter) handleSizeComparison(ctx *ConvertContext, callExpr *exprv1.Expr_Call, sizeCall *exprv1.Expr_Call) error {
@ -400,6 +431,11 @@ func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExp
func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error { func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error {
identifier := identExpr.GetName() identifier := identExpr.GetName()
// Only memo entity has boolean identifiers that can be used standalone
if c.entityType != "memo" {
return errors.Errorf("invalid identifier %s for entity type %s", identifier, c.entityType)
}
if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
return errors.Errorf("invalid identifier %s", identifier) return errors.Errorf("invalid identifier %s", identifier)
} }
@ -489,6 +525,36 @@ func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field,
return nil return nil
} }
func (c *CommonSQLConverter) handleUserStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", field)
}
valueStr, ok := value.(string)
if !ok {
return errors.New("invalid string value")
}
tablePrefix := c.dialect.GetTablePrefix("user")
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
// PostgreSQL doesn't use backticks
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
} else {
// MySQL and SQLite use backticks
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, valueStr)
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, operator string, value interface{}) error { func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
if operator != "=" && operator != "!=" { if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", field) return errors.Errorf("invalid operator for %s", field)

@ -1368,7 +1368,7 @@ func (s *APIV1Service) validateUserFilter(_ context.Context, filterStr string) e
dialect = &filter.SQLiteDialect{} dialect = &filter.SQLiteDialect{}
} }
converter := filter.NewCommonSQLConverter(dialect) converter := filter.NewUserSQLConverter(dialect)
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
if err != nil { if err != nil {
return errors.Wrap(err, "failed to convert filter to SQL") return errors.Wrap(err, "failed to convert filter to SQL")

@ -94,7 +94,7 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User
} }
convertCtx := filter.NewConvertContext() convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string. // ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) converter := filter.NewUserSQLConverter(&filter.MySQLDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err return nil, err
} }

@ -95,7 +95,7 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User
} }
convertCtx := filter.NewConvertContext() convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string. // ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{}) converter := filter.NewUserSQLConverter(&filter.PostgreSQLDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err return nil, err
} }

@ -96,7 +96,7 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User
} }
convertCtx := filter.NewConvertContext() convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string. // ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) converter := filter.NewUserSQLConverter(&filter.SQLiteDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err return nil, err
} }

Loading…
Cancel
Save