feat: implement creator_id factor

pull/4491/head
Johnny 1 month ago
parent ba52a786f9
commit e3a4f49c5c

@ -9,6 +9,7 @@ import (
// MemoFilterCELAttributes are the CEL attributes for memo. // MemoFilterCELAttributes are the CEL attributes for memo.
var MemoFilterCELAttributes = []cel.EnvOption{ var MemoFilterCELAttributes = []cel.EnvOption{
cel.Variable("content", cel.StringType), cel.Variable("content", cel.StringType),
cel.Variable("creator_id", cel.IntType),
// As the built-in timestamp type is deprecated, we use string type for now. // As the built-in timestamp type is deprecated, we use string type for now.
// e.g., "2021-01-01T00:00:00Z" // e.g., "2021-01-01T00:00:00Z"
cel.Variable("create_time", cel.StringType), cel.Variable("create_time", cel.StringType),

@ -120,6 +120,9 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
memoFind.OrderByTimeAsc = true memoFind.OrderByTimeAsc = true
} }
if request.Filter != "" { if request.Filter != "" {
if err := s.validateFilter(ctx, request.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
memoFind.Filter = &request.Filter memoFind.Filter = &request.Filter
} }
@ -129,9 +132,19 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
} }
if currentUser == nil { if currentUser == nil {
memoFind.VisibilityList = []store.Visibility{store.Public} memoFind.VisibilityList = []store.Visibility{store.Public}
} else if memoFind.CreatorID == nil || *memoFind.CreatorID != currentUser.ID { } else {
if memoFind.CreatorID == nil {
internalFilter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "Protected"]`, currentUser.ID)
if memoFind.Filter != nil {
filter := fmt.Sprintf("(%s) && (%s)", *memoFind.Filter, internalFilter)
memoFind.Filter = &filter
} else {
memoFind.Filter = &internalFilter
}
} else if *memoFind.CreatorID != currentUser.ID {
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected} memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
} }
}
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx) workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil { if err != nil {

@ -59,7 +59,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
if err != nil { if err != nil {
return err return err
} }
if !slices.Contains([]string{"create_time", "update_time"}, identifier) { if !slices.Contains([]string{"creator_id", "create_time", "update_time", "visibility", "content"}, identifier) {
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
} }
value, err := filter.GetConstValue(v.CallExpr.Args[1]) value, err := filter.GetConstValue(v.CallExpr.Args[1])
@ -121,6 +121,23 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
return err return err
} }
ctx.Args = append(ctx.Args, valueStr) ctx.Args = append(ctx.Args, valueStr)
} else if identifier == "creator_id" {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
}
valueInt, ok := value.(int64)
if !ok {
return errors.New("invalid int value")
}
var factor string
if identifier == "creator_id" {
factor = "`memo`.`creator_id`"
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
return err
}
ctx.Args = append(ctx.Args, valueInt)
} }
case "@in": case "@in":
if len(v.CallExpr.Args) != 2 { if len(v.CallExpr.Args) != 2 {

@ -59,7 +59,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
if err != nil { if err != nil {
return err return err
} }
if !slices.Contains([]string{"create_time", "update_time", "visibility", "content"}, identifier) { if !slices.Contains([]string{"creator_id", "create_time", "update_time", "visibility", "content"}, identifier) {
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
} }
value, err := filter.GetConstValue(v.CallExpr.Args[1]) value, err := filter.GetConstValue(v.CallExpr.Args[1])
@ -121,6 +121,23 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
return err return err
} }
ctx.Args = append(ctx.Args, valueStr) ctx.Args = append(ctx.Args, valueStr)
} else if identifier == "creator_id" {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
}
valueInt, ok := value.(int64)
if !ok {
return errors.New("invalid int value")
}
var factor string
if identifier == "creator_id" {
factor = "memo.creator_id"
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
return err
}
ctx.Args = append(ctx.Args, valueInt)
} }
case "@in": case "@in":
if len(v.CallExpr.Args) != 2 { if len(v.CallExpr.Args) != 2 {

@ -59,7 +59,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
if err != nil { if err != nil {
return err return err
} }
if !slices.Contains([]string{"create_time", "update_time", "visibility", "content"}, identifier) { if !slices.Contains([]string{"creator_id", "create_time", "update_time", "visibility", "content"}, identifier) {
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
} }
value, err := filter.GetConstValue(v.CallExpr.Args[1]) value, err := filter.GetConstValue(v.CallExpr.Args[1])
@ -121,6 +121,23 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
return err return err
} }
ctx.Args = append(ctx.Args, valueStr) ctx.Args = append(ctx.Args, valueStr)
} else if identifier == "creator_id" {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
}
valueInt, ok := value.(int64)
if !ok {
return errors.New("invalid int value")
}
var factor string
if identifier == "creator_id" {
factor = "`memo`.`creator_id`"
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
return err
}
ctx.Args = append(ctx.Args, valueInt)
} }
case "@in": case "@in":
if len(v.CallExpr.Args) != 2 { if len(v.CallExpr.Args) != 2 {

@ -69,6 +69,11 @@ func TestConvertExprToSQL(t *testing.T) {
want: "NOT (`memo`.`pinned` IS TRUE)", want: "NOT (`memo`.`pinned` IS TRUE)",
args: []any{}, args: []any{},
}, },
{
filter: `creator_id == 101 || visibility in ["PUBLIC", "PRIVATE"]`,
want: "(`memo`.`creator_id` = ? OR `memo`.`visibility` IN (?,?))",
args: []any{int64(101), "PUBLIC", "PRIVATE"},
},
} }
for _, tt := range tests { for _, tt := range tests {

Loading…
Cancel
Save