diff --git a/api/memo_resource.go b/api/memo_resource.go new file mode 100644 index 00000000..dc45fa5d --- /dev/null +++ b/api/memo_resource.go @@ -0,0 +1,24 @@ +package api + +type MemoResource struct { + MemoID int + ResourceID int + CreatedTs int64 + UpdatedTs int64 +} + +type MemoResourceUpsert struct { + MemoID int + ResourceID int + UpdatedTs *int64 +} + +type MemoResourceFind struct { + MemoID *int + ResourceID *int +} + +type MemoResourceDelete struct { + MemoID int + ResourceID *int +} diff --git a/store/db/migration/dev/LATEST__SCHEMA.sql b/store/db/migration/dev/LATEST__SCHEMA.sql index faa469b7..d7b1ae6e 100644 --- a/store/db/migration/dev/LATEST__SCHEMA.sql +++ b/store/db/migration/dev/LATEST__SCHEMA.sql @@ -164,6 +164,7 @@ CREATE TABLE memo_resource ( memo_id INTEGER NOT NULL, resource_id INTEGER NOT NULL, created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), + updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), FOREIGN KEY(memo_id) REFERENCES memo(id) ON DELETE CASCADE, FOREIGN KEY(resource_id) REFERENCES resource(id) ON DELETE CASCADE ); diff --git a/store/memo_resource.go b/store/memo_resource.go new file mode 100644 index 00000000..fa648818 --- /dev/null +++ b/store/memo_resource.go @@ -0,0 +1,188 @@ +package store + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/usememos/memos/api" + "github.com/usememos/memos/common" +) + +// memoResourceRaw is the store model for an MemoResource. +// Fields have exactly the same meanings as MemoResource. +type memoResourceRaw struct { + MemoID int + ResourceID int + CreatedTs int64 + UpdatedTs int64 +} + +func (raw *memoResourceRaw) toMemoResource() *api.MemoResource { + return &api.MemoResource{ + MemoID: raw.MemoID, + ResourceID: raw.ResourceID, + CreatedTs: raw.CreatedTs, + UpdatedTs: raw.UpdatedTs, + } +} + +func (s *Store) FindMemoResourceList(ctx context.Context, find *api.MemoResourceFind) ([]*api.MemoResource, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + memoResourceRawList, err := findMemoResourceList(ctx, tx, find) + if err != nil { + return nil, err + } + + list := []*api.MemoResource{} + for _, raw := range memoResourceRawList { + memoResource := raw.toMemoResource() + list = append(list, memoResource) + } + + return list, nil +} + +func (s *Store) UpsertMemoResource(ctx context.Context, upsert *api.MemoResourceUpsert) (*api.MemoResource, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + memoResourceRaw, err := upsertMemoResource(ctx, tx, upsert) + if err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, FormatError(err) + } + + return memoResourceRaw.toMemoResource(), nil +} + +func (s *Store) DeleteMemoResource(ctx context.Context, delete *api.MemoResourceDelete) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return FormatError(err) + } + defer tx.Rollback() + + if err := deleteMemoResource(ctx, tx, delete); err != nil { + return FormatError(err) + } + + if err := tx.Commit(); err != nil { + return FormatError(err) + } + + return nil +} + +func findMemoResourceList(ctx context.Context, tx *sql.Tx, find *api.MemoResourceFind) ([]*memoResourceRaw, error) { + where, args := []string{"1 = 1"}, []interface{}{} + + if v := find.MemoID; v != nil { + where, args = append(where, "memo_id = ?"), append(args, *v) + } + if v := find.ResourceID; v != nil { + where, args = append(where, "resource_id = ?"), append(args, *v) + } + + query := ` + SELECT + memo_id, + resource_id, + created_ts, + updated_ts + FROM memo_resource + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY updated_ts DESC + ` + rows, err := tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, FormatError(err) + } + defer rows.Close() + + memoResourceRawList := make([]*memoResourceRaw, 0) + for rows.Next() { + var memoResourceRaw memoResourceRaw + if err := rows.Scan( + &memoResourceRaw.MemoID, + &memoResourceRaw.ResourceID, + &memoResourceRaw.CreatedTs, + &memoResourceRaw.UpdatedTs, + ); err != nil { + return nil, FormatError(err) + } + + memoResourceRawList = append(memoResourceRawList, &memoResourceRaw) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return memoResourceRawList, nil +} + +func upsertMemoResource(ctx context.Context, tx *sql.Tx, upsert *api.MemoResourceUpsert) (*memoResourceRaw, error) { + set := []string{"memo_id", "resource_id"} + args := []interface{}{upsert.MemoID, upsert.ResourceID} + placeholder := []string{"?", "?"} + + if v := upsert.UpdatedTs; v != nil { + set, args, placeholder = append(set, "updated_ts"), append(args, v), append(placeholder, "?") + } + + query := ` + INSERT INTO memo_resource ( + ` + strings.Join(set, ", ") + ` + ) + VALUES (` + strings.Join(placeholder, ",") + `) + ON CONFLICT(memo_id, resource_id) DO UPDATE + SET + updated_ts = EXCLUDED.updated_ts + RETURNING memo_id, resource_id, created_ts, updated_ts + ` + var memoResourceRaw memoResourceRaw + if err := tx.QueryRowContext(ctx, query, args...).Scan( + &memoResourceRaw.MemoID, + &memoResourceRaw.ResourceID, + &memoResourceRaw.CreatedTs, + &memoResourceRaw.UpdatedTs, + ); err != nil { + return nil, FormatError(err) + } + + return &memoResourceRaw, nil +} + +func deleteMemoResource(ctx context.Context, tx *sql.Tx, delete *api.MemoResourceDelete) error { + where, args := []string{"memo_id = ?"}, []interface{}{delete.MemoID} + + if v := delete.ResourceID; v != nil { + where, args = append(where, "resource_id = ?"), append(args, *v) + } + + result, err := tx.ExecContext(ctx, ` + DELETE FROM memo WHERE `+strings.Join(where, " AND "), args...) + if err != nil { + return FormatError(err) + } + + rows, _ := result.RowsAffected() + if rows == 0 { + return &common.Error{Code: common.NotFound, Err: fmt.Errorf("memo resource not found")} + } + + return nil +}