chore: use `tx` for stores

pull/140/head
boojack 3 years ago
parent 8c28721839
commit d8e10ba399

@ -15,6 +15,7 @@ import (
func (s *Server) registerMemoRoutes(g *echo.Group) { func (s *Server) registerMemoRoutes(g *echo.Group) {
g.POST("/memo", func(c echo.Context) error { g.POST("/memo", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok { if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
@ -31,7 +32,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
memoCreate.Visibility = &private memoCreate.Visibility = &private
} }
memo, err := s.Store.CreateMemo(memoCreate) memo, err := s.Store.CreateMemo(ctx, memoCreate)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create memo").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create memo").SetInternal(err)
} }
@ -44,6 +45,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
}) })
g.PATCH("/memo/:memoId", func(c echo.Context) error { g.PATCH("/memo/:memoId", func(c echo.Context) error {
ctx := c.Request().Context()
memoID, err := strconv.Atoi(c.Param("memoId")) memoID, err := strconv.Atoi(c.Param("memoId"))
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
@ -56,7 +58,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch memo request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch memo request").SetInternal(err)
} }
memo, err := s.Store.PatchMemo(memoPatch) memo, err := s.Store.PatchMemo(ctx, memoPatch)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch memo").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch memo").SetInternal(err)
} }
@ -69,6 +71,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
}) })
g.GET("/memo", func(c echo.Context) error { g.GET("/memo", func(c echo.Context) error {
ctx := c.Request().Context()
memoFind := &api.MemoFind{} memoFind := &api.MemoFind{}
if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil {
@ -118,7 +121,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
memoFind.Offset = offset memoFind.Offset = offset
} }
list, err := s.Store.FindMemoList(memoFind) list, err := s.Store.FindMemoList(ctx, memoFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch memo list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch memo list").SetInternal(err)
} }
@ -131,6 +134,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
}) })
g.POST("/memo/:memoId/organizer", func(c echo.Context) error { g.POST("/memo/:memoId/organizer", func(c echo.Context) error {
ctx := c.Request().Context()
memoID, err := strconv.Atoi(c.Param("memoId")) memoID, err := strconv.Atoi(c.Param("memoId"))
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
@ -148,12 +152,12 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo organizer request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo organizer request").SetInternal(err)
} }
err = s.Store.UpsertMemoOrganizer(memoOrganizerUpsert) err = s.Store.UpsertMemoOrganizer(ctx, memoOrganizerUpsert)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo organizer").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo organizer").SetInternal(err)
} }
memo, err := s.Store.FindMemo(&api.MemoFind{ memo, err := s.Store.FindMemo(ctx, &api.MemoFind{
ID: &memoID, ID: &memoID,
}) })
if err != nil { if err != nil {
@ -172,6 +176,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
}) })
g.GET("/memo/:memoId", func(c echo.Context) error { g.GET("/memo/:memoId", func(c echo.Context) error {
ctx := c.Request().Context()
memoID, err := strconv.Atoi(c.Param("memoId")) memoID, err := strconv.Atoi(c.Param("memoId"))
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
@ -180,7 +185,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
memoFind := &api.MemoFind{ memoFind := &api.MemoFind{
ID: &memoID, ID: &memoID,
} }
memo, err := s.Store.FindMemo(memoFind) memo, err := s.Store.FindMemo(ctx, memoFind)
if err != nil { if err != nil {
if common.ErrorCode(err) == common.NotFound { if common.ErrorCode(err) == common.NotFound {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID)).SetInternal(err) return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID)).SetInternal(err)
@ -197,6 +202,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
}) })
g.DELETE("/memo/:memoId", func(c echo.Context) error { g.DELETE("/memo/:memoId", func(c echo.Context) error {
ctx := c.Request().Context()
memoID, err := strconv.Atoi(c.Param("memoId")) memoID, err := strconv.Atoi(c.Param("memoId"))
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
@ -205,7 +211,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
memoDelete := &api.MemoDelete{ memoDelete := &api.MemoDelete{
ID: memoID, ID: memoID,
} }
if err := s.Store.DeleteMemo(memoDelete); err != nil { if err := s.Store.DeleteMemo(ctx, memoDelete); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete memo ID: %v", memoID)).SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete memo ID: %v", memoID)).SetInternal(err)
} }
@ -213,6 +219,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
}) })
g.GET("/memo/amount", func(c echo.Context) error { g.GET("/memo/amount", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok { if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
@ -223,7 +230,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
RowStatus: &normalRowStatus, RowStatus: &normalRowStatus,
} }
memoList, err := s.Store.FindMemoList(memoFind) memoList, err := s.Store.FindMemoList(ctx, memoFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err)
} }

