From ed23cbc011a621a81c130b8ded889829c5e21daf Mon Sep 17 00:00:00 2001 From: johnnyjoy Date: Wed, 23 Jul 2025 22:10:16 +0800 Subject: [PATCH] refactor: memo filter --- plugin/filter/common_converter.go | 220 +++++++++-- plugin/filter/dialect.go | 28 +- proto/api/v1/memo_service.proto | 20 +- proto/gen/api/v1/memo_service.pb.go | 35 +- proto/gen/api/v1/memo_service.pb.gw.go | 92 ----- proto/gen/openapi.yaml | 84 ---- server/router/api/v1/memo_service.go | 7 - server/router/api/v1/shortcut_service.go | 18 +- store/db/mysql/memo.go | 3 +- store/db/mysql/memo_filter.go | 357 ----------------- store/db/mysql/memo_filter_test.go | 4 +- store/db/mysql/memo_relation.go | 3 +- store/db/postgres/memo.go | 3 +- store/db/postgres/memo_filter.go | 373 ------------------ store/db/postgres/memo_filter_test.go | 4 +- store/db/postgres/memo_relation.go | 3 +- store/db/sqlite/memo.go | 3 +- store/db/sqlite/memo_filter.go | 357 ----------------- store/db/sqlite/memo_filter_test.go | 4 +- store/db/sqlite/memo_relation.go | 3 +- store/driver.go | 7 - .../ActionButton/AddMemoRelationPopover.tsx | 6 +- .../PagedMemoList/PagedMemoList.tsx | 13 +- web/src/pages/Home.tsx | 3 +- web/src/pages/UserProfile.tsx | 3 +- web/src/store/common.ts | 4 + web/src/types/proto/api/v1/memo_service.ts | 101 +---- .../types/proto/google/protobuf/descriptor.ts | 225 ++++++++++- 28 files changed, 524 insertions(+), 1459 deletions(-) delete mode 100644 store/db/mysql/memo_filter.go delete mode 100644 store/db/postgres/memo_filter.go delete mode 100644 store/db/sqlite/memo_filter.go diff --git a/plugin/filter/common_converter.go b/plugin/filter/common_converter.go index 407e4d9ea..2a9e480b2 100644 --- a/plugin/filter/common_converter.go +++ b/plugin/filter/common_converter.go @@ -23,6 +23,14 @@ func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter { } } +// NewCommonSQLConverterWithOffset creates a new converter with the specified dialect and parameter offset. +func NewCommonSQLConverterWithOffset(dialect SQLDialect, offset int) *CommonSQLConverter { + return &CommonSQLConverter{ + dialect: dialect, + paramIndex: offset + 1, + } +} + // ConvertExprToSQL converts a CEL expression to SQL using the configured dialect. func (c *CommonSQLConverter) ConvertExprToSQL(ctx *ConvertContext, expr *exprv1.Expr) error { if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok { @@ -114,7 +122,7 @@ func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callE return err } - if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) { + if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { return errors.Errorf("invalid identifier for %s", callExpr.Function) } @@ -132,7 +140,7 @@ func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callE return c.handleStringComparison(ctx, identifier, operator, value) case "creator_id": return c.handleIntComparison(ctx, identifier, operator, value) - case "has_task_list": + case "has_task_list", "has_link", "has_code", "has_incomplete_tasks": return c.handleBooleanComparison(ctx, identifier, operator, value) } @@ -226,15 +234,18 @@ func (c *CommonSQLConverter) handleElementInTags(ctx *ConvertContext, elementExp } // Use dialect-specific JSON contains logic - sqlExpr := c.dialect.GetJSONContains("$.tags", "element") + template := c.dialect.GetJSONContains("$.tags", "element") + sqlExpr := strings.Replace(template, "?", c.dialect.GetParameterPlaceholder(c.paramIndex), 1) if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { return err } - // For SQLite, we need a different approach since it uses LIKE + // Handle args based on dialect if _, ok := c.dialect.(*SQLiteDialect); ok { + // SQLite uses LIKE with pattern ctx.Args = append(ctx.Args, fmt.Sprintf(`%%"%s"%%`, element)) } else { + // MySQL and PostgreSQL expect plain values ctx.Args = append(ctx.Args, element) } c.paramIndex++ @@ -251,7 +262,10 @@ func (c *CommonSQLConverter) handleTagInList(ctx *ConvertContext, values []any) subconditions = append(subconditions, c.dialect.GetJSONLike("$.tags", "pattern")) args = append(args, fmt.Sprintf(`%%"%s"%%`, v)) } else { - subconditions = append(subconditions, c.dialect.GetJSONContains("$.tags", "element")) + // Replace ? with proper placeholder for each dialect + template := c.dialect.GetJSONContains("$.tags", "element") + sql := strings.Replace(template, "?", c.dialect.GetParameterPlaceholder(c.paramIndex), 1) + subconditions = append(subconditions, sql) args = append(args, v) } c.paramIndex++ @@ -279,8 +293,14 @@ func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values } tablePrefix := c.dialect.GetTablePrefix() - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`visibility` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { - return err + if _, ok := c.dialect.(*PostgreSQLDialect); ok { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.visibility IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { + return err + } + } else { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`visibility` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { + return err + } } ctx.Args = append(ctx.Args, values...) @@ -307,8 +327,16 @@ func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExp } tablePrefix := c.dialect.GetTablePrefix() - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content` LIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err + + // PostgreSQL uses ILIKE and no backticks + if _, ok := c.dialect.(*PostgreSQLDialect); ok { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.content ILIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } + } else { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content` LIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } } ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg)) @@ -320,19 +348,37 @@ func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExp func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error { identifier := identExpr.GetName() - if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) { + if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { return errors.Errorf("invalid identifier %s", identifier) } if identifier == "pinned" { tablePrefix := c.dialect.GetTablePrefix() - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`pinned` IS TRUE", tablePrefix)); err != nil { - return err + if _, ok := c.dialect.(*PostgreSQLDialect); ok { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.pinned IS TRUE", tablePrefix)); err != nil { + return err + } + } else { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`pinned` IS TRUE", tablePrefix)); err != nil { + return err + } } } else if identifier == "has_task_list" { if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasTaskList")); err != nil { return err } + } else if identifier == "has_link" { + if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasLink")); err != nil { + return err + } + } else if identifier == "has_code" { + if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasCode")); err != nil { + return err + } + } else if identifier == "has_incomplete_tasks" { + if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasIncompleteTasks")); err != nil { + return err + } } return nil @@ -366,15 +412,23 @@ func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field, } tablePrefix := c.dialect.GetTablePrefix() - fieldName := field - if field == "visibility" { - fieldName = "`visibility`" - } else if field == "content" { - fieldName = "`content`" - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, fieldName, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err + if _, ok := c.dialect.(*PostgreSQLDialect); ok { + // PostgreSQL doesn't use backticks + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } + } else { + // MySQL and SQLite use backticks + fieldName := field + if field == "visibility" { + fieldName = "`visibility`" + } else if field == "content" { + fieldName = "`content`" + } + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, fieldName, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } } ctx.Args = append(ctx.Args, valueStr) @@ -394,8 +448,17 @@ func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, ope } tablePrefix := c.dialect.GetTablePrefix() - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err + + if _, ok := c.dialect.(*PostgreSQLDialect); ok { + // PostgreSQL doesn't use backticks + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } + } else { + // MySQL and SQLite use backticks + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } } ctx.Args = append(ctx.Args, valueInt) @@ -411,18 +474,121 @@ func (c *CommonSQLConverter) handleBooleanComparison(ctx *ConvertContext, field, valueBool, ok := value.(bool) if !ok { - return errors.New("invalid boolean value for has_task_list") + return errors.Errorf("invalid boolean value for %s", field) } - sqlExpr := c.dialect.GetBooleanComparison("$.property.hasTaskList", valueBool) - if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { - return err + // Map field name to JSON path + var jsonPath string + switch field { + case "has_task_list": + jsonPath = "$.property.hasTaskList" + case "has_link": + jsonPath = "$.property.hasLink" + case "has_code": + jsonPath = "$.property.hasCode" + case "has_incomplete_tasks": + jsonPath = "$.property.hasIncompleteTasks" } - // For dialects that need parameters (PostgreSQL) + // Special handling for SQLite based on field + if _, ok := c.dialect.(*SQLiteDialect); ok { + if field == "has_task_list" { + // has_task_list uses = 1 / = 0 / != 1 / != 0 + var sqlExpr string + if operator == "=" { + if valueBool { + sqlExpr = fmt.Sprintf("%s = 1", c.dialect.GetJSONExtract(jsonPath)) + } else { + sqlExpr = fmt.Sprintf("%s = 0", c.dialect.GetJSONExtract(jsonPath)) + } + } else { // operator == "!=" + if valueBool { + sqlExpr = fmt.Sprintf("%s != 1", c.dialect.GetJSONExtract(jsonPath)) + } else { + sqlExpr = fmt.Sprintf("%s != 0", c.dialect.GetJSONExtract(jsonPath)) + } + } + if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { + return err + } + return nil + } else { + // Other fields use IS TRUE / NOT(... IS TRUE) + var sqlExpr string + if operator == "=" { + if valueBool { + sqlExpr = fmt.Sprintf("%s IS TRUE", c.dialect.GetJSONExtract(jsonPath)) + } else { + sqlExpr = fmt.Sprintf("NOT(%s IS TRUE)", c.dialect.GetJSONExtract(jsonPath)) + } + } else { // operator == "!=" + if valueBool { + sqlExpr = fmt.Sprintf("NOT(%s IS TRUE)", c.dialect.GetJSONExtract(jsonPath)) + } else { + sqlExpr = fmt.Sprintf("%s IS TRUE", c.dialect.GetJSONExtract(jsonPath)) + } + } + if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { + return err + } + return nil + } + } + + // Special handling for MySQL - use raw operator with CAST + if _, ok := c.dialect.(*MySQLDialect); ok { + var sqlExpr string + boolStr := "false" + if valueBool { + boolStr = "true" + } + sqlExpr = fmt.Sprintf("%s %s CAST('%s' AS JSON)", c.dialect.GetJSONExtract(jsonPath), operator, boolStr) + if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { + return err + } + return nil + } + + // Handle PostgreSQL differently - it uses the raw operator if _, ok := c.dialect.(*PostgreSQLDialect); ok { + var jsonExtract string + // Special handling for has_link, has_code, has_incomplete_tasks + if field == "has_link" || field == "has_code" || field == "has_incomplete_tasks" { + // Use memo-> format for these fields + parts := strings.Split(strings.TrimPrefix(jsonPath, "$."), ".") + jsonExtract = "memo->'payload'" + for i, part := range parts { + if i == len(parts)-1 { + jsonExtract += fmt.Sprintf("->>'%s'", part) + } else { + jsonExtract += fmt.Sprintf("->'%s'", part) + } + } + } else { + // Use standard format for has_task_list + jsonExtract = c.dialect.GetJSONExtract(jsonPath) + } + + sqlExpr := fmt.Sprintf("(%s)::boolean %s %s", + jsonExtract, + operator, + c.dialect.GetParameterPlaceholder(c.paramIndex)) + if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { + return err + } ctx.Args = append(ctx.Args, valueBool) c.paramIndex++ + return nil + } + + // Handle other dialects + if operator == "!=" { + valueBool = !valueBool + } + + sqlExpr := c.dialect.GetBooleanComparison(jsonPath, valueBool) + if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { + return err } return nil diff --git a/plugin/filter/dialect.go b/plugin/filter/dialect.go index 6c6ab4ca0..a6e15fa98 100644 --- a/plugin/filter/dialect.go +++ b/plugin/filter/dialect.go @@ -85,7 +85,10 @@ func (*SQLiteDialect) GetBooleanValue(value bool) interface{} { } func (d *SQLiteDialect) GetBooleanComparison(path string, value bool) string { - return fmt.Sprintf("%s = %d", d.GetJSONExtract(path), d.GetBooleanValue(value)) + if value { + return fmt.Sprintf("%s = 1", d.GetJSONExtract(path)) + } + return fmt.Sprintf("%s = 0", d.GetJSONExtract(path)) } func (d *SQLiteDialect) GetBooleanCheck(path string) string { @@ -132,11 +135,10 @@ func (*MySQLDialect) GetBooleanValue(value bool) interface{} { } func (d *MySQLDialect) GetBooleanComparison(path string, value bool) string { - boolStr := "false" if value { - boolStr = "true" + return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path)) } - return fmt.Sprintf("%s = CAST('%s' AS JSON)", d.GetJSONExtract(path), boolStr) + return fmt.Sprintf("%s != CAST('true' AS JSON)", d.GetJSONExtract(path)) } func (d *MySQLDialect) GetBooleanCheck(path string) string { @@ -163,7 +165,7 @@ func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string { } func (d *PostgreSQLDialect) GetJSONExtract(path string) string { - // Convert $.property.hasTaskList to payload->'property'->>'hasTaskList' + // Convert $.property.hasTaskList to memo.payload->'property'->>'hasTaskList' parts := strings.Split(strings.TrimPrefix(path, "$."), ".") result := fmt.Sprintf("%s.payload", d.GetTablePrefix()) for i, part := range parts { @@ -196,10 +198,26 @@ func (*PostgreSQLDialect) GetBooleanValue(value bool) interface{} { } func (d *PostgreSQLDialect) GetBooleanComparison(path string, _ bool) string { + // Note: The parameter placeholder will be replaced by the caller return fmt.Sprintf("(%s)::boolean = ?", d.GetJSONExtract(path)) } func (d *PostgreSQLDialect) GetBooleanCheck(path string) string { + // Special handling for standalone boolean identifiers + if strings.Contains(path, "hasLink") || strings.Contains(path, "hasCode") || strings.Contains(path, "hasIncompleteTasks") { + // Use memo-> instead of memo.payload-> for these fields + parts := strings.Split(strings.TrimPrefix(path, "$."), ".") + result := fmt.Sprintf("%s->'payload'", d.GetTablePrefix()) + for i, part := range parts { + if i == len(parts)-1 { + result += fmt.Sprintf("->>'%s'", part) + } else { + result += fmt.Sprintf("->'%s'", part) + } + } + return fmt.Sprintf("(%s)::boolean = true", result) + } + // Use default format for other fields return fmt.Sprintf("(%s)::boolean IS TRUE", d.GetJSONExtract(path)) } diff --git a/proto/api/v1/memo_service.proto b/proto/api/v1/memo_service.proto index ce39885d1..da23c37f2 100644 --- a/proto/api/v1/memo_service.proto +++ b/proto/api/v1/memo_service.proto @@ -26,12 +26,8 @@ service MemoService { } // ListMemos lists memos with pagination and filter. rpc ListMemos(ListMemosRequest) returns (ListMemosResponse) { - option (google.api.http) = { - get: "/api/v1/memos" - additional_bindings: {get: "/api/v1/{parent=users/*}/memos"} - }; + option (google.api.http) = {get: "/api/v1/memos"}; option (google.api.method_signature) = ""; - option (google.api.method_signature) = "parent"; } // GetMemo gets a memo. rpc GetMemo(GetMemoRequest) returns (Memo) { @@ -276,27 +272,19 @@ message CreateMemoRequest { } message ListMemosRequest { - // Optional. The parent is the owner of the memos. - // If not specified or `users/-`, it will list all memos. - // Format: users/{user} - string parent = 1 [ - (google.api.field_behavior) = OPTIONAL, - (google.api.resource_reference) = {type: "memos.api.v1/User"} - ]; - // Optional. The maximum number of memos to return. // The service may return fewer than this value. // If unspecified, at most 50 memos will be returned. // The maximum value is 1000; values above 1000 will be coerced to 1000. - int32 page_size = 2 [(google.api.field_behavior) = OPTIONAL]; + int32 page_size = 1 [(google.api.field_behavior) = OPTIONAL]; // Optional. A page token, received from a previous `ListMemos` call. // Provide this to retrieve the subsequent page. - string page_token = 3 [(google.api.field_behavior) = OPTIONAL]; + string page_token = 2 [(google.api.field_behavior) = OPTIONAL]; // Optional. The state of the memos to list. // Default to `NORMAL`. Set to `ARCHIVED` to list archived memos. - State state = 4 [(google.api.field_behavior) = OPTIONAL]; + State state = 3 [(google.api.field_behavior) = OPTIONAL]; // Optional. The order to sort results by. // Default to "display_time desc". diff --git a/proto/gen/api/v1/memo_service.pb.go b/proto/gen/api/v1/memo_service.pb.go index 8ce53fddb..f993ed8cd 100644 --- a/proto/gen/api/v1/memo_service.pb.go +++ b/proto/gen/api/v1/memo_service.pb.go @@ -551,21 +551,17 @@ func (x *CreateMemoRequest) GetRequestId() string { type ListMemosRequest struct { state protoimpl.MessageState `protogen:"open.v1"` - // Optional. The parent is the owner of the memos. - // If not specified or `users/-`, it will list all memos. - // Format: users/{user} - Parent string `protobuf:"bytes,1,opt,name=parent,proto3" json:"parent,omitempty"` // Optional. The maximum number of memos to return. // The service may return fewer than this value. // If unspecified, at most 50 memos will be returned. // The maximum value is 1000; values above 1000 will be coerced to 1000. - PageSize int32 `protobuf:"varint,2,opt,name=page_size,json=pageSize,proto3" json:"page_size,omitempty"` + PageSize int32 `protobuf:"varint,1,opt,name=page_size,json=pageSize,proto3" json:"page_size,omitempty"` // Optional. A page token, received from a previous `ListMemos` call. // Provide this to retrieve the subsequent page. - PageToken string `protobuf:"bytes,3,opt,name=page_token,json=pageToken,proto3" json:"page_token,omitempty"` + PageToken string `protobuf:"bytes,2,opt,name=page_token,json=pageToken,proto3" json:"page_token,omitempty"` // Optional. The state of the memos to list. // Default to `NORMAL`. Set to `ARCHIVED` to list archived memos. - State State `protobuf:"varint,4,opt,name=state,proto3,enum=memos.api.v1.State" json:"state,omitempty"` + State State `protobuf:"varint,3,opt,name=state,proto3,enum=memos.api.v1.State" json:"state,omitempty"` // Optional. The order to sort results by. // Default to "display_time desc". // Example: "display_time desc" or "create_time asc" @@ -610,13 +606,6 @@ func (*ListMemosRequest) Descriptor() ([]byte, []int) { return file_api_v1_memo_service_proto_rawDescGZIP(), []int{4} } -func (x *ListMemosRequest) GetParent() string { - if x != nil { - return x.Parent - } - return "" -} - func (x *ListMemosRequest) GetPageSize() int32 { if x != nil { return x.PageSize @@ -2064,14 +2053,12 @@ const file_api_v1_memo_service_proto_rawDesc = "" + "\amemo_id\x18\x02 \x01(\tB\x03\xe0A\x01R\x06memoId\x12(\n" + "\rvalidate_only\x18\x03 \x01(\bB\x03\xe0A\x01R\fvalidateOnly\x12\"\n" + "\n" + - "request_id\x18\x04 \x01(\tB\x03\xe0A\x01R\trequestId\"\xa0\x02\n" + - "\x10ListMemosRequest\x121\n" + - "\x06parent\x18\x01 \x01(\tB\x19\xe0A\x01\xfaA\x13\n" + - "\x11memos.api.v1/UserR\x06parent\x12 \n" + - "\tpage_size\x18\x02 \x01(\x05B\x03\xe0A\x01R\bpageSize\x12\"\n" + + "request_id\x18\x04 \x01(\tB\x03\xe0A\x01R\trequestId\"\xed\x01\n" + + "\x10ListMemosRequest\x12 \n" + + "\tpage_size\x18\x01 \x01(\x05B\x03\xe0A\x01R\bpageSize\x12\"\n" + "\n" + - "page_token\x18\x03 \x01(\tB\x03\xe0A\x01R\tpageToken\x12.\n" + - "\x05state\x18\x04 \x01(\x0e2\x13.memos.api.v1.StateB\x03\xe0A\x01R\x05state\x12\x1e\n" + + "page_token\x18\x02 \x01(\tB\x03\xe0A\x01R\tpageToken\x12.\n" + + "\x05state\x18\x03 \x01(\x0e2\x13.memos.api.v1.StateB\x03\xe0A\x01R\x05state\x12\x1e\n" + "\border_by\x18\x05 \x01(\tB\x03\xe0A\x01R\aorderBy\x12\x1b\n" + "\x06filter\x18\x06 \x01(\tB\x03\xe0A\x01R\x06filter\x12&\n" + "\fshow_deleted\x18\a \x01(\bB\x03\xe0A\x01R\vshowDeleted\"\x84\x01\n" + @@ -2187,11 +2174,11 @@ const file_api_v1_memo_service_proto_rawDesc = "" + "\aPRIVATE\x10\x01\x12\r\n" + "\tPROTECTED\x10\x02\x12\n" + "\n" + - "\x06PUBLIC\x10\x032\x97\x11\n" + + "\x06PUBLIC\x10\x032\xeb\x10\n" + "\vMemoService\x12e\n" + "\n" + - "CreateMemo\x12\x1f.memos.api.v1.CreateMemoRequest\x1a\x12.memos.api.v1.Memo\"\"\xdaA\x04memo\x82\xd3\xe4\x93\x02\x15:\x04memo\"\r/api/v1/memos\x12\x91\x01\n" + - "\tListMemos\x12\x1e.memos.api.v1.ListMemosRequest\x1a\x1f.memos.api.v1.ListMemosResponse\"C\xdaA\x00\xdaA\x06parent\x82\xd3\xe4\x93\x021Z \x12\x1e/api/v1/{parent=users/*}/memos\x12\r/api/v1/memos\x12b\n" + + "CreateMemo\x12\x1f.memos.api.v1.CreateMemoRequest\x1a\x12.memos.api.v1.Memo\"\"\xdaA\x04memo\x82\xd3\xe4\x93\x02\x15:\x04memo\"\r/api/v1/memos\x12f\n" + + "\tListMemos\x12\x1e.memos.api.v1.ListMemosRequest\x1a\x1f.memos.api.v1.ListMemosResponse\"\x18\xdaA\x00\x82\xd3\xe4\x93\x02\x0f\x12\r/api/v1/memos\x12b\n" + "\aGetMemo\x12\x1c.memos.api.v1.GetMemoRequest\x1a\x12.memos.api.v1.Memo\"%\xdaA\x04name\x82\xd3\xe4\x93\x02\x18\x12\x16/api/v1/{name=memos/*}\x12\x7f\n" + "\n" + "UpdateMemo\x12\x1f.memos.api.v1.UpdateMemoRequest\x1a\x12.memos.api.v1.Memo\"<\xdaA\x10memo,update_mask\x82\xd3\xe4\x93\x02#:\x04memo2\x1b/api/v1/{memo.name=memos/*}\x12l\n" + diff --git a/proto/gen/api/v1/memo_service.pb.gw.go b/proto/gen/api/v1/memo_service.pb.gw.go index 77ee50303..ee8bf627f 100644 --- a/proto/gen/api/v1/memo_service.pb.gw.go +++ b/proto/gen/api/v1/memo_service.pb.gw.go @@ -111,59 +111,6 @@ func local_request_MemoService_ListMemos_0(ctx context.Context, marshaler runtim return msg, metadata, err } -var filter_MemoService_ListMemos_1 = &utilities.DoubleArray{Encoding: map[string]int{"parent": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}} - -func request_MemoService_ListMemos_1(ctx context.Context, marshaler runtime.Marshaler, client MemoServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var ( - protoReq ListMemosRequest - metadata runtime.ServerMetadata - err error - ) - if req.Body != nil { - _, _ = io.Copy(io.Discard, req.Body) - } - val, ok := pathParams["parent"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "parent") - } - protoReq.Parent, err = runtime.String(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "parent", err) - } - if err := req.ParseForm(); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_MemoService_ListMemos_1); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - msg, err := client.ListMemos(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err -} - -func local_request_MemoService_ListMemos_1(ctx context.Context, marshaler runtime.Marshaler, server MemoServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var ( - protoReq ListMemosRequest - metadata runtime.ServerMetadata - err error - ) - val, ok := pathParams["parent"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "parent") - } - protoReq.Parent, err = runtime.String(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "parent", err) - } - if err := req.ParseForm(); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_MemoService_ListMemos_1); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - msg, err := server.ListMemos(ctx, &protoReq) - return msg, metadata, err -} - var filter_MemoService_GetMemo_0 = &utilities.DoubleArray{Encoding: map[string]int{"name": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}} func request_MemoService_GetMemo_0(ctx context.Context, marshaler runtime.Marshaler, client MemoServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { @@ -956,26 +903,6 @@ func RegisterMemoServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux } forward_MemoService_ListMemos_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) - mux.Handle(http.MethodGet, pattern_MemoService_ListMemos_1, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - var stream runtime.ServerTransportStream - ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/memos.api.v1.MemoService/ListMemos", runtime.WithHTTPPathPattern("/api/v1/{parent=users/*}/memos")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_MemoService_ListMemos_1(annotatedContext, inboundMarshaler, server, req, pathParams) - md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - forward_MemoService_ListMemos_1(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) mux.Handle(http.MethodGet, pattern_MemoService_GetMemo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -1330,23 +1257,6 @@ func RegisterMemoServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux } forward_MemoService_ListMemos_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) - mux.Handle(http.MethodGet, pattern_MemoService_ListMemos_1, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/memos.api.v1.MemoService/ListMemos", runtime.WithHTTPPathPattern("/api/v1/{parent=users/*}/memos")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_MemoService_ListMemos_1(annotatedContext, inboundMarshaler, client, req, pathParams) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - forward_MemoService_ListMemos_1(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) mux.Handle(http.MethodGet, pattern_MemoService_GetMemo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -1591,7 +1501,6 @@ func RegisterMemoServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux var ( pattern_MemoService_CreateMemo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "memos"}, "")) pattern_MemoService_ListMemos_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "memos"}, "")) - pattern_MemoService_ListMemos_1 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 2, 5, 3, 2, 4}, []string{"api", "v1", "users", "parent", "memos"}, "")) pattern_MemoService_GetMemo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 2, 5, 3}, []string{"api", "v1", "memos", "name"}, "")) pattern_MemoService_UpdateMemo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 2, 5, 3}, []string{"api", "v1", "memos", "memo.name"}, "")) pattern_MemoService_DeleteMemo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 2, 5, 3}, []string{"api", "v1", "memos", "name"}, "")) @@ -1611,7 +1520,6 @@ var ( var ( forward_MemoService_CreateMemo_0 = runtime.ForwardResponseMessage forward_MemoService_ListMemos_0 = runtime.ForwardResponseMessage - forward_MemoService_ListMemos_1 = runtime.ForwardResponseMessage forward_MemoService_GetMemo_0 = runtime.ForwardResponseMessage forward_MemoService_UpdateMemo_0 = runtime.ForwardResponseMessage forward_MemoService_DeleteMemo_0 = runtime.ForwardResponseMessage diff --git a/proto/gen/openapi.yaml b/proto/gen/openapi.yaml index fdaeea45d..39650df2f 100644 --- a/proto/gen/openapi.yaml +++ b/proto/gen/openapi.yaml @@ -622,14 +622,6 @@ paths: description: ListMemos lists memos with pagination and filter. operationId: MemoService_ListMemos parameters: - - name: parent - in: query - description: |- - Optional. The parent is the owner of the memos. - If not specified or `users/-`, it will list all memos. - Format: users/{user} - schema: - type: string - name: pageSize in: query description: |- @@ -1597,82 +1589,6 @@ paths: application/json: schema: $ref: '#/components/schemas/Status' - /api/v1/users/{user}/memos: - get: - tags: - - MemoService - description: ListMemos lists memos with pagination and filter. - operationId: MemoService_ListMemos - parameters: - - name: user - in: path - description: The user id. - required: true - schema: - type: string - - name: pageSize - in: query - description: |- - Optional. The maximum number of memos to return. - The service may return fewer than this value. - If unspecified, at most 50 memos will be returned. - The maximum value is 1000; values above 1000 will be coerced to 1000. - schema: - type: integer - format: int32 - - name: pageToken - in: query - description: |- - Optional. A page token, received from a previous `ListMemos` call. - Provide this to retrieve the subsequent page. - schema: - type: string - - name: state - in: query - description: |- - Optional. The state of the memos to list. - Default to `NORMAL`. Set to `ARCHIVED` to list archived memos. - schema: - enum: - - STATE_UNSPECIFIED - - NORMAL - - ARCHIVED - type: string - format: enum - - name: orderBy - in: query - description: |- - Optional. The order to sort results by. - Default to "display_time desc". - Example: "display_time desc" or "create_time asc" - schema: - type: string - - name: filter - in: query - description: |- - Optional. Filter to apply to the list results. - Filter is a CEL expression to filter memos. - Refer to `Shortcut.filter`. - schema: - type: string - - name: showDeleted - in: query - description: Optional. If true, show deleted memos in the response. - schema: - type: boolean - responses: - "200": - description: OK - content: - application/json: - schema: - $ref: '#/components/schemas/ListMemosResponse' - default: - description: Default error response - content: - application/json: - schema: - $ref: '#/components/schemas/Status' /api/v1/users/{user}/sessions: get: tags: diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index 3ee1d59e8..7375a1f15 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -99,13 +99,6 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq // Exclude comments by default. ExcludeComments: true, } - if request.Parent != "" && request.Parent != "users/-" { - userID, err := ExtractUserIDFromName(request.Parent) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid parent: %v", err) - } - memoFind.CreatorID = &userID - } if request.State == v1pb.State_ARCHIVED { state := store.Archived memoFind.RowStatus = &state diff --git a/server/router/api/v1/shortcut_service.go b/server/router/api/v1/shortcut_service.go index e0146e505..46d1b3849 100644 --- a/server/router/api/v1/shortcut_service.go +++ b/server/router/api/v1/shortcut_service.go @@ -329,7 +329,23 @@ func (s *APIV1Service) validateFilter(_ context.Context, filterStr string) error return errors.Wrap(err, "failed to parse filter") } convertCtx := filter.NewConvertContext() - err = s.Store.GetDriver().ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + + // Determine the dialect based on the actual database driver + var dialect filter.SQLDialect + switch s.Profile.Driver { + case "sqlite": + dialect = &filter.SQLiteDialect{} + case "mysql": + dialect = &filter.MySQLDialect{} + case "postgres": + dialect = &filter.PostgreSQLDialect{} + default: + // Default to SQLite for unknown drivers + dialect = &filter.SQLiteDialect{} + } + + converter := filter.NewCommonSQLConverter(dialect) + err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) if err != nil { return errors.Wrap(err, "failed to convert filter to SQL") } diff --git a/store/db/mysql/memo.go b/store/db/mysql/memo.go index 19196cce9..7d59e431b 100644 --- a/store/db/mysql/memo.go +++ b/store/db/mysql/memo.go @@ -59,7 +59,8 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo } convertCtx := filter.NewConvertContext() // ConvertExprToSQL converts the parsed expression to a SQL condition string. - if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) + if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err } condition := convertCtx.Buffer.String() diff --git a/store/db/mysql/memo_filter.go b/store/db/mysql/memo_filter.go deleted file mode 100644 index b69332968..000000000 --- a/store/db/mysql/memo_filter.go +++ /dev/null @@ -1,357 +0,0 @@ -package mysql - -import ( - "fmt" - "slices" - "strings" - - "github.com/pkg/errors" - exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" - - "github.com/usememos/memos/plugin/filter" -) - -func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { - return d.convertWithTemplates(ctx, expr) -} - -func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) error { - const dbType = filter.MySQLTemplate - - if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok { - switch v.CallExpr.Function { - case "_||_", "_&&_": - if len(v.CallExpr.Args) != 2 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - if _, err := ctx.Buffer.WriteString("("); err != nil { - return err - } - if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil { - return err - } - operator := "AND" - if v.CallExpr.Function == "_||_" { - operator = "OR" - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { - return err - } - if err := d.convertWithTemplates(ctx, v.CallExpr.Args[1]); err != nil { - return err - } - if _, err := ctx.Buffer.WriteString(")"); err != nil { - return err - } - case "!_": - if len(v.CallExpr.Args) != 1 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { - return err - } - if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil { - return err - } - if _, err := ctx.Buffer.WriteString(")"); err != nil { - return err - } - case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_": - if len(v.CallExpr.Args) != 2 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - // Check if the left side is a function call like size(tags) - if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok { - if leftCallExpr.CallExpr.Function == "size" { - // Handle size(tags) comparison - if len(leftCallExpr.CallExpr.Args) != 1 { - return errors.New("size function requires exactly one argument") - } - identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0]) - if err != nil { - return err - } - if identifier != "tags" { - return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier) - } - value, err := filter.GetExprValue(v.CallExpr.Args[1]) - if err != nil { - return err - } - valueInt, ok := value.(int64) - if !ok { - return errors.New("size comparison value must be an integer") - } - operator := d.getComparisonOperator(v.CallExpr.Function) - - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", - filter.GetSQL("json_array_length", dbType), operator)); err != nil { - return err - } - ctx.Args = append(ctx.Args, valueInt) - return nil - } - } - - identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) - if err != nil { - return err - } - if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { - return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) - } - value, err := filter.GetExprValue(v.CallExpr.Args[1]) - if err != nil { - return err - } - operator := d.getComparisonOperator(v.CallExpr.Function) - - if identifier == "created_ts" || identifier == "updated_ts" { - valueInt, ok := value.(int64) - if !ok { - return errors.New("invalid integer timestamp value") - } - - timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier) - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", timestampSQL, operator)); err != nil { - return err - } - ctx.Args = append(ctx.Args, valueInt) - } else if identifier == "visibility" || identifier == "content" { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", v.CallExpr.Function) - } - valueStr, ok := value.(string) - if !ok { - return errors.New("invalid string value") - } - - var sqlTemplate string - if identifier == "visibility" { - sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`visibility`" - } else if identifier == "content" { - sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`content`" - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil { - return err - } - ctx.Args = append(ctx.Args, valueStr) - } else if identifier == "creator_id" { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", v.CallExpr.Function) - } - valueInt, ok := value.(int64) - if !ok { - return errors.New("invalid int value") - } - - sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".`creator_id`" - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil { - return err - } - ctx.Args = append(ctx.Args, valueInt) - } else if identifier == "has_task_list" { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", v.CallExpr.Function) - } - valueBool, ok := value.(bool) - if !ok { - return errors.New("invalid boolean value for has_task_list") - } - // Use template for boolean comparison - var sqlTemplate string - if operator == "=" { - if valueBool { - sqlTemplate = filter.GetSQL("boolean_true", dbType) - } else { - sqlTemplate = filter.GetSQL("boolean_false", dbType) - } - } else { // operator == "!=" - if valueBool { - sqlTemplate = filter.GetSQL("boolean_not_true", dbType) - } else { - sqlTemplate = filter.GetSQL("boolean_not_false", dbType) - } - } - if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { - return err - } - } else if identifier == "has_link" || identifier == "has_code" || identifier == "has_incomplete_tasks" { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", v.CallExpr.Function) - } - valueBool, ok := value.(bool) - if !ok { - return errors.Errorf("invalid boolean value for %s", identifier) - } - - // Map identifier to JSON path - var jsonPath string - switch identifier { - case "has_link": - jsonPath = "$.property.hasLink" - case "has_code": - jsonPath = "$.property.hasCode" - case "has_incomplete_tasks": - jsonPath = "$.property.hasIncompleteTasks" - } - - // Use JSON_EXTRACT for boolean comparison like has_task_list - var sqlTemplate string - if operator == "=" { - if valueBool { - sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') = CAST('true' AS JSON)", jsonPath) - } else { - sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') = CAST('false' AS JSON)", jsonPath) - } - } else { // operator == "!=" - if valueBool { - sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') != CAST('true' AS JSON)", jsonPath) - } else { - sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') != CAST('false' AS JSON)", jsonPath) - } - } - if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { - return err - } - } - case "@in": - if len(v.CallExpr.Args) != 2 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - - // Check if this is "element in collection" syntax - if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil { - // This is "element in collection" - the second argument is the collection - if !slices.Contains([]string{"tags"}, identifier) { - return errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier) - } - - if identifier == "tags" { - // Handle "element" in tags - element, err := filter.GetConstValue(v.CallExpr.Args[0]) - if err != nil { - return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err) - } - if _, err := ctx.Buffer.WriteString(filter.GetSQL("json_contains_element", dbType)); err != nil { - return err - } - ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element)) - } - return nil - } - - // Original logic for "identifier in [list]" syntax - identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) - if err != nil { - return err - } - if !slices.Contains([]string{"tag", "visibility"}, identifier) { - return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) - } - - values := []any{} - for _, element := range v.CallExpr.Args[1].GetListExpr().Elements { - value, err := filter.GetConstValue(element) - if err != nil { - return err - } - values = append(values, value) - } - if identifier == "tag" { - subconditions := []string{} - args := []any{} - for _, v := range values { - subconditions = append(subconditions, filter.GetSQL("json_contains_tag", dbType)) - args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v)) - } - if len(subconditions) == 1 { - if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil { - return err - } - } else { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil { - return err - } - } - ctx.Args = append(ctx.Args, args...) - } else if identifier == "visibility" { - placeholders := filter.FormatPlaceholders(dbType, len(values), 1) - visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ",")) - if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil { - return err - } - ctx.Args = append(ctx.Args, values...) - } - case "contains": - if len(v.CallExpr.Args) != 1 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - identifier, err := filter.GetIdentExprName(v.CallExpr.Target) - if err != nil { - return err - } - if identifier != "content" { - return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) - } - arg, err := filter.GetConstValue(v.CallExpr.Args[0]) - if err != nil { - return err - } - if _, err := ctx.Buffer.WriteString(filter.GetSQL("content_like", dbType)); err != nil { - return err - } - ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg)) - } - } else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok { - identifier := v.IdentExpr.GetName() - if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { - return errors.Errorf("invalid identifier %s", identifier) - } - if identifier == "pinned" { - if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".`pinned` IS TRUE"); err != nil { - return err - } - } else if identifier == "has_task_list" { - // Handle has_task_list as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil { - return err - } - } else if identifier == "has_link" { - // Handle has_link as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)"); err != nil { - return err - } - } else if identifier == "has_code" { - // Handle has_code as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('true' AS JSON)"); err != nil { - return err - } - } else if identifier == "has_incomplete_tasks" { - // Handle has_incomplete_tasks as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = CAST('true' AS JSON)"); err != nil { - return err - } - } - } - return nil -} - -func (*DB) getComparisonOperator(function string) string { - switch function { - case "_==_": - return "=" - case "_!=_": - return "!=" - case "_<_": - return "<" - case "_>_": - return ">" - case "_<=_": - return "<=" - case "_>=_": - return ">=" - default: - return "=" - } -} diff --git a/store/db/mysql/memo_filter_test.go b/store/db/mysql/memo_filter_test.go index 397e2ca5a..330b8fa54 100644 --- a/store/db/mysql/memo_filter_test.go +++ b/store/db/mysql/memo_filter_test.go @@ -148,11 +148,11 @@ func TestConvertExprToSQL(t *testing.T) { } for _, tt := range tests { - db := &DB{} parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) require.NoError(t, err) convertCtx := filter.NewConvertContext() - err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) + err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) require.NoError(t, err) require.Equal(t, tt.want, convertCtx.Buffer.String()) require.Equal(t, tt.args, convertCtx.Args) diff --git a/store/db/mysql/memo_relation.go b/store/db/mysql/memo_relation.go index 895248834..a57bda8eb 100644 --- a/store/db/mysql/memo_relation.go +++ b/store/db/mysql/memo_relation.go @@ -51,7 +51,8 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation } convertCtx := filter.NewConvertContext() // ConvertExprToSQL converts the parsed expression to a SQL condition string. - if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) + if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err } condition := convertCtx.Buffer.String() diff --git a/store/db/postgres/memo.go b/store/db/postgres/memo.go index b41d78805..b30d0bc11 100644 --- a/store/db/postgres/memo.go +++ b/store/db/postgres/memo.go @@ -51,7 +51,8 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo convertCtx := filter.NewConvertContext() convertCtx.ArgsOffset = len(args) // ConvertExprToSQL converts the parsed expression to a SQL condition string. - if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + converter := filter.NewCommonSQLConverterWithOffset(&filter.PostgreSQLDialect{}, convertCtx.ArgsOffset+len(convertCtx.Args)) + if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err } condition := convertCtx.Buffer.String() diff --git a/store/db/postgres/memo_filter.go b/store/db/postgres/memo_filter.go deleted file mode 100644 index e982880fc..000000000 --- a/store/db/postgres/memo_filter.go +++ /dev/null @@ -1,373 +0,0 @@ -package postgres - -import ( - "fmt" - "slices" - "strings" - - "github.com/pkg/errors" - exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" - - "github.com/usememos/memos/plugin/filter" -) - -func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { - const dbType = filter.PostgreSQLTemplate - _, err := d.convertWithParameterIndex(ctx, expr, dbType, ctx.ArgsOffset+len(ctx.Args)+1) - return err -} - -func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1.Expr, dbType filter.TemplateDBType, paramIndex int) (int, error) { - if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok { - switch v.CallExpr.Function { - case "_||_", "_&&_": - if len(v.CallExpr.Args) != 2 { - return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - if _, err := ctx.Buffer.WriteString("("); err != nil { - return paramIndex, err - } - newParamIndex, err := d.convertWithParameterIndex(ctx, v.CallExpr.Args[0], dbType, paramIndex) - if err != nil { - return paramIndex, err - } - operator := "AND" - if v.CallExpr.Function == "_||_" { - operator = "OR" - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { - return paramIndex, err - } - newParamIndex, err = d.convertWithParameterIndex(ctx, v.CallExpr.Args[1], dbType, newParamIndex) - if err != nil { - return paramIndex, err - } - if _, err := ctx.Buffer.WriteString(")"); err != nil { - return paramIndex, err - } - return newParamIndex, nil - case "!_": - if len(v.CallExpr.Args) != 1 { - return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { - return paramIndex, err - } - newParamIndex, err := d.convertWithParameterIndex(ctx, v.CallExpr.Args[0], dbType, paramIndex) - if err != nil { - return paramIndex, err - } - if _, err := ctx.Buffer.WriteString(")"); err != nil { - return paramIndex, err - } - return newParamIndex, nil - case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_": - if len(v.CallExpr.Args) != 2 { - return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - // Check if the left side is a function call like size(tags) - if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok { - if leftCallExpr.CallExpr.Function == "size" { - // Handle size(tags) comparison - if len(leftCallExpr.CallExpr.Args) != 1 { - return paramIndex, errors.New("size function requires exactly one argument") - } - identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0]) - if err != nil { - return paramIndex, err - } - if identifier != "tags" { - return paramIndex, errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier) - } - value, err := filter.GetExprValue(v.CallExpr.Args[1]) - if err != nil { - return paramIndex, err - } - valueInt, ok := value.(int64) - if !ok { - return paramIndex, errors.New("size comparison value must be an integer") - } - operator := d.getComparisonOperator(v.CallExpr.Function) - - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", - filter.GetSQL("json_array_length", dbType), operator, - filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil { - return paramIndex, err - } - ctx.Args = append(ctx.Args, valueInt) - return paramIndex + 1, nil - } - } - - identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) - if err != nil { - return paramIndex, err - } - if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { - return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function) - } - value, err := filter.GetExprValue(v.CallExpr.Args[1]) - if err != nil { - return paramIndex, err - } - operator := d.getComparisonOperator(v.CallExpr.Function) - - if identifier == "created_ts" || identifier == "updated_ts" { - valueInt, ok := value.(int64) - if !ok { - return paramIndex, errors.New("invalid integer timestamp value") - } - - timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier) - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampSQL, operator, - filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil { - return paramIndex, err - } - ctx.Args = append(ctx.Args, valueInt) - return paramIndex + 1, nil - } else if identifier == "visibility" || identifier == "content" { - if operator != "=" && operator != "!=" { - return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function) - } - valueStr, ok := value.(string) - if !ok { - return paramIndex, errors.New("invalid string value") - } - - var sqlTemplate string - if identifier == "visibility" { - sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".visibility" - } else if identifier == "content" { - sqlTemplate = filter.GetSQL("content_like", dbType) - if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { - return paramIndex, err - } - ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", valueStr)) - return paramIndex + 1, nil - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", sqlTemplate, operator, - filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil { - return paramIndex, err - } - ctx.Args = append(ctx.Args, valueStr) - return paramIndex + 1, nil - } else if identifier == "creator_id" { - if operator != "=" && operator != "!=" { - return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function) - } - valueInt, ok := value.(int64) - if !ok { - return paramIndex, errors.New("invalid int value") - } - - sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".creator_id" - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", sqlTemplate, operator, - filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil { - return paramIndex, err - } - ctx.Args = append(ctx.Args, valueInt) - return paramIndex + 1, nil - } else if identifier == "has_task_list" { - if operator != "=" && operator != "!=" { - return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function) - } - valueBool, ok := value.(bool) - if !ok { - return paramIndex, errors.New("invalid boolean value for has_task_list") - } - // Use parameterized template for boolean comparison (PostgreSQL only) - placeholder := filter.GetParameterPlaceholder(dbType, paramIndex) - sqlTemplate := fmt.Sprintf(filter.GetSQL("boolean_compare", dbType), operator) - sqlTemplate = strings.Replace(sqlTemplate, "?", placeholder, 1) - if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { - return paramIndex, err - } - ctx.Args = append(ctx.Args, valueBool) - return paramIndex + 1, nil - } else if identifier == "has_link" || identifier == "has_code" || identifier == "has_incomplete_tasks" { - if operator != "=" && operator != "!=" { - return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function) - } - valueBool, ok := value.(bool) - if !ok { - return paramIndex, errors.Errorf("invalid boolean value for %s", identifier) - } - - // Map identifier to JSON path - var jsonPath string - switch identifier { - case "has_link": - jsonPath = "$.property.hasLink" - case "has_code": - jsonPath = "$.property.hasCode" - case "has_incomplete_tasks": - jsonPath = "$.property.hasIncompleteTasks" - } - - // Use JSON path for boolean comparison with PostgreSQL parameter placeholder - placeholder := filter.GetParameterPlaceholder(dbType, paramIndex) - var sqlTemplate string - if operator == "=" { - sqlTemplate = fmt.Sprintf("(%s->'payload'->'property'->>'%s')::boolean = %s", filter.GetSQL("table_prefix", dbType), strings.TrimPrefix(jsonPath, "$.property."), placeholder) - } else { // operator == "!=" - sqlTemplate = fmt.Sprintf("(%s->'payload'->'property'->>'%s')::boolean != %s", filter.GetSQL("table_prefix", dbType), strings.TrimPrefix(jsonPath, "$.property."), placeholder) - } - if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { - return paramIndex, err - } - ctx.Args = append(ctx.Args, valueBool) - return paramIndex + 1, nil - } - case "@in": - if len(v.CallExpr.Args) != 2 { - return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - - // Check if this is "element in collection" syntax - if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil { - // This is "element in collection" - the second argument is the collection - if !slices.Contains([]string{"tags"}, identifier) { - return paramIndex, errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier) - } - - if identifier == "tags" { - // Handle "element" in tags - element, err := filter.GetConstValue(v.CallExpr.Args[0]) - if err != nil { - return paramIndex, errors.Errorf("first argument must be a constant value for 'element in tags': %v", err) - } - placeholder := filter.GetParameterPlaceholder(dbType, paramIndex) - sql := strings.Replace(filter.GetSQL("json_contains_element", dbType), "?", placeholder, 1) - if _, err := ctx.Buffer.WriteString(sql); err != nil { - return paramIndex, err - } - ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element)) - return paramIndex + 1, nil - } - return paramIndex, nil - } - - // Original logic for "identifier in [list]" syntax - identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) - if err != nil { - return paramIndex, err - } - if !slices.Contains([]string{"tag", "visibility"}, identifier) { - return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function) - } - - values := []any{} - for _, element := range v.CallExpr.Args[1].GetListExpr().Elements { - value, err := filter.GetConstValue(element) - if err != nil { - return paramIndex, err - } - values = append(values, value) - } - if identifier == "tag" { - subconditions := []string{} - args := []any{} - currentParamIndex := paramIndex - for _, v := range values { - // Use parameter index for each placeholder - placeholder := filter.GetParameterPlaceholder(dbType, currentParamIndex) - subcondition := strings.Replace(filter.GetSQL("json_contains_tag", dbType), "?", placeholder, 1) - subconditions = append(subconditions, subcondition) - args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v)) - currentParamIndex++ - } - if len(subconditions) == 1 { - if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil { - return paramIndex, err - } - } else { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil { - return paramIndex, err - } - } - ctx.Args = append(ctx.Args, args...) - return paramIndex + len(args), nil - } else if identifier == "visibility" { - placeholders := filter.FormatPlaceholders(dbType, len(values), paramIndex) - visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ",")) - if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil { - return paramIndex, err - } - ctx.Args = append(ctx.Args, values...) - return paramIndex + len(values), nil - } - case "contains": - if len(v.CallExpr.Args) != 1 { - return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - identifier, err := filter.GetIdentExprName(v.CallExpr.Target) - if err != nil { - return paramIndex, err - } - if identifier != "content" { - return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function) - } - arg, err := filter.GetConstValue(v.CallExpr.Args[0]) - if err != nil { - return paramIndex, err - } - placeholder := filter.GetParameterPlaceholder(dbType, paramIndex) - sql := strings.Replace(filter.GetSQL("content_like", dbType), "?", placeholder, 1) - if _, err := ctx.Buffer.WriteString(sql); err != nil { - return paramIndex, err - } - ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg)) - return paramIndex + 1, nil - } - } else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok { - identifier := v.IdentExpr.GetName() - if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { - return paramIndex, errors.Errorf("invalid identifier %s", identifier) - } - if identifier == "pinned" { - if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".pinned IS TRUE"); err != nil { - return paramIndex, err - } - } else if identifier == "has_task_list" { - // Handle has_task_list as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil { - return paramIndex, err - } - } else if identifier == "has_link" { - // Handle has_link as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s->'payload'->'property'->>'hasLink')::boolean = true", filter.GetSQL("table_prefix", dbType))); err != nil { - return paramIndex, err - } - } else if identifier == "has_code" { - // Handle has_code as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s->'payload'->'property'->>'hasCode')::boolean = true", filter.GetSQL("table_prefix", dbType))); err != nil { - return paramIndex, err - } - } else if identifier == "has_incomplete_tasks" { - // Handle has_incomplete_tasks as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s->'payload'->'property'->>'hasIncompleteTasks')::boolean = true", filter.GetSQL("table_prefix", dbType))); err != nil { - return paramIndex, err - } - } - } - return paramIndex, nil -} - -func (*DB) getComparisonOperator(function string) string { - switch function { - case "_==_": - return "=" - case "_!=_": - return "!=" - case "_<_": - return "<" - case "_>_": - return ">" - case "_<=_": - return "<=" - case "_>=_": - return ">=" - default: - return "=" - } -} diff --git a/store/db/postgres/memo_filter_test.go b/store/db/postgres/memo_filter_test.go index 755f88ebe..30f2c69e1 100644 --- a/store/db/postgres/memo_filter_test.go +++ b/store/db/postgres/memo_filter_test.go @@ -148,11 +148,11 @@ func TestConvertExprToSQL(t *testing.T) { } for _, tt := range tests { - db := &DB{} parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) require.NoError(t, err) convertCtx := filter.NewConvertContext() - err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + converter := filter.NewCommonSQLConverterWithOffset(&filter.PostgreSQLDialect{}, convertCtx.ArgsOffset+len(convertCtx.Args)) + err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) require.NoError(t, err) require.Equal(t, tt.want, convertCtx.Buffer.String()) require.Equal(t, tt.args, convertCtx.Args) diff --git a/store/db/postgres/memo_relation.go b/store/db/postgres/memo_relation.go index f76440861..5cc1cbd07 100644 --- a/store/db/postgres/memo_relation.go +++ b/store/db/postgres/memo_relation.go @@ -58,7 +58,8 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation convertCtx := filter.NewConvertContext() convertCtx.ArgsOffset = len(args) // ConvertExprToSQL converts the parsed expression to a SQL condition string. - if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + converter := filter.NewCommonSQLConverterWithOffset(&filter.PostgreSQLDialect{}, convertCtx.ArgsOffset+len(convertCtx.Args)) + if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err } condition := convertCtx.Buffer.String() diff --git a/store/db/sqlite/memo.go b/store/db/sqlite/memo.go index ace8c20fb..73f628a7e 100644 --- a/store/db/sqlite/memo.go +++ b/store/db/sqlite/memo.go @@ -51,7 +51,8 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo } convertCtx := filter.NewConvertContext() // ConvertExprToSQL converts the parsed expression to a SQL condition string. - if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) + if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err } condition := convertCtx.Buffer.String() diff --git a/store/db/sqlite/memo_filter.go b/store/db/sqlite/memo_filter.go deleted file mode 100644 index e3e72893c..000000000 --- a/store/db/sqlite/memo_filter.go +++ /dev/null @@ -1,357 +0,0 @@ -package sqlite - -import ( - "fmt" - "slices" - "strings" - - "github.com/pkg/errors" - exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" - - "github.com/usememos/memos/plugin/filter" -) - -func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { - return d.convertWithTemplates(ctx, expr) -} - -func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) error { - const dbType = filter.SQLiteTemplate - - if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok { - switch v.CallExpr.Function { - case "_||_", "_&&_": - if len(v.CallExpr.Args) != 2 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - if _, err := ctx.Buffer.WriteString("("); err != nil { - return err - } - if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil { - return err - } - operator := "AND" - if v.CallExpr.Function == "_||_" { - operator = "OR" - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { - return err - } - if err := d.convertWithTemplates(ctx, v.CallExpr.Args[1]); err != nil { - return err - } - if _, err := ctx.Buffer.WriteString(")"); err != nil { - return err - } - case "!_": - if len(v.CallExpr.Args) != 1 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { - return err - } - if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil { - return err - } - if _, err := ctx.Buffer.WriteString(")"); err != nil { - return err - } - case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_": - if len(v.CallExpr.Args) != 2 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - // Check if the left side is a function call like size(tags) - if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok { - if leftCallExpr.CallExpr.Function == "size" { - // Handle size(tags) comparison - if len(leftCallExpr.CallExpr.Args) != 1 { - return errors.New("size function requires exactly one argument") - } - identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0]) - if err != nil { - return err - } - if identifier != "tags" { - return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier) - } - value, err := filter.GetExprValue(v.CallExpr.Args[1]) - if err != nil { - return err - } - valueInt, ok := value.(int64) - if !ok { - return errors.New("size comparison value must be an integer") - } - operator := d.getComparisonOperator(v.CallExpr.Function) - - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", - filter.GetSQL("json_array_length", dbType), operator)); err != nil { - return err - } - ctx.Args = append(ctx.Args, valueInt) - return nil - } - } - - identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) - if err != nil { - return err - } - if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { - return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) - } - value, err := filter.GetExprValue(v.CallExpr.Args[1]) - if err != nil { - return err - } - operator := d.getComparisonOperator(v.CallExpr.Function) - - if identifier == "created_ts" || identifier == "updated_ts" { - valueInt, ok := value.(int64) - if !ok { - return errors.New("invalid integer timestamp value") - } - - timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier) - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", timestampSQL, operator)); err != nil { - return err - } - ctx.Args = append(ctx.Args, valueInt) - } else if identifier == "visibility" || identifier == "content" { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", v.CallExpr.Function) - } - valueStr, ok := value.(string) - if !ok { - return errors.New("invalid string value") - } - - var sqlTemplate string - if identifier == "visibility" { - sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`visibility`" - } else if identifier == "content" { - sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`content`" - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil { - return err - } - ctx.Args = append(ctx.Args, valueStr) - } else if identifier == "creator_id" { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", v.CallExpr.Function) - } - valueInt, ok := value.(int64) - if !ok { - return errors.New("invalid int value") - } - - sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".`creator_id`" - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil { - return err - } - ctx.Args = append(ctx.Args, valueInt) - } else if identifier == "has_task_list" { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", v.CallExpr.Function) - } - valueBool, ok := value.(bool) - if !ok { - return errors.New("invalid boolean value for has_task_list") - } - // Use template for boolean comparison - var sqlTemplate string - if operator == "=" { - if valueBool { - sqlTemplate = filter.GetSQL("boolean_true", dbType) - } else { - sqlTemplate = filter.GetSQL("boolean_false", dbType) - } - } else { // operator == "!=" - if valueBool { - sqlTemplate = filter.GetSQL("boolean_not_true", dbType) - } else { - sqlTemplate = filter.GetSQL("boolean_not_false", dbType) - } - } - if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { - return err - } - } else if identifier == "has_link" || identifier == "has_code" || identifier == "has_incomplete_tasks" { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", v.CallExpr.Function) - } - valueBool, ok := value.(bool) - if !ok { - return errors.Errorf("invalid boolean value for %s", identifier) - } - - // Map identifier to JSON path - var jsonPath string - switch identifier { - case "has_link": - jsonPath = "$.property.hasLink" - case "has_code": - jsonPath = "$.property.hasCode" - case "has_incomplete_tasks": - jsonPath = "$.property.hasIncompleteTasks" - } - - // Use JSON_EXTRACT for boolean comparison like has_task_list - var sqlTemplate string - if operator == "=" { - if valueBool { - sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') IS TRUE", jsonPath) - } else { - sqlTemplate = fmt.Sprintf("NOT(JSON_EXTRACT(`memo`.`payload`, '%s') IS TRUE)", jsonPath) - } - } else { // operator == "!=" - if valueBool { - sqlTemplate = fmt.Sprintf("NOT(JSON_EXTRACT(`memo`.`payload`, '%s') IS TRUE)", jsonPath) - } else { - sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') IS TRUE", jsonPath) - } - } - if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { - return err - } - } - case "@in": - if len(v.CallExpr.Args) != 2 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - - // Check if this is "element in collection" syntax - if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil { - // This is "element in collection" - the second argument is the collection - if !slices.Contains([]string{"tags"}, identifier) { - return errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier) - } - - if identifier == "tags" { - // Handle "element" in tags - element, err := filter.GetConstValue(v.CallExpr.Args[0]) - if err != nil { - return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err) - } - if _, err := ctx.Buffer.WriteString(filter.GetSQL("json_contains_element", dbType)); err != nil { - return err - } - ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element)) - } - return nil - } - - // Original logic for "identifier in [list]" syntax - identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) - if err != nil { - return err - } - if !slices.Contains([]string{"tag", "visibility"}, identifier) { - return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) - } - - values := []any{} - for _, element := range v.CallExpr.Args[1].GetListExpr().Elements { - value, err := filter.GetConstValue(element) - if err != nil { - return err - } - values = append(values, value) - } - if identifier == "tag" { - subconditions := []string{} - args := []any{} - for _, v := range values { - subconditions = append(subconditions, filter.GetSQL("json_contains_tag", dbType)) - args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v)) - } - if len(subconditions) == 1 { - if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil { - return err - } - } else { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil { - return err - } - } - ctx.Args = append(ctx.Args, args...) - } else if identifier == "visibility" { - placeholders := filter.FormatPlaceholders(dbType, len(values), 1) - visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ",")) - if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil { - return err - } - ctx.Args = append(ctx.Args, values...) - } - case "contains": - if len(v.CallExpr.Args) != 1 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) - } - identifier, err := filter.GetIdentExprName(v.CallExpr.Target) - if err != nil { - return err - } - if identifier != "content" { - return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) - } - arg, err := filter.GetConstValue(v.CallExpr.Args[0]) - if err != nil { - return err - } - if _, err := ctx.Buffer.WriteString(filter.GetSQL("content_like", dbType)); err != nil { - return err - } - ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg)) - } - } else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok { - identifier := v.IdentExpr.GetName() - if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { - return errors.Errorf("invalid identifier %s", identifier) - } - if identifier == "pinned" { - if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".`pinned` IS TRUE"); err != nil { - return err - } - } else if identifier == "has_task_list" { - // Handle has_task_list as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil { - return err - } - } else if identifier == "has_link" { - // Handle has_link as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') IS TRUE"); err != nil { - return err - } - } else if identifier == "has_code" { - // Handle has_code as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE"); err != nil { - return err - } - } else if identifier == "has_incomplete_tasks" { - // Handle has_incomplete_tasks as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE"); err != nil { - return err - } - } - } - return nil -} - -func (*DB) getComparisonOperator(function string) string { - switch function { - case "_==_": - return "=" - case "_!=_": - return "!=" - case "_<_": - return "<" - case "_>_": - return ">" - case "_<=_": - return "<=" - case "_>=_": - return ">=" - default: - return "=" - } -} diff --git a/store/db/sqlite/memo_filter_test.go b/store/db/sqlite/memo_filter_test.go index cff143f82..6c67daab7 100644 --- a/store/db/sqlite/memo_filter_test.go +++ b/store/db/sqlite/memo_filter_test.go @@ -153,11 +153,11 @@ func TestConvertExprToSQL(t *testing.T) { } for _, tt := range tests { - db := &DB{} parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) require.NoError(t, err) convertCtx := filter.NewConvertContext() - err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) + err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) require.NoError(t, err) require.Equal(t, tt.want, convertCtx.Buffer.String()) require.Equal(t, tt.args, convertCtx.Args) diff --git a/store/db/sqlite/memo_relation.go b/store/db/sqlite/memo_relation.go index 9507b163a..56182e7f4 100644 --- a/store/db/sqlite/memo_relation.go +++ b/store/db/sqlite/memo_relation.go @@ -57,7 +57,8 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation } convertCtx := filter.NewConvertContext() // ConvertExprToSQL converts the parsed expression to a SQL condition string. - if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) + if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err } condition := convertCtx.Buffer.String() diff --git a/store/driver.go b/store/driver.go index 48fd6287e..bc35464f0 100644 --- a/store/driver.go +++ b/store/driver.go @@ -3,10 +3,6 @@ package store import ( "context" "database/sql" - - exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" - - "github.com/usememos/memos/plugin/filter" ) // Driver is an interface for store driver. @@ -73,7 +69,4 @@ type Driver interface { UpsertReaction(ctx context.Context, create *Reaction) (*Reaction, error) ListReactions(ctx context.Context, find *FindReaction) ([]*Reaction, error) DeleteReaction(ctx context.Context, delete *DeleteReaction) error - - // Shortcut related methods. - ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error } diff --git a/web/src/components/MemoEditor/ActionButton/AddMemoRelationPopover.tsx b/web/src/components/MemoEditor/ActionButton/AddMemoRelationPopover.tsx index ac43a91c6..38bad5be8 100644 --- a/web/src/components/MemoEditor/ActionButton/AddMemoRelationPopover.tsx +++ b/web/src/components/MemoEditor/ActionButton/AddMemoRelationPopover.tsx @@ -47,12 +47,14 @@ const AddMemoRelationPopover = (props: Props) => { setIsFetching(true); try { const conditions = []; + // Extract user ID from user name (format: users/{user_id}) + const userId = user.name.replace("users/", ""); + conditions.push(`creator_id == ${userId}`); if (searchText) { conditions.push(`content.contains("${searchText}")`); } const { memos } = await memoServiceClient.listMemos({ - parent: user.name, - filter: conditions.length > 0 ? conditions.join(" && ") : undefined, + filter: conditions.join(" && "), pageSize: DEFAULT_LIST_MEMOS_PAGE_SIZE, }); setFetchedMemos(memos); diff --git a/web/src/components/PagedMemoList/PagedMemoList.tsx b/web/src/components/PagedMemoList/PagedMemoList.tsx index 4ad774224..4e3133267 100644 --- a/web/src/components/PagedMemoList/PagedMemoList.tsx +++ b/web/src/components/PagedMemoList/PagedMemoList.tsx @@ -47,11 +47,20 @@ const PagedMemoList = observer((props: Props) => { setIsRequesting(true); try { + const filters = []; + if (props.owner) { + // Extract user ID from owner name (format: users/{user_id}) + const userId = props.owner.replace("users/", ""); + filters.push(`creator_id == ${userId}`); + } + if (props.filter) { + filters.push(props.filter); + } + const response = await memoStore.fetchMemos({ - parent: props.owner || "", state: props.state || State.NORMAL, orderBy: props.orderBy || "display_time desc", - filter: props.filter || "", + filter: filters.length > 0 ? filters.join(" && ") : undefined, pageSize: props.pageSize || DEFAULT_LIST_MEMOS_PAGE_SIZE, pageToken, }); diff --git a/web/src/pages/Home.tsx b/web/src/pages/Home.tsx index 38e6112f4..89bb3724b 100644 --- a/web/src/pages/Home.tsx +++ b/web/src/pages/Home.tsx @@ -5,6 +5,7 @@ import MemoView from "@/components/MemoView"; import PagedMemoList from "@/components/PagedMemoList"; import useCurrentUser from "@/hooks/useCurrentUser"; import { viewStore, userStore, workspaceStore } from "@/store"; +import { extractUserIdFromName } from "@/store/common"; import memoFilterStore from "@/store/memoFilter"; import { State } from "@/types/proto/api/v1/common"; import { Memo } from "@/types/proto/api/v1/memo_service"; @@ -22,7 +23,7 @@ const Home = observer(() => { const selectedShortcut = userStore.state.shortcuts.find((shortcut) => getShortcutId(shortcut.name) === memoFilterStore.shortcut); const memoFilter = useMemo(() => { - const conditions = []; + const conditions = [`creator_id == "${extractUserIdFromName(user.name)}"`]; if (selectedShortcut?.filter) { conditions.push(selectedShortcut.filter); } diff --git a/web/src/pages/UserProfile.tsx b/web/src/pages/UserProfile.tsx index 549346ed0..f7fd3d3f2 100644 --- a/web/src/pages/UserProfile.tsx +++ b/web/src/pages/UserProfile.tsx @@ -11,6 +11,7 @@ import UserAvatar from "@/components/UserAvatar"; import { Button } from "@/components/ui/button"; import useLoading from "@/hooks/useLoading"; import { viewStore, userStore } from "@/store"; +import { extractUserIdFromName } from "@/store/common"; import memoFilterStore from "@/store/memoFilter"; import { State } from "@/types/proto/api/v1/common"; import { Memo } from "@/types/proto/api/v1/memo_service"; @@ -46,7 +47,7 @@ const UserProfile = observer(() => { return undefined; } - const conditions = []; + const conditions = [`creator_id == "${extractUserIdFromName(user.name)}"`]; for (const filter of memoFilterStore.filters) { if (filter.factor === "contentSearch") { conditions.push(`content.contains("${filter.value}")`); diff --git a/web/src/store/common.ts b/web/src/store/common.ts index 347ac1c8b..1b0774cc2 100644 --- a/web/src/store/common.ts +++ b/web/src/store/common.ts @@ -4,6 +4,10 @@ export const memoNamePrefix = "memos/"; export const identityProviderNamePrefix = "identityProviders/"; export const activityNamePrefix = "activities/"; +export const extractUserIdFromName = (name: string) => { + return name.split(userNamePrefix).pop() || ""; +}; + export const extractMemoIdFromName = (name: string) => { return name.split(memoNamePrefix).pop() || ""; }; diff --git a/web/src/types/proto/api/v1/memo_service.ts b/web/src/types/proto/api/v1/memo_service.ts index 4d4120e94..d7ba52963 100644 --- a/web/src/types/proto/api/v1/memo_service.ts +++ b/web/src/types/proto/api/v1/memo_service.ts @@ -175,12 +175,6 @@ export interface CreateMemoRequest { } export interface ListMemosRequest { - /** - * Optional. The parent is the owner of the memos. - * If not specified or `users/-`, it will list all memos. - * Format: users/{user} - */ - parent: string; /** * Optional. The maximum number of memos to return. * The service may return fewer than this value. @@ -1090,30 +1084,19 @@ export const CreateMemoRequest: MessageFns = { }; function createBaseListMemosRequest(): ListMemosRequest { - return { - parent: "", - pageSize: 0, - pageToken: "", - state: State.STATE_UNSPECIFIED, - orderBy: "", - filter: "", - showDeleted: false, - }; + return { pageSize: 0, pageToken: "", state: State.STATE_UNSPECIFIED, orderBy: "", filter: "", showDeleted: false }; } export const ListMemosRequest: MessageFns = { encode(message: ListMemosRequest, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { - if (message.parent !== "") { - writer.uint32(10).string(message.parent); - } if (message.pageSize !== 0) { - writer.uint32(16).int32(message.pageSize); + writer.uint32(8).int32(message.pageSize); } if (message.pageToken !== "") { - writer.uint32(26).string(message.pageToken); + writer.uint32(18).string(message.pageToken); } if (message.state !== State.STATE_UNSPECIFIED) { - writer.uint32(32).int32(stateToNumber(message.state)); + writer.uint32(24).int32(stateToNumber(message.state)); } if (message.orderBy !== "") { writer.uint32(42).string(message.orderBy); @@ -1135,31 +1118,23 @@ export const ListMemosRequest: MessageFns = { const tag = reader.uint32(); switch (tag >>> 3) { case 1: { - if (tag !== 10) { - break; - } - - message.parent = reader.string(); - continue; - } - case 2: { - if (tag !== 16) { + if (tag !== 8) { break; } message.pageSize = reader.int32(); continue; } - case 3: { - if (tag !== 26) { + case 2: { + if (tag !== 18) { break; } message.pageToken = reader.string(); continue; } - case 4: { - if (tag !== 32) { + case 3: { + if (tag !== 24) { break; } @@ -1204,7 +1179,6 @@ export const ListMemosRequest: MessageFns = { }, fromPartial(object: DeepPartial): ListMemosRequest { const message = createBaseListMemosRequest(); - message.parent = object.parent ?? ""; message.pageSize = object.pageSize ?? 0; message.pageToken = object.pageToken ?? ""; message.state = object.state ?? State.STATE_UNSPECIFIED; @@ -2662,61 +2636,8 @@ export const MemoServiceDefinition = { responseStream: false, options: { _unknownFields: { - 8410: [new Uint8Array([0]), new Uint8Array([6, 112, 97, 114, 101, 110, 116])], - 578365826: [ - new Uint8Array([ - 49, - 90, - 32, - 18, - 30, - 47, - 97, - 112, - 105, - 47, - 118, - 49, - 47, - 123, - 112, - 97, - 114, - 101, - 110, - 116, - 61, - 117, - 115, - 101, - 114, - 115, - 47, - 42, - 125, - 47, - 109, - 101, - 109, - 111, - 115, - 18, - 13, - 47, - 97, - 112, - 105, - 47, - 118, - 49, - 47, - 109, - 101, - 109, - 111, - 115, - ]), - ], + 8410: [new Uint8Array([0])], + 578365826: [new Uint8Array([15, 18, 13, 47, 97, 112, 105, 47, 118, 49, 47, 109, 101, 109, 111, 115])], }, }, }, diff --git a/web/src/types/proto/google/protobuf/descriptor.ts b/web/src/types/proto/google/protobuf/descriptor.ts index 89514564e..db1d2d4a3 100644 --- a/web/src/types/proto/google/protobuf/descriptor.ts +++ b/web/src/types/proto/google/protobuf/descriptor.ts @@ -128,6 +128,52 @@ export function editionToNumber(object: Edition): number { } } +/** + * Describes the 'visibility' of a symbol with respect to the proto import + * system. Symbols can only be imported when the visibility rules do not prevent + * it (ex: local symbols cannot be imported). Visibility modifiers can only set + * on `message` and `enum` as they are the only types available to be referenced + * from other files. + */ +export enum SymbolVisibility { + VISIBILITY_UNSET = "VISIBILITY_UNSET", + VISIBILITY_LOCAL = "VISIBILITY_LOCAL", + VISIBILITY_EXPORT = "VISIBILITY_EXPORT", + UNRECOGNIZED = "UNRECOGNIZED", +} + +export function symbolVisibilityFromJSON(object: any): SymbolVisibility { + switch (object) { + case 0: + case "VISIBILITY_UNSET": + return SymbolVisibility.VISIBILITY_UNSET; + case 1: + case "VISIBILITY_LOCAL": + return SymbolVisibility.VISIBILITY_LOCAL; + case 2: + case "VISIBILITY_EXPORT": + return SymbolVisibility.VISIBILITY_EXPORT; + case -1: + case "UNRECOGNIZED": + default: + return SymbolVisibility.UNRECOGNIZED; + } +} + +export function symbolVisibilityToNumber(object: SymbolVisibility): number { + switch (object) { + case SymbolVisibility.VISIBILITY_UNSET: + return 0; + case SymbolVisibility.VISIBILITY_LOCAL: + return 1; + case SymbolVisibility.VISIBILITY_EXPORT: + return 2; + case SymbolVisibility.UNRECOGNIZED: + default: + return -1; + } +} + /** * The protocol compiler can output a FileDescriptorSet containing the .proto * files it parses. @@ -155,6 +201,11 @@ export interface FileDescriptorProto { * For Google-internal migration only. Do not use. */ weakDependency: number[]; + /** + * Names of files imported by this file purely for the purpose of providing + * option extensions. These are excluded from the dependency list above. + */ + optionDependency: string[]; /** All top-level definitions in this file. */ messageType: DescriptorProto[]; enumType: EnumDescriptorProto[]; @@ -209,6 +260,8 @@ export interface DescriptorProto { * A given name may only be reserved once. */ reservedName: string[]; + /** Support for `export` and `local` keywords on enums. */ + visibility?: SymbolVisibility | undefined; } export interface DescriptorProto_ExtensionRange { @@ -632,6 +685,8 @@ export interface EnumDescriptorProto { * be reserved once. */ reservedName: string[]; + /** Support for `export` and `local` keywords on enums. */ + visibility?: SymbolVisibility | undefined; } /** @@ -1594,6 +1649,7 @@ export interface FeatureSet { messageEncoding?: FeatureSet_MessageEncoding | undefined; jsonFormat?: FeatureSet_JsonFormat | undefined; enforceNamingStyle?: FeatureSet_EnforceNamingStyle | undefined; + defaultSymbolVisibility?: FeatureSet_VisibilityFeature_DefaultSymbolVisibility | undefined; } export enum FeatureSet_FieldPresence { @@ -1875,6 +1931,72 @@ export function featureSet_EnforceNamingStyleToNumber(object: FeatureSet_Enforce } } +export interface FeatureSet_VisibilityFeature { +} + +export enum FeatureSet_VisibilityFeature_DefaultSymbolVisibility { + DEFAULT_SYMBOL_VISIBILITY_UNKNOWN = "DEFAULT_SYMBOL_VISIBILITY_UNKNOWN", + /** EXPORT_ALL - Default pre-EDITION_2024, all UNSET visibility are export. */ + EXPORT_ALL = "EXPORT_ALL", + /** EXPORT_TOP_LEVEL - All top-level symbols default to export, nested default to local. */ + EXPORT_TOP_LEVEL = "EXPORT_TOP_LEVEL", + /** LOCAL_ALL - All symbols default to local. */ + LOCAL_ALL = "LOCAL_ALL", + /** + * STRICT - All symbols local by default. Nested types cannot be exported. + * With special case caveat for message { enum {} reserved 1 to max; } + * This is the recommended setting for new protos. + */ + STRICT = "STRICT", + UNRECOGNIZED = "UNRECOGNIZED", +} + +export function featureSet_VisibilityFeature_DefaultSymbolVisibilityFromJSON( + object: any, +): FeatureSet_VisibilityFeature_DefaultSymbolVisibility { + switch (object) { + case 0: + case "DEFAULT_SYMBOL_VISIBILITY_UNKNOWN": + return FeatureSet_VisibilityFeature_DefaultSymbolVisibility.DEFAULT_SYMBOL_VISIBILITY_UNKNOWN; + case 1: + case "EXPORT_ALL": + return FeatureSet_VisibilityFeature_DefaultSymbolVisibility.EXPORT_ALL; + case 2: + case "EXPORT_TOP_LEVEL": + return FeatureSet_VisibilityFeature_DefaultSymbolVisibility.EXPORT_TOP_LEVEL; + case 3: + case "LOCAL_ALL": + return FeatureSet_VisibilityFeature_DefaultSymbolVisibility.LOCAL_ALL; + case 4: + case "STRICT": + return FeatureSet_VisibilityFeature_DefaultSymbolVisibility.STRICT; + case -1: + case "UNRECOGNIZED": + default: + return FeatureSet_VisibilityFeature_DefaultSymbolVisibility.UNRECOGNIZED; + } +} + +export function featureSet_VisibilityFeature_DefaultSymbolVisibilityToNumber( + object: FeatureSet_VisibilityFeature_DefaultSymbolVisibility, +): number { + switch (object) { + case FeatureSet_VisibilityFeature_DefaultSymbolVisibility.DEFAULT_SYMBOL_VISIBILITY_UNKNOWN: + return 0; + case FeatureSet_VisibilityFeature_DefaultSymbolVisibility.EXPORT_ALL: + return 1; + case FeatureSet_VisibilityFeature_DefaultSymbolVisibility.EXPORT_TOP_LEVEL: + return 2; + case FeatureSet_VisibilityFeature_DefaultSymbolVisibility.LOCAL_ALL: + return 3; + case FeatureSet_VisibilityFeature_DefaultSymbolVisibility.STRICT: + return 4; + case FeatureSet_VisibilityFeature_DefaultSymbolVisibility.UNRECOGNIZED: + default: + return -1; + } +} + /** * A compiled specification for the defaults of a set of features. These * messages are generated from FeatureSet extensions and can be used to seed @@ -2195,6 +2317,7 @@ function createBaseFileDescriptorProto(): FileDescriptorProto { dependency: [], publicDependency: [], weakDependency: [], + optionDependency: [], messageType: [], enumType: [], service: [], @@ -2227,6 +2350,9 @@ export const FileDescriptorProto: MessageFns = { writer.int32(v); } writer.join(); + for (const v of message.optionDependency) { + writer.uint32(122).string(v!); + } for (const v of message.messageType) { DescriptorProto.encode(v!, writer.uint32(34).fork()).join(); } @@ -2321,6 +2447,14 @@ export const FileDescriptorProto: MessageFns = { break; } + case 15: { + if (tag !== 122) { + break; + } + + message.optionDependency.push(reader.string()); + continue; + } case 4: { if (tag !== 34) { break; @@ -2404,6 +2538,7 @@ export const FileDescriptorProto: MessageFns = { message.dependency = object.dependency?.map((e) => e) || []; message.publicDependency = object.publicDependency?.map((e) => e) || []; message.weakDependency = object.weakDependency?.map((e) => e) || []; + message.optionDependency = object.optionDependency?.map((e) => e) || []; message.messageType = object.messageType?.map((e) => DescriptorProto.fromPartial(e)) || []; message.enumType = object.enumType?.map((e) => EnumDescriptorProto.fromPartial(e)) || []; message.service = object.service?.map((e) => ServiceDescriptorProto.fromPartial(e)) || []; @@ -2432,6 +2567,7 @@ function createBaseDescriptorProto(): DescriptorProto { options: undefined, reservedRange: [], reservedName: [], + visibility: SymbolVisibility.VISIBILITY_UNSET, }; } @@ -2467,6 +2603,9 @@ export const DescriptorProto: MessageFns = { for (const v of message.reservedName) { writer.uint32(82).string(v!); } + if (message.visibility !== undefined && message.visibility !== SymbolVisibility.VISIBILITY_UNSET) { + writer.uint32(88).int32(symbolVisibilityToNumber(message.visibility)); + } return writer; }, @@ -2557,6 +2696,14 @@ export const DescriptorProto: MessageFns = { message.reservedName.push(reader.string()); continue; } + case 11: { + if (tag !== 88) { + break; + } + + message.visibility = symbolVisibilityFromJSON(reader.int32()); + continue; + } } if ((tag & 7) === 4 || tag === 0) { break; @@ -2583,6 +2730,7 @@ export const DescriptorProto: MessageFns = { : undefined; message.reservedRange = object.reservedRange?.map((e) => DescriptorProto_ReservedRange.fromPartial(e)) || []; message.reservedName = object.reservedName?.map((e) => e) || []; + message.visibility = object.visibility ?? SymbolVisibility.VISIBILITY_UNSET; return message; }, }; @@ -3143,7 +3291,14 @@ export const OneofDescriptorProto: MessageFns = { }; function createBaseEnumDescriptorProto(): EnumDescriptorProto { - return { name: "", value: [], options: undefined, reservedRange: [], reservedName: [] }; + return { + name: "", + value: [], + options: undefined, + reservedRange: [], + reservedName: [], + visibility: SymbolVisibility.VISIBILITY_UNSET, + }; } export const EnumDescriptorProto: MessageFns = { @@ -3163,6 +3318,9 @@ export const EnumDescriptorProto: MessageFns = { for (const v of message.reservedName) { writer.uint32(42).string(v!); } + if (message.visibility !== undefined && message.visibility !== SymbolVisibility.VISIBILITY_UNSET) { + writer.uint32(48).int32(symbolVisibilityToNumber(message.visibility)); + } return writer; }, @@ -3213,6 +3371,14 @@ export const EnumDescriptorProto: MessageFns = { message.reservedName.push(reader.string()); continue; } + case 6: { + if (tag !== 48) { + break; + } + + message.visibility = symbolVisibilityFromJSON(reader.int32()); + continue; + } } if ((tag & 7) === 4 || tag === 0) { break; @@ -3235,6 +3401,7 @@ export const EnumDescriptorProto: MessageFns = { message.reservedRange = object.reservedRange?.map((e) => EnumDescriptorProto_EnumReservedRange.fromPartial(e)) || []; message.reservedName = object.reservedName?.map((e) => e) || []; + message.visibility = object.visibility ?? SymbolVisibility.VISIBILITY_UNSET; return message; }, }; @@ -4999,6 +5166,7 @@ function createBaseFeatureSet(): FeatureSet { messageEncoding: FeatureSet_MessageEncoding.MESSAGE_ENCODING_UNKNOWN, jsonFormat: FeatureSet_JsonFormat.JSON_FORMAT_UNKNOWN, enforceNamingStyle: FeatureSet_EnforceNamingStyle.ENFORCE_NAMING_STYLE_UNKNOWN, + defaultSymbolVisibility: FeatureSet_VisibilityFeature_DefaultSymbolVisibility.DEFAULT_SYMBOL_VISIBILITY_UNKNOWN, }; } @@ -5039,6 +5207,15 @@ export const FeatureSet: MessageFns = { ) { writer.uint32(56).int32(featureSet_EnforceNamingStyleToNumber(message.enforceNamingStyle)); } + if ( + message.defaultSymbolVisibility !== undefined && + message.defaultSymbolVisibility !== + FeatureSet_VisibilityFeature_DefaultSymbolVisibility.DEFAULT_SYMBOL_VISIBILITY_UNKNOWN + ) { + writer.uint32(64).int32( + featureSet_VisibilityFeature_DefaultSymbolVisibilityToNumber(message.defaultSymbolVisibility), + ); + } return writer; }, @@ -5105,6 +5282,16 @@ export const FeatureSet: MessageFns = { message.enforceNamingStyle = featureSet_EnforceNamingStyleFromJSON(reader.int32()); continue; } + case 8: { + if (tag !== 64) { + break; + } + + message.defaultSymbolVisibility = featureSet_VisibilityFeature_DefaultSymbolVisibilityFromJSON( + reader.int32(), + ); + continue; + } } if ((tag & 7) === 4 || tag === 0) { break; @@ -5128,6 +5315,42 @@ export const FeatureSet: MessageFns = { message.jsonFormat = object.jsonFormat ?? FeatureSet_JsonFormat.JSON_FORMAT_UNKNOWN; message.enforceNamingStyle = object.enforceNamingStyle ?? FeatureSet_EnforceNamingStyle.ENFORCE_NAMING_STYLE_UNKNOWN; + message.defaultSymbolVisibility = object.defaultSymbolVisibility ?? + FeatureSet_VisibilityFeature_DefaultSymbolVisibility.DEFAULT_SYMBOL_VISIBILITY_UNKNOWN; + return message; + }, +}; + +function createBaseFeatureSet_VisibilityFeature(): FeatureSet_VisibilityFeature { + return {}; +} + +export const FeatureSet_VisibilityFeature: MessageFns = { + encode(_: FeatureSet_VisibilityFeature, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): FeatureSet_VisibilityFeature { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseFeatureSet_VisibilityFeature(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + create(base?: DeepPartial): FeatureSet_VisibilityFeature { + return FeatureSet_VisibilityFeature.fromPartial(base ?? {}); + }, + fromPartial(_: DeepPartial): FeatureSet_VisibilityFeature { + const message = createBaseFeatureSet_VisibilityFeature(); return message; }, };