mirror of https://github.com/usememos/memos
feat: support Postgres (#2569)
* skeleton of postgres skeleton * Adding Postgres specific db schema sql * user test passed * memo store test passed * tag is working * update user setting test done * activity test done * idp test passed * inbox test done * memo_organizer, UNTESTED * memo relation test passed * webhook test passed * system setting test passed * passed storage test * pass resource test * migration_history done * fix memo_relation_test * fixing server memo_relation test * passes memo relation server test * paess memo test * final manual testing done * final fixes * final fixes cleanup * sync schema * lint * lint * lint * lint * lintpull/2576/head
parent
484efbbfe2
commit
9c18960f47
@ -0,0 +1,117 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal activity payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
|
||||
qb := squirrel.Insert("activity").
|
||||
Columns("creator_id", "type", "level", "payload").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
values := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
|
||||
|
||||
if create.ID != 0 {
|
||||
qb = qb.Columns("id")
|
||||
values = append(values, create.ID)
|
||||
}
|
||||
|
||||
if create.CreatedTs != 0 {
|
||||
qb = qb.Columns("created_ts")
|
||||
values = append(values, squirrel.Expr("TO_TIMESTAMP(?)", create.CreatedTs))
|
||||
}
|
||||
|
||||
qb = qb.Values(values...).Suffix("RETURNING id")
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to construct query")
|
||||
}
|
||||
|
||||
var id int32
|
||||
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to execute statement and retrieve ID")
|
||||
}
|
||||
|
||||
list, err := d.ListActivities(ctx, &store.FindActivity{ID: &id})
|
||||
if err != nil || len(list) == 0 {
|
||||
return nil, errors.Wrap(err, "failed to find activity")
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
|
||||
qb := squirrel.Select("id", "creator_id", "type", "level", "payload", "created_ts").
|
||||
From("activity").
|
||||
Where("1 = 1").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
if find.ID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"id": *find.ID})
|
||||
}
|
||||
if find.Type != nil {
|
||||
qb = qb.Where(squirrel.Eq{"type": find.Type.String()})
|
||||
}
|
||||
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Activity{}
|
||||
for rows.Next() {
|
||||
activity := &store.Activity{}
|
||||
var payloadBytes []byte
|
||||
createdTsPlaceHolder := time.Time{}
|
||||
if err := rows.Scan(
|
||||
&activity.ID,
|
||||
&activity.CreatorID,
|
||||
&activity.Type,
|
||||
&activity.Level,
|
||||
&payloadBytes,
|
||||
&createdTsPlaceHolder,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
activity.CreatedTs = createdTsPlaceHolder.Unix()
|
||||
|
||||
payload := &storepb.ActivityPayload{}
|
||||
if err := protojson.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
activity.Payload = payload
|
||||
list = append(list, activity)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
@ -0,0 +1,9 @@
|
||||
package postgres
|
||||
|
||||
import "google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
var (
|
||||
protojsonUnmarshaler = protojson.UnmarshalOptions{
|
||||
DiscardUnknown: true,
|
||||
}
|
||||
)
|
@ -0,0 +1,178 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
|
||||
var configBytes []byte
|
||||
if create.Type == store.IdentityProviderOAuth2Type {
|
||||
bytes, err := json.Marshal(create.Config.OAuth2Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
configBytes = bytes
|
||||
} else {
|
||||
return nil, errors.Errorf("unsupported idp type %s", string(create.Type))
|
||||
}
|
||||
|
||||
qb := squirrel.Insert("idp").Columns("name", "type", "identifier_filter", "config")
|
||||
values := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)}
|
||||
|
||||
if create.ID != 0 {
|
||||
qb = qb.Columns("id")
|
||||
values = append(values, create.ID)
|
||||
}
|
||||
|
||||
qb = qb.Values(values...).PlaceholderFormat(squirrel.Dollar)
|
||||
qb = qb.Suffix("RETURNING id")
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var id int32
|
||||
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
create.ID = id
|
||||
return create, nil
|
||||
}
|
||||
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
|
||||
qb := squirrel.Select("id", "name", "type", "identifier_filter", "config").
|
||||
From("idp").
|
||||
Where("1 = 1").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"id": *v})
|
||||
}
|
||||
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var identityProviders []*store.IdentityProvider
|
||||
for rows.Next() {
|
||||
var identityProvider store.IdentityProvider
|
||||
var identityProviderConfig string
|
||||
if err := rows.Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.Name,
|
||||
&identityProvider.Type,
|
||||
&identityProvider.IdentifierFilter,
|
||||
&identityProviderConfig,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if identityProvider.Type == store.IdentityProviderOAuth2Type {
|
||||
oauth2Config := &store.IdentityProviderOAuth2Config{}
|
||||
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProvider.Config = &store.IdentityProviderConfig{
|
||||
OAuth2Config: oauth2Config,
|
||||
}
|
||||
} else {
|
||||
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
|
||||
}
|
||||
identityProviders = append(identityProviders, &identityProvider)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return identityProviders, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) {
|
||||
list, err := d.ListIdentityProviders(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
|
||||
qb := squirrel.Update("idp").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
var err error
|
||||
|
||||
if v := update.Name; v != nil {
|
||||
qb = qb.Set("name", *v)
|
||||
}
|
||||
if v := update.IdentifierFilter; v != nil {
|
||||
qb = qb.Set("identifier_filter", *v)
|
||||
}
|
||||
if v := update.Config; v != nil {
|
||||
var configBytes []byte
|
||||
if update.Type == store.IdentityProviderOAuth2Type {
|
||||
bytes, err := json.Marshal(update.Config.OAuth2Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
configBytes = bytes
|
||||
} else {
|
||||
return nil, errors.Errorf("unsupported idp type %s", string(update.Type))
|
||||
}
|
||||
qb = qb.Set("config", string(configBytes))
|
||||
}
|
||||
|
||||
qb = qb.Where(squirrel.Eq{"id": update.ID})
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &update.ID})
|
||||
}
|
||||
|
||||
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
|
||||
qb := squirrel.Delete("idp").
|
||||
Where(squirrel.Eq{"id": delete.ID}).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -0,0 +1,144 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox, error) {
|
||||
messageString := "{}"
|
||||
if create.Message != nil {
|
||||
bytes, err := protojson.Marshal(create.Message)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal inbox message")
|
||||
}
|
||||
messageString = string(bytes)
|
||||
}
|
||||
|
||||
qb := squirrel.Insert("inbox").
|
||||
Columns("sender_id", "receiver_id", "status", "message").
|
||||
Values(create.SenderID, create.ReceiverID, create.Status, messageString).
|
||||
Suffix("RETURNING id").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var id int32
|
||||
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d.GetInbox(ctx, &store.FindInbox{ID: &id})
|
||||
}
|
||||
|
||||
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
|
||||
qb := squirrel.Select("id", "created_ts", "sender_id", "receiver_id", "status", "message").
|
||||
From("inbox").
|
||||
Where("1 = 1").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
if find.ID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"id": *find.ID})
|
||||
}
|
||||
if find.SenderID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"sender_id": *find.SenderID})
|
||||
}
|
||||
if find.ReceiverID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"receiver_id": *find.ReceiverID})
|
||||
}
|
||||
if find.Status != nil {
|
||||
qb = qb.Where(squirrel.Eq{"status": *find.Status})
|
||||
}
|
||||
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var list []*store.Inbox
|
||||
for rows.Next() {
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
createdTsPlaceHolder := time.Time{}
|
||||
if err := rows.Scan(&inbox.ID, &createdTsPlaceHolder, &inbox.SenderID, &inbox.ReceiverID, &inbox.Status, &messageBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inbox.CreatedTs = createdTsPlaceHolder.Unix()
|
||||
|
||||
message := &storepb.InboxMessage{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbox.Message = message
|
||||
list = append(list, inbox)
|
||||
}
|
||||
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) {
|
||||
list, err := d.ListInboxes(ctx, find)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get inbox")
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Wrapf(nil, "unexpected inbox count: %d", len(list))
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
|
||||
qb := squirrel.Update("inbox").
|
||||
Set("status", update.Status.String()).
|
||||
Where(squirrel.Eq{"id": update.ID}).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d.GetInbox(ctx, &store.FindInbox{ID: &update.ID})
|
||||
}
|
||||
|
||||
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
|
||||
qb := squirrel.Delete("inbox").
|
||||
Where(squirrel.Eq{"id": delete.ID}).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = result.RowsAffected()
|
||||
return err
|
||||
}
|
@ -0,0 +1,370 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
|
||||
// Initialize a Squirrel statement builder for PostgreSQL
|
||||
builder := squirrel.Insert("memo").
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
Columns("creator_id", "content", "visibility")
|
||||
|
||||
// Add initial values for the columns
|
||||
values := []any{create.CreatorID, create.Content, create.Visibility}
|
||||
|
||||
// Conditionally add other fields and values
|
||||
if create.ID != 0 {
|
||||
builder = builder.Columns("id")
|
||||
values = append(values, create.ID)
|
||||
}
|
||||
|
||||
if create.CreatedTs != 0 {
|
||||
builder = builder.Columns("created_ts")
|
||||
values = append(values, squirrel.Expr("to_timestamp(?)", create.CreatedTs))
|
||||
}
|
||||
|
||||
if create.UpdatedTs != 0 {
|
||||
builder = builder.Columns("updated_ts")
|
||||
values = append(values, squirrel.Expr("to_timestamp(?)", create.UpdatedTs))
|
||||
}
|
||||
|
||||
if create.RowStatus != "" {
|
||||
builder = builder.Columns("row_status")
|
||||
values = append(values, create.RowStatus)
|
||||
}
|
||||
|
||||
// Add all the values at once
|
||||
builder = builder.Values(values...)
|
||||
|
||||
// Add the RETURNING clause to get the ID of the inserted row
|
||||
builder = builder.Suffix("RETURNING id")
|
||||
|
||||
// Prepare and execute the query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var id int32
|
||||
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Retrieve the newly created memo
|
||||
memo, err := d.GetMemo(ctx, &store.FindMemo{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, errors.Errorf("failed to create memo")
|
||||
}
|
||||
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
|
||||
// Start building the SELECT statement
|
||||
builder := squirrel.Select(
|
||||
"memo.id AS id",
|
||||
"memo.creator_id AS creator_id",
|
||||
"EXTRACT(EPOCH FROM memo.created_ts) AS created_ts",
|
||||
"EXTRACT(EPOCH FROM memo.updated_ts) AS updated_ts",
|
||||
"memo.row_status AS row_status",
|
||||
"memo.content AS content",
|
||||
"memo.visibility AS visibility",
|
||||
"MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned",
|
||||
"string_agg(CAST(resource.id AS TEXT), ',') AS resource_id_list", // Cast to TEXT
|
||||
"(SELECT string_agg(CAST(memo_id AS TEXT) || ':' || CAST(related_memo_id AS TEXT) || ':' || type, ',') FROM memo_relation WHERE memo_relation.memo_id = memo.id OR memo_relation.related_memo_id = memo.id) AS relation_list"). // Cast IDs to TEXT
|
||||
From("memo").
|
||||
LeftJoin("memo_organizer ON memo.id = memo_organizer.memo_id").
|
||||
LeftJoin("resource ON memo.id = resource.memo_id").
|
||||
GroupBy("memo.id").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
// Add conditional where clauses
|
||||
if v := find.ID; v != nil {
|
||||
builder = builder.Where("memo.id = ?", *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
builder = builder.Where("memo.creator_id = ?", *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
builder = builder.Where("memo.row_status = ?", *v)
|
||||
}
|
||||
if v := find.CreatedTsBefore; v != nil {
|
||||
builder = builder.Where("EXTRACT(EPOCH FROM memo.created_ts) < ?", *v)
|
||||
}
|
||||
if v := find.CreatedTsAfter; v != nil {
|
||||
builder = builder.Where("EXTRACT(EPOCH FROM memo.created_ts) > ?", *v)
|
||||
}
|
||||
if v := find.Pinned; v != nil {
|
||||
builder = builder.Where("memo_organizer.pinned = 1")
|
||||
}
|
||||
if v := find.ContentSearch; len(v) != 0 {
|
||||
for _, s := range v {
|
||||
builder = builder.Where("memo.content LIKE ?", "%"+s+"%")
|
||||
}
|
||||
}
|
||||
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
placeholders := make([]string, len(v))
|
||||
args := make([]any, len(v))
|
||||
for i, visibility := range v {
|
||||
placeholders[i] = "?"
|
||||
args[i] = visibility // Assuming visibility can be directly used as an argument
|
||||
}
|
||||
inClause := strings.Join(placeholders, ",")
|
||||
builder = builder.Where("memo.visibility IN ("+inClause+")", args...)
|
||||
}
|
||||
// Add order by clauses
|
||||
if find.OrderByPinned {
|
||||
builder = builder.OrderBy("pinned DESC")
|
||||
}
|
||||
if find.OrderByUpdatedTs {
|
||||
builder = builder.OrderBy("updated_ts DESC")
|
||||
} else {
|
||||
builder = builder.OrderBy("created_ts DESC")
|
||||
}
|
||||
builder = builder.OrderBy("id DESC")
|
||||
|
||||
// Handle pagination
|
||||
if find.Limit != nil {
|
||||
builder = builder.Limit(uint64(*find.Limit))
|
||||
if find.Offset != nil {
|
||||
builder = builder.Offset(uint64(*find.Offset))
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare and execute the query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Process the result set
|
||||
list := make([]*store.Memo, 0)
|
||||
updatedTsPlaceHolder, createdTsPlaceHolder := make([]uint8, 8), make([]uint8, 8)
|
||||
for rows.Next() {
|
||||
var memo store.Memo
|
||||
var memoResourceIDList sql.NullString
|
||||
var memoRelationList sql.NullString
|
||||
if err := rows.Scan(
|
||||
&memo.ID,
|
||||
&memo.CreatorID,
|
||||
&createdTsPlaceHolder,
|
||||
&updatedTsPlaceHolder,
|
||||
&memo.RowStatus,
|
||||
&memo.Content,
|
||||
&memo.Visibility,
|
||||
&memo.Pinned,
|
||||
&memoResourceIDList,
|
||||
&memoRelationList,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert the timestamps from Postgres to Go
|
||||
memo.CreatedTs = int64(binary.BigEndian.Uint64(createdTsPlaceHolder))
|
||||
memo.UpdatedTs = int64(binary.BigEndian.Uint64(updatedTsPlaceHolder))
|
||||
|
||||
if memoResourceIDList.Valid {
|
||||
idStringList := strings.Split(memoResourceIDList.String, ",")
|
||||
memo.ResourceIDList = make([]int32, 0, len(idStringList))
|
||||
for _, idString := range idStringList {
|
||||
id, err := util.ConvertStringToInt32(idString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
memo.ResourceIDList = append(memo.ResourceIDList, id)
|
||||
}
|
||||
}
|
||||
if memoRelationList.Valid {
|
||||
memo.RelationList = make([]*store.MemoRelation, 0)
|
||||
relatedMemoTypeList := strings.Split(memoRelationList.String, ",")
|
||||
for _, relatedMemoType := range relatedMemoTypeList {
|
||||
relatedMemoTypeList := strings.Split(relatedMemoType, ":")
|
||||
if len(relatedMemoTypeList) != 3 {
|
||||
return nil, errors.Errorf("invalid relation format")
|
||||
}
|
||||
memoID, err := util.ConvertStringToInt32(relatedMemoTypeList[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
relatedMemoID, err := util.ConvertStringToInt32(relatedMemoTypeList[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
relationType := store.MemoRelationType(relatedMemoTypeList[2])
|
||||
memo.RelationList = append(memo.RelationList, &store.MemoRelation{
|
||||
MemoID: memoID,
|
||||
RelatedMemoID: relatedMemoID,
|
||||
Type: relationType,
|
||||
})
|
||||
// Set the first parent ID if relation type is comment.
|
||||
if memo.ParentID == nil && memoID == memo.ID && relationType == store.MemoRelationComment {
|
||||
memo.ParentID = &relatedMemoID
|
||||
}
|
||||
}
|
||||
}
|
||||
list = append(list, &memo)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, error) {
|
||||
list, err := d.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
memo := list[0]
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
|
||||
// Start building the update statement
|
||||
builder := squirrel.Update("memo").
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
Where("id = ?", update.ID)
|
||||
|
||||
// Conditionally add set clauses
|
||||
if v := update.CreatedTs; v != nil {
|
||||
builder = builder.Set("created_ts", squirrel.Expr("to_timestamp(?)", *v))
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
builder = builder.Set("updated_ts", squirrel.Expr("to_timestamp(?)", *v))
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
builder = builder.Set("row_status", *v)
|
||||
}
|
||||
if v := update.Content; v != nil {
|
||||
builder = builder.Set("content", *v)
|
||||
}
|
||||
if v := update.Visibility; v != nil {
|
||||
builder = builder.Set("visibility", *v)
|
||||
}
|
||||
|
||||
// Prepare and execute the query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
|
||||
// Start building the DELETE statement
|
||||
builder := squirrel.Delete("memo").
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
Where(squirrel.Eq{"id": delete.ID})
|
||||
|
||||
// Prepare the final query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute the query with the context
|
||||
result, err := d.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Perform any additional cleanup or operations such as vacuuming
|
||||
// irving: wait, why do we need to vacuum here?
|
||||
// I don't know why delete memo needs to vacuum. so I commented out.
|
||||
// REVIEWERS LOOK AT ME: please check this.
|
||||
|
||||
return d.Vacuum(ctx)
|
||||
}
|
||||
|
||||
func (d *DB) FindMemosVisibilityList(ctx context.Context, memoIDs []int32) ([]store.Visibility, error) {
|
||||
// Start building the SELECT statement
|
||||
builder := squirrel.Select("DISTINCT(visibility)").From("memo").
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
Where(squirrel.Eq{"id": memoIDs})
|
||||
|
||||
// Prepare the final query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query with the context
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
visibilityList := make([]store.Visibility, 0)
|
||||
for rows.Next() {
|
||||
var visibility store.Visibility
|
||||
if err := rows.Scan(&visibility); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
visibilityList = append(visibilityList, visibility)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return visibilityList, nil
|
||||
}
|
||||
|
||||
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
|
||||
// First, build the subquery
|
||||
subQuery, subArgs, err := squirrel.Select("id").From("user").PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now, build the main delete query using the subquery
|
||||
query, args, err := squirrel.Delete("memo").
|
||||
Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute the query
|
||||
_, err = tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
@ -0,0 +1,123 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertMemoOrganizer(ctx context.Context, upsert *store.MemoOrganizer) (*store.MemoOrganizer, error) {
|
||||
pinnedValue := 0
|
||||
if upsert.Pinned {
|
||||
pinnedValue = 1
|
||||
}
|
||||
qb := squirrel.Insert("memo_organizer").
|
||||
Columns("memo_id", "user_id", "pinned").
|
||||
Values(upsert.MemoID, upsert.UserID, pinnedValue).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err = d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemoOrganizer(ctx context.Context, find *store.FindMemoOrganizer) ([]*store.MemoOrganizer, error) {
|
||||
qb := squirrel.Select("memo_id", "user_id", "pinned").
|
||||
From("memo_organizer").
|
||||
Where("1 = 1").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
if find.MemoID != 0 {
|
||||
qb = qb.Where(squirrel.Eq{"memo_id": find.MemoID})
|
||||
}
|
||||
if find.UserID != 0 {
|
||||
qb = qb.Where(squirrel.Eq{"user_id": find.UserID})
|
||||
}
|
||||
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var list []*store.MemoOrganizer
|
||||
for rows.Next() {
|
||||
memoOrganizer := &store.MemoOrganizer{}
|
||||
if err := rows.Scan(&memoOrganizer.MemoID, &memoOrganizer.UserID, &memoOrganizer.Pinned); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, memoOrganizer)
|
||||
}
|
||||
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemoOrganizer(ctx context.Context, delete *store.DeleteMemoOrganizer) error {
|
||||
qb := squirrel.Delete("memo_organizer").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
if v := delete.MemoID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"memo_id": *v})
|
||||
}
|
||||
if v := delete.UserID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"user_id": *v})
|
||||
}
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error {
|
||||
// First, build the subquery for memo_id
|
||||
subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build the subquery for user_id
|
||||
subQueryUser, subArgsUser, err := squirrel.Select("id").From("\"user\"").PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now, build the main delete query using the subqueries
|
||||
query, args, err := squirrel.Delete("memo_organizer").
|
||||
Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
|
||||
Where(fmt.Sprintf("user_id NOT IN (%s)", subQueryUser), subArgsUser...).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine the arguments from both subqueries
|
||||
args = append(args, subArgsUser...)
|
||||
|
||||
// Execute the query
|
||||
_, err = tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
@ -0,0 +1,128 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
|
||||
qb := squirrel.Insert("memo_relation").
|
||||
Columns("memo_id", "related_memo_id", "type").
|
||||
Values(create.MemoID, create.RelatedMemoID, create.Type).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &store.MemoRelation{
|
||||
MemoID: create.MemoID,
|
||||
RelatedMemoID: create.RelatedMemoID,
|
||||
Type: create.Type,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
|
||||
qb := squirrel.Select("memo_id", "related_memo_id", "type").
|
||||
From("memo_relation").
|
||||
Where("TRUE").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
if find.MemoID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"memo_id": *find.MemoID})
|
||||
}
|
||||
if find.RelatedMemoID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"related_memo_id": *find.RelatedMemoID})
|
||||
}
|
||||
if find.Type != nil {
|
||||
qb = qb.Where(squirrel.Eq{"type": *find.Type})
|
||||
}
|
||||
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var list []*store.MemoRelation
|
||||
for rows.Next() {
|
||||
memoRelation := &store.MemoRelation{}
|
||||
if err := rows.Scan(&memoRelation.MemoID, &memoRelation.RelatedMemoID, &memoRelation.Type); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, memoRelation)
|
||||
}
|
||||
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
|
||||
qb := squirrel.Delete("memo_relation").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
if delete.MemoID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"memo_id": *delete.MemoID})
|
||||
}
|
||||
if delete.RelatedMemoID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"related_memo_id": *delete.RelatedMemoID})
|
||||
}
|
||||
if delete.Type != nil {
|
||||
qb = qb.Where(squirrel.Eq{"type": *delete.Type})
|
||||
}
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = result.RowsAffected()
|
||||
return err
|
||||
}
|
||||
|
||||
func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error {
|
||||
// First, build the subquery for memo_id
|
||||
subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Note: The same subquery is used for related_memo_id as it's also checking against the "memo" table
|
||||
|
||||
// Now, build the main delete query using the subqueries
|
||||
query, args, err := squirrel.Delete("memo_relation").
|
||||
Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
|
||||
Where(fmt.Sprintf("related_memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine the arguments for both instances of the same subquery
|
||||
args = append(args, subArgsMemo...)
|
||||
|
||||
// Execute the query
|
||||
_, err = tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
@ -0,0 +1,163 @@
|
||||
-- drop all tables first (PostgreSQL style)
|
||||
DROP TABLE IF EXISTS migration_history CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS system_setting CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS "user" CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS user_setting CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS memo CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS memo_organizer CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS memo_relation CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS resource CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS tag CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS activity CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS storage CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS idp CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS inbox CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS webhook CASCADE;
|
||||
|
||||
-- migration_history
|
||||
CREATE TABLE migration_history (
|
||||
version VARCHAR(255) NOT NULL PRIMARY KEY,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- system_setting
|
||||
CREATE TABLE system_setting (
|
||||
name VARCHAR(255) NOT NULL PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
description TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- user
|
||||
CREATE TABLE "user" (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
row_status VARCHAR(255) NOT NULL DEFAULT 'NORMAL',
|
||||
username VARCHAR(255) NOT NULL UNIQUE,
|
||||
role VARCHAR(255) NOT NULL DEFAULT 'USER',
|
||||
email VARCHAR(255) NOT NULL DEFAULT '',
|
||||
nickname VARCHAR(255) NOT NULL DEFAULT '',
|
||||
password_hash VARCHAR(255) NOT NULL,
|
||||
avatar_url TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- user_setting
|
||||
CREATE TABLE user_setting (
|
||||
user_id INT NOT NULL,
|
||||
key VARCHAR(255) NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
UNIQUE(user_id, key),
|
||||
FOREIGN KEY (user_id) REFERENCES "user"(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- memo
|
||||
CREATE TABLE memo (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INT NOT NULL,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
row_status VARCHAR(255) NOT NULL DEFAULT 'NORMAL',
|
||||
content TEXT NOT NULL,
|
||||
visibility VARCHAR(255) NOT NULL DEFAULT 'PRIVATE'
|
||||
);
|
||||
|
||||
-- memo_organizer
|
||||
CREATE TABLE memo_organizer (
|
||||
memo_id INT NOT NULL,
|
||||
user_id INT NOT NULL,
|
||||
pinned INT NOT NULL DEFAULT 0,
|
||||
UNIQUE(memo_id, user_id)
|
||||
);
|
||||
|
||||
-- memo_relation
|
||||
CREATE TABLE memo_relation (
|
||||
memo_id INT NOT NULL,
|
||||
related_memo_id INT NOT NULL,
|
||||
type VARCHAR(256) NOT NULL,
|
||||
UNIQUE(memo_id, related_memo_id, type),
|
||||
FOREIGN KEY (memo_id) REFERENCES memo(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (related_memo_id) REFERENCES memo(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- resource
|
||||
CREATE TABLE resource (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INT NOT NULL,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
filename TEXT NOT NULL,
|
||||
blob BYTEA,
|
||||
external_link TEXT NOT NULL,
|
||||
type VARCHAR(255) NOT NULL DEFAULT '',
|
||||
size INT NOT NULL DEFAULT 0,
|
||||
internal_path VARCHAR(255) NOT NULL DEFAULT '',
|
||||
memo_id INT DEFAULT NULL
|
||||
);
|
||||
|
||||
-- tag
|
||||
CREATE TABLE tag (
|
||||
name VARCHAR(255) NOT NULL,
|
||||
creator_id INT NOT NULL,
|
||||
UNIQUE(name, creator_id)
|
||||
);
|
||||
|
||||
-- activity
|
||||
CREATE TABLE activity (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INT NOT NULL,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
type VARCHAR(255) NOT NULL DEFAULT '',
|
||||
level VARCHAR(255) NOT NULL DEFAULT 'INFO',
|
||||
payload TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- storage
|
||||
CREATE TABLE storage (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(256) NOT NULL,
|
||||
type VARCHAR(256) NOT NULL,
|
||||
config TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- idp
|
||||
CREATE TABLE idp (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
identifier_filter VARCHAR(256) NOT NULL DEFAULT '',
|
||||
config TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- inbox
|
||||
CREATE TABLE inbox (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
sender_id INT NOT NULL,
|
||||
receiver_id INT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
message TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- webhook
|
||||
CREATE TABLE webhook (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
row_status TEXT NOT NULL DEFAULT 'NORMAL',
|
||||
creator_id INT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
url TEXT NOT NULL
|
||||
);
|
@ -0,0 +1,163 @@
|
||||
-- drop all tables first (PostgreSQL style)
|
||||
DROP TABLE IF EXISTS migration_history CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS system_setting CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS "user" CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS user_setting CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS memo CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS memo_organizer CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS memo_relation CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS resource CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS tag CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS activity CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS storage CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS idp CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS inbox CASCADE;
|
||||
|
||||
DROP TABLE IF EXISTS webhook CASCADE;
|
||||
|
||||
-- migration_history
|
||||
CREATE TABLE migration_history (
|
||||
version VARCHAR(255) NOT NULL PRIMARY KEY,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- system_setting
|
||||
CREATE TABLE system_setting (
|
||||
name VARCHAR(255) NOT NULL PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
description TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- user
|
||||
CREATE TABLE "user" (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
row_status VARCHAR(255) NOT NULL DEFAULT 'NORMAL',
|
||||
username VARCHAR(255) NOT NULL UNIQUE,
|
||||
role VARCHAR(255) NOT NULL DEFAULT 'USER',
|
||||
email VARCHAR(255) NOT NULL DEFAULT '',
|
||||
nickname VARCHAR(255) NOT NULL DEFAULT '',
|
||||
password_hash VARCHAR(255) NOT NULL,
|
||||
avatar_url TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- user_setting
|
||||
CREATE TABLE user_setting (
|
||||
user_id INT NOT NULL,
|
||||
key VARCHAR(255) NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
UNIQUE(user_id, key),
|
||||
FOREIGN KEY (user_id) REFERENCES "user"(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- memo
|
||||
CREATE TABLE memo (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INT NOT NULL,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
row_status VARCHAR(255) NOT NULL DEFAULT 'NORMAL',
|
||||
content TEXT NOT NULL,
|
||||
visibility VARCHAR(255) NOT NULL DEFAULT 'PRIVATE'
|
||||
);
|
||||
|
||||
-- memo_organizer
|
||||
CREATE TABLE memo_organizer (
|
||||
memo_id INT NOT NULL,
|
||||
user_id INT NOT NULL,
|
||||
pinned INT NOT NULL DEFAULT 0,
|
||||
UNIQUE(memo_id, user_id)
|
||||
);
|
||||
|
||||
-- memo_relation
|
||||
CREATE TABLE memo_relation (
|
||||
memo_id INT NOT NULL,
|
||||
related_memo_id INT NOT NULL,
|
||||
type VARCHAR(256) NOT NULL,
|
||||
UNIQUE(memo_id, related_memo_id, type),
|
||||
FOREIGN KEY (memo_id) REFERENCES memo(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (related_memo_id) REFERENCES memo(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- resource
|
||||
CREATE TABLE resource (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INT NOT NULL,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
filename TEXT NOT NULL,
|
||||
blob BYTEA,
|
||||
external_link TEXT NOT NULL,
|
||||
type VARCHAR(255) NOT NULL DEFAULT '',
|
||||
size INT NOT NULL DEFAULT 0,
|
||||
internal_path VARCHAR(255) NOT NULL DEFAULT '',
|
||||
memo_id INT DEFAULT NULL
|
||||
);
|
||||
|
||||
-- tag
|
||||
CREATE TABLE tag (
|
||||
name VARCHAR(255) NOT NULL,
|
||||
creator_id INT NOT NULL,
|
||||
UNIQUE(name, creator_id)
|
||||
);
|
||||
|
||||
-- activity
|
||||
CREATE TABLE activity (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INT NOT NULL,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
type VARCHAR(255) NOT NULL DEFAULT '',
|
||||
level VARCHAR(255) NOT NULL DEFAULT 'INFO',
|
||||
payload TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- storage
|
||||
CREATE TABLE storage (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(256) NOT NULL,
|
||||
type VARCHAR(256) NOT NULL,
|
||||
config TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- idp
|
||||
CREATE TABLE idp (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
identifier_filter VARCHAR(256) NOT NULL DEFAULT '',
|
||||
config TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- inbox
|
||||
CREATE TABLE inbox (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
sender_id INT NOT NULL,
|
||||
receiver_id INT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
message TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- webhook
|
||||
CREATE TABLE webhook (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
row_status TEXT NOT NULL DEFAULT 'NORMAL',
|
||||
creator_id INT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
url TEXT NOT NULL
|
||||
);
|
@ -0,0 +1,79 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) {
|
||||
qb := squirrel.Select("version", "created_ts").
|
||||
From("migration_history").
|
||||
OrderBy("created_ts DESC")
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.MigrationHistory, 0)
|
||||
for rows.Next() {
|
||||
var migrationHistory store.MigrationHistory
|
||||
var createdTs time.Time
|
||||
if err := rows.Scan(&migrationHistory.Version, &createdTs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
migrationHistory.CreatedTs = createdTs.UnixNano()
|
||||
list = append(list, &migrationHistory)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) {
|
||||
qb := squirrel.Insert("migration_history").
|
||||
Columns("version").
|
||||
Values(upsert.Version).
|
||||
Suffix("ON CONFLICT (version) DO UPDATE SET version = ?", upsert.Version)
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var migrationHistory store.MigrationHistory
|
||||
var createdTs time.Time
|
||||
query, args, err = squirrel.Select("version", "created_ts").
|
||||
From("migration_history").
|
||||
Where(squirrel.Eq{"version": upsert.Version}).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(&migrationHistory.Version, &createdTs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
migrationHistory.CreatedTs = createdTs.UnixNano()
|
||||
|
||||
return &migrationHistory, nil
|
||||
}
|
@ -0,0 +1,207 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/server/version"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
latestSchemaFileName = "LATEST__SCHEMA.sql"
|
||||
)
|
||||
|
||||
//go:embed migration
|
||||
var migrationFS embed.FS
|
||||
|
||||
func (d *DB) Migrate(ctx context.Context) error {
|
||||
if d.profile.IsDev() {
|
||||
return d.nonProdMigrate(ctx)
|
||||
}
|
||||
|
||||
return d.prodMigrate(ctx)
|
||||
}
|
||||
|
||||
func (d *DB) nonProdMigrate(ctx context.Context) error {
|
||||
rows, err := d.db.QueryContext(ctx, "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema';")
|
||||
if err != nil {
|
||||
return errors.Errorf("failed to query database tables: %s", err)
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return errors.Errorf("failed to query database tables: %s", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var table string
|
||||
err := rows.Scan(&table)
|
||||
if err != nil {
|
||||
return errors.Errorf("failed to scan table name: %s", err)
|
||||
}
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
if len(tables) != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
println("no tables in the database. start migration")
|
||||
|
||||
buf, err := migrationFS.ReadFile("migration/dev/" + latestSchemaFileName)
|
||||
if err != nil {
|
||||
return errors.Errorf("failed to read latest schema file: %s", err)
|
||||
}
|
||||
|
||||
stmt := string(buf)
|
||||
if _, err := d.db.ExecContext(ctx, stmt); err != nil {
|
||||
return errors.Errorf("failed to exec SQL %s: %s", stmt, err)
|
||||
}
|
||||
|
||||
// In demo mode, we should seed the database.
|
||||
if d.profile.Mode == "demo" {
|
||||
if err := d.seed(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to seed")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) prodMigrate(ctx context.Context) error {
|
||||
currentVersion := version.GetCurrentVersion(d.profile.Mode)
|
||||
migrationHistoryList, err := d.FindMigrationHistoryList(ctx, &store.FindMigrationHistory{})
|
||||
// If there is no migration history, we should apply the latest schema.
|
||||
if err != nil || len(migrationHistoryList) == 0 {
|
||||
buf, err := migrationFS.ReadFile("migration/prod/" + latestSchemaFileName)
|
||||
if err != nil {
|
||||
return errors.Errorf("failed to read latest schema file: %s", err)
|
||||
}
|
||||
|
||||
stmt := string(buf)
|
||||
if _, err := d.db.ExecContext(ctx, stmt); err != nil {
|
||||
return errors.Errorf("failed to exec SQL %s: %s", stmt, err)
|
||||
}
|
||||
if _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
|
||||
Version: currentVersion,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to upsert migration history")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
migrationHistoryVersionList := []string{}
|
||||
for _, migrationHistory := range migrationHistoryList {
|
||||
migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version)
|
||||
}
|
||||
sort.Sort(version.SortVersion(migrationHistoryVersionList))
|
||||
latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
|
||||
if !version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
|
||||
return nil
|
||||
}
|
||||
|
||||
println("start migrate")
|
||||
for _, minorVersion := range getMinorVersionList() {
|
||||
normalizedVersion := minorVersion + ".0"
|
||||
if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) {
|
||||
println("applying migration for", normalizedVersion)
|
||||
if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
|
||||
return errors.Wrap(err, "failed to apply minor version migration")
|
||||
}
|
||||
}
|
||||
}
|
||||
println("end migrate")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error {
|
||||
filenames, err := fs.Glob(migrationFS, fmt.Sprintf("migration/prod/%s/*.sql", minorVersion))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read ddl files")
|
||||
}
|
||||
|
||||
sort.Strings(filenames)
|
||||
// Loop over all migration files and execute them in order.
|
||||
for _, filename := range filenames {
|
||||
buf, err := migrationFS.ReadFile(filename)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename)
|
||||
}
|
||||
for _, stmt := range strings.Split(string(buf), ";") {
|
||||
if strings.TrimSpace(stmt) == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := d.db.ExecContext(ctx, stmt); err != nil {
|
||||
return errors.Wrapf(err, "migrate error: %s", stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert the newest version to migration_history.
|
||||
version := minorVersion + ".0"
|
||||
if _, err = d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{Version: version}); err != nil {
|
||||
return errors.Wrapf(err, "failed to upsert migration history with version: %s", version)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//go:embed seed
|
||||
var seedFS embed.FS
|
||||
|
||||
func (d *DB) seed(ctx context.Context) error {
|
||||
filenames, err := fs.Glob(seedFS, "seed/*.sql")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read seed files")
|
||||
}
|
||||
|
||||
sort.Strings(filenames)
|
||||
// Loop over all seed files and execute them in order.
|
||||
for _, filename := range filenames {
|
||||
buf, err := seedFS.ReadFile(filename)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read seed file, filename=%s", filename)
|
||||
}
|
||||
|
||||
for _, stmt := range strings.Split(string(buf), ";") {
|
||||
if strings.TrimSpace(stmt) == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := d.db.ExecContext(ctx, stmt); err != nil {
|
||||
return errors.Wrapf(err, "seed error: %s", stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// minorDirRegexp is a regular expression for minor version directory.
|
||||
var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`)
|
||||
|
||||
func getMinorVersionList() []string {
|
||||
minorVersionList := []string{}
|
||||
|
||||
if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if file.IsDir() && minorDirRegexp.MatchString(path) {
|
||||
minorVersionList = append(minorVersionList, file.Name())
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sort.Sort(version.SortVersion(minorVersionList))
|
||||
|
||||
return minorVersionList
|
||||
}
|
@ -0,0 +1,87 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log"
|
||||
|
||||
// Import the PostgreSQL driver.
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/server/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
profile *profile.Profile
|
||||
// Add any other fields as needed
|
||||
}
|
||||
|
||||
func NewDB(profile *profile.Profile) (store.Driver, error) {
|
||||
if profile == nil {
|
||||
return nil, errors.New("profile is nil")
|
||||
}
|
||||
|
||||
// Open the PostgreSQL connection
|
||||
db, err := sql.Open("postgres", profile.DSN)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open database: %s", err)
|
||||
return nil, errors.Wrapf(err, "failed to open database: %s", profile.DSN)
|
||||
}
|
||||
|
||||
var driver store.Driver = &DB{
|
||||
db: db,
|
||||
profile: profile,
|
||||
}
|
||||
|
||||
// Return the DB struct
|
||||
return driver, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetDB() *sql.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
func (d *DB) Vacuum(ctx context.Context) error {
|
||||
tx, err := d.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if err := vacuumMemo(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumResource(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumUserSetting(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumMemoOrganizer(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumMemoRelations(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumTag(ctx, tx); err != nil {
|
||||
// Prevent revive warning.
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (*DB) BackupTo(context.Context, string) error {
|
||||
return errors.New("Please use postgresdump to backup")
|
||||
}
|
||||
|
||||
func (*DB) GetCurrentDBSize(context.Context) (int64, error) {
|
||||
return 0, errors.New("unimplemented")
|
||||
}
|
||||
|
||||
func (d *DB) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
@ -0,0 +1,229 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateResource(ctx context.Context, create *store.Resource) (*store.Resource, error) {
|
||||
qb := squirrel.Insert("resource").Columns("filename", "blob", "external_link", "type", "size", "creator_id", "internal_path")
|
||||
values := []any{create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath}
|
||||
|
||||
if create.ID != 0 {
|
||||
qb = qb.Columns("id")
|
||||
values = append(values, create.ID)
|
||||
}
|
||||
|
||||
if create.CreatedTs != 0 {
|
||||
qb = qb.Columns("created_ts")
|
||||
values = append(values, time.Unix(0, create.CreatedTs))
|
||||
}
|
||||
|
||||
if create.UpdatedTs != 0 {
|
||||
qb = qb.Columns("updated_ts")
|
||||
values = append(values, time.Unix(0, create.UpdatedTs))
|
||||
}
|
||||
|
||||
if create.MemoID != nil {
|
||||
qb = qb.Columns("memo_id")
|
||||
values = append(values, *create.MemoID)
|
||||
}
|
||||
|
||||
qb = qb.Values(values...).Suffix("RETURNING id")
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var id int32
|
||||
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list, err := d.ListResources(ctx, &store.FindResource{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list))
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*store.Resource, error) {
|
||||
qb := squirrel.Select("id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path", "memo_id").From("resource")
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"id": *v})
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"creator_id": *v})
|
||||
}
|
||||
if v := find.Filename; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"filename": *v})
|
||||
}
|
||||
if v := find.MemoID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"memo_id": *v})
|
||||
}
|
||||
if find.HasRelatedMemo {
|
||||
qb = qb.Where("memo_id IS NOT NULL")
|
||||
}
|
||||
if find.GetBlob {
|
||||
qb = qb.Columns("blob")
|
||||
}
|
||||
|
||||
qb = qb.GroupBy("id").OrderBy("created_ts DESC")
|
||||
|
||||
if find.Limit != nil {
|
||||
qb = qb.Limit(uint64(*find.Limit))
|
||||
if find.Offset != nil {
|
||||
qb = qb.Offset(uint64(*find.Offset))
|
||||
}
|
||||
}
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Resource, 0)
|
||||
for rows.Next() {
|
||||
resource := store.Resource{}
|
||||
var memoID sql.NullInt32
|
||||
var createdTs, updatedTs time.Time
|
||||
dests := []any{
|
||||
&resource.ID,
|
||||
&resource.Filename,
|
||||
&resource.ExternalLink,
|
||||
&resource.Type,
|
||||
&resource.Size,
|
||||
&resource.CreatorID,
|
||||
&createdTs,
|
||||
&updatedTs,
|
||||
&resource.InternalPath,
|
||||
&memoID,
|
||||
}
|
||||
if find.GetBlob {
|
||||
dests = append(dests, &resource.Blob)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resource.CreatedTs = createdTs.UnixNano()
|
||||
resource.UpdatedTs = updatedTs.UnixNano()
|
||||
|
||||
if memoID.Valid {
|
||||
resource.MemoID = &memoID.Int32
|
||||
}
|
||||
list = append(list, &resource)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateResource(ctx context.Context, update *store.UpdateResource) (*store.Resource, error) {
|
||||
qb := squirrel.Update("resource")
|
||||
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
qb = qb.Set("updated_ts", time.Unix(0, *v))
|
||||
}
|
||||
if v := update.Filename; v != nil {
|
||||
qb = qb.Set("filename", *v)
|
||||
}
|
||||
if v := update.InternalPath; v != nil {
|
||||
qb = qb.Set("internal_path", *v)
|
||||
}
|
||||
if v := update.MemoID; v != nil {
|
||||
qb = qb.Set("memo_id", *v)
|
||||
}
|
||||
if v := update.Blob; v != nil {
|
||||
qb = qb.Set("blob", v)
|
||||
}
|
||||
|
||||
qb = qb.Where(squirrel.Eq{"id": update.ID})
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list, err := d.ListResources(ctx, &store.FindResource{ID: &update.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list))
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteResource(ctx context.Context, delete *store.DeleteResource) error {
|
||||
qb := squirrel.Delete("resource").Where(squirrel.Eq{"id": delete.ID})
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := d.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := d.Vacuum(ctx); err != nil {
|
||||
// Prevent linter warning.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumResource(ctx context.Context, tx *sql.Tx) error {
|
||||
// First, build the subquery
|
||||
subQuery, subArgs, err := squirrel.Select("id").From("user").PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now, build the main delete query using the subquery
|
||||
query, args, err := squirrel.Delete("resource").
|
||||
Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute the query
|
||||
_, err = tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
TRUNCATE TABLE memo_organizer;
|
||||
TRUNCATE TABLE resource;
|
||||
TRUNCATE TABLE memo;
|
||||
TRUNCATE TABLE user;
|
@ -0,0 +1,44 @@
|
||||
INSERT INTO "user" (
|
||||
id,
|
||||
username,
|
||||
role,
|
||||
email,
|
||||
nickname,
|
||||
row_status,
|
||||
avatar_url,
|
||||
password_hash
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
101,
|
||||
'memos-demo',
|
||||
'HOST',
|
||||
'demo@usememos.com',
|
||||
'Derobot',
|
||||
'NORMAL',
|
||||
'',
|
||||
-- raw password: secret
|
||||
'$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK'
|
||||
),
|
||||
(
|
||||
102,
|
||||
'jack',
|
||||
'USER',
|
||||
'jack@usememos.com',
|
||||
'Jack',
|
||||
'NORMAL',
|
||||
'',
|
||||
-- raw password: secret
|
||||
'$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK'
|
||||
),
|
||||
(
|
||||
103,
|
||||
'bob',
|
||||
'USER',
|
||||
'bob@usememos.com',
|
||||
'Bob',
|
||||
'ARCHIVED',
|
||||
'',
|
||||
-- raw password: secret
|
||||
'$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK'
|
||||
);
|
@ -0,0 +1,34 @@
|
||||
INSERT INTO memo (id, content, creator_id)
|
||||
VALUES
|
||||
(
|
||||
1,
|
||||
'#Hello 👋 Welcome to memos.',
|
||||
101
|
||||
);
|
||||
|
||||
INSERT INTO memo (id, content, creator_id, visibility)
|
||||
VALUES
|
||||
(
|
||||
2,
|
||||
E'#TODO\n- [x] Take more photos about **🌄 sunset**\n- [x] Clean the room\n- [ ] Read *📖 The Little Prince*\n(👆 click to toggle status)',
|
||||
101,
|
||||
'PROTECTED'
|
||||
),
|
||||
(
|
||||
3,
|
||||
E'**[Slash](https://github.com/yourselfhosted/slash)**: A bookmarking and url shortener, save and share your links very easily.\n**[SQL Chat](https://www.sqlchat.ai)**: Chat-based SQL Client',
|
||||
101,
|
||||
'PUBLIC'
|
||||
),
|
||||
(
|
||||
4,
|
||||
E'#TODO\n- [x] Take more photos about **🌄 sunset**\n- [ ] Clean the classroom\n- [ ] Watch *👦 The Boys*\n(👆 click to toggle status)',
|
||||
102,
|
||||
'PROTECTED'
|
||||
),
|
||||
(
|
||||
5,
|
||||
'三人行,必有我师焉!👨🏫',
|
||||
102,
|
||||
'PUBLIC'
|
||||
);
|
@ -0,0 +1,5 @@
|
||||
INSERT INTO
|
||||
memo_organizer (memo_id, user_id, pinned)
|
||||
VALUES
|
||||
(1, 101, 1),
|
||||
(3, 101, 1);
|
@ -0,0 +1,6 @@
|
||||
INSERT INTO
|
||||
tag (name, creator_id)
|
||||
VALUES
|
||||
('Hello', 101),
|
||||
('TODO', 101),
|
||||
('TODO', 102);
|
@ -0,0 +1,125 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateStorage(ctx context.Context, create *store.Storage) (*store.Storage, error) {
|
||||
qb := squirrel.Insert("storage").Columns("name", "type", "config")
|
||||
values := []any{create.Name, create.Type, create.Config}
|
||||
|
||||
if create.ID != 0 {
|
||||
qb = qb.Columns("id")
|
||||
values = append(values, create.ID)
|
||||
}
|
||||
|
||||
qb = qb.Values(values...).Suffix("RETURNING id")
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = d.db.QueryRowContext(ctx, query, args...).Scan(&create.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListStorages(ctx context.Context, find *store.FindStorage) ([]*store.Storage, error) {
|
||||
qb := squirrel.Select("id", "name", "type", "config").From("storage").OrderBy("id DESC")
|
||||
|
||||
if find.ID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"id": *find.ID})
|
||||
}
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Storage{}
|
||||
for rows.Next() {
|
||||
storage := &store.Storage{}
|
||||
if err := rows.Scan(&storage.ID, &storage.Name, &storage.Type, &storage.Config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, storage)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateStorage(ctx context.Context, update *store.UpdateStorage) (*store.Storage, error) {
|
||||
qb := squirrel.Update("storage")
|
||||
|
||||
if update.Name != nil {
|
||||
qb = qb.Set("name", *update.Name)
|
||||
}
|
||||
if update.Config != nil {
|
||||
qb = qb.Set("config", *update.Config)
|
||||
}
|
||||
|
||||
qb = qb.Where(squirrel.Eq{"id": update.ID})
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storage := &store.Storage{}
|
||||
query, args, err = squirrel.Select("id", "name", "type", "config").
|
||||
From("storage").
|
||||
Where(squirrel.Eq{"id": update.ID}).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(&storage.ID, &storage.Name, &storage.Type, &storage.Config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteStorage(ctx context.Context, delete *store.DeleteStorage) error {
|
||||
qb := squirrel.Delete("storage").Where(squirrel.Eq{"id": delete.ID})
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := d.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -0,0 +1,61 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertSystemSetting(ctx context.Context, upsert *store.SystemSetting) (*store.SystemSetting, error) {
|
||||
qb := squirrel.Insert("system_setting").
|
||||
Columns("name", "value", "description").
|
||||
Values(upsert.Name, upsert.Value, upsert.Description)
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListSystemSettings(ctx context.Context, find *store.FindSystemSetting) ([]*store.SystemSetting, error) {
|
||||
qb := squirrel.Select("name", "value", "description").From("system_setting")
|
||||
|
||||
if find.Name != "" {
|
||||
qb = qb.Where(squirrel.Eq{"name": find.Name})
|
||||
}
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.SystemSetting{}
|
||||
for rows.Next() {
|
||||
systemSetting := &store.SystemSetting{}
|
||||
if err := rows.Scan(&systemSetting.Name, &systemSetting.Value, &systemSetting.Description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, systemSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
@ -0,0 +1,113 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertTag(ctx context.Context, upsert *store.Tag) (*store.Tag, error) {
|
||||
builder := squirrel.Insert("tag").
|
||||
Columns("name", "creator_id").
|
||||
Values(upsert.Name, upsert.CreatorID). // on conflict is not necessary, as only the pair of name and creator_id is unique
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListTags(ctx context.Context, find *store.FindTag) ([]*store.Tag, error) {
|
||||
builder := squirrel.Select("name", "creator_id").From("tag").
|
||||
Where("1 = 1").
|
||||
OrderBy("name ASC").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
if find.CreatorID != 0 {
|
||||
builder = builder.Where("creator_id = ?", find.CreatorID)
|
||||
}
|
||||
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Tag{}
|
||||
for rows.Next() {
|
||||
tag := &store.Tag{}
|
||||
if err := rows.Scan(
|
||||
&tag.Name,
|
||||
&tag.CreatorID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list = append(list, tag)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteTag(ctx context.Context, delete *store.DeleteTag) error {
|
||||
builder := squirrel.Delete("tag").
|
||||
Where(squirrel.Eq{"name": delete.Name, "creator_id": delete.CreatorID}).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := d.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumTag(ctx context.Context, tx *sql.Tx) error {
|
||||
// First, build the subquery for creator_id
|
||||
subQuery, subArgs, err := squirrel.Select("id").From("\"user\"").PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now, build the main delete query using the subquery
|
||||
query, args, err := squirrel.Delete("tag").
|
||||
Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute the query
|
||||
_, err = tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
@ -0,0 +1,225 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
|
||||
// Start building the insert statement
|
||||
builder := squirrel.Insert("\"user\"").PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
columns := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"}
|
||||
builder = builder.Columns(columns...)
|
||||
|
||||
values := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
|
||||
|
||||
if create.RowStatus != "" {
|
||||
builder = builder.Columns("row_status")
|
||||
values = append(values, create.RowStatus)
|
||||
}
|
||||
|
||||
if create.CreatedTs != 0 {
|
||||
builder = builder.Columns("created_ts")
|
||||
values = append(values, squirrel.Expr("TO_TIMESTAMP(?)", create.CreatedTs))
|
||||
}
|
||||
|
||||
if create.UpdatedTs != 0 {
|
||||
builder = builder.Columns("updated_ts")
|
||||
values = append(values, squirrel.Expr("TO_TIMESTAMP(?)", create.UpdatedTs))
|
||||
}
|
||||
|
||||
if create.ID != 0 {
|
||||
builder = builder.Columns("id")
|
||||
values = append(values, create.ID)
|
||||
}
|
||||
|
||||
builder = builder.Values(values...)
|
||||
|
||||
builder = builder.Suffix("RETURNING id")
|
||||
|
||||
// Prepare the final query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query and get the returned ID
|
||||
var id int32
|
||||
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use the returned ID to retrieve the full user object
|
||||
user, err := d.GetUser(ctx, &store.FindUser{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
|
||||
// Start building the update statement
|
||||
builder := squirrel.Update("\"user\"").PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
// Conditionally add set clauses
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
builder = builder.Set("updated_ts", squirrel.Expr("to_timestamp(?)", *v))
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
builder = builder.Set("row_status", *v)
|
||||
}
|
||||
if v := update.Username; v != nil {
|
||||
builder = builder.Set("username", *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
builder = builder.Set("email", *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
builder = builder.Set("nickname", *v)
|
||||
}
|
||||
if v := update.AvatarURL; v != nil {
|
||||
builder = builder.Set("avatar_url", *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
builder = builder.Set("password_hash", *v)
|
||||
}
|
||||
|
||||
// Add the WHERE clause
|
||||
builder = builder.Where(squirrel.Eq{"id": update.ID})
|
||||
|
||||
// Prepare the final query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query with the context
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Retrieve the updated user
|
||||
user, err := d.GetUser(ctx, &store.FindUser{ID: &update.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
|
||||
// Start building the SELECT statement
|
||||
builder := squirrel.Select("id", "username", "role", "email", "nickname", "password_hash", "avatar_url",
|
||||
"FLOOR(EXTRACT(EPOCH FROM created_ts)) AS created_ts", "FLOOR(EXTRACT(EPOCH FROM updated_ts)) AS updated_ts", "row_status").
|
||||
From("\"user\"").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
// 1 = 1 is often used as a no-op in SQL, ensuring there's always a WHERE clause
|
||||
builder = builder.Where("1 = 1")
|
||||
|
||||
// Conditionally add where clauses
|
||||
if v := find.ID; v != nil {
|
||||
builder = builder.Where(squirrel.Eq{"id": *v})
|
||||
}
|
||||
if v := find.Username; v != nil {
|
||||
builder = builder.Where(squirrel.Eq{"username": *v})
|
||||
}
|
||||
if v := find.Role; v != nil {
|
||||
builder = builder.Where(squirrel.Eq{"role": *v})
|
||||
}
|
||||
if v := find.Email; v != nil {
|
||||
builder = builder.Where(squirrel.Eq{"email": *v})
|
||||
}
|
||||
if v := find.Nickname; v != nil {
|
||||
builder = builder.Where(squirrel.Eq{"nickname": *v})
|
||||
}
|
||||
|
||||
// Add ordering
|
||||
builder = builder.OrderBy("created_ts DESC", "row_status DESC")
|
||||
|
||||
// Prepare the final query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query with the context
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.User, 0)
|
||||
for rows.Next() {
|
||||
var user store.User
|
||||
if err := rows.Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, &user)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) {
|
||||
list, err := d.ListUsers(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list))
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
|
||||
// Start building the DELETE statement
|
||||
builder := squirrel.Delete("\"user\"").
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
Where(squirrel.Eq{"id": delete.ID})
|
||||
|
||||
// Prepare the final query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute the query with the context
|
||||
result, err := d.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := d.Vacuum(ctx); err != nil {
|
||||
// Prevent linter warning.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -0,0 +1,194 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) {
|
||||
// Construct the query using Squirrel
|
||||
query, args, err := squirrel.
|
||||
Insert("user_setting").
|
||||
Columns("user_id", "key", "value").
|
||||
Values(upsert.UserID, upsert.Key, upsert.Value).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
// no need to specify ON CONFLICT clause, as the primary key is (user_id, key)
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) {
|
||||
// Start building the query
|
||||
qb := squirrel.Select("user_id", "key", "value").From("user_setting").Where("1 = 1").PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
// Add conditions based on the provided find parameters
|
||||
if v := find.Key; v != "" {
|
||||
qb = qb.Where(squirrel.Eq{"key": v})
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"user_id": *v})
|
||||
}
|
||||
|
||||
// Finalize the query
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Process the rows
|
||||
userSettingList := make([]*store.UserSetting, 0)
|
||||
for rows.Next() {
|
||||
var userSetting store.UserSetting
|
||||
if err := rows.Scan(
|
||||
&userSetting.UserID,
|
||||
&userSetting.Key,
|
||||
&userSetting.Value,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSettingList = append(userSettingList, &userSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpsertUserSettingV1(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
|
||||
var valueString string
|
||||
if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
|
||||
valueBytes, err := protojson.Marshal(upsert.GetAccessTokens())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
valueString = string(valueBytes)
|
||||
} else {
|
||||
return nil, errors.New("invalid user setting key")
|
||||
}
|
||||
|
||||
// Construct the query using Squirrel
|
||||
query, args, err := squirrel.
|
||||
Insert("user_setting").
|
||||
Columns("user_id", "key", "value").
|
||||
Values(upsert.UserId, upsert.Key.String(), valueString).
|
||||
Suffix("ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value").
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUserSettingsV1(ctx context.Context, find *store.FindUserSettingV1) ([]*storepb.UserSetting, error) {
|
||||
// Start building the query using Squirrel
|
||||
qb := squirrel.Select("user_id", "key", "value").From("user_setting").PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
// Add conditions based on the provided find parameters
|
||||
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
|
||||
qb = qb.Where(squirrel.Eq{"key": v.String()})
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"user_id": *v})
|
||||
}
|
||||
|
||||
// Finalize the query
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Process the rows
|
||||
userSettingList := make([]*storepb.UserSetting, 0)
|
||||
for rows.Next() {
|
||||
userSetting := &storepb.UserSetting{}
|
||||
var keyString, valueString string
|
||||
if err := rows.Scan(
|
||||
&userSetting.UserId,
|
||||
&keyString,
|
||||
&valueString,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString])
|
||||
if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
|
||||
accessTokensUserSetting := &storepb.AccessTokensUserSetting{}
|
||||
if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Value = &storepb.UserSetting_AccessTokens{
|
||||
AccessTokens: accessTokensUserSetting,
|
||||
}
|
||||
} else {
|
||||
// Skip unknown user setting v1 key
|
||||
continue
|
||||
}
|
||||
userSettingList = append(userSettingList, userSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
|
||||
// First, build the subquery
|
||||
subQuery, subArgs, err := squirrel.Select("id").From("\"user\"").PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now, build the main delete query using the subquery
|
||||
query, args, err := squirrel.Delete("user_setting").
|
||||
Where(fmt.Sprintf("user_id NOT IN (%s)", subQuery), subArgs...).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute the query
|
||||
_, err = tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
@ -0,0 +1,148 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateWebhook(ctx context.Context, create *storepb.Webhook) (*storepb.Webhook, error) {
|
||||
qb := squirrel.Insert("webhook").Columns("name", "url", "creator_id")
|
||||
values := []any{create.Name, create.Url, create.CreatorId}
|
||||
|
||||
if create.Id != 0 {
|
||||
qb = qb.Columns("id")
|
||||
values = append(values, create.Id)
|
||||
}
|
||||
|
||||
qb = qb.Values(values...).Suffix("RETURNING id")
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = d.db.QueryRowContext(ctx, query, args...).Scan(&create.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
create, err = d.GetWebhook(ctx, &store.FindWebhook{ID: &create.Id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListWebhooks(ctx context.Context, find *store.FindWebhook) ([]*storepb.Webhook, error) {
|
||||
qb := squirrel.Select("id", "created_ts", "updated_ts", "row_status", "creator_id", "name", "url").From("webhook").OrderBy("id DESC")
|
||||
|
||||
if find.ID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"id": *find.ID})
|
||||
}
|
||||
if find.CreatorID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"creator_id": *find.CreatorID})
|
||||
}
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*storepb.Webhook{}
|
||||
for rows.Next() {
|
||||
webhook := &storepb.Webhook{}
|
||||
var rowStatus string
|
||||
var createdTs, updatedTs time.Time
|
||||
|
||||
if err := rows.Scan(
|
||||
&webhook.Id,
|
||||
&createdTs,
|
||||
&updatedTs,
|
||||
&rowStatus,
|
||||
&webhook.CreatorId,
|
||||
&webhook.Name,
|
||||
&webhook.Url,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
webhook.CreatedTs = createdTs.UnixNano()
|
||||
webhook.UpdatedTs = updatedTs.UnixNano()
|
||||
webhook.RowStatus = storepb.RowStatus(storepb.RowStatus_value[rowStatus])
|
||||
|
||||
list = append(list, webhook)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetWebhook(ctx context.Context, find *store.FindWebhook) (*storepb.Webhook, error) {
|
||||
list, err := d.ListWebhooks(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateWebhook(ctx context.Context, update *store.UpdateWebhook) (*storepb.Webhook, error) {
|
||||
qb := squirrel.Update("webhook")
|
||||
|
||||
if update.RowStatus != nil {
|
||||
qb = qb.Set("row_status", update.RowStatus.String())
|
||||
}
|
||||
if update.Name != nil {
|
||||
qb = qb.Set("name", *update.Name)
|
||||
}
|
||||
if update.URL != nil {
|
||||
qb = qb.Set("url", *update.URL)
|
||||
}
|
||||
|
||||
qb = qb.Where(squirrel.Eq{"id": update.ID})
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
webhook, err := d.GetWebhook(ctx, &store.FindWebhook{ID: &update.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return webhook, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteWebhook(ctx context.Context, delete *store.DeleteWebhook) error {
|
||||
qb := squirrel.Delete("webhook").Where(squirrel.Eq{"id": delete.ID})
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
Loading…
Reference in New Issue