@ -14,6 +14,7 @@ import (
func (s *Server) registerResourceRoutes(g *echo.Group) { func (s *Server) registerResourceRoutes(g *echo.Group) {
g.POST("/resource", func(c echo.Context) error { g.POST("/resource", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok { if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
@ -51,7 +52,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
CreatorID: userID, CreatorID: userID,
} }
resource, err := s.Store.CreateResource(resourceCreate) resource, err := s.Store.CreateResource(ctx, resourceCreate)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err)
} }
@ -64,6 +65,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
}) })
g.GET("/resource", func(c echo.Context) error { g.GET("/resource", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok { if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
@ -71,7 +73,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
resourceFind := &api.ResourceFind{ resourceFind := &api.ResourceFind{
CreatorID: &userID, CreatorID: &userID,
} }
list, err := s.Store.FindResourceList(resourceFind) list, err := s.Store.FindResourceList(ctx, resourceFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource list").SetInternal(err)
} }
@ -84,6 +86,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
}) })
g.GET("/resource/:resourceId", func(c echo.Context) error { g.GET("/resource/:resourceId", func(c echo.Context) error {
ctx := c.Request().Context()
resourceID, err := strconv.Atoi(c.Param("resourceId")) resourceID, err := strconv.Atoi(c.Param("resourceId"))
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err)
@ -97,7 +100,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
ID: &resourceID, ID: &resourceID,
CreatorID: &userID, CreatorID: &userID,
} }
resource, err := s.Store.FindResource(resourceFind) resource, err := s.Store.FindResource(ctx, resourceFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource").SetInternal(err)
} }
@ -110,6 +113,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
}) })
g.GET("/resource/:resourceId/blob", func(c echo.Context) error { g.GET("/resource/:resourceId/blob", func(c echo.Context) error {
ctx := c.Request().Context()
resourceID, err := strconv.Atoi(c.Param("resourceId")) resourceID, err := strconv.Atoi(c.Param("resourceId"))
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err)
@ -123,7 +127,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
ID: &resourceID, ID: &resourceID,
CreatorID: &userID, CreatorID: &userID,
} }
resource, err := s.Store.FindResource(resourceFind) resource, err := s.Store.FindResource(ctx, resourceFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource").SetInternal(err)
} }
@ -138,6 +142,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
}) })
g.DELETE("/resource/:resourceId", func(c echo.Context) error { g.DELETE("/resource/:resourceId", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok { if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
@ -152,7 +157,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
ID: resourceID, ID: resourceID,
CreatorID: userID, CreatorID: userID,
} }
if err := s.Store.DeleteResource(resourceDelete); err != nil { if err := s.Store.DeleteResource(ctx, resourceDelete); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete resource").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete resource").SetInternal(err)
} }

