diff --git a/store/db/postgres/activity.go b/store/db/postgres/activity.go index 900bb885..84e15328 100644 --- a/store/db/postgres/activity.go +++ b/store/db/postgres/activity.go @@ -2,8 +2,8 @@ package postgres import ( "context" + "strings" - "github.com/Masterminds/squirrel" "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" @@ -21,50 +21,29 @@ func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store payloadString = string(bytes) } - qb := squirrel.Insert("activity"). - Columns("creator_id", "type", "level", "payload"). - PlaceholderFormat(squirrel.Dollar) - - values := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString} - qb = qb.Values(values...).Suffix("RETURNING id") - - stmt, args, err := qb.ToSql() - if err != nil { - return nil, errors.Wrap(err, "failed to construct query") - } - - var id int32 - err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id) - if err != nil { - return nil, errors.Wrap(err, "failed to execute statement and retrieve ID") - } - - list, err := d.ListActivities(ctx, &store.FindActivity{ID: &id}) - if err != nil || len(list) == 0 { - return nil, errors.Wrap(err, "failed to find activity") + fields := []string{"creator_id", "type", "level", "payload"} + args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString} + stmt := "INSERT INTO activity (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts" + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &create.ID, + &create.CreatedTs, + ); err != nil { + return nil, err } - return list[0], nil + return create, nil } func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) { - qb := squirrel.Select("id", "created_ts", "creator_id", "type", "level", "payload"). - From("activity"). - Where("1 = 1"). - PlaceholderFormat(squirrel.Dollar) - + where, args := []string{"1 = 1"}, []any{} if find.ID != nil { - qb = qb.Where(squirrel.Eq{"id": *find.ID}) + where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID) } if find.Type != nil { - qb = qb.Where(squirrel.Eq{"type": find.Type.String()}) - } - - query, args, err := qb.ToSql() - if err != nil { - return nil, err + where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type.String()) } + query := "SELECT id, creator_id, type, level, payload, created_ts FROM activity WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC" rows, err := d.db.QueryContext(ctx, query, args...) if err != nil { return nil, err @@ -77,17 +56,17 @@ func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*s var payloadBytes []byte if err := rows.Scan( &activity.ID, - &activity.CreatedTs, &activity.CreatorID, &activity.Type, &activity.Level, &payloadBytes, + &activity.CreatedTs, ); err != nil { return nil, err } payload := &storepb.ActivityPayload{} - if err := protojson.Unmarshal(payloadBytes, payload); err != nil { + if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil { return nil, err } activity.Payload = payload diff --git a/store/db/postgres/common.go b/store/db/postgres/common.go index fd5706d9..b4f074d6 100644 --- a/store/db/postgres/common.go +++ b/store/db/postgres/common.go @@ -1,9 +1,26 @@ package postgres -import "google.golang.org/protobuf/encoding/protojson" +import ( + "fmt" + "strings" + + "google.golang.org/protobuf/encoding/protojson" +) var ( protojsonUnmarshaler = protojson.UnmarshalOptions{ DiscardUnknown: true, } ) + +func placeholder(n int) string { + return "$" + fmt.Sprint(n) +} + +func placeholders(n int) string { + list := []string{} + for i := 0; i < n; i++ { + list = append(list, placeholder(i+1)) + } + return strings.Join(list, ", ") +} diff --git a/store/db/postgres/idp.go b/store/db/postgres/idp.go index 819cb793..2257df51 100644 --- a/store/db/postgres/idp.go +++ b/store/db/postgres/idp.go @@ -3,8 +3,8 @@ package postgres import ( "context" "encoding/json" + "strings" - "github.com/Masterminds/squirrel" "github.com/pkg/errors" "github.com/usememos/memos/store" @@ -22,42 +22,34 @@ func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityP return nil, errors.Errorf("unsupported idp type %s", string(create.Type)) } - qb := squirrel.Insert("idp").Columns("name", "type", "identifier_filter", "config") - values := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)} - - qb = qb.Values(values...).PlaceholderFormat(squirrel.Dollar) - qb = qb.Suffix("RETURNING id") - - stmt, args, err := qb.ToSql() - if err != nil { + fields := []string{"name", "type", "identifier_filter", "config"} + args := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)} + stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id" + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil { return nil, err } - var id int32 - err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id) - if err != nil { - return nil, err - } - - create.ID = id - return create, nil + identityProvider := create + return identityProvider, nil } -func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) { - qb := squirrel.Select("id", "name", "type", "identifier_filter", "config"). - From("idp"). - Where("1 = 1"). - PlaceholderFormat(squirrel.Dollar) +func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) { + where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { - qb = qb.Where(squirrel.Eq{"id": *v}) - } - - query, args, err := qb.ToSql() - if err != nil { - return nil, err - } - - rows, err := d.db.QueryContext(ctx, query, args...) + where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v) + } + + rows, err := d.db.QueryContext(ctx, ` + SELECT + id, + name, + type, + identifier_filter, + config + FROM idp + WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`, + args..., + ) if err != nil { return nil, err } @@ -111,15 +103,12 @@ func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityPr } func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) { - qb := squirrel.Update("idp"). - PlaceholderFormat(squirrel.Dollar) - var err error - + set, args := []string{}, []any{} if v := update.Name; v != nil { - qb = qb.Set("name", *v) + set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *v) } if v := update.IdentifierFilter; v != nil { - qb = qb.Set("identifier_filter", *v) + set, args = append(set, "identifier_filter = "+placeholder(len(args)+1)), append(args, *v) } if v := update.Config; v != nil { var configBytes []byte @@ -132,42 +121,53 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde } else { return nil, errors.Errorf("unsupported idp type %s", string(update.Type)) } - qb = qb.Set("config", string(configBytes)) - } - - qb = qb.Where(squirrel.Eq{"id": update.ID}) - - stmt, args, err := qb.ToSql() - if err != nil { + set, args = append(set, "config = "+placeholder(len(args)+1)), append(args, string(configBytes)) + } + + stmt := ` + UPDATE idp + SET ` + strings.Join(set, ", ") + ` + WHERE id = ` + placeholder(len(args)+1) + ` + RETURNING id, name, type, identifier_filter, config + ` + args = append(args, update.ID) + + var identityProvider store.IdentityProvider + var identityProviderConfig string + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &identityProvider.ID, + &identityProvider.Name, + &identityProvider.Type, + &identityProvider.IdentifierFilter, + &identityProviderConfig, + ); err != nil { return nil, err } - _, err = d.db.ExecContext(ctx, stmt, args...) - if err != nil { - return nil, err + if identityProvider.Type == store.IdentityProviderOAuth2Type { + oauth2Config := &store.IdentityProviderOAuth2Config{} + if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { + return nil, err + } + identityProvider.Config = &store.IdentityProviderConfig{ + OAuth2Config: oauth2Config, + } + } else { + return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) } - return d.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &update.ID}) + return &identityProvider, nil } func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error { - qb := squirrel.Delete("idp"). - Where(squirrel.Eq{"id": delete.ID}). - PlaceholderFormat(squirrel.Dollar) - - stmt, args, err := qb.ToSql() - if err != nil { - return err - } - + where, args := []string{"id = $1"}, []any{delete.ID} + stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ") result, err := d.db.ExecContext(ctx, stmt, args...) if err != nil { return err } - if _, err = result.RowsAffected(); err != nil { return err } - return nil } diff --git a/store/db/postgres/inbox.go b/store/db/postgres/inbox.go index 5decb3c9..1191e414 100644 --- a/store/db/postgres/inbox.go +++ b/store/db/postgres/inbox.go @@ -2,8 +2,8 @@ package postgres import ( "context" + "strings" - "github.com/Masterminds/squirrel" "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" @@ -21,61 +21,54 @@ func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox messageString = string(bytes) } - qb := squirrel.Insert("inbox"). - Columns("sender_id", "receiver_id", "status", "message"). - Values(create.SenderID, create.ReceiverID, create.Status, messageString). - Suffix("RETURNING id"). - PlaceholderFormat(squirrel.Dollar) - - stmt, args, err := qb.ToSql() - if err != nil { - return nil, err - } - - var id int32 - err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id) - if err != nil { + fields := []string{"sender_id", "receiver_id", "status", "message"} + args := []any{create.SenderID, create.ReceiverID, create.Status, messageString} + stmt := "INSERT INTO inbox (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts" + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &create.ID, + &create.CreatedTs, + ); err != nil { return nil, err } - return d.GetInbox(ctx, &store.FindInbox{ID: &id}) + return create, nil } func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) { - qb := squirrel.Select("id", "created_ts", "sender_id", "receiver_id", "status", "message"). - From("inbox"). - Where("1 = 1"). - PlaceholderFormat(squirrel.Dollar) + where, args := []string{"1 = 1"}, []any{} if find.ID != nil { - qb = qb.Where(squirrel.Eq{"id": *find.ID}) + where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID) } if find.SenderID != nil { - qb = qb.Where(squirrel.Eq{"sender_id": *find.SenderID}) + where, args = append(where, "sender_id = "+placeholder(len(args)+1)), append(args, *find.SenderID) } if find.ReceiverID != nil { - qb = qb.Where(squirrel.Eq{"receiver_id": *find.ReceiverID}) + where, args = append(where, "receiver_id = "+placeholder(len(args)+1)), append(args, *find.ReceiverID) } if find.Status != nil { - qb = qb.Where(squirrel.Eq{"status": *find.Status}) - } - - query, args, err := qb.ToSql() - if err != nil { - return nil, err + where, args = append(where, "status = "+placeholder(len(args)+1)), append(args, *find.Status) } + query := "SELECT id, created_ts, sender_id, receiver_id, status, message FROM inbox WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC" rows, err := d.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() - var list []*store.Inbox + list := []*store.Inbox{} for rows.Next() { inbox := &store.Inbox{} var messageBytes []byte - if err := rows.Scan(&inbox.ID, &inbox.CreatedTs, &inbox.SenderID, &inbox.ReceiverID, &inbox.Status, &messageBytes); err != nil { + if err := rows.Scan( + &inbox.ID, + &inbox.CreatedTs, + &inbox.SenderID, + &inbox.ReceiverID, + &inbox.Status, + &messageBytes, + ); err != nil { return nil, err } @@ -87,7 +80,11 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I list = append(list, inbox) } - return list, rows.Err() + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil } func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) { @@ -102,39 +99,36 @@ func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, } func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) { - qb := squirrel.Update("inbox"). - Set("status", update.Status.String()). - Where(squirrel.Eq{"id": update.ID}). - PlaceholderFormat(squirrel.Dollar) - - stmt, args, err := qb.ToSql() - if err != nil { + set, args := []string{"status = $1"}, []any{update.Status.String()} + args = append(args, update.ID) + query := "UPDATE inbox SET " + strings.Join(set, ", ") + " WHERE id = $2 RETURNING id, created_ts, sender_id, receiver_id, status, message" + inbox := &store.Inbox{} + var messageBytes []byte + if err := d.db.QueryRowContext(ctx, query, args...).Scan( + &inbox.ID, + &inbox.CreatedTs, + &inbox.SenderID, + &inbox.ReceiverID, + &inbox.Status, + &messageBytes, + ); err != nil { return nil, err } - - _, err = d.db.ExecContext(ctx, stmt, args...) - if err != nil { + message := &storepb.InboxMessage{} + if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil { return nil, err } - - return d.GetInbox(ctx, &store.FindInbox{ID: &update.ID}) + inbox.Message = message + return inbox, nil } func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error { - qb := squirrel.Delete("inbox"). - Where(squirrel.Eq{"id": delete.ID}). - PlaceholderFormat(squirrel.Dollar) - - stmt, args, err := qb.ToSql() + result, err := d.db.ExecContext(ctx, "DELETE FROM inbox WHERE id = $1", delete.ID) if err != nil { return err } - - result, err := d.db.ExecContext(ctx, stmt, args...) - if err != nil { + if _, err := result.RowsAffected(); err != nil { return err } - - _, err = result.RowsAffected() - return err + return nil } diff --git a/store/db/postgres/memo.go b/store/db/postgres/memo.go index f1d31bd8..651a4738 100644 --- a/store/db/postgres/memo.go +++ b/store/db/postgres/memo.go @@ -3,153 +3,127 @@ package postgres import ( "context" "database/sql" + "fmt" "strings" - "github.com/Masterminds/squirrel" "github.com/pkg/errors" "github.com/usememos/memos/store" ) func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) { - // Initialize a Squirrel statement builder for PostgreSQL - builder := squirrel.Insert("memo"). - PlaceholderFormat(squirrel.Dollar). - Columns("creator_id", "content", "visibility") - - // Add initial values for the columns - values := []any{create.CreatorID, create.Content, create.Visibility} - - // Add all the values at once - builder = builder.Values(values...) - - // Add the RETURNING clause to get the ID of the inserted row - builder = builder.Suffix("RETURNING id") - - // Prepare and execute the query - query, args, err := builder.ToSql() - if err != nil { + fields := []string{"creator_id", "content", "visibility"} + args := []any{create.CreatorID, create.Content, create.Visibility} + + stmt := "INSERT INTO memo (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts, row_status" + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &create.ID, + &create.CreatedTs, + &create.UpdatedTs, + &create.RowStatus, + ); err != nil { return nil, err } - var id int32 - err = d.db.QueryRowContext(ctx, query, args...).Scan(&id) - if err != nil { - return nil, err - } - - // Retrieve the newly created memo - memo, err := d.GetMemo(ctx, &store.FindMemo{ID: &id}) - if err != nil { - return nil, err - } - if memo == nil { - return nil, errors.Errorf("failed to create memo") - } - - return memo, nil + return create, nil } func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) { - // Start building the SELECT statement - builder := squirrel.Select( - "memo.id AS id", - "memo.creator_id AS creator_id", - "memo.created_ts AS created_ts", - "memo.updated_ts AS updated_ts", - "memo.row_status AS row_status", - "memo.content AS content", - "memo.visibility AS visibility", - "MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned"). - From("memo"). - LeftJoin("memo_organizer ON memo.id = memo_organizer.memo_id"). - LeftJoin("resource ON memo.id = resource.memo_id"). - GroupBy("memo.id"). - PlaceholderFormat(squirrel.Dollar) - - // Add conditional where clauses + where, args := []string{"1 = 1"}, []any{} + if v := find.ID; v != nil { - builder = builder.Where("memo.id = ?", *v) + where, args = append(where, "memo.id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.CreatorID; v != nil { - builder = builder.Where("memo.creator_id = ?", *v) + where, args = append(where, "memo.creator_id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.RowStatus; v != nil { - builder = builder.Where("memo.row_status = ?", *v) + where, args = append(where, "memo.row_status = "+placeholder(len(args)+1)), append(args, *v) } if v := find.CreatedTsBefore; v != nil { - builder = builder.Where("memo.created_ts < ?", *v) + where, args = append(where, "memo.created_ts < "+placeholder(len(args)+1)), append(args, *v) } if v := find.CreatedTsAfter; v != nil { - builder = builder.Where("memo.created_ts > ?", *v) - } - if v := find.Pinned; v != nil { - builder = builder.Where("memo_organizer.pinned = 1") + where, args = append(where, "memo.created_ts > "+placeholder(len(args)+1)), append(args, *v) } if v := find.ContentSearch; len(v) != 0 { for _, s := range v { - builder = builder.Where("memo.content LIKE ?", "%"+s+"%") + where, args = append(where, "memo.content LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", s)) } } - if v := find.VisibilityList; len(v) != 0 { - placeholders := make([]string, len(v)) - args := make([]any, len(v)) - for i, visibility := range v { - placeholders[i] = "?" - args[i] = visibility // Assuming visibility can be directly used as an argument + holders := []string{} + for _, visibility := range v { + holders = append(holders, placeholder(len(args)+1)) + args = append(args, visibility.String()) } - inClause := strings.Join(placeholders, ",") - builder = builder.Where("memo.visibility IN ("+inClause+")", args...) + where = append(where, fmt.Sprintf("memo.visibility in (%s)", strings.Join(holders, ", "))) + } + if v := find.Pinned; v != nil { + where = append(where, "memo_organizer.pinned = 1") } - // Add order by clauses + + orders := []string{} if find.OrderByPinned { - builder = builder.OrderBy("pinned DESC") + orders = append(orders, "pinned DESC") } if find.OrderByUpdatedTs { - builder = builder.OrderBy("updated_ts DESC") + orders = append(orders, "updated_ts DESC") } else { - builder = builder.OrderBy("created_ts DESC") + orders = append(orders, "created_ts DESC") } - builder = builder.OrderBy("id DESC") + orders = append(orders, "id DESC") - // Handle pagination + fields := []string{ + `memo.id AS id`, + `memo.creator_id AS creator_id`, + `memo.created_ts AS created_ts`, + `memo.updated_ts AS updated_ts`, + `memo.row_status AS row_status`, + `memo.visibility AS visibility`, + `MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned`, + } + if !find.ExcludeContent { + fields = append(fields, `memo.content AS content`) + } + + query := `SELECT ` + strings.Join(fields, ", ") + ` + FROM memo + LEFT JOIN memo_organizer ON memo.id = memo_organizer.memo_id + WHERE ` + strings.Join(where, " AND ") + ` + GROUP BY memo.id + ORDER BY ` + strings.Join(orders, ", ") if find.Limit != nil { - builder = builder.Limit(uint64(*find.Limit)) + query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit) if find.Offset != nil { - builder = builder.Offset(uint64(*find.Offset)) + query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset) } } - // Prepare and execute the query - query, args, err := builder.ToSql() - if err != nil { - return nil, err - } - rows, err := d.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() - // Process the result set list := make([]*store.Memo, 0) for rows.Next() { var memo store.Memo - if err := rows.Scan( + dests := []any{ &memo.ID, &memo.CreatorID, &memo.CreatedTs, &memo.UpdatedTs, &memo.RowStatus, - &memo.Content, &memo.Visibility, &memo.Pinned, - ); err != nil { + } + if !find.ExcludeContent { + dests = append(dests, &memo.Content) + } + if err := rows.Scan(dests...); err != nil { return nil, err } - list = append(list, &memo) } @@ -174,51 +148,42 @@ func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, er } func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error { - // Start building the update statement - builder := squirrel.Update("memo"). - PlaceholderFormat(squirrel.Dollar). - Where("id = ?", update.ID) - - // Conditionally add set clauses + set, args := []string{}, []any{} if v := update.CreatedTs; v != nil { - builder = builder.Set("created_ts", *v) + set, args = append(set, "created_ts = "+placeholder(len(args)+1)), append(args, *v) } if v := update.UpdatedTs; v != nil { - builder = builder.Set("updated_ts", *v) + set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v) } if v := update.RowStatus; v != nil { - builder = builder.Set("row_status", *v) + set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v) } if v := update.Content; v != nil { - builder = builder.Set("content", *v) + set, args = append(set, "content = "+placeholder(len(args)+1)), append(args, *v) } if v := update.Visibility; v != nil { - builder = builder.Set("visibility", *v) - } - - // Prepare and execute the query - query, args, err := builder.ToSql() - if err != nil { - return err + set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, *v) } - - if _, err := d.db.ExecContext(ctx, query, args...); err != nil { + stmt := `UPDATE memo SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1) + args = append(args, update.ID) + if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil { return err } - return nil } func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error { - stmt := `DELETE FROM memo WHERE id = $1` - result, err := d.db.ExecContext(ctx, stmt, delete.ID) + where, args := []string{"id = " + placeholder(1)}, []any{delete.ID} + stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ") + println("stmt", stmt, delete.ID) + result, err := d.db.ExecContext(ctx, stmt, args...) if err != nil { - return err + return errors.Wrap(err, "failed to delete memo") } if _, err := result.RowsAffected(); err != nil { return err } - return d.Vacuum(ctx) + return nil } func vacuumMemo(ctx context.Context, tx *sql.Tx) error { diff --git a/store/db/postgres/memo_organizer.go b/store/db/postgres/memo_organizer.go index 3bf03198..00b5807b 100644 --- a/store/db/postgres/memo_organizer.go +++ b/store/db/postgres/memo_organizer.go @@ -4,8 +4,7 @@ import ( "context" "database/sql" "fmt" - - "github.com/Masterminds/squirrel" + "strings" "github.com/usememos/memos/store" ) @@ -15,99 +14,94 @@ func (d *DB) UpsertMemoOrganizer(ctx context.Context, upsert *store.MemoOrganize if upsert.Pinned { pinned = 1 } - stmt := "INSERT INTO memo_organizer (memo_id, user_id, pinned) VALUES ($1, $2, $3) ON CONFLICT (memo_id, user_id) DO UPDATE SET pinned = $4" - if _, err := d.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, pinned, pinned); err != nil { + stmt := ` + INSERT INTO memo_organizer ( + memo_id, + user_id, + pinned + ) + VALUES (` + placeholders(3) + `) + ON CONFLICT(memo_id, user_id) DO UPDATE + SET pinned = EXCLUDED.pinned` + if _, err := d.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, pinned); err != nil { return nil, err } + return upsert, nil } func (d *DB) ListMemoOrganizer(ctx context.Context, find *store.FindMemoOrganizer) ([]*store.MemoOrganizer, error) { - qb := squirrel.Select("memo_id", "user_id", "pinned"). - From("memo_organizer"). - Where("1 = 1"). - PlaceholderFormat(squirrel.Dollar) - + where, args := []string{"1 = 1"}, []any{} if find.MemoID != 0 { - qb = qb.Where(squirrel.Eq{"memo_id": find.MemoID}) + where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID) } if find.UserID != 0 { - qb = qb.Where(squirrel.Eq{"user_id": find.UserID}) - } - - query, args, err := qb.ToSql() - if err != nil { - return nil, err + where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, find.UserID) } + query := fmt.Sprintf(` + SELECT + memo_id, + user_id, + pinned + FROM memo_organizer + WHERE %s + `, strings.Join(where, " AND ")) rows, err := d.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() - var list []*store.MemoOrganizer + list := []*store.MemoOrganizer{} for rows.Next() { memoOrganizer := &store.MemoOrganizer{} - if err := rows.Scan(&memoOrganizer.MemoID, &memoOrganizer.UserID, &memoOrganizer.Pinned); err != nil { + pinned := 0 + if err := rows.Scan( + &memoOrganizer.MemoID, + &memoOrganizer.UserID, + &pinned, + ); err != nil { return nil, err } + + memoOrganizer.Pinned = pinned == 1 list = append(list, memoOrganizer) } - return list, rows.Err() + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil } func (d *DB) DeleteMemoOrganizer(ctx context.Context, delete *store.DeleteMemoOrganizer) error { - qb := squirrel.Delete("memo_organizer"). - PlaceholderFormat(squirrel.Dollar) - + where, args := []string{}, []any{} if v := delete.MemoID; v != nil { - qb = qb.Where(squirrel.Eq{"memo_id": *v}) + where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *v) } if v := delete.UserID; v != nil { - qb = qb.Where(squirrel.Eq{"user_id": *v}) - } - - stmt, args, err := qb.ToSql() - if err != nil { - return err + where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *v) } - - if _, err = d.db.ExecContext(ctx, stmt, args...); err != nil { + stmt := `DELETE FROM memo_organizer WHERE ` + strings.Join(where, " AND ") + if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil { return err } - return nil } func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error { - // First, build the subquery for memo_id - subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql() - if err != nil { - return err - } - - // Build the subquery for user_id - subQueryUser, subArgsUser, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql() + stmt := ` + DELETE FROM + memo_organizer + WHERE + memo_id NOT IN (SELECT id FROM memo) + OR user_id NOT IN (SELECT id FROM "user")` + _, err := tx.ExecContext(ctx, stmt) if err != nil { return err } - // Now, build the main delete query using the subqueries - query, args, err := squirrel.Delete("memo_organizer"). - Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...). - Where(fmt.Sprintf("user_id NOT IN (%s)", subQueryUser), subArgsUser...). - PlaceholderFormat(squirrel.Dollar). - ToSql() - if err != nil { - return err - } - - // Combine the arguments from both subqueries - args = append(args, subArgsUser...) - - // Execute the query - _, err = tx.ExecContext(ctx, query, args...) - return err + return nil } diff --git a/store/db/postgres/memo_relation.go b/store/db/postgres/memo_relation.go index a5c8b11f..39bc10cb 100644 --- a/store/db/postgres/memo_relation.go +++ b/store/db/postgres/memo_relation.go @@ -3,127 +3,111 @@ package postgres import ( "context" "database/sql" - "fmt" - - "github.com/Masterminds/squirrel" + "strings" "github.com/usememos/memos/store" ) func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) { - qb := squirrel.Insert("memo_relation"). - Columns("memo_id", "related_memo_id", "type"). - Values(create.MemoID, create.RelatedMemoID, create.Type). - Suffix("ON CONFLICT (version) DO NOTHING"). - PlaceholderFormat(squirrel.Dollar) - - stmt, args, err := qb.ToSql() - if err != nil { - return nil, err - } - - _, err = d.db.ExecContext(ctx, stmt, args...) - if err != nil { + stmt := ` + INSERT INTO memo_relation ( + memo_id, + related_memo_id, + type + ) + VALUES (` + placeholders(3) + `) + RETURNING memo_id, related_memo_id, type + ` + memoRelation := &store.MemoRelation{} + if err := d.db.QueryRowContext( + ctx, + stmt, + create.MemoID, + create.RelatedMemoID, + create.Type, + ).Scan( + &memoRelation.MemoID, + &memoRelation.RelatedMemoID, + &memoRelation.Type, + ); err != nil { return nil, err } - return &store.MemoRelation{ - MemoID: create.MemoID, - RelatedMemoID: create.RelatedMemoID, - Type: create.Type, - }, nil + return memoRelation, nil } func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) { - qb := squirrel.Select("memo_id", "related_memo_id", "type"). - From("memo_relation"). - Where("TRUE"). - PlaceholderFormat(squirrel.Dollar) - + where, args := []string{"1 = 1"}, []any{} if find.MemoID != nil { - qb = qb.Where(squirrel.Eq{"memo_id": *find.MemoID}) + where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID) } if find.RelatedMemoID != nil { - qb = qb.Where(squirrel.Eq{"related_memo_id": *find.RelatedMemoID}) + where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, find.RelatedMemoID) } if find.Type != nil { - qb = qb.Where(squirrel.Eq{"type": *find.Type}) + where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type) } - query, args, err := qb.ToSql() - if err != nil { - return nil, err - } - - rows, err := d.db.QueryContext(ctx, query, args...) + rows, err := d.db.QueryContext(ctx, ` + SELECT + memo_id, + related_memo_id, + type + FROM memo_relation + WHERE `+strings.Join(where, " AND "), args...) if err != nil { return nil, err } defer rows.Close() - var list []*store.MemoRelation + list := []*store.MemoRelation{} for rows.Next() { memoRelation := &store.MemoRelation{} - if err := rows.Scan(&memoRelation.MemoID, &memoRelation.RelatedMemoID, &memoRelation.Type); err != nil { + if err := rows.Scan( + &memoRelation.MemoID, + &memoRelation.RelatedMemoID, + &memoRelation.Type, + ); err != nil { return nil, err } list = append(list, memoRelation) } - return list, rows.Err() + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil } func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error { - qb := squirrel.Delete("memo_relation"). - PlaceholderFormat(squirrel.Dollar) - + where, args := []string{"1 = 1"}, []any{} if delete.MemoID != nil { - qb = qb.Where(squirrel.Eq{"memo_id": *delete.MemoID}) + where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, delete.MemoID) } if delete.RelatedMemoID != nil { - qb = qb.Where(squirrel.Eq{"related_memo_id": *delete.RelatedMemoID}) + where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, delete.RelatedMemoID) } if delete.Type != nil { - qb = qb.Where(squirrel.Eq{"type": *delete.Type}) + where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, delete.Type) } - - stmt, args, err := qb.ToSql() + stmt := `DELETE FROM memo_relation WHERE ` + strings.Join(where, " AND ") + result, err := d.db.ExecContext(ctx, stmt, args...) if err != nil { return err } - - result, err := d.db.ExecContext(ctx, stmt, args...) - if err != nil { + if _, err = result.RowsAffected(); err != nil { return err } - - _, err = result.RowsAffected() - return err + return nil } func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error { - // First, build the subquery for memo_id - subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql() - if err != nil { + if _, err := tx.ExecContext(ctx, ` + DELETE FROM memo_relation + WHERE memo_id NOT IN (SELECT id FROM memo) OR related_memo_id NOT IN (SELECT id FROM memo) + `); err != nil { return err } - - // Note: The same subquery is used for related_memo_id as it's also checking against the "memo" table - - // Now, build the main delete query using the subqueries - query, args, err := squirrel.Delete("memo_relation"). - Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...). - Where(fmt.Sprintf("related_memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...). - PlaceholderFormat(squirrel.Dollar). - ToSql() - if err != nil { - return err - } - - // Combine the arguments for both instances of the same subquery - args = append(args, subArgsMemo...) - - // Execute the query - _, err = tx.ExecContext(ctx, query, args...) - return err + return nil } diff --git a/store/db/postgres/migration_history.go b/store/db/postgres/migration_history.go index d7b8a665..385b898b 100644 --- a/store/db/postgres/migration_history.go +++ b/store/db/postgres/migration_history.go @@ -3,22 +3,12 @@ package postgres import ( "context" - "github.com/Masterminds/squirrel" - "github.com/usememos/memos/store" ) func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) { - qb := squirrel.Select("version", "created_ts"). - From("migration_history"). - OrderBy("created_ts DESC") - - query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() - if err != nil { - return nil, err - } - - rows, err := d.db.QueryContext(ctx, query, args...) + query := "SELECT version, created_ts FROM migration_history ORDER BY created_ts DESC" + rows, err := d.db.QueryContext(ctx, query) if err != nil { return nil, err } @@ -27,9 +17,13 @@ func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigratio list := make([]*store.MigrationHistory, 0) for rows.Next() { var migrationHistory store.MigrationHistory - if err := rows.Scan(&migrationHistory.Version, &migrationHistory.CreatedTs); err != nil { + if err := rows.Scan( + &migrationHistory.Version, + &migrationHistory.CreatedTs, + ); err != nil { return nil, err } + list = append(list, &migrationHistory) } @@ -41,33 +35,21 @@ func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigratio } func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) { - qb := squirrel.Insert("migration_history"). - Columns("version"). - Values(upsert.Version). - Suffix("ON CONFLICT (version) DO NOTHING"). - PlaceholderFormat(squirrel.Dollar) - - query, args, err := qb.ToSql() - if err != nil { - return nil, err - } - - _, err = d.db.ExecContext(ctx, query, args...) - if err != nil { - return nil, err - } - + stmt := ` + INSERT INTO migration_history ( + version + ) + VALUES ($1) + ON CONFLICT(version) DO UPDATE + SET + version=EXCLUDED.version + RETURNING version, created_ts + ` var migrationHistory store.MigrationHistory - query, args, err = squirrel.Select("version", "created_ts"). - From("migration_history"). - Where(squirrel.Eq{"version": upsert.Version}). - PlaceholderFormat(squirrel.Dollar). - ToSql() - if err != nil { - return nil, err - } - - if err := d.db.QueryRowContext(ctx, query, args...).Scan(&migrationHistory.Version, &migrationHistory.CreatedTs); err != nil { + if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan( + &migrationHistory.Version, + &migrationHistory.CreatedTs, + ); err != nil { return nil, err } diff --git a/store/db/postgres/resource.go b/store/db/postgres/resource.go index 41780b57..b344dbf0 100644 --- a/store/db/postgres/resource.go +++ b/store/db/postgres/resource.go @@ -4,77 +4,61 @@ import ( "context" "database/sql" "fmt" - "time" - - "github.com/Masterminds/squirrel" - "github.com/pkg/errors" + "strings" "github.com/usememos/memos/store" ) func (d *DB) CreateResource(ctx context.Context, create *store.Resource) (*store.Resource, error) { - qb := squirrel.Insert("resource").Columns("filename", "blob", "external_link", "type", "size", "creator_id", "internal_path", "memo_id") - values := []any{create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath, create.MemoID} - - qb = qb.Values(values...).Suffix("RETURNING id") - query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() - if err != nil { - return nil, err - } - - var id int32 - err = d.db.QueryRowContext(ctx, query, args...).Scan(&id) - if err != nil { - return nil, err - } + fields := []string{"filename", "blob", "external_link", "type", "size", "creator_id", "internal_path", "memo_id"} + args := []any{create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath, create.MemoID} - list, err := d.ListResources(ctx, &store.FindResource{ID: &id}) - if err != nil { + stmt := "INSERT INTO resource (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts" + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil { return nil, err } - if len(list) != 1 { - return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list)) - } - - return list[0], nil + return create, nil } func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*store.Resource, error) { - qb := squirrel.Select("id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path", "memo_id").From("resource") + where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { - qb = qb.Where(squirrel.Eq{"id": *v}) + where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.CreatorID; v != nil { - qb = qb.Where(squirrel.Eq{"creator_id": *v}) + where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.Filename; v != nil { - qb = qb.Where(squirrel.Eq{"filename": *v}) + where, args = append(where, "filename = "+placeholder(len(args)+1)), append(args, *v) } if v := find.MemoID; v != nil { - qb = qb.Where(squirrel.Eq{"memo_id": *v}) + where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *v) } if find.HasRelatedMemo { - qb = qb.Where("memo_id IS NOT NULL") + where = append(where, "memo_id IS NOT NULL") } + + fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path", "memo_id"} if find.GetBlob { - qb = qb.Columns("blob") + fields = append(fields, "blob") } - qb = qb.GroupBy("id").OrderBy("created_ts DESC") - + query := fmt.Sprintf(` + SELECT + %s + FROM resource + WHERE %s + GROUP BY id + ORDER BY created_ts DESC + `, strings.Join(fields, ", "), strings.Join(where, " AND ")) if find.Limit != nil { - qb = qb.Limit(uint64(*find.Limit)) + query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit) if find.Offset != nil { - qb = qb.Offset(uint64(*find.Offset)) + query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset) } } - query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() - if err != nil { - return nil, err - } - rows, err := d.db.QueryContext(ctx, query, args...) if err != nil { return nil, err @@ -103,7 +87,6 @@ func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*st if err := rows.Scan(dests...); err != nil { return nil, err } - if memoID.Valid { resource.MemoID = &memoID.Int32 } @@ -118,88 +101,72 @@ func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*st } func (d *DB) UpdateResource(ctx context.Context, update *store.UpdateResource) (*store.Resource, error) { - qb := squirrel.Update("resource") + set, args := []string{}, []any{} if v := update.UpdatedTs; v != nil { - qb = qb.Set("updated_ts", time.Unix(0, *v)) + set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v) } if v := update.Filename; v != nil { - qb = qb.Set("filename", *v) + set, args = append(set, "filename = "+placeholder(len(args)+1)), append(args, *v) } if v := update.InternalPath; v != nil { - qb = qb.Set("internal_path", *v) + set, args = append(set, "internal_path = "+placeholder(len(args)+1)), append(args, *v) } if v := update.MemoID; v != nil { - qb = qb.Set("memo_id", *v) + set, args = append(set, "memo_id = "+placeholder(len(args)+1)), append(args, *v) } if v := update.Blob; v != nil { - qb = qb.Set("blob", v) - } - - qb = qb.Where(squirrel.Eq{"id": update.ID}) - - query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() - if err != nil { + set, args = append(set, "blob = "+placeholder(len(args)+1)), append(args, v) + } + + fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path"} + stmt := ` + UPDATE resource + SET ` + strings.Join(set, ", ") + ` + WHERE id = ` + placeholder(len(args)+1) + ` + RETURNING ` + strings.Join(fields, ", ") + args = append(args, update.ID) + resource := store.Resource{} + dests := []any{ + &resource.ID, + &resource.Filename, + &resource.ExternalLink, + &resource.Type, + &resource.Size, + &resource.CreatorID, + &resource.CreatedTs, + &resource.UpdatedTs, + &resource.InternalPath, + } + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(dests...); err != nil { return nil, err } - if _, err := d.db.ExecContext(ctx, query, args...); err != nil { - return nil, err - } - - list, err := d.ListResources(ctx, &store.FindResource{ID: &update.ID}) - if err != nil { - return nil, err - } - if len(list) != 1 { - return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list)) - } - - return list[0], nil + return &resource, nil } func (d *DB) DeleteResource(ctx context.Context, delete *store.DeleteResource) error { - qb := squirrel.Delete("resource").Where(squirrel.Eq{"id": delete.ID}) - - query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + stmt := `DELETE FROM resource WHERE id = $1` + result, err := d.db.ExecContext(ctx, stmt, delete.ID) if err != nil { return err } - - result, err := d.db.ExecContext(ctx, query, args...) - if err != nil { - return err - } - if _, err := result.RowsAffected(); err != nil { return err } - - if err := d.Vacuum(ctx); err != nil { - // Prevent linter warning. - return err - } - return nil } func vacuumResource(ctx context.Context, tx *sql.Tx) error { - // First, build the subquery - subQuery, subArgs, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql() + stmt := ` + DELETE FROM + resource + WHERE + creator_id NOT IN (SELECT id FROM "user")` + _, err := tx.ExecContext(ctx, stmt) if err != nil { return err } - // Now, build the main delete query using the subquery - query, args, err := squirrel.Delete("resource"). - Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...). - PlaceholderFormat(squirrel.Dollar). - ToSql() - if err != nil { - return err - } - - // Execute the query - _, err = tx.ExecContext(ctx, query, args...) - return err + return nil } diff --git a/store/db/postgres/tag.go b/store/db/postgres/tag.go index 1dece866..65d89ab3 100644 --- a/store/db/postgres/tag.go +++ b/store/db/postgres/tag.go @@ -3,7 +3,6 @@ package postgres import ( "context" "database/sql" - "fmt" "github.com/Masterminds/squirrel" @@ -82,22 +81,15 @@ func (d *DB) DeleteTag(ctx context.Context, delete *store.DeleteTag) error { } func vacuumTag(ctx context.Context, tx *sql.Tx) error { - // First, build the subquery for creator_id - subQuery, subArgs, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql() + stmt := ` + DELETE FROM + tag + WHERE + creator_id NOT IN (SELECT id FROM "user")` + _, err := tx.ExecContext(ctx, stmt) if err != nil { return err } - // Now, build the main delete query using the subquery - query, args, err := squirrel.Delete("tag"). - Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...). - PlaceholderFormat(squirrel.Dollar). - ToSql() - if err != nil { - return err - } - - // Execute the query - _, err = tx.ExecContext(ctx, query, args...) - return err + return nil } diff --git a/store/db/postgres/user.go b/store/db/postgres/user.go index 8941920a..588bbac3 100644 --- a/store/db/postgres/user.go +++ b/store/db/postgres/user.go @@ -2,132 +2,113 @@ package postgres import ( "context" - - "github.com/Masterminds/squirrel" - "github.com/pkg/errors" + "strings" "github.com/usememos/memos/store" ) func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) { - // Start building the insert statement - builder := squirrel.Insert(`"user"`).PlaceholderFormat(squirrel.Dollar) - - columns := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"} - builder = builder.Columns(columns...) - - values := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL} - - builder = builder.Values(values...) - builder = builder.Suffix("RETURNING id") - - // Prepare the final query - query, args, err := builder.ToSql() - if err != nil { + fields := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"} + args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL} + stmt := "INSERT INTO \"user\" (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, avatar_url, created_ts, updated_ts, row_status" + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &create.ID, + &create.AvatarURL, + &create.CreatedTs, + &create.UpdatedTs, + &create.RowStatus, + ); err != nil { return nil, err } - // Execute the query and get the returned ID - var id int32 - err = d.db.QueryRowContext(ctx, query, args...).Scan(&id) - if err != nil { - return nil, err - } - - // Use the returned ID to retrieve the full user object - user, err := d.GetUser(ctx, &store.FindUser{ID: &id}) - if err != nil { - return nil, err - } - - return user, nil + return create, nil } func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) { - // Start building the update statement - builder := squirrel.Update(`"user"`).PlaceholderFormat(squirrel.Dollar) - - // Conditionally add set clauses + set, args := []string{}, []any{} if v := update.UpdatedTs; v != nil { - builder = builder.Set("updated_ts", *v) + set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v) } if v := update.RowStatus; v != nil { - builder = builder.Set("row_status", *v) + set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v) } if v := update.Username; v != nil { - builder = builder.Set("username", *v) + set, args = append(set, "username = "+placeholder(len(args)+1)), append(args, *v) } if v := update.Email; v != nil { - builder = builder.Set("email", *v) + set, args = append(set, "email = "+placeholder(len(args)+1)), append(args, *v) } if v := update.Nickname; v != nil { - builder = builder.Set("nickname", *v) + set, args = append(set, "nickname = "+placeholder(len(args)+1)), append(args, *v) } if v := update.AvatarURL; v != nil { - builder = builder.Set("avatar_url", *v) + set, args = append(set, "avatar_url = "+placeholder(len(args)+1)), append(args, *v) } if v := update.PasswordHash; v != nil { - builder = builder.Set("password_hash", *v) - } - - // Add the WHERE clause - builder = builder.Where(squirrel.Eq{"id": update.ID}) - - // Prepare the final query - query, args, err := builder.ToSql() - if err != nil { - return nil, err - } - - // Execute the query with the context - if _, err := d.db.ExecContext(ctx, query, args...); err != nil { + set, args = append(set, "password_hash = "+placeholder(len(args)+1)), append(args, *v) + } + + query := ` + UPDATE "user" + SET ` + strings.Join(set, ", ") + ` + WHERE id = ` + placeholder(len(args)+1) + ` + RETURNING id, username, role, email, nickname, password_hash, avatar_url, created_ts, updated_ts, row_status + ` + args = append(args, update.ID) + user := &store.User{} + if err := d.db.QueryRowContext(ctx, query, args...).Scan( + &user.ID, + &user.Username, + &user.Role, + &user.Email, + &user.Nickname, + &user.PasswordHash, + &user.AvatarURL, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + ); err != nil { return nil, err } - // Retrieve the updated user - user, err := d.GetUser(ctx, &store.FindUser{ID: &update.ID}) - if err != nil { - return nil, err - } return user, nil } func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) { - // Start building the SELECT statement - builder := squirrel.Select("id", "username", "role", "email", "nickname", "password_hash", "avatar_url", "created_ts", "updated_ts", "row_status"). - From(`"user"`). - PlaceholderFormat(squirrel.Dollar) + where, args := []string{"1 = 1"}, []any{} - // 1 = 1 is often used as a no-op in SQL, ensuring there's always a WHERE clause - builder = builder.Where("1 = 1") - - // Conditionally add where clauses if v := find.ID; v != nil { - builder = builder.Where(squirrel.Eq{"id": *v}) + where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.Username; v != nil { - builder = builder.Where(squirrel.Eq{"username": *v}) + where, args = append(where, "username = "+placeholder(len(args)+1)), append(args, *v) } if v := find.Role; v != nil { - builder = builder.Where(squirrel.Eq{"role": *v}) + where, args = append(where, "role = "+placeholder(len(args)+1)), append(args, *v) } if v := find.Email; v != nil { - builder = builder.Where(squirrel.Eq{"email": *v}) + where, args = append(where, "email = "+placeholder(len(args)+1)), append(args, *v) } if v := find.Nickname; v != nil { - builder = builder.Where(squirrel.Eq{"nickname": *v}) - } - - // Add ordering - builder = builder.OrderBy("created_ts DESC", "row_status DESC") - - // Prepare the final query - query, args, err := builder.ToSql() - if err != nil { - return nil, err - } - - // Execute the query with the context + where, args = append(where, "nickname = "+placeholder(len(args)+1)), append(args, *v) + } + + query := ` + SELECT + id, + username, + role, + email, + nickname, + password_hash, + avatar_url, + created_ts, + updated_ts, + row_status + FROM "user" + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY created_ts DESC, row_status DESC + ` rows, err := d.db.QueryContext(ctx, query, args...) if err != nil { return nil, err @@ -161,35 +142,13 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User return list, nil } -func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) { - list, err := d.ListUsers(ctx, find) - if err != nil { - return nil, err - } - if len(list) != 1 { - return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list)) - } - return list[0], nil -} - func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error { - // Start building the DELETE statement - builder := squirrel.Delete(`"user"`). - PlaceholderFormat(squirrel.Dollar). - Where(squirrel.Eq{"id": delete.ID}) - - // Prepare the final query - query, args, err := builder.ToSql() + result, err := d.db.ExecContext(ctx, ` + DELETE FROM "user" WHERE id = $1 + `, delete.ID) if err != nil { return err } - - // Execute the query with the context - result, err := d.db.ExecContext(ctx, query, args...) - if err != nil { - return err - } - if _, err := result.RowsAffected(); err != nil { return err } diff --git a/store/db/postgres/user_setting.go b/store/db/postgres/user_setting.go index 7e25e3a7..e76a3a3f 100644 --- a/store/db/postgres/user_setting.go +++ b/store/db/postgres/user_setting.go @@ -3,7 +3,6 @@ package postgres import ( "context" "database/sql" - "fmt" "github.com/Masterminds/squirrel" "github.com/pkg/errors" @@ -130,22 +129,15 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) } func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { - // First, build the subquery - subQuery, subArgs, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql() + stmt := ` + DELETE FROM + user_setting + WHERE + user_id NOT IN (SELECT id FROM "user")` + _, err := tx.ExecContext(ctx, stmt) if err != nil { return err } - // Now, build the main delete query using the subquery - query, args, err := squirrel.Delete("user_setting"). - Where(fmt.Sprintf("user_id NOT IN (%s)", subQuery), subArgs...). - PlaceholderFormat(squirrel.Dollar). - ToSql() - if err != nil { - return err - } - - // Execute the query - _, err = tx.ExecContext(ctx, query, args...) - return err + return nil } diff --git a/store/db/sqlite/activity.go b/store/db/sqlite/activity.go index 9dbf26d2..8bbe331c 100644 --- a/store/db/sqlite/activity.go +++ b/store/db/sqlite/activity.go @@ -38,7 +38,6 @@ func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) { where, args := []string{"1 = 1"}, []any{} - if find.ID != nil { where, args = append(where, "`id` = ?"), append(args, *find.ID) } diff --git a/test/store/memo_organizer_test.go b/test/store/memo_organizer_test.go new file mode 100644 index 00000000..c2370684 --- /dev/null +++ b/test/store/memo_organizer_test.go @@ -0,0 +1,62 @@ +package teststore + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/store" +) + +func TestMemoOrganizerStore(t *testing.T) { + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + memoCreate := &store.Memo{ + CreatorID: user.ID, + Content: "main memo content", + Visibility: store.Public, + } + memo, err := ts.CreateMemo(ctx, memoCreate) + require.NoError(t, err) + require.Equal(t, memoCreate.Content, memo.Content) + + memoOrganizer, err := ts.UpsertMemoOrganizer(ctx, &store.MemoOrganizer{ + MemoID: memo.ID, + UserID: user.ID, + Pinned: true, + }) + require.NoError(t, err) + require.NotNil(t, memoOrganizer) + require.Equal(t, memo.ID, memoOrganizer.MemoID) + require.Equal(t, user.ID, memoOrganizer.UserID) + require.Equal(t, true, memoOrganizer.Pinned) + + memoOrganizerTemp, err := ts.GetMemoOrganizer(ctx, &store.FindMemoOrganizer{ + MemoID: memo.ID, + }) + require.NoError(t, err) + require.Equal(t, memoOrganizer, memoOrganizerTemp) + memoOrganizerTemp, err = ts.UpsertMemoOrganizer(ctx, &store.MemoOrganizer{ + MemoID: memo.ID, + UserID: user.ID, + Pinned: false, + }) + require.NoError(t, err) + require.NotNil(t, memoOrganizerTemp) + require.Equal(t, memo.ID, memoOrganizerTemp.MemoID) + require.Equal(t, user.ID, memoOrganizerTemp.UserID) + require.Equal(t, false, memoOrganizerTemp.Pinned) + err = ts.DeleteMemoOrganizer(ctx, &store.DeleteMemoOrganizer{ + MemoID: &memo.ID, + UserID: &user.ID, + }) + require.NoError(t, err) + memoOrganizers, err := ts.ListMemoOrganizer(ctx, &store.FindMemoOrganizer{ + UserID: user.ID, + }) + require.NoError(t, err) + require.Equal(t, 0, len(memoOrganizers)) +} diff --git a/test/store/memo_test.go b/test/store/memo_test.go index e1df1f77..9f8db0da 100644 --- a/test/store/memo_test.go +++ b/test/store/memo_test.go @@ -59,3 +59,22 @@ func TestMemoStore(t *testing.T) { require.NoError(t, err) require.Equal(t, 0, len(memoList)) } + +func TestDeleteMemoStore(t *testing.T) { + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + memoCreate := &store.Memo{ + CreatorID: user.ID, + Content: "test_content", + Visibility: store.Public, + } + memo, err := ts.CreateMemo(ctx, memoCreate) + require.NoError(t, err) + require.Equal(t, memoCreate.Content, memo.Content) + err = ts.DeleteMemo(ctx, &store.DeleteMemo{ + ID: memo.ID, + }) + require.NoError(t, err) +} diff --git a/test/store/store.go b/test/store/store.go index 8d1b128e..e05af5f0 100644 --- a/test/store/store.go +++ b/test/store/store.go @@ -8,6 +8,7 @@ import ( // sqlite driver. _ "modernc.org/sqlite" + "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" "github.com/usememos/memos/store/db" "github.com/usememos/memos/test" @@ -19,6 +20,7 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store { if err != nil { fmt.Printf("failed to create db driver, error: %+v\n", err) } + resetTestingDB(ctx, profile, dbDriver) if err := dbDriver.Migrate(ctx); err != nil { fmt.Printf("failed to migrate db, error: %+v\n", err) } @@ -26,3 +28,13 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store { store := store.New(dbDriver, profile) return store } + +func resetTestingDB(ctx context.Context, profile *profile.Profile, dbDriver store.Driver) { + if profile.Driver == "postgres" { + _, err := dbDriver.GetDB().ExecContext(ctx, `DROP SCHEMA public CASCADE; CREATE SCHEMA public;`) + if err != nil { + fmt.Printf("failed to reset testing db, error: %+v\n", err) + panic(err) + } + } +}