From 4c33d8d762aeb95a18c867001d12eeefac0adfc0 Mon Sep 17 00:00:00 2001 From: boojack Date: Thu, 20 Jul 2023 23:15:56 +0800 Subject: [PATCH] chore: remove unused transaction in store (#1995) * chore: remove unused transaction in store * chore: update --- store/activity.go | 14 +- store/db/db.go | 12 +- store/db/migration_history.go | 49 +------ store/idp.go | 174 ++++++++---------------- store/memo.go | 249 +++++++++++++--------------------- store/memo_organizer.go | 43 +----- store/memo_relation.go | 123 ++++++----------- store/memo_resource.go | 118 +++++----------- store/resource.go | 205 +++++++++++----------------- store/shortcut.go | 159 +++++++--------------- store/storage.go | 136 ++++++------------- store/system_setting.go | 100 ++++---------- store/tag.go | 47 +------ store/user.go | 157 +++++++-------------- store/user_setting.go | 93 ++++--------- test/store/store_test.go | 47 +++++++ 16 files changed, 570 insertions(+), 1156 deletions(-) create mode 100644 test/store/store_test.go diff --git a/store/activity.go b/store/activity.go index 1b7a12a58..639bb43e6 100644 --- a/store/activity.go +++ b/store/activity.go @@ -18,13 +18,7 @@ type Activity struct { } func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO activity ( creator_id, type, @@ -34,17 +28,13 @@ func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity VALUES (?, ?, ?, ?) RETURNING id, created_ts ` - if err := tx.QueryRowContext(ctx, query, create.CreatorID, create.Type, create.Level, create.Payload).Scan( + if err := s.db.QueryRowContext(ctx, stmt, create.CreatorID, create.Type, create.Level, create.Payload).Scan( &create.ID, &create.CreatedTs, ); err != nil { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - activity := create return activity, nil } diff --git a/store/db/db.go b/store/db/db.go index df0cdc531..b57f59cc5 100644 --- a/store/db/db.go +++ b/store/db/db.go @@ -190,21 +190,15 @@ func (db *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion st } } - tx, err := db.DBInstance.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - // upsert the newest version to migration_history + // Upsert the newest version to migration_history. version := minorVersion + ".0" - if _, err = upsertMigrationHistory(ctx, tx, &MigrationHistoryUpsert{ + if _, err = db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{ Version: version, }); err != nil { return fmt.Errorf("failed to upsert migration history with version: %s, err: %w", version, err) } - return tx.Commit() + return nil } func (db *DB) seed(ctx context.Context) error { diff --git a/store/db/migration_history.go b/store/db/migration_history.go index cbda3445b..e4b897e64 100644 --- a/store/db/migration_history.go +++ b/store/db/migration_history.go @@ -2,7 +2,6 @@ package db import ( "context" - "database/sql" "strings" ) @@ -20,40 +19,6 @@ type MigrationHistoryFind struct { } func (db *DB) FindMigrationHistoryList(ctx context.Context, find *MigrationHistoryFind) ([]*MigrationHistory, error) { - tx, err := db.DBInstance.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := findMigrationHistoryList(ctx, tx, find) - if err != nil { - return nil, err - } - - return list, nil -} - -func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { - tx, err := db.DBInstance.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - migrationHistory, err := upsertMigrationHistory(ctx, tx, upsert) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return migrationHistory, nil -} - -func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHistoryFind) ([]*MigrationHistory, error) { where, args := []string{"1 = 1"}, []any{} if v := find.Version; v != nil { @@ -69,13 +34,13 @@ func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHi WHERE ` + strings.Join(where, " AND ") + ` ORDER BY created_ts DESC ` - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := db.DBInstance.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() - migrationHistoryList := make([]*MigrationHistory, 0) + list := make([]*MigrationHistory, 0) for rows.Next() { var migrationHistory MigrationHistory if err := rows.Scan( @@ -85,18 +50,18 @@ func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHi return nil, err } - migrationHistoryList = append(migrationHistoryList, &migrationHistory) + list = append(list, &migrationHistory) } if err := rows.Err(); err != nil { return nil, err } - return migrationHistoryList, nil + return list, nil } -func upsertMigrationHistory(ctx context.Context, tx *sql.Tx, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { - query := ` +func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { + stmt := ` INSERT INTO migration_history ( version ) @@ -107,7 +72,7 @@ func upsertMigrationHistory(ctx context.Context, tx *sql.Tx, upsert *MigrationHi RETURNING version, created_ts ` var migrationHistory MigrationHistory - if err := tx.QueryRowContext(ctx, query, upsert.Version).Scan( + if err := db.DBInstance.QueryRowContext(ctx, stmt, upsert.Version).Scan( &migrationHistory.Version, &migrationHistory.CreatedTs, ); err != nil { diff --git a/store/idp.go b/store/idp.go index 4f22a8716..e94eab73f 100644 --- a/store/idp.go +++ b/store/idp.go @@ -2,7 +2,6 @@ package store import ( "context" - "database/sql" "encoding/json" "fmt" "strings" @@ -63,23 +62,18 @@ type DeleteIdentityProvider struct { } func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - var configBytes []byte if create.Type == IdentityProviderOAuth2Type { - configBytes, err = json.Marshal(create.Config.OAuth2Config) + bytes, err := json.Marshal(create.Config.OAuth2Config) if err != nil { return nil, err } + configBytes = bytes } else { return nil, fmt.Errorf("unsupported idp type %s", string(create.Type)) } - query := ` + stmt := ` INSERT INTO idp ( name, type, @@ -89,9 +83,9 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv VALUES (?, ?, ?, ?) RETURNING id ` - if err := tx.QueryRowContext( + if err := s.db.QueryRowContext( ctx, - query, + stmt, create.Name, create.Type, create.IdentifierFilter, @@ -102,35 +96,69 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - identityProvider := create s.idpCache.Store(identityProvider.ID, identityProvider) return identityProvider, nil } func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err + where, args := []string{"1 = 1"}, []any{} + if v := find.ID; v != nil { + where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) } - defer tx.Rollback() - list, err := listIdentityProviders(ctx, tx, find) + rows, err := s.db.QueryContext(ctx, ` + SELECT + id, + name, + type, + identifier_filter, + config + FROM idp + WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`, + args..., + ) if err != nil { return nil, err } + defer rows.Close() + + var identityProviders []*IdentityProvider + for rows.Next() { + var identityProvider IdentityProvider + var identityProviderConfig string + if err := rows.Scan( + &identityProvider.ID, + &identityProvider.Name, + &identityProvider.Type, + &identityProvider.IdentifierFilter, + &identityProviderConfig, + ); err != nil { + return nil, err + } + + if identityProvider.Type == IdentityProviderOAuth2Type { + oauth2Config := &IdentityProviderOAuth2Config{} + if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { + return nil, err + } + identityProvider.Config = &IdentityProviderConfig{ + OAuth2Config: oauth2Config, + } + } else { + return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type)) + } + identityProviders = append(identityProviders, &identityProvider) + } - if err := tx.Commit(); err != nil { + if err := rows.Err(); err != nil { return nil, err } - for _, item := range list { + for _, item := range identityProviders { s.idpCache.Store(item.ID, item) } - return list, nil + return identityProviders, nil } func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) { @@ -140,13 +168,7 @@ func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvi } } - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listIdentityProviders(ctx, tx, find) + list, err := s.ListIdentityProviders(ctx, find) if err != nil { return nil, err } @@ -154,22 +176,12 @@ func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvi return nil, nil } - if err := tx.Commit(); err != nil { - return nil, err - } - identityProvider := list[0] s.idpCache.Store(identityProvider.ID, identityProvider) return identityProvider, nil } func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - set, args := []string{}, []any{} if v := update.Name; v != nil { set, args = append(set, "name = ?"), append(args, *v) @@ -180,10 +192,11 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti if v := update.Config; v != nil { var configBytes []byte if update.Type == IdentityProviderOAuth2Type { - configBytes, err = json.Marshal(update.Config.OAuth2Config) + bytes, err := json.Marshal(update.Config.OAuth2Config) if err != nil { return nil, err } + configBytes = bytes } else { return nil, fmt.Errorf("unsupported idp type %s", string(update.Type)) } @@ -191,7 +204,7 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti } args = append(args, update.ID) - query := ` + stmt := ` UPDATE idp SET ` + strings.Join(set, ", ") + ` WHERE id = ? @@ -199,7 +212,7 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti ` var identityProvider IdentityProvider var identityProviderConfig string - if err := tx.QueryRowContext(ctx, query, args...).Scan( + if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( &identityProvider.ID, &identityProvider.Name, &identityProvider.Type, @@ -221,93 +234,20 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type)) } - if err := tx.Commit(); err != nil { - return nil, err - } - s.idpCache.Store(identityProvider.ID, identityProvider) return &identityProvider, nil } func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - where, args := []string{"id = ?"}, []any{delete.ID} stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ") - result, err := tx.ExecContext(ctx, stmt, args...) + result, err := s.db.ExecContext(ctx, stmt, args...) if err != nil { return err } - if _, err = result.RowsAffected(); err != nil { return err } - - if err := tx.Commit(); err != nil { - return err - } - s.idpCache.Delete(delete.ID) return nil } - -func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProvider) ([]*IdentityProvider, error) { - where, args := []string{"1 = 1"}, []any{} - if v := find.ID; v != nil { - where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) - } - - rows, err := tx.QueryContext(ctx, ` - SELECT - id, - name, - type, - identifier_filter, - config - FROM idp - WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`, - args..., - ) - if err != nil { - return nil, err - } - defer rows.Close() - - var identityProviders []*IdentityProvider - for rows.Next() { - var identityProvider IdentityProvider - var identityProviderConfig string - if err := rows.Scan( - &identityProvider.ID, - &identityProvider.Name, - &identityProvider.Type, - &identityProvider.IdentifierFilter, - &identityProviderConfig, - ); err != nil { - return nil, err - } - - if identityProvider.Type == IdentityProviderOAuth2Type { - oauth2Config := &IdentityProviderOAuth2Config{} - if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { - return nil, err - } - identityProvider.Config = &IdentityProviderConfig{ - OAuth2Config: oauth2Config, - } - } else { - return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type)) - } - identityProviders = append(identityProviders, &identityProvider) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return identityProviders, nil -} diff --git a/store/memo.go b/store/memo.go index fba079c3e..83c7b0f97 100644 --- a/store/memo.go +++ b/store/memo.go @@ -84,17 +84,11 @@ type DeleteMemo struct { } func (s *Store) CreateMemo(ctx context.Context, create *Memo) (*Memo, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - if create.CreatedTs == 0 { create.CreatedTs = time.Now().Unix() } - query := ` + stmt := ` INSERT INTO memo ( creator_id, created_ts, @@ -104,9 +98,9 @@ func (s *Store) CreateMemo(ctx context.Context, create *Memo) (*Memo, error) { VALUES (?, ?, ?, ?) RETURNING id, created_ts, updated_ts, row_status ` - if err := tx.QueryRowContext( + if err := s.db.QueryRowContext( ctx, - query, + stmt, create.CreatorID, create.CreatedTs, create.Content, @@ -119,155 +113,12 @@ func (s *Store) CreateMemo(ctx context.Context, create *Memo) (*Memo, error) { ); err != nil { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } memo := create return memo, nil } func (s *Store) ListMemos(ctx context.Context, find *FindMemo) ([]*Memo, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listMemos(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return list, nil -} - -func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*Memo, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listMemos(ctx, tx, find) - if err != nil { - return nil, err - } - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - memo := list[0] - return memo, nil -} - -func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - set, args := []string{}, []any{} - if v := update.CreatedTs; v != nil { - set, args = append(set, "created_ts = ?"), append(args, *v) - } - if v := update.UpdatedTs; v != nil { - set, args = append(set, "updated_ts = ?"), append(args, *v) - } - if v := update.RowStatus; v != nil { - set, args = append(set, "row_status = ?"), append(args, *v) - } - if v := update.Content; v != nil { - set, args = append(set, "content = ?"), append(args, *v) - } - if v := update.Visibility; v != nil { - set, args = append(set, "visibility = ?"), append(args, *v) - } - args = append(args, update.ID) - - query := ` - UPDATE memo - SET ` + strings.Join(set, ", ") + ` - WHERE id = ? - ` - if _, err := tx.ExecContext(ctx, query, args...); err != nil { - return err - } - err = tx.Commit() - return err -} - -func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - where, args := []string{"id = ?"}, []any{delete.ID} - stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ") - _, err = tx.ExecContext(ctx, stmt, args...) - if err != nil { - return err - } - - if err := s.vacuumImpl(ctx, tx); err != nil { - return err - } - err = tx.Commit() - return err -} - -func (s *Store) FindMemosVisibilityList(ctx context.Context, memoIDs []int) ([]Visibility, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - args := make([]any, 0, len(memoIDs)) - list := make([]string, 0, len(memoIDs)) - for _, memoID := range memoIDs { - args = append(args, memoID) - list = append(list, "?") - } - - where := fmt.Sprintf("id in (%s)", strings.Join(list, ",")) - - query := `SELECT DISTINCT(visibility) FROM memo WHERE ` + where - - rows, err := tx.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - visibilityList := make([]Visibility, 0) - for rows.Next() { - var visibility Visibility - if err := rows.Scan(&visibility); err != nil { - return nil, err - } - visibilityList = append(visibilityList, visibility) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return visibilityList, nil -} - -func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemo) ([]*Memo, error) { where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { @@ -341,7 +192,7 @@ func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemo) ([]*Memo, error) } } - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -407,6 +258,98 @@ func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemo) ([]*Memo, error) return list, nil } +func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*Memo, error) { + list, err := s.ListMemos(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + memo := list[0] + return memo, nil +} + +func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) error { + set, args := []string{}, []any{} + if v := update.CreatedTs; v != nil { + set, args = append(set, "created_ts = ?"), append(args, *v) + } + if v := update.UpdatedTs; v != nil { + set, args = append(set, "updated_ts = ?"), append(args, *v) + } + if v := update.RowStatus; v != nil { + set, args = append(set, "row_status = ?"), append(args, *v) + } + if v := update.Content; v != nil { + set, args = append(set, "content = ?"), append(args, *v) + } + if v := update.Visibility; v != nil { + set, args = append(set, "visibility = ?"), append(args, *v) + } + args = append(args, update.ID) + + stmt := ` + UPDATE memo + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + ` + if _, err := s.db.ExecContext(ctx, stmt, args...); err != nil { + return err + } + return nil +} + +func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error { + where, args := []string{"id = ?"}, []any{delete.ID} + stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ") + result, err := s.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + if err := s.Vacuum(ctx); err != nil { + // Prevent linter warning. + return err + } + return nil +} + +func (s *Store) FindMemosVisibilityList(ctx context.Context, memoIDs []int) ([]Visibility, error) { + args := make([]any, 0, len(memoIDs)) + list := make([]string, 0, len(memoIDs)) + for _, memoID := range memoIDs { + args = append(args, memoID) + list = append(list, "?") + } + + where := fmt.Sprintf("id in (%s)", strings.Join(list, ",")) + query := `SELECT DISTINCT(visibility) FROM memo WHERE ` + where + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + visibilityList := make([]Visibility, 0) + for rows.Next() { + var visibility Visibility + if err := rows.Scan(&visibility); err != nil { + return nil, err + } + visibilityList = append(visibilityList, visibility) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return visibilityList, nil +} + func vacuumMemo(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/store/memo_organizer.go b/store/memo_organizer.go index 1047f30c7..910440a8e 100644 --- a/store/memo_organizer.go +++ b/store/memo_organizer.go @@ -24,13 +24,7 @@ type DeleteMemoOrganizer struct { } func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *MemoOrganizer) (*MemoOrganizer, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO memo_organizer ( memo_id, user_id, @@ -41,11 +35,7 @@ func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *MemoOrganizer) SET pinned = EXCLUDED.pinned ` - if _, err := tx.ExecContext(ctx, query, upsert.MemoID, upsert.UserID, upsert.Pinned); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { + if _, err := s.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, upsert.Pinned); err != nil { return nil, err } @@ -54,12 +44,6 @@ func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *MemoOrganizer) } func (s *Store) GetMemoOrganizer(ctx context.Context, find *FindMemoOrganizer) (*MemoOrganizer, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - where, args := []string{}, []any{} if find.MemoID != 0 { where = append(where, "memo_id = ?") @@ -78,7 +62,7 @@ func (s *Store) GetMemoOrganizer(ctx context.Context, find *FindMemoOrganizer) ( FROM memo_organizer WHERE %s `, strings.Join(where, " AND ")) - row := tx.QueryRowContext(ctx, query, args...) + row := s.db.QueryRowContext(ctx, query, args...) if err := row.Err(); err != nil { return nil, err } @@ -95,40 +79,21 @@ func (s *Store) GetMemoOrganizer(ctx context.Context, find *FindMemoOrganizer) ( return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - return memoOrganizer, nil } func (s *Store) DeleteMemoOrganizer(ctx context.Context, delete *DeleteMemoOrganizer) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - where, args := []string{}, []any{} - if v := delete.MemoID; v != nil { where, args = append(where, "memo_id = ?"), append(args, *v) } if v := delete.UserID; v != nil { where, args = append(where, "user_id = ?"), append(args, *v) } - stmt := `DELETE FROM memo_organizer WHERE ` + strings.Join(where, " AND ") - _, err = tx.ExecContext(ctx, stmt, args...) - if err != nil { - return err - } - - if err := tx.Commit(); err != nil { - // Prevent linter warning. + if _, err := s.db.ExecContext(ctx, stmt, args...); err != nil { return err } - return nil } diff --git a/store/memo_relation.go b/store/memo_relation.go index 3230b54ac..d32a7f353 100644 --- a/store/memo_relation.go +++ b/store/memo_relation.go @@ -32,13 +32,7 @@ type DeleteMemoRelation struct { } func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*MemoRelation, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO memo_relation ( memo_id, related_memo_id, @@ -50,9 +44,9 @@ func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (* RETURNING memo_id, related_memo_id, type ` memoRelation := &MemoRelation{} - if err := tx.QueryRowContext( + if err := s.db.QueryRowContext( ctx, - query, + stmt, create.MemoID, create.RelatedMemoID, create.Type, @@ -64,26 +58,47 @@ func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (* return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - return memoRelation, nil } func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelation) ([]*MemoRelation, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err + where, args := []string{"TRUE"}, []any{} + if find.MemoID != nil { + where, args = append(where, "memo_id = ?"), append(args, find.MemoID) + } + if find.RelatedMemoID != nil { + where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID) + } + if find.Type != nil { + where, args = append(where, "type = ?"), append(args, find.Type) } - defer tx.Rollback() - list, err := listMemoRelations(ctx, tx, find) + rows, err := s.db.QueryContext(ctx, ` + SELECT + memo_id, + related_memo_id, + type + FROM memo_relation + WHERE `+strings.Join(where, " AND "), args...) if err != nil { return nil, err } + defer rows.Close() + + list := []*MemoRelation{} + for rows.Next() { + memoRelation := &MemoRelation{} + if err := rows.Scan( + &memoRelation.MemoID, + &memoRelation.RelatedMemoID, + &memoRelation.Type, + ); err != nil { + return nil, err + } + list = append(list, memoRelation) + } - if err := tx.Commit(); err != nil { + if err := rows.Err(); err != nil { return nil, err } @@ -91,13 +106,7 @@ func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelation) ( } func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelation) (*MemoRelation, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listMemoRelations(ctx, tx, find) + list, err := s.ListMemoRelations(ctx, find) if err != nil { return nil, err } @@ -106,20 +115,10 @@ func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelation) (*M return nil, nil } - if err := tx.Commit(); err != nil { - return nil, err - } - return list[0], nil } func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelation) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - where, args := []string{"TRUE"}, []any{} if delete.MemoID != nil { where, args = append(where, "memo_id = ?"), append(args, delete.MemoID) @@ -130,63 +129,19 @@ func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelati if delete.Type != nil { where, args = append(where, "type = ?"), append(args, delete.Type) } - - query := ` + stmt := ` DELETE FROM memo_relation WHERE ` + strings.Join(where, " AND ") - if _, err := tx.ExecContext(ctx, query, args...); err != nil { + result, err := s.db.ExecContext(ctx, stmt, args...) + if err != nil { return err } - - if err := tx.Commit(); err != nil { - // Prevent lint warning. + if _, err = result.RowsAffected(); err != nil { return err } return nil } -func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelation) ([]*MemoRelation, error) { - where, args := []string{"TRUE"}, []any{} - if find.MemoID != nil { - where, args = append(where, "memo_id = ?"), append(args, find.MemoID) - } - if find.RelatedMemoID != nil { - where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID) - } - if find.Type != nil { - where, args = append(where, "type = ?"), append(args, find.Type) - } - - rows, err := tx.QueryContext(ctx, ` - SELECT - memo_id, - related_memo_id, - type - FROM memo_relation - WHERE `+strings.Join(where, " AND "), args...) - if err != nil { - return nil, err - } - defer rows.Close() - - memoRelationMessages := []*MemoRelation{} - for rows.Next() { - memoRelationMessage := &MemoRelation{} - if err := rows.Scan( - &memoRelationMessage.MemoID, - &memoRelationMessage.RelatedMemoID, - &memoRelationMessage.Type, - ); err != nil { - return nil, err - } - memoRelationMessages = append(memoRelationMessages, memoRelationMessage) - } - if err := rows.Err(); err != nil { - return nil, err - } - return memoRelationMessages, nil -} - func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error { if _, err := tx.ExecContext(ctx, ` DELETE FROM memo_relation diff --git a/store/memo_resource.go b/store/memo_resource.go index 41024b37e..80d993d08 100644 --- a/store/memo_resource.go +++ b/store/memo_resource.go @@ -31,12 +31,6 @@ type DeleteMemoResource struct { } func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResource) (*MemoResource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - set := []string{"memo_id", "resource_id"} args := []any{upsert.MemoID, upsert.ResourceID} placeholder := []string{"?", "?"} @@ -56,7 +50,7 @@ func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResour RETURNING memo_id, resource_id, created_ts, updated_ts ` memoResource := &MemoResource{} - if err := tx.QueryRowContext(ctx, query, args...).Scan( + if err := s.db.QueryRowContext(ctx, query, args...).Scan( &memoResource.MemoID, &memoResource.ResourceID, &memoResource.CreatedTs, @@ -65,86 +59,10 @@ func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResour return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - return memoResource, nil } func (s *Store) ListMemoResources(ctx context.Context, find *FindMemoResource) ([]*MemoResource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listMemoResources(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return list, nil -} - -func (s *Store) GetMemoResource(ctx context.Context, find *FindMemoResource) (*MemoResource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listMemoResources(ctx, tx, find) - if err != nil { - return nil, err - } - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - memoResource := list[0] - return memoResource, nil -} - -func (s *Store) DeleteMemoResource(ctx context.Context, delete *DeleteMemoResource) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - where, args := []string{}, []any{} - - if v := delete.MemoID; v != nil { - where, args = append(where, "memo_id = ?"), append(args, *v) - } - if v := delete.ResourceID; v != nil { - where, args = append(where, "resource_id = ?"), append(args, *v) - } - - stmt := `DELETE FROM memo_resource WHERE ` + strings.Join(where, " AND ") - _, err = tx.ExecContext(ctx, stmt, args...) - if err != nil { - return err - } - - if err := tx.Commit(); err != nil { - // Prevent linter warning. - return err - } - - return nil -} - -func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource) ([]*MemoResource, error) { where, args := []string{"1 = 1"}, []any{} if v := find.MemoID; v != nil { @@ -164,7 +82,7 @@ func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource) WHERE ` + strings.Join(where, " AND ") + ` ORDER BY updated_ts DESC ` - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -192,6 +110,38 @@ func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource) return list, nil } +func (s *Store) GetMemoResource(ctx context.Context, find *FindMemoResource) (*MemoResource, error) { + list, err := s.ListMemoResources(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + memoResource := list[0] + return memoResource, nil +} + +func (s *Store) DeleteMemoResource(ctx context.Context, delete *DeleteMemoResource) error { + where, args := []string{}, []any{} + if v := delete.MemoID; v != nil { + where, args = append(where, "memo_id = ?"), append(args, *v) + } + if v := delete.ResourceID; v != nil { + where, args = append(where, "resource_id = ?"), append(args, *v) + } + stmt := `DELETE FROM memo_resource WHERE ` + strings.Join(where, " AND ") + result, err := s.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err = result.RowsAffected(); err != nil { + return err + } + return nil +} + func vacuumMemoResource(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/store/resource.go b/store/resource.go index 6fbfebd51..7700124c1 100644 --- a/store/resource.go +++ b/store/resource.go @@ -46,13 +46,7 @@ type DeleteResource struct { } func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - if err := tx.QueryRowContext(ctx, ` + stmt := ` INSERT INTO resource ( filename, blob, @@ -64,131 +58,26 @@ func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource ) VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING id, created_ts, updated_ts - `, - create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath, + ` + if err := s.db.QueryRowContext( + ctx, + stmt, + create.Filename, + create.Blob, + create.ExternalLink, + create.Type, + create.Size, + create.CreatorID, + create.InternalPath, ).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - resource := create return resource, nil } func (s *Store) ListResources(ctx context.Context, find *FindResource) ([]*Resource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - resources, err := listResources(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return resources, nil -} - -func (s *Store) GetResource(ctx context.Context, find *FindResource) (*Resource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - resources, err := listResources(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(resources) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return resources[0], nil -} - -func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Resource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - set, args := []string{}, []any{} - - if v := update.UpdatedTs; v != nil { - set, args = append(set, "updated_ts = ?"), append(args, *v) - } - if v := update.Filename; v != nil { - set, args = append(set, "filename = ?"), append(args, *v) - } - - args = append(args, update.ID) - fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path"} - query := ` - UPDATE resource - SET ` + strings.Join(set, ", ") + ` - WHERE id = ? - RETURNING ` + strings.Join(fields, ", ") - resource := Resource{} - dests := []any{ - &resource.ID, - &resource.Filename, - &resource.ExternalLink, - &resource.Type, - &resource.Size, - &resource.CreatorID, - &resource.CreatedTs, - &resource.UpdatedTs, - &resource.InternalPath, - } - if err := tx.QueryRowContext(ctx, query, args...).Scan(dests...); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return &resource, nil -} - -func (s *Store) DeleteResource(ctx context.Context, delete *DeleteResource) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - if _, err := tx.ExecContext(ctx, ` - DELETE FROM resource - WHERE id = ? - `, delete.ID); err != nil { - return err - } - - if err := tx.Commit(); err != nil { - // Prevent linter warning. - return err - } - - return nil -} - -func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Resource, error) { where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { @@ -226,7 +115,7 @@ func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Reso } } - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -263,6 +152,74 @@ func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Reso return list, nil } +func (s *Store) GetResource(ctx context.Context, find *FindResource) (*Resource, error) { + resources, err := s.ListResources(ctx, find) + if err != nil { + return nil, err + } + + if len(resources) == 0 { + return nil, nil + } + + return resources[0], nil +} + +func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Resource, error) { + set, args := []string{}, []any{} + + if v := update.UpdatedTs; v != nil { + set, args = append(set, "updated_ts = ?"), append(args, *v) + } + if v := update.Filename; v != nil { + set, args = append(set, "filename = ?"), append(args, *v) + } + + args = append(args, update.ID) + fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path"} + stmt := ` + UPDATE resource + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + RETURNING ` + strings.Join(fields, ", ") + resource := Resource{} + dests := []any{ + &resource.ID, + &resource.Filename, + &resource.ExternalLink, + &resource.Type, + &resource.Size, + &resource.CreatorID, + &resource.CreatedTs, + &resource.UpdatedTs, + &resource.InternalPath, + } + if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(dests...); err != nil { + return nil, err + } + + return &resource, nil +} + +func (s *Store) DeleteResource(ctx context.Context, delete *DeleteResource) error { + stmt := ` + DELETE FROM resource + WHERE id = ? + ` + result, err := s.db.ExecContext(ctx, stmt, delete.ID) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + if err := s.Vacuum(ctx); err != nil { + // Prevent linter warning. + return err + } + return nil +} + func vacuumResource(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/store/shortcut.go b/store/shortcut.go index 7fb4047fa..969e8b965 100644 --- a/store/shortcut.go +++ b/store/shortcut.go @@ -41,13 +41,7 @@ type DeleteShortcut struct { } func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO shortcut ( title, payload, @@ -56,7 +50,7 @@ func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut VALUES (?, ?, ?) RETURNING id, created_ts, updated_ts, row_status ` - if err := tx.QueryRowContext(ctx, query, create.Title, create.Payload, create.CreatorID).Scan( + if err := s.db.QueryRowContext(ctx, stmt, create.Title, create.Payload, create.CreatorID).Scan( &create.ID, &create.CreatedTs, &create.UpdatedTs, @@ -65,27 +59,60 @@ func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - shortcut := create return shortcut, nil } func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*Shortcut, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err + where, args := []string{"1 = 1"}, []any{} + + if v := find.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := find.CreatorID; v != nil { + where, args = append(where, "creator_id = ?"), append(args, *v) + } + if v := find.Title; v != nil { + where, args = append(where, "title = ?"), append(args, *v) } - defer tx.Rollback() - list, err := listShortcuts(ctx, tx, find) + rows, err := s.db.QueryContext(ctx, ` + SELECT + id, + title, + payload, + creator_id, + created_ts, + updated_ts, + row_status + FROM shortcut + WHERE `+strings.Join(where, " AND ")+` + ORDER BY created_ts DESC`, + args..., + ) if err != nil { return nil, err } + defer rows.Close() + + list := make([]*Shortcut, 0) + for rows.Next() { + var shortcut Shortcut + if err := rows.Scan( + &shortcut.ID, + &shortcut.Title, + &shortcut.Payload, + &shortcut.CreatorID, + &shortcut.CreatedTs, + &shortcut.UpdatedTs, + &shortcut.RowStatus, + ); err != nil { + return nil, err + } + list = append(list, &shortcut) + } - if err := tx.Commit(); err != nil { + if err := rows.Err(); err != nil { return nil, err } @@ -93,13 +120,7 @@ func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*Short } func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listShortcuts(ctx, tx, find) + list, err := s.ListShortcuts(ctx, find) if err != nil { return nil, err } @@ -108,21 +129,11 @@ func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, return nil, nil } - if err := tx.Commit(); err != nil { - return nil, err - } - shortcut := list[0] return shortcut, nil } func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Shortcut, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - set, args := []string{}, []any{} if v := update.UpdatedTs; v != nil { set, args = append(set, "updated_ts = ?"), append(args, *v) @@ -138,14 +149,14 @@ func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Sh } args = append(args, update.ID) - query := ` + stmt := ` UPDATE shortcut SET ` + strings.Join(set, ", ") + ` WHERE id = ? RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status ` shortcut := &Shortcut{} - if err := tx.QueryRowContext(ctx, query, args...).Scan( + if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( &shortcut.ID, &shortcut.Title, &shortcut.Payload, @@ -157,20 +168,10 @@ func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Sh return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - return shortcut, nil } func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - where, args := []string{}, []any{} if v := delete.ID; v != nil { where, args = append(where, "id = ?"), append(args, *v) @@ -178,76 +179,18 @@ func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) erro if v := delete.CreatorID; v != nil { where, args = append(where, "creator_id = ?"), append(args, *v) } - stmt := `DELETE FROM shortcut WHERE ` + strings.Join(where, " AND ") - if _, err := tx.ExecContext(ctx, stmt, args...); err != nil { + result, err := s.db.ExecContext(ctx, stmt, args...) + if err != nil { return err } - - if err := tx.Commit(); err != nil { + if _, err := result.RowsAffected(); err != nil { return err } - s.shortcutCache.Delete(*delete.ID) return nil } -func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shortcut, error) { - where, args := []string{"1 = 1"}, []any{} - - if v := find.ID; v != nil { - where, args = append(where, "id = ?"), append(args, *v) - } - if v := find.CreatorID; v != nil { - where, args = append(where, "creator_id = ?"), append(args, *v) - } - if v := find.Title; v != nil { - where, args = append(where, "title = ?"), append(args, *v) - } - - rows, err := tx.QueryContext(ctx, ` - SELECT - id, - title, - payload, - creator_id, - created_ts, - updated_ts, - row_status - FROM shortcut - WHERE `+strings.Join(where, " AND ")+` - ORDER BY created_ts DESC`, - args..., - ) - if err != nil { - return nil, err - } - defer rows.Close() - - list := make([]*Shortcut, 0) - for rows.Next() { - var shortcut Shortcut - if err := rows.Scan( - &shortcut.ID, - &shortcut.Title, - &shortcut.Payload, - &shortcut.CreatorID, - &shortcut.CreatedTs, - &shortcut.UpdatedTs, - &shortcut.RowStatus, - ); err != nil { - return nil, err - } - list = append(list, &shortcut) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return list, nil -} - func vacuumShortcut(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/store/storage.go b/store/storage.go index d043c3eb2..8e8c7802c 100644 --- a/store/storage.go +++ b/store/storage.go @@ -2,7 +2,6 @@ package store import ( "context" - "database/sql" "strings" ) @@ -28,13 +27,7 @@ type DeleteStorage struct { } func (s *Store) CreateStorage(ctx context.Context, create *Storage) (*Storage, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO storage ( name, type, @@ -43,33 +36,53 @@ func (s *Store) CreateStorage(ctx context.Context, create *Storage) (*Storage, e VALUES (?, ?, ?) RETURNING id ` - if err := tx.QueryRowContext(ctx, query, create.Name, create.Type, create.Config).Scan( + if err := s.db.QueryRowContext(ctx, stmt, create.Name, create.Type, create.Config).Scan( &create.ID, ); err != nil { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - storage := create return storage, nil } func (s *Store) ListStorages(ctx context.Context, find *FindStorage) ([]*Storage, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err + where, args := []string{"1 = 1"}, []any{} + if find.ID != nil { + where, args = append(where, "id = ?"), append(args, *find.ID) } - defer tx.Rollback() - list, err := listStorages(ctx, tx, find) + rows, err := s.db.QueryContext(ctx, ` + SELECT + id, + name, + type, + config + FROM storage + WHERE `+strings.Join(where, " AND ")+` + ORDER BY id DESC`, + args..., + ) if err != nil { return nil, err } + defer rows.Close() + + list := []*Storage{} + for rows.Next() { + storage := &Storage{} + if err := rows.Scan( + &storage.ID, + &storage.Name, + &storage.Type, + &storage.Config, + ); err != nil { + return nil, err + } + list = append(list, storage) + } - if err := tx.Commit(); err != nil { + if err := rows.Err(); err != nil { return nil, err } @@ -77,13 +90,7 @@ func (s *Store) ListStorages(ctx context.Context, find *FindStorage) ([]*Storage } func (s *Store) GetStorage(ctx context.Context, find *FindStorage) (*Storage, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listStorages(ctx, tx, find) + list, err := s.ListStorages(ctx, find) if err != nil { return nil, err } @@ -91,20 +98,10 @@ func (s *Store) GetStorage(ctx context.Context, find *FindStorage) (*Storage, er return nil, nil } - if err := tx.Commit(); err != nil { - return nil, err - } - return list[0], nil } func (s *Store) UpdateStorage(ctx context.Context, update *UpdateStorage) (*Storage, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - set, args := []string{}, []any{} if update.Name != nil { set = append(set, "name = ?") @@ -116,7 +113,7 @@ func (s *Store) UpdateStorage(ctx context.Context, update *UpdateStorage) (*Stor } args = append(args, update.ID) - query := ` + stmt := ` UPDATE storage SET ` + strings.Join(set, ", ") + ` WHERE id = ? @@ -127,7 +124,7 @@ func (s *Store) UpdateStorage(ctx context.Context, update *UpdateStorage) (*Stor config ` storage := &Storage{} - if err := tx.QueryRowContext(ctx, query, args...).Scan( + if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( &storage.ID, &storage.Name, &storage.Type, @@ -136,75 +133,20 @@ func (s *Store) UpdateStorage(ctx context.Context, update *UpdateStorage) (*Stor return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - return storage, nil } func (s *Store) DeleteStorage(ctx context.Context, delete *DeleteStorage) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - query := ` + stmt := ` DELETE FROM storage WHERE id = ? ` - if _, err := tx.ExecContext(ctx, query, delete.ID); err != nil { + result, err := s.db.ExecContext(ctx, stmt, delete.ID) + if err != nil { return err } - - if err := tx.Commit(); err != nil { - // Prevent linter warning. + if _, err := result.RowsAffected(); err != nil { return err } - return nil } - -func listStorages(ctx context.Context, tx *sql.Tx, find *FindStorage) ([]*Storage, error) { - where, args := []string{"1 = 1"}, []any{} - if find.ID != nil { - where, args = append(where, "id = ?"), append(args, *find.ID) - } - - rows, err := tx.QueryContext(ctx, ` - SELECT - id, - name, - type, - config - FROM storage - WHERE `+strings.Join(where, " AND ")+` - ORDER BY id DESC`, - args..., - ) - if err != nil { - return nil, err - } - defer rows.Close() - - list := []*Storage{} - for rows.Next() { - storage := &Storage{} - if err := rows.Scan( - &storage.ID, - &storage.Name, - &storage.Type, - &storage.Config, - ); err != nil { - return nil, err - } - list = append(list, storage) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return list, nil -} diff --git a/store/system_setting.go b/store/system_setting.go index 6c06ff5b3..aa94849f5 100644 --- a/store/system_setting.go +++ b/store/system_setting.go @@ -2,7 +2,6 @@ package store import ( "context" - "database/sql" "strings" ) @@ -17,13 +16,7 @@ type FindSystemSetting struct { } func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) (*SystemSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO system_setting ( name, value, description ) @@ -33,11 +26,7 @@ func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) value = EXCLUDED.value, description = EXCLUDED.description ` - if _, err := tx.ExecContext(ctx, query, upsert.Name, upsert.Value, upsert.Description); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { + if _, err := s.db.ExecContext(ctx, stmt, upsert.Name, upsert.Value, upsert.Description); err != nil { return nil, err } @@ -46,18 +35,39 @@ func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) } func (s *Store) ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err + where, args := []string{"1 = 1"}, []any{} + if find.Name != "" { + where, args = append(where, "name = ?"), append(args, find.Name) } - defer tx.Rollback() - list, err := listSystemSettings(ctx, tx, find) + query := ` + SELECT + name, + value, + description + FROM system_setting + WHERE ` + strings.Join(where, " AND ") + + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } + defer rows.Close() + + list := []*SystemSetting{} + for rows.Next() { + systemSettingMessage := &SystemSetting{} + if err := rows.Scan( + &systemSettingMessage.Name, + &systemSettingMessage.Value, + &systemSettingMessage.Description, + ); err != nil { + return nil, err + } + list = append(list, systemSettingMessage) + } - if err := tx.Commit(); err != nil { + if err := rows.Err(); err != nil { return nil, err } @@ -74,13 +84,7 @@ func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) ( } } - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listSystemSettings(ctx, tx, find) + list, err := s.ListSystemSettings(ctx, find) if err != nil { return nil, err } @@ -89,10 +93,6 @@ func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) ( return nil, nil } - if err := tx.Commit(); err != nil { - return nil, err - } - systemSettingMessage := list[0] s.systemSettingCache.Store(systemSettingMessage.Name, systemSettingMessage) return systemSettingMessage, nil @@ -106,43 +106,3 @@ func (s *Store) GetSystemSettingValueWithDefault(ctx *context.Context, settingNa } return defaultValue } - -func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting) ([]*SystemSetting, error) { - where, args := []string{"1 = 1"}, []any{} - if find.Name != "" { - where, args = append(where, "name = ?"), append(args, find.Name) - } - - query := ` - SELECT - name, - value, - description - FROM system_setting - WHERE ` + strings.Join(where, " AND ") - - rows, err := tx.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - list := []*SystemSetting{} - for rows.Next() { - systemSettingMessage := &SystemSetting{} - if err := rows.Scan( - &systemSettingMessage.Name, - &systemSettingMessage.Value, - &systemSettingMessage.Description, - ); err != nil { - return nil, err - } - list = append(list, systemSettingMessage) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return list, nil -} diff --git a/store/tag.go b/store/tag.go index c6295291c..37b0077f1 100644 --- a/store/tag.go +++ b/store/tag.go @@ -3,7 +3,6 @@ package store import ( "context" "database/sql" - "fmt" "strings" ) @@ -22,13 +21,7 @@ type DeleteTag struct { } func (s *Store) UpsertTag(ctx context.Context, upsert *Tag) (*Tag, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO tag ( name, creator_id ) @@ -37,11 +30,7 @@ func (s *Store) UpsertTag(ctx context.Context, upsert *Tag) (*Tag, error) { SET name = EXCLUDED.name ` - if _, err := tx.ExecContext(ctx, query, upsert.Name, upsert.CreatorID); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { + if _, err := s.db.ExecContext(ctx, stmt, upsert.Name, upsert.CreatorID); err != nil { return nil, err } @@ -50,12 +39,6 @@ func (s *Store) UpsertTag(ctx context.Context, upsert *Tag) (*Tag, error) { } func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - where, args := []string{"creator_id = ?"}, []any{find.CreatorID} query := ` SELECT @@ -65,7 +48,7 @@ func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) { WHERE ` + strings.Join(where, " AND ") + ` ORDER BY name ASC ` - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -88,37 +71,19 @@ func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - return list, nil } func (s *Store) DeleteTag(ctx context.Context, delete *DeleteTag) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - where, args := []string{"name = ?", "creator_id = ?"}, []any{delete.Name, delete.CreatorID} - query := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ") - result, err := tx.ExecContext(ctx, query, args...) + stmt := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ") + result, err := s.db.ExecContext(ctx, stmt, args...) if err != nil { return err } - - rows, _ := result.RowsAffected() - if rows == 0 { - return fmt.Errorf("tag not found") - } - - if err := tx.Commit(); err != nil { - // Prevent linter warning. + if _, err = result.RowsAffected(); err != nil { return err } - return nil } diff --git a/store/user.go b/store/user.go index 13b71ef44..4ce11d707 100644 --- a/store/user.go +++ b/store/user.go @@ -2,8 +2,6 @@ package store import ( "context" - "database/sql" - "errors" "strings" ) @@ -79,13 +77,7 @@ type DeleteUser struct { } func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO user ( username, role, @@ -97,7 +89,9 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { VALUES (?, ?, ?, ?, ?, ?) RETURNING id, avatar_url, created_ts, updated_ts, row_status ` - if err := tx.QueryRowContext(ctx, query, + if err := s.db.QueryRowContext( + ctx, + stmt, create.Username, create.Role, create.Email, @@ -113,9 +107,6 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { ); err != nil { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } user := create s.userCache.Store(user.ID, user) @@ -123,12 +114,6 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { } func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - set, args := []string{}, []any{} if v := update.UpdatedTs; v != nil { set, args = append(set, "updated_ts = ?"), append(args, *v) @@ -163,7 +148,7 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro RETURNING id, username, role, email, nickname, password_hash, open_id, avatar_url, created_ts, updated_ts, row_status ` user := &User{} - if err := tx.QueryRowContext(ctx, query, args...).Scan( + if err := s.db.QueryRowContext(ctx, query, args...).Scan( &user.ID, &user.Username, &user.Role, @@ -179,100 +164,11 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - s.userCache.Store(user.ID, user) return user, nil } func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listUsers(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - for _, user := range list { - s.userCache.Store(user.ID, user) - } - return list, nil -} - -func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) { - if find.ID != nil { - if cache, ok := s.userCache.Load(*find.ID); ok { - return cache.(*User), nil - } - } - - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listUsers(ctx, tx, find) - if err != nil { - return nil, err - } - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - user := list[0] - s.userCache.Store(user.ID, user) - return user, nil -} - -func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - result, err := tx.ExecContext(ctx, ` - DELETE FROM user WHERE id = ? - `, delete.ID) - if err != nil { - return err - } - - rows, err := result.RowsAffected() - if err != nil { - return err - } - if rows == 0 { - return errors.New("user not found") - } - if err := s.vacuumImpl(ctx, tx); err != nil { - return err - } - - if err := tx.Commit(); err != nil { - return err - } - - s.userCache.Delete(delete.ID) - return nil -} - -func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) { where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { @@ -311,7 +207,7 @@ func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) WHERE ` + strings.Join(where, " AND ") + ` ORDER BY created_ts DESC, row_status DESC ` - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -342,5 +238,46 @@ func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) return nil, err } + for _, user := range list { + s.userCache.Store(user.ID, user) + } return list, nil } + +func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) { + if find.ID != nil { + if cache, ok := s.userCache.Load(*find.ID); ok { + return cache.(*User), nil + } + } + + list, err := s.ListUsers(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + user := list[0] + s.userCache.Store(user.ID, user) + return user, nil +} + +func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error { + result, err := s.db.ExecContext(ctx, ` + DELETE FROM user WHERE id = ? + `, delete.ID) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + if err := s.Vacuum(ctx); err != nil { + // Prevent linter warning. + return err + } + s.userCache.Delete(delete.ID) + return nil +} diff --git a/store/user_setting.go b/store/user_setting.go index 8fd6c2845..fa948167d 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -18,13 +18,7 @@ type FindUserSetting struct { } func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO user_setting ( user_id, key, value ) @@ -32,11 +26,7 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*Us ON CONFLICT(user_id, key) DO UPDATE SET value = EXCLUDED.value ` - if _, err := tx.ExecContext(ctx, query, upsert.UserID, upsert.Key, upsert.Value); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { + if _, err := s.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key, upsert.Value); err != nil { return nil, err } @@ -46,59 +36,6 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*Us } func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - userSettingList, err := listUserSettings(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - for _, userSetting := range userSettingList { - s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting) - } - return userSettingList, nil -} - -func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) { - if find.UserID != nil { - if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key)); ok { - return cache.(*UserSetting), nil - } - } - - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listUserSettings(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - userSetting := list[0] - s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting) - return userSetting, nil -} - -func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([]*UserSetting, error) { where, args := []string{"1 = 1"}, []any{} if v := find.Key; v != "" { @@ -115,7 +52,7 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([ value FROM user_setting WHERE ` + strings.Join(where, " AND ") - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -138,9 +75,33 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([ return nil, err } + for _, userSetting := range userSettingList { + s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting) + } return userSettingList, nil } +func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) { + if find.UserID != nil { + if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key)); ok { + return cache.(*UserSetting), nil + } + } + + list, err := s.ListUserSettings(ctx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + + userSetting := list[0] + s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting) + return userSetting, nil +} + func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/test/store/store_test.go b/test/store/store_test.go new file mode 100644 index 000000000..b82a1ad3a --- /dev/null +++ b/test/store/store_test.go @@ -0,0 +1,47 @@ +package teststore + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/require" + "github.com/usememos/memos/store" +) + +func TestConcurrentReadWrite(t *testing.T) { + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + const numWorkers = 10 + const numIterations = 100 + + wg := sync.WaitGroup{} + wg.Add(numWorkers) + + for i := 0; i < numWorkers; i++ { + go func() { + for j := 0; j < numIterations; j++ { + _, err := ts.CreateMemo(ctx, &store.Memo{ + CreatorID: user.ID, + Content: fmt.Sprintf("test_content_%d", i), + Visibility: store.Public, + }) + require.NoError(t, err) + } + }() + + go func() { + _, err := ts.ListMemos(ctx, &store.FindMemo{ + CreatorID: &user.ID, + }) + require.NoError(t, err) + wg.Done() + }() + } + + wg.Wait() +}