From c3d4f8e9d1f74eb29143f52bb0c2e5e9027c07e5 Mon Sep 17 00:00:00 2001 From: Steven Date: Wed, 10 Sep 2025 21:05:26 +0800 Subject: [PATCH] feat: implement user-specific SQL converter for filtering in user service --- plugin/filter/common_converter.go | 110 +++++++++++++++++++++------ server/router/api/v1/user_service.go | 2 +- store/db/mysql/user.go | 2 +- store/db/postgres/user.go | 2 +- store/db/sqlite/user.go | 2 +- 5 files changed, 92 insertions(+), 26 deletions(-) diff --git a/plugin/filter/common_converter.go b/plugin/filter/common_converter.go index 60beb825d..aa3942929 100644 --- a/plugin/filter/common_converter.go +++ b/plugin/filter/common_converter.go @@ -11,23 +11,39 @@ import ( // CommonSQLConverter handles the common CEL to SQL conversion logic. type CommonSQLConverter struct { - dialect SQLDialect - paramIndex int + dialect SQLDialect + paramIndex int + allowedFields []string + entityType string } -// NewCommonSQLConverter creates a new converter with the specified dialect. +// NewCommonSQLConverter creates a new converter with the specified dialect for memo filters. func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter { return &CommonSQLConverter{ - dialect: dialect, - paramIndex: 1, + dialect: dialect, + paramIndex: 1, + allowedFields: []string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, + entityType: "memo", } } -// NewCommonSQLConverterWithOffset creates a new converter with the specified dialect and parameter offset. +// NewCommonSQLConverterWithOffset creates a new converter with the specified dialect and parameter offset for memo filters. func NewCommonSQLConverterWithOffset(dialect SQLDialect, offset int) *CommonSQLConverter { return &CommonSQLConverter{ - dialect: dialect, - paramIndex: offset + 1, + dialect: dialect, + paramIndex: offset + 1, + allowedFields: []string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, + entityType: "memo", + } +} + +// NewUserSQLConverter creates a new converter for user filters. +func NewUserSQLConverter(dialect SQLDialect) *CommonSQLConverter { + return &CommonSQLConverter{ + dialect: dialect, + paramIndex: 1, + allowedFields: []string{"username"}, + entityType: "user", } } @@ -124,7 +140,7 @@ func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callE return err } - if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { + if !slices.Contains(c.allowedFields, identifier) { return errors.Errorf("invalid identifier for %s", callExpr.Function) } @@ -135,20 +151,35 @@ func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callE operator := c.getComparisonOperator(callExpr.Function) - switch identifier { - case "created_ts", "updated_ts": - return c.handleTimestampComparison(ctx, identifier, operator, value) - case "visibility", "content": - return c.handleStringComparison(ctx, identifier, operator, value) - case "creator_id": - return c.handleIntComparison(ctx, identifier, operator, value) - case "pinned": - return c.handlePinnedComparison(ctx, operator, value) - case "has_task_list", "has_link", "has_code", "has_incomplete_tasks": - return c.handleBooleanComparison(ctx, identifier, operator, value) - default: - return errors.Errorf("unsupported identifier in comparison: %s", identifier) + // Handle memo fields + if c.entityType == "memo" { + switch identifier { + case "created_ts", "updated_ts": + return c.handleTimestampComparison(ctx, identifier, operator, value) + case "visibility", "content": + return c.handleStringComparison(ctx, identifier, operator, value) + case "creator_id": + return c.handleIntComparison(ctx, identifier, operator, value) + case "pinned": + return c.handlePinnedComparison(ctx, operator, value) + case "has_task_list", "has_link", "has_code", "has_incomplete_tasks": + return c.handleBooleanComparison(ctx, identifier, operator, value) + default: + return errors.Errorf("unsupported identifier in comparison: %s", identifier) + } + } + + // Handle user fields + if c.entityType == "user" { + switch identifier { + case "username": + return c.handleUserStringComparison(ctx, identifier, operator, value) + default: + return errors.Errorf("unsupported user identifier in comparison: %s", identifier) + } } + + return errors.Errorf("unsupported entity type: %s", c.entityType) } func (c *CommonSQLConverter) handleSizeComparison(ctx *ConvertContext, callExpr *exprv1.Expr_Call, sizeCall *exprv1.Expr_Call) error { @@ -400,6 +431,11 @@ func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExp func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error { identifier := identExpr.GetName() + // Only memo entity has boolean identifiers that can be used standalone + if c.entityType != "memo" { + return errors.Errorf("invalid identifier %s for entity type %s", identifier, c.entityType) + } + if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { return errors.Errorf("invalid identifier %s", identifier) } @@ -489,6 +525,36 @@ func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field, return nil } +func (c *CommonSQLConverter) handleUserStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error { + if operator != "=" && operator != "!=" { + return errors.Errorf("invalid operator for %s", field) + } + + valueStr, ok := value.(string) + if !ok { + return errors.New("invalid string value") + } + + tablePrefix := c.dialect.GetTablePrefix("user") + + if _, ok := c.dialect.(*PostgreSQLDialect); ok { + // PostgreSQL doesn't use backticks + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } + } else { + // MySQL and SQLite use backticks + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } + } + + ctx.Args = append(ctx.Args, valueStr) + c.paramIndex++ + + return nil +} + func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, operator string, value interface{}) error { if operator != "=" && operator != "!=" { return errors.Errorf("invalid operator for %s", field) diff --git a/server/router/api/v1/user_service.go b/server/router/api/v1/user_service.go index 9bdc5b58f..207427611 100644 --- a/server/router/api/v1/user_service.go +++ b/server/router/api/v1/user_service.go @@ -1368,7 +1368,7 @@ func (s *APIV1Service) validateUserFilter(_ context.Context, filterStr string) e dialect = &filter.SQLiteDialect{} } - converter := filter.NewCommonSQLConverter(dialect) + converter := filter.NewUserSQLConverter(dialect) err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) if err != nil { return errors.Wrap(err, "failed to convert filter to SQL") diff --git a/store/db/mysql/user.go b/store/db/mysql/user.go index 291be66e2..e2bddfd0d 100644 --- a/store/db/mysql/user.go +++ b/store/db/mysql/user.go @@ -94,7 +94,7 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User } convertCtx := filter.NewConvertContext() // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) + converter := filter.NewUserSQLConverter(&filter.MySQLDialect{}) if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err } diff --git a/store/db/postgres/user.go b/store/db/postgres/user.go index bbca02c1d..b582cff24 100644 --- a/store/db/postgres/user.go +++ b/store/db/postgres/user.go @@ -95,7 +95,7 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User } convertCtx := filter.NewConvertContext() // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{}) + converter := filter.NewUserSQLConverter(&filter.PostgreSQLDialect{}) if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err } diff --git a/store/db/sqlite/user.go b/store/db/sqlite/user.go index b1e8df06e..2ebf64dd7 100644 --- a/store/db/sqlite/user.go +++ b/store/db/sqlite/user.go @@ -96,7 +96,7 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User } convertCtx := filter.NewConvertContext() // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) + converter := filter.NewUserSQLConverter(&filter.SQLiteDialect{}) if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err }