diff --git a/plugin/filter/filter.go b/plugin/filter/filter.go index 0643176b..be1d49eb 100644 --- a/plugin/filter/filter.go +++ b/plugin/filter/filter.go @@ -9,6 +9,7 @@ import ( // MemoFilterCELAttributes are the CEL attributes for memo. var MemoFilterCELAttributes = []cel.EnvOption{ cel.Variable("content", cel.StringType), + cel.Variable("creator_id", cel.IntType), // As the built-in timestamp type is deprecated, we use string type for now. // e.g., "2021-01-01T00:00:00Z" cel.Variable("create_time", cel.StringType), diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index b47d87db..93caf6e1 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -120,6 +120,9 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq memoFind.OrderByTimeAsc = true } if request.Filter != "" { + if err := s.validateFilter(ctx, request.Filter); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err) + } memoFind.Filter = &request.Filter } @@ -129,8 +132,18 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq } if currentUser == nil { memoFind.VisibilityList = []store.Visibility{store.Public} - } else if memoFind.CreatorID == nil || *memoFind.CreatorID != currentUser.ID { - memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected} + } else { + if memoFind.CreatorID == nil { + internalFilter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "Protected"]`, currentUser.ID) + if memoFind.Filter != nil { + filter := fmt.Sprintf("(%s) && (%s)", *memoFind.Filter, internalFilter) + memoFind.Filter = &filter + } else { + memoFind.Filter = &internalFilter + } + } else if *memoFind.CreatorID != currentUser.ID { + memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected} + } } workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx) diff --git a/store/db/mysql/memo_filter.go b/store/db/mysql/memo_filter.go index 544735ba..dcdfb63d 100644 --- a/store/db/mysql/memo_filter.go +++ b/store/db/mysql/memo_filter.go @@ -59,7 +59,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if err != nil { return err } - if !slices.Contains([]string{"create_time", "update_time"}, identifier) { + if !slices.Contains([]string{"creator_id", "create_time", "update_time", "visibility", "content"}, identifier) { return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } value, err := filter.GetConstValue(v.CallExpr.Args[1]) @@ -121,6 +121,23 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err return err } ctx.Args = append(ctx.Args, valueStr) + } else if identifier == "creator_id" { + if operator != "=" && operator != "!=" { + return errors.Errorf("invalid operator for %s", v.CallExpr.Function) + } + valueInt, ok := value.(int64) + if !ok { + return errors.New("invalid int value") + } + + var factor string + if identifier == "creator_id" { + factor = "`memo`.`creator_id`" + } + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil { + return err + } + ctx.Args = append(ctx.Args, valueInt) } case "@in": if len(v.CallExpr.Args) != 2 { diff --git a/store/db/postgres/memo_filter.go b/store/db/postgres/memo_filter.go index 4f3b4065..ee6d78f8 100644 --- a/store/db/postgres/memo_filter.go +++ b/store/db/postgres/memo_filter.go @@ -59,7 +59,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if err != nil { return err } - if !slices.Contains([]string{"create_time", "update_time", "visibility", "content"}, identifier) { + if !slices.Contains([]string{"creator_id", "create_time", "update_time", "visibility", "content"}, identifier) { return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } value, err := filter.GetConstValue(v.CallExpr.Args[1]) @@ -121,6 +121,23 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err return err } ctx.Args = append(ctx.Args, valueStr) + } else if identifier == "creator_id" { + if operator != "=" && operator != "!=" { + return errors.Errorf("invalid operator for %s", v.CallExpr.Function) + } + valueInt, ok := value.(int64) + if !ok { + return errors.New("invalid int value") + } + + var factor string + if identifier == "creator_id" { + factor = "memo.creator_id" + } + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil { + return err + } + ctx.Args = append(ctx.Args, valueInt) } case "@in": if len(v.CallExpr.Args) != 2 { diff --git a/store/db/sqlite/memo_filter.go b/store/db/sqlite/memo_filter.go index 71829cf0..452272ea 100644 --- a/store/db/sqlite/memo_filter.go +++ b/store/db/sqlite/memo_filter.go @@ -59,7 +59,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if err != nil { return err } - if !slices.Contains([]string{"create_time", "update_time", "visibility", "content"}, identifier) { + if !slices.Contains([]string{"creator_id", "create_time", "update_time", "visibility", "content"}, identifier) { return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } value, err := filter.GetConstValue(v.CallExpr.Args[1]) @@ -121,6 +121,23 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err return err } ctx.Args = append(ctx.Args, valueStr) + } else if identifier == "creator_id" { + if operator != "=" && operator != "!=" { + return errors.Errorf("invalid operator for %s", v.CallExpr.Function) + } + valueInt, ok := value.(int64) + if !ok { + return errors.New("invalid int value") + } + + var factor string + if identifier == "creator_id" { + factor = "`memo`.`creator_id`" + } + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil { + return err + } + ctx.Args = append(ctx.Args, valueInt) } case "@in": if len(v.CallExpr.Args) != 2 { diff --git a/store/db/sqlite/memo_filter_test.go b/store/db/sqlite/memo_filter_test.go index 2a75fc81..c9f3f3a0 100644 --- a/store/db/sqlite/memo_filter_test.go +++ b/store/db/sqlite/memo_filter_test.go @@ -69,6 +69,11 @@ func TestConvertExprToSQL(t *testing.T) { want: "NOT (`memo`.`pinned` IS TRUE)", args: []any{}, }, + { + filter: `creator_id == 101 || visibility in ["PUBLIC", "PRIVATE"]`, + want: "(`memo`.`creator_id` = ? OR `memo`.`visibility` IN (?,?))", + args: []any{int64(101), "PUBLIC", "PRIVATE"}, + }, } for _, tt := range tests {