From 8c28721839078d67425b6ddd74879aaf238559e9 Mon Sep 17 00:00:00 2001 From: boojack Date: Sun, 7 Aug 2022 09:23:46 +0800 Subject: [PATCH] chore: use `tx` for user store --- api/auth.go | 4 -- server/acl.go | 36 ++++++++------- server/auth.go | 8 ++-- server/system.go | 5 ++- server/user.go | 22 +++++---- store/user.go | 114 ++++++++++++++++++++++++++++++----------------- 6 files changed, 114 insertions(+), 75 deletions(-) diff --git a/api/auth.go b/api/auth.go index b04310d6..ddb7998d 100644 --- a/api/auth.go +++ b/api/auth.go @@ -1,9 +1,5 @@ package api -var ( - UNKNOWN_ID = 0 -) - type Signin struct { Email string `json:"email"` Password string `json:"password"` diff --git a/server/acl.go b/server/acl.go index 82d47b2a..db52c88a 100644 --- a/server/acl.go +++ b/server/acl.go @@ -52,42 +52,44 @@ func removeUserSession(ctx echo.Context) error { } func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { + return func(c echo.Context) error { + ctx := c.Request().Context() + path := c.Path() // Skip auth. - if common.HasPrefixes(ctx.Path(), "/api/auth") { - return next(ctx) + if common.HasPrefixes(path, "/api/auth") { + return next(c) } - if common.HasPrefixes(ctx.Path(), "/api/ping", "/api/status", "/api/user/:id") && ctx.Request().Method == http.MethodGet { - return next(ctx) + if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/user/:id") && c.Request().Method == http.MethodGet { + return next(c) } // If there is openId in query string and related user is found, then skip auth. - openID := ctx.QueryParam("openId") + openID := c.QueryParam("openId") if openID != "" { userFind := &api.UserFind{ OpenID: &openID, } - user, err := s.Store.FindUser(userFind) + user, err := s.Store.FindUser(ctx, userFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by open_id").SetInternal(err) } if user != nil { // Stores userID into context. - ctx.Set(getUserIDContextKey(), user.ID) - return next(ctx) + c.Set(getUserIDContextKey(), user.ID) + return next(c) } } { - sess, _ := session.Get("session", ctx) + sess, _ := session.Get("session", c) userIDValue := sess.Values[userIDContextKey] if userIDValue != nil { userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue)) userFind := &api.UserFind{ ID: &userID, } - user, err := s.Store.FindUser(userFind) + user, err := s.Store.FindUser(ctx, userFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err) } @@ -95,22 +97,22 @@ func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc { if user.RowStatus == api.Archived { return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email)) } - ctx.Set(getUserIDContextKey(), userID) + c.Set(getUserIDContextKey(), userID) } } } - if common.HasPrefixes(ctx.Path(), "/api/memo", "/api/tag", "/api/shortcut") && ctx.Request().Method == http.MethodGet { - if _, err := strconv.Atoi(ctx.QueryParam("creatorId")); err == nil { - return next(ctx) + if common.HasPrefixes(path, "/api/memo", "/api/tag", "/api/shortcut") && c.Request().Method == http.MethodGet { + if _, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { + return next(c) } } - userID := ctx.Get(getUserIDContextKey()) + userID := c.Get(getUserIDContextKey()) if userID == nil { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } - return next(ctx) + return next(c) } } diff --git a/server/auth.go b/server/auth.go index c9ed3e52..967be228 100644 --- a/server/auth.go +++ b/server/auth.go @@ -14,6 +14,7 @@ import ( func (s *Server) registerAuthRoutes(g *echo.Group) { g.POST("/auth/signin", func(c echo.Context) error { + ctx := c.Request().Context() signin := &api.Signin{} if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err) @@ -22,7 +23,7 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { userFind := &api.UserFind{ Email: &signin.Email, } - user, err := s.Store.FindUser(userFind) + user, err := s.Store.FindUser(ctx, userFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by email %s", signin.Email)).SetInternal(err) } @@ -60,12 +61,13 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { }) g.POST("/auth/signup", func(c echo.Context) error { + ctx := c.Request().Context() // Don't allow to signup by this api if site host existed. hostUserType := api.Host hostUserFind := api.UserFind{ Role: &hostUserType, } - hostUser, err := s.Store.FindUser(&hostUserFind) + hostUser, err := s.Store.FindUser(ctx, &hostUserFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find host user").SetInternal(err) } @@ -99,7 +101,7 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { PasswordHash: string(passwordHash), OpenID: common.GenUUID(), } - user, err := s.Store.CreateUser(userCreate) + user, err := s.Store.CreateUser(ctx, userCreate) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) } diff --git a/server/system.go b/server/system.go index d21b0b66..e72b041f 100644 --- a/server/system.go +++ b/server/system.go @@ -21,11 +21,12 @@ func (s *Server) registerSystemRoutes(g *echo.Group) { }) g.GET("/status", func(c echo.Context) error { + ctx := c.Request().Context() hostUserType := api.Host hostUserFind := api.UserFind{ Role: &hostUserType, } - hostUser, err := s.Store.FindUser(&hostUserFind) + hostUser, err := s.Store.FindUser(ctx, &hostUserFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find host user").SetInternal(err) } @@ -36,7 +37,7 @@ func (s *Server) registerSystemRoutes(g *echo.Group) { } systemStatus := api.SystemStatus{ - Host: hostUser, + Host: hostUser, Profile: s.Profile, } diff --git a/server/user.go b/server/user.go index 4e499cbe..e6dfb63b 100644 --- a/server/user.go +++ b/server/user.go @@ -15,6 +15,7 @@ import ( func (s *Server) registerUserRoutes(g *echo.Group) { g.POST("/user", func(c echo.Context) error { + ctx := c.Request().Context() userCreate := &api.UserCreate{} if err := json.NewDecoder(c.Request().Body).Decode(userCreate); err != nil { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post user request").SetInternal(err) @@ -26,7 +27,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) { } userCreate.PasswordHash = string(passwordHash) - user, err := s.Store.CreateUser(userCreate) + user, err := s.Store.CreateUser(ctx, userCreate) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) } @@ -39,7 +40,8 @@ func (s *Server) registerUserRoutes(g *echo.Group) { }) g.GET("/user", func(c echo.Context) error { - userList, err := s.Store.FindUserList(&api.UserFind{}) + ctx := c.Request().Context() + userList, err := s.Store.FindUserList(ctx, &api.UserFind{}) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch user list").SetInternal(err) } @@ -57,12 +59,13 @@ func (s *Server) registerUserRoutes(g *echo.Group) { }) g.GET("/user/:id", func(c echo.Context) error { + ctx := c.Request().Context() id, err := strconv.Atoi(c.Param("id")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted user id").SetInternal(err) } - user, err := s.Store.FindUser(&api.UserFind{ + user, err := s.Store.FindUser(ctx, &api.UserFind{ ID: &id, }) if err != nil { @@ -83,6 +86,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) { // GET /api/user/me is used to check if the user is logged in. g.GET("/user/me", func(c echo.Context) error { + ctx := c.Request().Context() userID, ok := c.Get(getUserIDContextKey()).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") @@ -91,7 +95,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) { userFind := &api.UserFind{ ID: &userID, } - user, err := s.Store.FindUser(userFind) + user, err := s.Store.FindUser(ctx, userFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch user").SetInternal(err) } @@ -104,6 +108,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) { }) g.PATCH("/user/:id", func(c echo.Context) error { + ctx := c.Request().Context() userID, err := strconv.Atoi(c.Param("id")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err) @@ -112,7 +117,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) { if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } - currentUser, err := s.Store.FindUser(&api.UserFind{ + currentUser, err := s.Store.FindUser(ctx, &api.UserFind{ ID: ¤tUserID, }) if err != nil { @@ -146,7 +151,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) { userPatch.OpenID = &openID } - user, err := s.Store.PatchUser(userPatch) + user, err := s.Store.PatchUser(ctx, userPatch) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch user").SetInternal(err) } @@ -159,11 +164,12 @@ func (s *Server) registerUserRoutes(g *echo.Group) { }) g.DELETE("/user/:id", func(c echo.Context) error { + ctx := c.Request().Context() currentUserID, ok := c.Get(getUserIDContextKey()).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } - currentUser, err := s.Store.FindUser(&api.UserFind{ + currentUser, err := s.Store.FindUser(ctx, &api.UserFind{ ID: ¤tUserID, }) if err != nil { @@ -183,7 +189,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) { userDelete := &api.UserDelete{ ID: userID, } - if err := s.Store.DeleteUser(userDelete); err != nil { + if err := s.Store.DeleteUser(ctx, userDelete); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete user").SetInternal(err) } diff --git a/store/user.go b/store/user.go index 0c7776b4..b9b24cce 100644 --- a/store/user.go +++ b/store/user.go @@ -1,6 +1,7 @@ package store import ( + "context" "database/sql" "fmt" "strings" @@ -43,12 +44,22 @@ func (raw *userRaw) toUser() *api.User { } } -func (s *Store) CreateUser(create *api.UserCreate) (*api.User, error) { - userRaw, err := createUser(s.db, create) +func (s *Store) CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + userRaw, err := createUser(ctx, tx, create) if err != nil { return nil, err } + if err := tx.Commit(); err != nil { + return nil, FormatError(err) + } + user := userRaw.toUser() if err := s.cache.UpsertCache(api.UserCache, user.ID, user); err != nil { @@ -58,12 +69,22 @@ func (s *Store) CreateUser(create *api.UserCreate) (*api.User, error) { return user, nil } -func (s *Store) PatchUser(patch *api.UserPatch) (*api.User, error) { - userRaw, err := patchUser(s.db, patch) +func (s *Store) PatchUser(ctx context.Context, patch *api.UserPatch) (*api.User, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + userRaw, err := patchUser(ctx, tx, patch) if err != nil { return nil, err } + if err := tx.Commit(); err != nil { + return nil, FormatError(err) + } + user := userRaw.toUser() if err := s.cache.UpsertCache(api.UserCache, user.ID, user); err != nil { @@ -73,8 +94,14 @@ func (s *Store) PatchUser(patch *api.UserPatch) (*api.User, error) { return user, nil } -func (s *Store) FindUserList(find *api.UserFind) ([]*api.User, error) { - userRawList, err := findUserList(s.db, find) +func (s *Store) FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + userRawList, err := findUserList(ctx, tx, find) if err != nil { return nil, err } @@ -87,7 +114,7 @@ func (s *Store) FindUserList(find *api.UserFind) ([]*api.User, error) { return list, nil } -func (s *Store) FindUser(find *api.UserFind) (*api.User, error) { +func (s *Store) FindUser(ctx context.Context, find *api.UserFind) (*api.User, error) { if find.ID != nil { user := &api.User{} has, err := s.cache.FindCache(api.UserCache, *find.ID, user) @@ -99,7 +126,13 @@ func (s *Store) FindUser(find *api.UserFind) (*api.User, error) { } } - list, err := findUserList(s.db, find) + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + list, err := findUserList(ctx, tx, find) if err != nil { return nil, err } @@ -119,19 +152,29 @@ func (s *Store) FindUser(find *api.UserFind) (*api.User, error) { return user, nil } -func (s *Store) DeleteUser(delete *api.UserDelete) error { - err := deleteUser(s.db, delete) +func (s *Store) DeleteUser(ctx context.Context, delete *api.UserDelete) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return FormatError(err) + } + defer tx.Rollback() + + err = deleteUser(ctx, tx, delete) if err != nil { return FormatError(err) } + if err := tx.Commit(); err != nil { + return FormatError(err) + } + s.cache.DeleteCache(api.UserCache, delete.ID) return nil } -func createUser(db *sql.DB, create *api.UserCreate) (*userRaw, error) { - row, err := db.Query(` +func createUser(ctx context.Context, tx *sql.Tx, create *api.UserCreate) (*userRaw, error) { + query := ` INSERT INTO user ( email, role, @@ -141,21 +184,15 @@ func createUser(db *sql.DB, create *api.UserCreate) (*userRaw, error) { ) VALUES (?, ?, ?, ?, ?) RETURNING id, email, role, name, password_hash, open_id, created_ts, updated_ts, row_status - `, + ` + var userRaw userRaw + if err := tx.QueryRowContext(ctx, query, create.Email, create.Role, create.Name, create.PasswordHash, create.OpenID, - ) - if err != nil { - return nil, FormatError(err) - } - defer row.Close() - - row.Next() - var userRaw userRaw - if err := row.Scan( + ).Scan( &userRaw.ID, &userRaw.Email, &userRaw.Role, @@ -172,7 +209,7 @@ func createUser(db *sql.DB, create *api.UserCreate) (*userRaw, error) { return &userRaw, nil } -func patchUser(db *sql.DB, patch *api.UserPatch) (*userRaw, error) { +func patchUser(ctx context.Context, tx *sql.Tx, patch *api.UserPatch) (*userRaw, error) { set, args := []string{}, []interface{}{} if v := patch.RowStatus; v != nil { @@ -193,12 +230,13 @@ func patchUser(db *sql.DB, patch *api.UserPatch) (*userRaw, error) { args = append(args, patch.ID) - row, err := db.Query(` + query := ` UPDATE user - SET `+strings.Join(set, ", ")+` + SET ` + strings.Join(set, ", ") + ` WHERE id = ? RETURNING id, email, role, name, password_hash, open_id, created_ts, updated_ts, row_status - `, args...) + ` + row, err := tx.QueryContext(ctx, query, args...) if err != nil { return nil, FormatError(err) } @@ -226,7 +264,7 @@ func patchUser(db *sql.DB, patch *api.UserPatch) (*userRaw, error) { return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user ID not found: %d", patch.ID)} } -func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) { +func findUserList(ctx context.Context, tx *sql.Tx, find *api.UserFind) ([]*userRaw, error) { where, args := []string{"1 = 1"}, []interface{}{} if v := find.ID; v != nil { @@ -245,7 +283,7 @@ func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) { where, args = append(where, "open_id = ?"), append(args, *v) } - rows, err := db.Query(` + query := ` SELECT id, email, @@ -257,10 +295,10 @@ func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) { updated_ts, row_status FROM user - WHERE `+strings.Join(where, " AND ")+` - ORDER BY created_ts DESC, row_status DESC`, - args..., - ) + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY created_ts DESC, row_status DESC + ` + rows, err := tx.QueryContext(ctx, query, args...) if err != nil { return nil, FormatError(err) } @@ -293,19 +331,13 @@ func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) { return userRawList, nil } -func deleteUser(db *sql.DB, delete *api.UserDelete) error { - result, err := db.Exec(` +func deleteUser(ctx context.Context, tx *sql.Tx, delete *api.UserDelete) error { + if _, err := tx.ExecContext(ctx, ` PRAGMA foreign_keys = ON; DELETE FROM user WHERE id = ? - `, delete.ID) - if err != nil { + `, delete.ID); err != nil { return FormatError(err) } - rows, _ := result.RowsAffected() - if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("user ID not found: %d", delete.ID)} - } - return nil }