@ -13,6 +13,7 @@ import (
func (s *Server) registerShortcutRoutes(g *echo.Group) { func (s *Server) registerShortcutRoutes(g *echo.Group) {
g.POST("/shortcut", func(c echo.Context) error { g.POST("/shortcut", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok { if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
@ -24,7 +25,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post shortcut request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post shortcut request").SetInternal(err)
} }
shortcut, err := s.Store.CreateShortcut(shortcutCreate) shortcut, err := s.Store.CreateShortcut(ctx, shortcutCreate)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create shortcut").SetInternal(err)
} }
@ -37,6 +38,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
}) })
g.PATCH("/shortcut/:shortcutId", func(c echo.Context) error { g.PATCH("/shortcut/:shortcutId", func(c echo.Context) error {
ctx := c.Request().Context()
shortcutID, err := strconv.Atoi(c.Param("shortcutId")) shortcutID, err := strconv.Atoi(c.Param("shortcutId"))
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err)
@ -49,7 +51,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch shortcut request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch shortcut request").SetInternal(err)
} }
shortcut, err := s.Store.PatchShortcut(shortcutPatch) shortcut, err := s.Store.PatchShortcut(ctx, shortcutPatch)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch shortcut").SetInternal(err)
} }
@ -62,6 +64,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
}) })
g.GET("/shortcut", func(c echo.Context) error { g.GET("/shortcut", func(c echo.Context) error {
ctx := c.Request().Context()
shortcutFind := &api.ShortcutFind{} shortcutFind := &api.ShortcutFind{}
if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil {
@ -75,7 +78,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
shortcutFind.CreatorID = &userID shortcutFind.CreatorID = &userID
} }
list, err := s.Store.FindShortcutList(shortcutFind) list, err := s.Store.FindShortcutList(ctx, shortcutFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch shortcut list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch shortcut list").SetInternal(err)
} }
@ -88,6 +91,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
}) })
g.GET("/shortcut/:shortcutId", func(c echo.Context) error { g.GET("/shortcut/:shortcutId", func(c echo.Context) error {
ctx := c.Request().Context()
shortcutID, err := strconv.Atoi(c.Param("shortcutId")) shortcutID, err := strconv.Atoi(c.Param("shortcutId"))
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err)
@ -96,7 +100,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
shortcutFind := &api.ShortcutFind{ shortcutFind := &api.ShortcutFind{
ID: &shortcutID, ID: &shortcutID,
} }
shortcut, err := s.Store.FindShortcut(shortcutFind) shortcut, err := s.Store.FindShortcut(ctx, shortcutFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch shortcut by ID %d", *shortcutFind.ID)).SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch shortcut by ID %d", *shortcutFind.ID)).SetInternal(err)
} }
@ -109,6 +113,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
}) })
g.DELETE("/shortcut/:shortcutId", func(c echo.Context) error { g.DELETE("/shortcut/:shortcutId", func(c echo.Context) error {
ctx := c.Request().Context()
shortcutID, err := strconv.Atoi(c.Param("shortcutId")) shortcutID, err := strconv.Atoi(c.Param("shortcutId"))
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err)
@ -117,7 +122,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
shortcutDelete := &api.ShortcutDelete{ shortcutDelete := &api.ShortcutDelete{
ID: shortcutID, ID: shortcutID,
} }
if err := s.Store.DeleteShortcut(shortcutDelete); err != nil { if err := s.Store.DeleteShortcut(ctx, shortcutDelete); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete shortcut").SetInternal(err)
} }

@ -14,6 +14,7 @@ import (
func (s *Server) registerTagRoutes(g *echo.Group) { func (s *Server) registerTagRoutes(g *echo.Group) {
g.GET("/tag", func(c echo.Context) error { g.GET("/tag", func(c echo.Context) error {
ctx := c.Request().Context()
contentSearch := "#" contentSearch := "#"
normalRowStatus := api.Normal normalRowStatus := api.Normal
memoFind := api.MemoFind{ memoFind := api.MemoFind{
@ -39,7 +40,7 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
} }
} }
memoList, err := s.Store.FindMemoList(&memoFind) memoList, err := s.Store.FindMemoList(ctx, &memoFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err)
} }

@ -16,6 +16,7 @@ func (s *Server) registerWebhookRoutes(g *echo.Group) {
}) })
g.GET("/r/:resourceId/:filename", func(c echo.Context) error { g.GET("/r/:resourceId/:filename", func(c echo.Context) error {
ctx := c.Request().Context()
resourceID, err := strconv.Atoi(c.Param("resourceId")) resourceID, err := strconv.Atoi(c.Param("resourceId"))
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err)
@ -26,7 +27,7 @@ func (s *Server) registerWebhookRoutes(g *echo.Group) {
ID: &resourceID, ID: &resourceID,
Filename: &filename, Filename: &filename,
} }
resource, err := s.Store.FindResource(resourceFind) resource, err := s.Store.FindResource(ctx, resourceFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch resource ID: %v", resourceID)).SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch resource ID: %v", resourceID)).SetInternal(err)
} }

@ -1,6 +1,7 @@
package store package store
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "strings"
@ -43,13 +44,39 @@ func (raw *memoRaw) toMemo() *api.Memo {
} }
} }
func (s *Store) CreateMemo(create *api.MemoCreate) (*api.Memo, error) { func (s *Store) composeMemo(ctx context.Context, raw *memoRaw) (*api.Memo, error) {
memoRaw, err := createMemoRaw(s.db, create) memo := raw.toMemo()
memoOrganizer, err := s.FindMemoOrganizer(ctx, &api.MemoOrganizerFind{
MemoID: memo.ID,
UserID: memo.CreatorID,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return nil, err
} else if memoOrganizer != nil {
memo.Pinned = memoOrganizer.Pinned
}
return memo, nil
}
func (s *Store) CreateMemo(ctx context.Context, create *api.MemoCreate) (*api.Memo, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
memoRaw, err := createMemoRaw(ctx, tx, create)
if err != nil { if err != nil {
return nil, err return nil, err
} }
memo, err := s.composeMemo(memoRaw) if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
memo, err := s.composeMemo(ctx, memoRaw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -61,13 +88,23 @@ func (s *Store) CreateMemo(create *api.MemoCreate) (*api.Memo, error) {
return memo, nil return memo, nil
} }
func (s *Store) PatchMemo(patch *api.MemoPatch) (*api.Memo, error) { func (s *Store) PatchMemo(ctx context.Context, patch *api.MemoPatch) (*api.Memo, error) {
memoRaw, err := patchMemoRaw(s.db, patch) tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
memoRaw, err := patchMemoRaw(ctx, tx, patch)
if err != nil { if err != nil {
return nil, err return nil, err
} }
memo, err := s.composeMemo(memoRaw) if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
memo, err := s.composeMemo(ctx, memoRaw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -79,15 +116,21 @@ func (s *Store) PatchMemo(patch *api.MemoPatch) (*api.Memo, error) {
return memo, nil return memo, nil
} }
func (s *Store) FindMemoList(find *api.MemoFind) ([]*api.Memo, error) { func (s *Store) FindMemoList(ctx context.Context, find *api.MemoFind) ([]*api.Memo, error) {
memoRawList, err := findMemoRawList(s.db, find) tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
memoRawList, err := findMemoRawList(ctx, tx, find)
if err != nil { if err != nil {
return nil, err return nil, err
} }
list := []*api.Memo{} list := []*api.Memo{}
for _, raw := range memoRawList { for _, raw := range memoRawList {
memo, err := s.composeMemo(raw) memo, err := s.composeMemo(ctx, raw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -98,7 +141,7 @@ func (s *Store) FindMemoList(find *api.MemoFind) ([]*api.Memo, error) {
return list, nil return list, nil
} }
func (s *Store) FindMemo(find *api.MemoFind) (*api.Memo, error) { func (s *Store) FindMemo(ctx context.Context, find *api.MemoFind) (*api.Memo, error) {
if find.ID != nil { if find.ID != nil {
memo := &api.Memo{} memo := &api.Memo{}
has, err := s.cache.FindCache(api.MemoCache, *find.ID, memo) has, err := s.cache.FindCache(api.MemoCache, *find.ID, memo)
@ -110,7 +153,13 @@ func (s *Store) FindMemo(find *api.MemoFind) (*api.Memo, error) {
} }
} }
list, err := findMemoRawList(s.db, find) tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
list, err := findMemoRawList(ctx, tx, find)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -119,7 +168,7 @@ func (s *Store) FindMemo(find *api.MemoFind) (*api.Memo, error) {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
} }
memo, err := s.composeMemo(list[0]) memo, err := s.composeMemo(ctx, list[0])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -131,18 +180,27 @@ func (s *Store) FindMemo(find *api.MemoFind) (*api.Memo, error) {
return memo, nil return memo, nil
} }
func (s *Store) DeleteMemo(delete *api.MemoDelete) error { func (s *Store) DeleteMemo(ctx context.Context, delete *api.MemoDelete) error {
err := deleteMemo(s.db, delete) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return FormatError(err) return FormatError(err)
} }
defer tx.Rollback()
if err := deleteMemo(ctx, tx, delete); err != nil {
return FormatError(err)
}
if err := tx.Commit(); err != nil {
return FormatError(err)
}
s.cache.DeleteCache(api.MemoCache, delete.ID) s.cache.DeleteCache(api.MemoCache, delete.ID)
return nil return nil
} }
func createMemoRaw(db *sql.DB, create *api.MemoCreate) (*memoRaw, error) { func createMemoRaw(ctx context.Context, tx *sql.Tx, create *api.MemoCreate) (*memoRaw, error) {
set := []string{"creator_id", "content"} set := []string{"creator_id", "content"}
placeholder := []string{"?", "?"} placeholder := []string{"?", "?"}
args := []interface{}{create.CreatorID, create.Content} args := []interface{}{create.CreatorID, create.Content}
@ -155,22 +213,14 @@ func createMemoRaw(db *sql.DB, create *api.MemoCreate) (*memoRaw, error) {
} }
query := ` query := `
INSERT INTO memo ( INSERT INTO memo (
` + strings.Join(set, ", ") + ` ` + strings.Join(set, ", ") + `
) )
VALUES (` + strings.Join(placeholder, ",") + `) VALUES (` + strings.Join(placeholder, ",") + `)
RETURNING id, creator_id, created_ts, updated_ts, row_status, content, visibility` RETURNING id, creator_id, created_ts, updated_ts, row_status, content, visibility
row, err := db.Query(query, `
args...,
)
if err != nil {
return nil, FormatError(err)
}
defer row.Close()
row.Next()
var memoRaw memoRaw var memoRaw memoRaw
if err := row.Scan( if err := tx.QueryRowContext(ctx, query, args...).Scan(
&memoRaw.ID, &memoRaw.ID,
&memoRaw.CreatorID, &memoRaw.CreatorID,
&memoRaw.CreatedTs, &memoRaw.CreatedTs,
@ -185,7 +235,7 @@ func createMemoRaw(db *sql.DB, create *api.MemoCreate) (*memoRaw, error) {
return &memoRaw, nil return &memoRaw, nil
} }
func patchMemoRaw(db *sql.DB, patch *api.MemoPatch) (*memoRaw, error) { func patchMemoRaw(ctx context.Context, tx *sql.Tx, patch *api.MemoPatch) (*memoRaw, error) {
set, args := []string{}, []interface{}{} set, args := []string{}, []interface{}{}
if v := patch.Content; v != nil { if v := patch.Content; v != nil {
@ -200,21 +250,14 @@ func patchMemoRaw(db *sql.DB, patch *api.MemoPatch) (*memoRaw, error) {
args = append(args, patch.ID) args = append(args, patch.ID)
row, err := db.Query(` query := `
UPDATE memo UPDATE memo
SET `+strings.Join(set, ", ")+` SET ` + strings.Join(set, ", ") + `
WHERE id = ? WHERE id = ?
RETURNING id, creator_id, created_ts, updated_ts, row_status, content, visibility RETURNING id, creator_id, created_ts, updated_ts, row_status, content, visibility
`, args...) `
if err != nil {
return nil, FormatError(err)
}
defer row.Close()
row.Next()
var memoRaw memoRaw var memoRaw memoRaw
if err := row.Scan( if err := tx.QueryRowContext(ctx, query, args...).Scan(
&memoRaw.ID, &memoRaw.ID,
&memoRaw.CreatorID, &memoRaw.CreatorID,
&memoRaw.CreatedTs, &memoRaw.CreatedTs,
@ -229,7 +272,7 @@ func patchMemoRaw(db *sql.DB, patch *api.MemoPatch) (*memoRaw, error) {
return &memoRaw, nil return &memoRaw, nil
} }
func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) { func findMemoRawList(ctx context.Context, tx *sql.Tx, find *api.MemoFind) ([]*memoRaw, error) {
where, args := []string{"1 = 1"}, []interface{}{} where, args := []string{"1 = 1"}, []interface{}{}
if v := find.ID; v != nil { if v := find.ID; v != nil {
@ -264,7 +307,7 @@ func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) {
} }
} }
rows, err := db.Query(` query := `
SELECT SELECT
id, id,
creator_id, creator_id,
@ -274,10 +317,10 @@ func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) {
content, content,
visibility visibility
FROM memo FROM memo
WHERE `+strings.Join(where, " AND ")+` WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC`+pagination, ORDER BY created_ts DESC
args..., ` + pagination
) rows, err := tx.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
@ -308,8 +351,8 @@ func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) {
return memoRawList, nil return memoRawList, nil
} }
func deleteMemo(db *sql.DB, delete *api.MemoDelete) error { func deleteMemo(ctx context.Context, tx *sql.Tx, delete *api.MemoDelete) error {
result, err := db.Exec(` _, err := tx.ExecContext(ctx, `
PRAGMA foreign_keys = ON; PRAGMA foreign_keys = ON;
DELETE FROM memo WHERE id = ? DELETE FROM memo WHERE id = ?
`, delete.ID) `, delete.ID)
@ -317,26 +360,5 @@ func deleteMemo(db *sql.DB, delete *api.MemoDelete) error {
return FormatError(err) return FormatError(err)
} }
rows, _ := result.RowsAffected()
if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("memo ID not found: %d", delete.ID)}
}
return nil return nil
} }
func (s *Store) composeMemo(raw *memoRaw) (*api.Memo, error) {
memo := raw.toMemo()
memoOrganizer, err := s.FindMemoOrganizer(&api.MemoOrganizerFind{
MemoID: memo.ID,
UserID: memo.CreatorID,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return nil, err
} else if memoOrganizer != nil {
memo.Pinned = memoOrganizer.Pinned
}
return memo, nil
}

@ -1,6 +1,7 @@
package store package store
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
@ -29,8 +30,14 @@ func (raw *memoOrganizerRaw) toMemoOrganizer() *api.MemoOrganizer {
} }
} }
func (s *Store) FindMemoOrganizer(find *api.MemoOrganizerFind) (*api.MemoOrganizer, error) { func (s *Store) FindMemoOrganizer(ctx context.Context, find *api.MemoOrganizerFind) (*api.MemoOrganizer, error) {
memoOrganizerRaw, err := findMemoOrganizer(s.db, find) tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
memoOrganizerRaw, err := findMemoOrganizer(ctx, tx, find)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -40,17 +47,26 @@ func (s *Store) FindMemoOrganizer(find *api.MemoOrganizerFind) (*api.MemoOrganiz
return memoOrganizer, nil return memoOrganizer, nil
} }
func (s *Store) UpsertMemoOrganizer(upsert *api.MemoOrganizerUpsert) error { func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *api.MemoOrganizerUpsert) error {
err := upsertMemoOrganizer(s.db, upsert) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return FormatError(err)
}
defer tx.Rollback()
if err := upsertMemoOrganizer(ctx, tx, upsert); err != nil {
return err return err
} }
if err := tx.Commit(); err != nil {
return FormatError(err)
}
return nil return nil
} }
func findMemoOrganizer(db *sql.DB, find *api.MemoOrganizerFind) (*memoOrganizerRaw, error) { func findMemoOrganizer(ctx context.Context, tx *sql.Tx, find *api.MemoOrganizerFind) (*memoOrganizerRaw, error) {
row, err := db.Query(` query := `
SELECT SELECT
id, id,
memo_id, memo_id,
@ -58,7 +74,8 @@ func findMemoOrganizer(db *sql.DB, find *api.MemoOrganizerFind) (*memoOrganizerR
pinned pinned
FROM memo_organizer FROM memo_organizer
WHERE memo_id = ? AND user_id = ? WHERE memo_id = ? AND user_id = ?
`, find.MemoID, find.UserID) `
row, err := tx.QueryContext(ctx, query, find.MemoID, find.UserID)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
@ -81,8 +98,8 @@ func findMemoOrganizer(db *sql.DB, find *api.MemoOrganizerFind) (*memoOrganizerR
return &memoOrganizerRaw, nil return &memoOrganizerRaw, nil
} }
func upsertMemoOrganizer(db *sql.DB, upsert *api.MemoOrganizerUpsert) error { func upsertMemoOrganizer(ctx context.Context, tx *sql.Tx, upsert *api.MemoOrganizerUpsert) error {
row, err := db.Query(` query := `
INSERT INTO memo_organizer ( INSERT INTO memo_organizer (
memo_id, memo_id,
user_id, user_id,
@ -93,20 +110,9 @@ func upsertMemoOrganizer(db *sql.DB, upsert *api.MemoOrganizerUpsert) error {
SET SET
pinned = EXCLUDED.pinned pinned = EXCLUDED.pinned
RETURNING id, memo_id, user_id, pinned RETURNING id, memo_id, user_id, pinned
`, `
upsert.MemoID,
upsert.UserID,
upsert.Pinned,
)
if err != nil {
return FormatError(err)
}
defer row.Close()
row.Next()
var memoOrganizer api.MemoOrganizer var memoOrganizer api.MemoOrganizer
if err := row.Scan( if err := tx.QueryRowContext(ctx, query, upsert.MemoID, upsert.UserID, upsert.Pinned).Scan(
&memoOrganizer.ID, &memoOrganizer.ID,
&memoOrganizer.MemoID, &memoOrganizer.MemoID,
&memoOrganizer.UserID, &memoOrganizer.UserID,

@ -1,6 +1,7 @@
package store package store
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "strings"
@ -43,12 +44,22 @@ func (raw *resourceRaw) toResource() *api.Resource {
} }
} }
func (s *Store) CreateResource(create *api.ResourceCreate) (*api.Resource, error) { func (s *Store) CreateResource(ctx context.Context, create *api.ResourceCreate) (*api.Resource, error) {
resourceRaw, err := createResource(s.db, create) tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
resourceRaw, err := createResource(ctx, tx, create)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
resource := resourceRaw.toResource() resource := resourceRaw.toResource()
if err := s.cache.UpsertCache(api.ResourceCache, resource.ID, resource); err != nil { if err := s.cache.UpsertCache(api.ResourceCache, resource.ID, resource); err != nil {
@ -58,8 +69,14 @@ func (s *Store) CreateResource(create *api.ResourceCreate) (*api.Resource, error
return resource, nil return resource, nil
} }
func (s *Store) FindResourceList(find *api.ResourceFind) ([]*api.Resource, error) { func (s *Store) FindResourceList(ctx context.Context, find *api.ResourceFind) ([]*api.Resource, error) {
resourceRawList, err := findResourceList(s.db, find) tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
resourceRawList, err := findResourceList(ctx, tx, find)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -72,7 +89,7 @@ func (s *Store) FindResourceList(find *api.ResourceFind) ([]*api.Resource, error
return resourceList, nil return resourceList, nil
} }
func (s *Store) FindResource(find *api.ResourceFind) (*api.Resource, error) { func (s *Store) FindResource(ctx context.Context, find *api.ResourceFind) (*api.Resource, error) {
if find.ID != nil { if find.ID != nil {
resource := &api.Resource{} resource := &api.Resource{}
has, err := s.cache.FindCache(api.ResourceCache, *find.ID, resource) has, err := s.cache.FindCache(api.ResourceCache, *find.ID, resource)
@ -84,7 +101,13 @@ func (s *Store) FindResource(find *api.ResourceFind) (*api.Resource, error) {
} }
} }
list, err := findResourceList(s.db, find) tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
list, err := findResourceList(ctx, tx, find)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -102,19 +125,29 @@ func (s *Store) FindResource(find *api.ResourceFind) (*api.Resource, error) {
return resource, nil return resource, nil
} }
func (s *Store) DeleteResource(delete *api.ResourceDelete) error { func (s *Store) DeleteResource(ctx context.Context, delete *api.ResourceDelete) error {
err := deleteResource(s.db, delete) tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
}
defer tx.Rollback()
err = deleteResource(ctx, tx, delete)
if err != nil { if err != nil {
return err return err
} }
if err := tx.Commit(); err != nil {
return FormatError(err)
}
s.cache.DeleteCache(api.ResourceCache, delete.ID) s.cache.DeleteCache(api.ResourceCache, delete.ID)
return nil return nil
} }
func createResource(db *sql.DB, create *api.ResourceCreate) (*resourceRaw, error) { func createResource(ctx context.Context, tx *sql.Tx, create *api.ResourceCreate) (*resourceRaw, error) {
row, err := db.Query(` query := `
INSERT INTO resource ( INSERT INTO resource (
filename, filename,
blob, blob,
@ -124,21 +157,9 @@ func createResource(db *sql.DB, create *api.ResourceCreate) (*resourceRaw, error
) )
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
RETURNING id, filename, blob, type, size, creator_id, created_ts, updated_ts RETURNING id, filename, blob, type, size, creator_id, created_ts, updated_ts
`, `
create.Filename,
create.Blob,
create.Type,
create.Size,
create.CreatorID,
)
if err != nil {
return nil, FormatError(err)
}
defer row.Close()
row.Next()
var resourceRaw resourceRaw var resourceRaw resourceRaw
if err := row.Scan( if err := tx.QueryRowContext(ctx, query, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID).Scan(
&resourceRaw.ID, &resourceRaw.ID,
&resourceRaw.Filename, &resourceRaw.Filename,
&resourceRaw.Blob, &resourceRaw.Blob,
@ -154,7 +175,7 @@ func createResource(db *sql.DB, create *api.ResourceCreate) (*resourceRaw, error
return &resourceRaw, nil return &resourceRaw, nil
} }
func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error) { func findResourceList(ctx context.Context, tx *sql.Tx, find *api.ResourceFind) ([]*resourceRaw, error) {
where, args := []string{"1 = 1"}, []interface{}{} where, args := []string{"1 = 1"}, []interface{}{}
if v := find.ID; v != nil { if v := find.ID; v != nil {
@ -167,7 +188,7 @@ func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error
where, args = append(where, "filename = ?"), append(args, *v) where, args = append(where, "filename = ?"), append(args, *v)
} }
rows, err := db.Query(` query := `
SELECT SELECT
id, id,
filename, filename,
@ -178,10 +199,10 @@ func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error
created_ts, created_ts,
updated_ts updated_ts
FROM resource FROM resource
WHERE `+strings.Join(where, " AND ")+` WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC`, ORDER BY created_ts DESC
args..., `
) rows, err := tx.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
@ -213,8 +234,8 @@ func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error
return resourceRawList, nil return resourceRawList, nil
} }
func deleteResource(db *sql.DB, delete *api.ResourceDelete) error { func deleteResource(ctx context.Context, tx *sql.Tx, delete *api.ResourceDelete) error {
result, err := db.Exec(` _, err := tx.ExecContext(ctx, `
PRAGMA foreign_keys = ON; PRAGMA foreign_keys = ON;
DELETE FROM resource WHERE id = ? AND creator_id = ? DELETE FROM resource WHERE id = ? AND creator_id = ?
`, delete.ID, delete.CreatorID) `, delete.ID, delete.CreatorID)
@ -222,10 +243,5 @@ func deleteResource(db *sql.DB, delete *api.ResourceDelete) error {
return FormatError(err) return FormatError(err)
} }
rows, _ := result.RowsAffected()
if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("resource ID not found: %d", delete.ID)}
}
return nil return nil
} }

@ -1,6 +1,7 @@
package store package store
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "strings"
@ -39,12 +40,22 @@ func (raw *shortcutRaw) toShortcut() *api.Shortcut {
} }
} }
func (s *Store) CreateShortcut(create *api.ShortcutCreate) (*api.Shortcut, error) { func (s *Store) CreateShortcut(ctx context.Context, create *api.ShortcutCreate) (*api.Shortcut, error) {
shortcutRaw, err := createShortcut(s.db, create) tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
shortcutRaw, err := createShortcut(ctx, tx, create)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
shortcut := shortcutRaw.toShortcut() shortcut := shortcutRaw.toShortcut()
if err := s.cache.UpsertCache(api.ShortcutCache, shortcut.ID, shortcut); err != nil { if err := s.cache.UpsertCache(api.ShortcutCache, shortcut.ID, shortcut); err != nil {
@ -54,12 +65,22 @@ func (s *Store) CreateShortcut(create *api.ShortcutCreate) (*api.Shortcut, error
return shortcut, nil return shortcut, nil
} }
func (s *Store) PatchShortcut(patch *api.ShortcutPatch) (*api.Shortcut, error) { func (s *Store) PatchShortcut(ctx context.Context, patch *api.ShortcutPatch) (*api.Shortcut, error) {
shortcutRaw, err := patchShortcut(s.db, patch) tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
shortcutRaw, err := patchShortcut(ctx, tx, patch)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
shortcut := shortcutRaw.toShortcut() shortcut := shortcutRaw.toShortcut()
if err := s.cache.UpsertCache(api.ShortcutCache, shortcut.ID, shortcut); err != nil { if err := s.cache.UpsertCache(api.ShortcutCache, shortcut.ID, shortcut); err != nil {
@ -69,8 +90,14 @@ func (s *Store) PatchShortcut(patch *api.ShortcutPatch) (*api.Shortcut, error) {
return shortcut, nil return shortcut, nil
} }
func (s *Store) FindShortcutList(find *api.ShortcutFind) ([]*api.Shortcut, error) { func (s *Store) FindShortcutList(ctx context.Context, find *api.ShortcutFind) ([]*api.Shortcut, error) {
shortcutRawList, err := findShortcutList(s.db, find) tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
shortcutRawList, err := findShortcutList(ctx, tx, find)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -83,7 +110,7 @@ func (s *Store) FindShortcutList(find *api.ShortcutFind) ([]*api.Shortcut, error
return list, nil return list, nil
} }
func (s *Store) FindShortcut(find *api.ShortcutFind) (*api.Shortcut, error) { func (s *Store) FindShortcut(ctx context.Context, find *api.ShortcutFind) (*api.Shortcut, error) {
if find.ID != nil { if find.ID != nil {
shortcut := &api.Shortcut{} shortcut := &api.Shortcut{}
has, err := s.cache.FindCache(api.ShortcutCache, *find.ID, shortcut) has, err := s.cache.FindCache(api.ShortcutCache, *find.ID, shortcut)
@ -95,7 +122,13 @@ func (s *Store) FindShortcut(find *api.ShortcutFind) (*api.Shortcut, error) {
} }
} }
list, err := findShortcutList(s.db, find) tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
list, err := findShortcutList(ctx, tx, find)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -113,19 +146,29 @@ func (s *Store) FindShortcut(find *api.ShortcutFind) (*api.Shortcut, error) {
return shortcut, nil return shortcut, nil
} }
func (s *Store) DeleteShortcut(delete *api.ShortcutDelete) error { func (s *Store) DeleteShortcut(ctx context.Context, delete *api.ShortcutDelete) error {
err := deleteShortcut(s.db, delete) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return FormatError(err) return FormatError(err)
} }
defer tx.Rollback()
err = deleteShortcut(ctx, tx, delete)
if err != nil {
return FormatError(err)
}
if err := tx.Commit(); err != nil {
return FormatError(err)
}
s.cache.DeleteCache(api.ShortcutCache, delete.ID) s.cache.DeleteCache(api.ShortcutCache, delete.ID)
return nil return nil
} }
func createShortcut(db *sql.DB, create *api.ShortcutCreate) (*shortcutRaw, error) { func createShortcut(ctx context.Context, tx *sql.Tx, create *api.ShortcutCreate) (*shortcutRaw, error) {
row, err := db.Query(` query := `
INSERT INTO shortcut ( INSERT INTO shortcut (
title, title,
payload, payload,
@ -133,19 +176,9 @@ func createShortcut(db *sql.DB, create *api.ShortcutCreate) (*shortcutRaw, error
) )
VALUES (?, ?, ?) VALUES (?, ?, ?)
RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status
`, `
create.Title,
create.Payload,
create.CreatorID,
)
if err != nil {
return nil, FormatError(err)
}
defer row.Close()
row.Next()
var shortcutRaw shortcutRaw var shortcutRaw shortcutRaw
if err := row.Scan( if err := tx.QueryRowContext(ctx, query, create.Title, create.Payload, create.CreatorID).Scan(
&shortcutRaw.ID, &shortcutRaw.ID,
&shortcutRaw.Title, &shortcutRaw.Title,
&shortcutRaw.Payload, &shortcutRaw.Payload,
@ -160,7 +193,7 @@ func createShortcut(db *sql.DB, create *api.ShortcutCreate) (*shortcutRaw, error
return &shortcutRaw, nil return &shortcutRaw, nil
} }
func patchShortcut(db *sql.DB, patch *api.ShortcutPatch) (*shortcutRaw, error) { func patchShortcut(ctx context.Context, tx *sql.Tx, patch *api.ShortcutPatch) (*shortcutRaw, error) {
set, args := []string{}, []interface{}{} set, args := []string{}, []interface{}{}
if v := patch.Title; v != nil { if v := patch.Title; v != nil {
@ -175,23 +208,14 @@ func patchShortcut(db *sql.DB, patch *api.ShortcutPatch) (*shortcutRaw, error) {
args = append(args, patch.ID) args = append(args, patch.ID)
row, err := db.Query(` query := `
UPDATE shortcut UPDATE shortcut
SET `+strings.Join(set, ", ")+` SET ` + strings.Join(set, ", ") + `
WHERE id = ? WHERE id = ?
RETURNING id, title, payload, created_ts, updated_ts, row_status RETURNING id, title, payload, created_ts, updated_ts, row_status
`, args...) `
if err != nil {
return nil, FormatError(err)
}
defer row.Close()
if !row.Next() {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
}
var shortcutRaw shortcutRaw var shortcutRaw shortcutRaw
if err := row.Scan( if err := tx.QueryRowContext(ctx, query, args...).Scan(
&shortcutRaw.ID, &shortcutRaw.ID,
&shortcutRaw.Title, &shortcutRaw.Title,
&shortcutRaw.Payload, &shortcutRaw.Payload,
@ -205,7 +229,7 @@ func patchShortcut(db *sql.DB, patch *api.ShortcutPatch) (*shortcutRaw, error) {
return &shortcutRaw, nil return &shortcutRaw, nil
} }
func findShortcutList(db *sql.DB, find *api.ShortcutFind) ([]*shortcutRaw, error) { func findShortcutList(ctx context.Context, tx *sql.Tx, find *api.ShortcutFind) ([]*shortcutRaw, error) {
where, args := []string{"1 = 1"}, []interface{}{} where, args := []string{"1 = 1"}, []interface{}{}
if v := find.ID; v != nil { if v := find.ID; v != nil {
@ -218,7 +242,7 @@ func findShortcutList(db *sql.DB, find *api.ShortcutFind) ([]*shortcutRaw, error
where, args = append(where, "title = ?"), append(args, *v) where, args = append(where, "title = ?"), append(args, *v)
} }
rows, err := db.Query(` rows, err := tx.QueryContext(ctx, `
SELECT SELECT
id, id,
title, title,
@ -262,8 +286,8 @@ func findShortcutList(db *sql.DB, find *api.ShortcutFind) ([]*shortcutRaw, error
return shortcutRawList, nil return shortcutRawList, nil
} }
func deleteShortcut(db *sql.DB, delete *api.ShortcutDelete) error { func deleteShortcut(ctx context.Context, tx *sql.Tx, delete *api.ShortcutDelete) error {
result, err := db.Exec(` _, err := tx.ExecContext(ctx, `
PRAGMA foreign_keys = ON; PRAGMA foreign_keys = ON;
DELETE FROM shortcut WHERE id = ? DELETE FROM shortcut WHERE id = ?
`, delete.ID) `, delete.ID)
@ -271,10 +295,5 @@ func deleteShortcut(db *sql.DB, delete *api.ShortcutDelete) error {
return FormatError(err) return FormatError(err)
} }
rows, _ := result.RowsAffected()
if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("shortcut ID not found: %d", delete.ID)}
}
return nil return nil
} }

Loading…
Cancel
Save