feat: support Postgres (#2569)

* skeleton of postgres

skeleton

* Adding Postgres specific db schema sql

* user test passed

* memo store test passed

* tag is working

* update user setting test done

* activity test done

* idp test passed

* inbox test done

* memo_organizer, UNTESTED

* memo relation test passed

* webhook test passed

* system setting test passed

* passed storage test

* pass resource test

* migration_history done

* fix memo_relation_test

* fixing server memo_relation test

* passes memo relation server test

* paess memo test

* final manual testing done

* final fixes

* final fixes cleanup

* sync schema

* lint

* lint

* lint

* lint

* lint
pull/2576/head
Irving Ou 1 year ago committed by GitHub
parent 484efbbfe2
commit 9c18960f47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,6 +3,7 @@ module github.com/usememos/memos
go 1.21
require (
github.com/Masterminds/squirrel v1.5.4
github.com/aws/aws-sdk-go-v2 v1.22.1
github.com/aws/aws-sdk-go-v2/config v1.22.1
github.com/aws/aws-sdk-go-v2/credentials v1.15.1
@ -16,6 +17,7 @@ require (
github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1
github.com/improbable-eng/grpc-web v0.15.0
github.com/labstack/echo/v4 v4.11.2
github.com/lib/pq v1.10.9
github.com/microcosm-cc/bluemonday v1.0.26
github.com/pkg/errors v0.9.1
github.com/spf13/cobra v1.8.0
@ -50,6 +52,8 @@ require (
github.com/gorilla/css v1.0.1 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rs/cors v1.10.1 // indirect

@ -41,6 +41,8 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0=
github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc=
github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE=
github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM=
github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10=
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g=
@ -368,7 +370,13 @@ github.com/labstack/echo/v4 v4.11.2 h1:T+cTLQxWCDfqDEoydYm5kCobjmHwOwcv4OJAPHilm
github.com/labstack/echo/v4 v4.11.2/go.mod h1:UcGuQ8V6ZNRmSweBIJkPvGfwCMIlFmiqrPqiEBfPYws=
github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8=
github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM=
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw=
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o=
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk=
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ=

@ -6,6 +6,7 @@ import (
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"
"github.com/usememos/memos/store/db/mysql"
"github.com/usememos/memos/store/db/postgres"
"github.com/usememos/memos/store/db/sqlite"
)
@ -19,6 +20,8 @@ func NewDBDriver(profile *profile.Profile) (store.Driver, error) {
driver, err = sqlite.NewDB(profile)
case "mysql":
driver, err = mysql.NewDB(profile)
case "postgres":
driver, err = postgres.NewDB(profile)
default:
return nil, errors.New("unknown db driver")
}

@ -0,0 +1,117 @@
package postgres
import (
"context"
"time"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
payloadString := "{}"
if create.Payload != nil {
bytes, err := protojson.Marshal(create.Payload)
if err != nil {
return nil, errors.Wrap(err, "failed to marshal activity payload")
}
payloadString = string(bytes)
}
qb := squirrel.Insert("activity").
Columns("creator_id", "type", "level", "payload").
PlaceholderFormat(squirrel.Dollar)
values := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
if create.ID != 0 {
qb = qb.Columns("id")
values = append(values, create.ID)
}
if create.CreatedTs != 0 {
qb = qb.Columns("created_ts")
values = append(values, squirrel.Expr("TO_TIMESTAMP(?)", create.CreatedTs))
}
qb = qb.Values(values...).Suffix("RETURNING id")
stmt, args, err := qb.ToSql()
if err != nil {
return nil, errors.Wrap(err, "failed to construct query")
}
var id int32
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, errors.Wrap(err, "failed to execute statement and retrieve ID")
}
list, err := d.ListActivities(ctx, &store.FindActivity{ID: &id})
if err != nil || len(list) == 0 {
return nil, errors.Wrap(err, "failed to find activity")
}
return list[0], nil
}
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
qb := squirrel.Select("id", "creator_id", "type", "level", "payload", "created_ts").
From("activity").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)
if find.ID != nil {
qb = qb.Where(squirrel.Eq{"id": *find.ID})
}
if find.Type != nil {
qb = qb.Where(squirrel.Eq{"type": find.Type.String()})
}
query, args, err := qb.ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.Activity{}
for rows.Next() {
activity := &store.Activity{}
var payloadBytes []byte
createdTsPlaceHolder := time.Time{}
if err := rows.Scan(
&activity.ID,
&activity.CreatorID,
&activity.Type,
&activity.Level,
&payloadBytes,
&createdTsPlaceHolder,
); err != nil {
return nil, err
}
activity.CreatedTs = createdTsPlaceHolder.Unix()
payload := &storepb.ActivityPayload{}
if err := protojson.Unmarshal(payloadBytes, payload); err != nil {
return nil, err
}
activity.Payload = payload
list = append(list, activity)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}

@ -0,0 +1,9 @@
package postgres
import "google.golang.org/protobuf/encoding/protojson"
var (
protojsonUnmarshaler = protojson.UnmarshalOptions{
DiscardUnknown: true,
}
)

@ -0,0 +1,178 @@
package postgres
import (
"context"
"encoding/json"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
)
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
var configBytes []byte
if create.Type == store.IdentityProviderOAuth2Type {
bytes, err := json.Marshal(create.Config.OAuth2Config)
if err != nil {
return nil, err
}
configBytes = bytes
} else {
return nil, errors.Errorf("unsupported idp type %s", string(create.Type))
}
qb := squirrel.Insert("idp").Columns("name", "type", "identifier_filter", "config")
values := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)}
if create.ID != 0 {
qb = qb.Columns("id")
values = append(values, create.ID)
}
qb = qb.Values(values...).PlaceholderFormat(squirrel.Dollar)
qb = qb.Suffix("RETURNING id")
stmt, args, err := qb.ToSql()
if err != nil {
return nil, err
}
var id int32
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, err
}
create.ID = id
return create, nil
}
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
qb := squirrel.Select("id", "name", "type", "identifier_filter", "config").
From("idp").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)
if v := find.ID; v != nil {
qb = qb.Where(squirrel.Eq{"id": *v})
}
query, args, err := qb.ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var identityProviders []*store.IdentityProvider
for rows.Next() {
var identityProvider store.IdentityProvider
var identityProviderConfig string
if err := rows.Scan(
&identityProvider.ID,
&identityProvider.Name,
&identityProvider.Type,
&identityProvider.IdentifierFilter,
&identityProviderConfig,
); err != nil {
return nil, err
}
if identityProvider.Type == store.IdentityProviderOAuth2Type {
oauth2Config := &store.IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err
}
identityProvider.Config = &store.IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
identityProviders = append(identityProviders, &identityProvider)
}
if err := rows.Err(); err != nil {
return nil, err
}
return identityProviders, nil
}
func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) {
list, err := d.ListIdentityProviders(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
qb := squirrel.Update("idp").
PlaceholderFormat(squirrel.Dollar)
var err error
if v := update.Name; v != nil {
qb = qb.Set("name", *v)
}
if v := update.IdentifierFilter; v != nil {
qb = qb.Set("identifier_filter", *v)
}
if v := update.Config; v != nil {
var configBytes []byte
if update.Type == store.IdentityProviderOAuth2Type {
bytes, err := json.Marshal(update.Config.OAuth2Config)
if err != nil {
return nil, err
}
configBytes = bytes
} else {
return nil, errors.Errorf("unsupported idp type %s", string(update.Type))
}
qb = qb.Set("config", string(configBytes))
}
qb = qb.Where(squirrel.Eq{"id": update.ID})
stmt, args, err := qb.ToSql()
if err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return nil, err
}
return d.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &update.ID})
}
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
qb := squirrel.Delete("idp").
Where(squirrel.Eq{"id": delete.ID}).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
if err != nil {
return err
}
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if _, err = result.RowsAffected(); err != nil {
return err
}
return nil
}

@ -0,0 +1,144 @@
package postgres
import (
"context"
"time"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox, error) {
messageString := "{}"
if create.Message != nil {
bytes, err := protojson.Marshal(create.Message)
if err != nil {
return nil, errors.Wrap(err, "failed to marshal inbox message")
}
messageString = string(bytes)
}
qb := squirrel.Insert("inbox").
Columns("sender_id", "receiver_id", "status", "message").
Values(create.SenderID, create.ReceiverID, create.Status, messageString).
Suffix("RETURNING id").
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
if err != nil {
return nil, err
}
var id int32
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, err
}
return d.GetInbox(ctx, &store.FindInbox{ID: &id})
}
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
qb := squirrel.Select("id", "created_ts", "sender_id", "receiver_id", "status", "message").
From("inbox").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)
if find.ID != nil {
qb = qb.Where(squirrel.Eq{"id": *find.ID})
}
if find.SenderID != nil {
qb = qb.Where(squirrel.Eq{"sender_id": *find.SenderID})
}
if find.ReceiverID != nil {
qb = qb.Where(squirrel.Eq{"receiver_id": *find.ReceiverID})
}
if find.Status != nil {
qb = qb.Where(squirrel.Eq{"status": *find.Status})
}
query, args, err := qb.ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var list []*store.Inbox
for rows.Next() {
inbox := &store.Inbox{}
var messageBytes []byte
createdTsPlaceHolder := time.Time{}
if err := rows.Scan(&inbox.ID, &createdTsPlaceHolder, &inbox.SenderID, &inbox.ReceiverID, &inbox.Status, &messageBytes); err != nil {
return nil, err
}
inbox.CreatedTs = createdTsPlaceHolder.Unix()
message := &storepb.InboxMessage{}
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
return nil, err
}
inbox.Message = message
list = append(list, inbox)
}
return list, rows.Err()
}
func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) {
list, err := d.ListInboxes(ctx, find)
if err != nil {
return nil, errors.Wrap(err, "failed to get inbox")
}
if len(list) != 1 {
return nil, errors.Wrapf(nil, "unexpected inbox count: %d", len(list))
}
return list[0], nil
}
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
qb := squirrel.Update("inbox").
Set("status", update.Status.String()).
Where(squirrel.Eq{"id": update.ID}).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
if err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return nil, err
}
return d.GetInbox(ctx, &store.FindInbox{ID: &update.ID})
}
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
qb := squirrel.Delete("inbox").
Where(squirrel.Eq{"id": delete.ID}).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
if err != nil {
return err
}
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
_, err = result.RowsAffected()
return err
}

@ -0,0 +1,370 @@
package postgres
import (
"context"
"database/sql"
"encoding/binary"
"fmt"
"strings"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/store"
)
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
// Initialize a Squirrel statement builder for PostgreSQL
builder := squirrel.Insert("memo").
PlaceholderFormat(squirrel.Dollar).
Columns("creator_id", "content", "visibility")
// Add initial values for the columns
values := []any{create.CreatorID, create.Content, create.Visibility}
// Conditionally add other fields and values
if create.ID != 0 {
builder = builder.Columns("id")
values = append(values, create.ID)
}
if create.CreatedTs != 0 {
builder = builder.Columns("created_ts")
values = append(values, squirrel.Expr("to_timestamp(?)", create.CreatedTs))
}
if create.UpdatedTs != 0 {
builder = builder.Columns("updated_ts")
values = append(values, squirrel.Expr("to_timestamp(?)", create.UpdatedTs))
}
if create.RowStatus != "" {
builder = builder.Columns("row_status")
values = append(values, create.RowStatus)
}
// Add all the values at once
builder = builder.Values(values...)
// Add the RETURNING clause to get the ID of the inserted row
builder = builder.Suffix("RETURNING id")
// Prepare and execute the query
query, args, err := builder.ToSql()
if err != nil {
return nil, err
}
var id int32
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
if err != nil {
return nil, err
}
// Retrieve the newly created memo
memo, err := d.GetMemo(ctx, &store.FindMemo{ID: &id})
if err != nil {
return nil, err
}
if memo == nil {
return nil, errors.Errorf("failed to create memo")
}
return memo, nil
}
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
// Start building the SELECT statement
builder := squirrel.Select(
"memo.id AS id",
"memo.creator_id AS creator_id",
"EXTRACT(EPOCH FROM memo.created_ts) AS created_ts",
"EXTRACT(EPOCH FROM memo.updated_ts) AS updated_ts",
"memo.row_status AS row_status",
"memo.content AS content",
"memo.visibility AS visibility",
"MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned",
"string_agg(CAST(resource.id AS TEXT), ',') AS resource_id_list", // Cast to TEXT
"(SELECT string_agg(CAST(memo_id AS TEXT) || ':' || CAST(related_memo_id AS TEXT) || ':' || type, ',') FROM memo_relation WHERE memo_relation.memo_id = memo.id OR memo_relation.related_memo_id = memo.id) AS relation_list"). // Cast IDs to TEXT
From("memo").
LeftJoin("memo_organizer ON memo.id = memo_organizer.memo_id").
LeftJoin("resource ON memo.id = resource.memo_id").
GroupBy("memo.id").
PlaceholderFormat(squirrel.Dollar)
// Add conditional where clauses
if v := find.ID; v != nil {
builder = builder.Where("memo.id = ?", *v)
}
if v := find.CreatorID; v != nil {
builder = builder.Where("memo.creator_id = ?", *v)
}
if v := find.RowStatus; v != nil {
builder = builder.Where("memo.row_status = ?", *v)
}
if v := find.CreatedTsBefore; v != nil {
builder = builder.Where("EXTRACT(EPOCH FROM memo.created_ts) < ?", *v)
}
if v := find.CreatedTsAfter; v != nil {
builder = builder.Where("EXTRACT(EPOCH FROM memo.created_ts) > ?", *v)
}
if v := find.Pinned; v != nil {
builder = builder.Where("memo_organizer.pinned = 1")
}
if v := find.ContentSearch; len(v) != 0 {
for _, s := range v {
builder = builder.Where("memo.content LIKE ?", "%"+s+"%")
}
}
if v := find.VisibilityList; len(v) != 0 {
placeholders := make([]string, len(v))
args := make([]any, len(v))
for i, visibility := range v {
placeholders[i] = "?"
args[i] = visibility // Assuming visibility can be directly used as an argument
}
inClause := strings.Join(placeholders, ",")
builder = builder.Where("memo.visibility IN ("+inClause+")", args...)
}
// Add order by clauses
if find.OrderByPinned {
builder = builder.OrderBy("pinned DESC")
}
if find.OrderByUpdatedTs {
builder = builder.OrderBy("updated_ts DESC")
} else {
builder = builder.OrderBy("created_ts DESC")
}
builder = builder.OrderBy("id DESC")
// Handle pagination
if find.Limit != nil {
builder = builder.Limit(uint64(*find.Limit))
if find.Offset != nil {
builder = builder.Offset(uint64(*find.Offset))
}
}
// Prepare and execute the query
query, args, err := builder.ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
// Process the result set
list := make([]*store.Memo, 0)
updatedTsPlaceHolder, createdTsPlaceHolder := make([]uint8, 8), make([]uint8, 8)
for rows.Next() {
var memo store.Memo
var memoResourceIDList sql.NullString
var memoRelationList sql.NullString
if err := rows.Scan(
&memo.ID,
&memo.CreatorID,
&createdTsPlaceHolder,
&updatedTsPlaceHolder,
&memo.RowStatus,
&memo.Content,
&memo.Visibility,
&memo.Pinned,
&memoResourceIDList,
&memoRelationList,
); err != nil {
return nil, err
}
// Convert the timestamps from Postgres to Go
memo.CreatedTs = int64(binary.BigEndian.Uint64(createdTsPlaceHolder))
memo.UpdatedTs = int64(binary.BigEndian.Uint64(updatedTsPlaceHolder))
if memoResourceIDList.Valid {
idStringList := strings.Split(memoResourceIDList.String, ",")
memo.ResourceIDList = make([]int32, 0, len(idStringList))
for _, idString := range idStringList {
id, err := util.ConvertStringToInt32(idString)
if err != nil {
return nil, err
}
memo.ResourceIDList = append(memo.ResourceIDList, id)
}
}
if memoRelationList.Valid {
memo.RelationList = make([]*store.MemoRelation, 0)
relatedMemoTypeList := strings.Split(memoRelationList.String, ",")
for _, relatedMemoType := range relatedMemoTypeList {
relatedMemoTypeList := strings.Split(relatedMemoType, ":")
if len(relatedMemoTypeList) != 3 {
return nil, errors.Errorf("invalid relation format")
}
memoID, err := util.ConvertStringToInt32(relatedMemoTypeList[0])
if err != nil {
return nil, err
}
relatedMemoID, err := util.ConvertStringToInt32(relatedMemoTypeList[1])
if err != nil {
return nil, err
}
relationType := store.MemoRelationType(relatedMemoTypeList[2])
memo.RelationList = append(memo.RelationList, &store.MemoRelation{
MemoID: memoID,
RelatedMemoID: relatedMemoID,
Type: relationType,
})
// Set the first parent ID if relation type is comment.
if memo.ParentID == nil && memoID == memo.ID && relationType == store.MemoRelationComment {
memo.ParentID = &relatedMemoID
}
}
}
list = append(list, &memo)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, error) {
list, err := d.ListMemos(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
memo := list[0]
return memo, nil
}
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
// Start building the update statement
builder := squirrel.Update("memo").
PlaceholderFormat(squirrel.Dollar).
Where("id = ?", update.ID)
// Conditionally add set clauses
if v := update.CreatedTs; v != nil {
builder = builder.Set("created_ts", squirrel.Expr("to_timestamp(?)", *v))
}
if v := update.UpdatedTs; v != nil {
builder = builder.Set("updated_ts", squirrel.Expr("to_timestamp(?)", *v))
}
if v := update.RowStatus; v != nil {
builder = builder.Set("row_status", *v)
}
if v := update.Content; v != nil {
builder = builder.Set("content", *v)
}
if v := update.Visibility; v != nil {
builder = builder.Set("visibility", *v)
}
// Prepare and execute the query
query, args, err := builder.ToSql()
if err != nil {
return err
}
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
return err
}
return nil
}
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
// Start building the DELETE statement
builder := squirrel.Delete("memo").
PlaceholderFormat(squirrel.Dollar).
Where(squirrel.Eq{"id": delete.ID})
// Prepare the final query
query, args, err := builder.ToSql()
if err != nil {
return err
}
// Execute the query with the context
result, err := d.db.ExecContext(ctx, query, args...)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
// Perform any additional cleanup or operations such as vacuuming
// irving: wait, why do we need to vacuum here?
// I don't know why delete memo needs to vacuum. so I commented out.
// REVIEWERS LOOK AT ME: please check this.
return d.Vacuum(ctx)
}
func (d *DB) FindMemosVisibilityList(ctx context.Context, memoIDs []int32) ([]store.Visibility, error) {
// Start building the SELECT statement
builder := squirrel.Select("DISTINCT(visibility)").From("memo").
PlaceholderFormat(squirrel.Dollar).
Where(squirrel.Eq{"id": memoIDs})
// Prepare the final query
query, args, err := builder.ToSql()
if err != nil {
return nil, err
}
// Execute the query with the context
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
visibilityList := make([]store.Visibility, 0)
for rows.Next() {
var visibility store.Visibility
if err := rows.Scan(&visibility); err != nil {
return nil, err
}
visibilityList = append(visibilityList, visibility)
}
if err := rows.Err(); err != nil {
return nil, err
}
return visibilityList, nil
}
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery
subQuery, subArgs, err := squirrel.Select("id").From("user").PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Now, build the main delete query using the subquery
query, args, err := squirrel.Delete("memo").
Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
}

@ -0,0 +1,123 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/Masterminds/squirrel"
"github.com/usememos/memos/store"
)
func (d *DB) UpsertMemoOrganizer(ctx context.Context, upsert *store.MemoOrganizer) (*store.MemoOrganizer, error) {
pinnedValue := 0
if upsert.Pinned {
pinnedValue = 1
}
qb := squirrel.Insert("memo_organizer").
Columns("memo_id", "user_id", "pinned").
Values(upsert.MemoID, upsert.UserID, pinnedValue).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
if err != nil {
return nil, err
}
if _, err = d.db.ExecContext(ctx, stmt, args...); err != nil {
return nil, err
}
return upsert, nil
}
func (d *DB) ListMemoOrganizer(ctx context.Context, find *store.FindMemoOrganizer) ([]*store.MemoOrganizer, error) {
qb := squirrel.Select("memo_id", "user_id", "pinned").
From("memo_organizer").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)
if find.MemoID != 0 {
qb = qb.Where(squirrel.Eq{"memo_id": find.MemoID})
}
if find.UserID != 0 {
qb = qb.Where(squirrel.Eq{"user_id": find.UserID})
}
query, args, err := qb.ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var list []*store.MemoOrganizer
for rows.Next() {
memoOrganizer := &store.MemoOrganizer{}
if err := rows.Scan(&memoOrganizer.MemoID, &memoOrganizer.UserID, &memoOrganizer.Pinned); err != nil {
return nil, err
}
list = append(list, memoOrganizer)
}
return list, rows.Err()
}
func (d *DB) DeleteMemoOrganizer(ctx context.Context, delete *store.DeleteMemoOrganizer) error {
qb := squirrel.Delete("memo_organizer").
PlaceholderFormat(squirrel.Dollar)
if v := delete.MemoID; v != nil {
qb = qb.Where(squirrel.Eq{"memo_id": *v})
}
if v := delete.UserID; v != nil {
qb = qb.Where(squirrel.Eq{"user_id": *v})
}
stmt, args, err := qb.ToSql()
if err != nil {
return err
}
if _, err = d.db.ExecContext(ctx, stmt, args...); err != nil {
return err
}
return nil
}
func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery for memo_id
subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Build the subquery for user_id
subQueryUser, subArgsUser, err := squirrel.Select("id").From("\"user\"").PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Now, build the main delete query using the subqueries
query, args, err := squirrel.Delete("memo_organizer").
Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
Where(fmt.Sprintf("user_id NOT IN (%s)", subQueryUser), subArgsUser...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Combine the arguments from both subqueries
args = append(args, subArgsUser...)
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
}

@ -0,0 +1,128 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/Masterminds/squirrel"
"github.com/usememos/memos/store"
)
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
qb := squirrel.Insert("memo_relation").
Columns("memo_id", "related_memo_id", "type").
Values(create.MemoID, create.RelatedMemoID, create.Type).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
if err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return nil, err
}
return &store.MemoRelation{
MemoID: create.MemoID,
RelatedMemoID: create.RelatedMemoID,
Type: create.Type,
}, nil
}
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
qb := squirrel.Select("memo_id", "related_memo_id", "type").
From("memo_relation").
Where("TRUE").
PlaceholderFormat(squirrel.Dollar)
if find.MemoID != nil {
qb = qb.Where(squirrel.Eq{"memo_id": *find.MemoID})
}
if find.RelatedMemoID != nil {
qb = qb.Where(squirrel.Eq{"related_memo_id": *find.RelatedMemoID})
}
if find.Type != nil {
qb = qb.Where(squirrel.Eq{"type": *find.Type})
}
query, args, err := qb.ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var list []*store.MemoRelation
for rows.Next() {
memoRelation := &store.MemoRelation{}
if err := rows.Scan(&memoRelation.MemoID, &memoRelation.RelatedMemoID, &memoRelation.Type); err != nil {
return nil, err
}
list = append(list, memoRelation)
}
return list, rows.Err()
}
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
qb := squirrel.Delete("memo_relation").
PlaceholderFormat(squirrel.Dollar)
if delete.MemoID != nil {
qb = qb.Where(squirrel.Eq{"memo_id": *delete.MemoID})
}
if delete.RelatedMemoID != nil {
qb = qb.Where(squirrel.Eq{"related_memo_id": *delete.RelatedMemoID})
}
if delete.Type != nil {
qb = qb.Where(squirrel.Eq{"type": *delete.Type})
}
stmt, args, err := qb.ToSql()
if err != nil {
return err
}
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
_, err = result.RowsAffected()
return err
}
func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery for memo_id
subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Note: The same subquery is used for related_memo_id as it's also checking against the "memo" table
// Now, build the main delete query using the subqueries
query, args, err := squirrel.Delete("memo_relation").
Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
Where(fmt.Sprintf("related_memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Combine the arguments for both instances of the same subquery
args = append(args, subArgsMemo...)
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
}

@ -0,0 +1,163 @@
-- drop all tables first (PostgreSQL style)
DROP TABLE IF EXISTS migration_history CASCADE;
DROP TABLE IF EXISTS system_setting CASCADE;
DROP TABLE IF EXISTS "user" CASCADE;
DROP TABLE IF EXISTS user_setting CASCADE;
DROP TABLE IF EXISTS memo CASCADE;
DROP TABLE IF EXISTS memo_organizer CASCADE;
DROP TABLE IF EXISTS memo_relation CASCADE;
DROP TABLE IF EXISTS resource CASCADE;
DROP TABLE IF EXISTS tag CASCADE;
DROP TABLE IF EXISTS activity CASCADE;
DROP TABLE IF EXISTS storage CASCADE;
DROP TABLE IF EXISTS idp CASCADE;
DROP TABLE IF EXISTS inbox CASCADE;
DROP TABLE IF EXISTS webhook CASCADE;
-- migration_history
CREATE TABLE migration_history (
version VARCHAR(255) NOT NULL PRIMARY KEY,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- system_setting
CREATE TABLE system_setting (
name VARCHAR(255) NOT NULL PRIMARY KEY,
value TEXT NOT NULL,
description TEXT NOT NULL
);
-- user
CREATE TABLE "user" (
id SERIAL PRIMARY KEY,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
row_status VARCHAR(255) NOT NULL DEFAULT 'NORMAL',
username VARCHAR(255) NOT NULL UNIQUE,
role VARCHAR(255) NOT NULL DEFAULT 'USER',
email VARCHAR(255) NOT NULL DEFAULT '',
nickname VARCHAR(255) NOT NULL DEFAULT '',
password_hash VARCHAR(255) NOT NULL,
avatar_url TEXT NOT NULL
);
-- user_setting
CREATE TABLE user_setting (
user_id INT NOT NULL,
key VARCHAR(255) NOT NULL,
value TEXT NOT NULL,
UNIQUE(user_id, key),
FOREIGN KEY (user_id) REFERENCES "user"(id) ON DELETE CASCADE
);
-- memo
CREATE TABLE memo (
id SERIAL PRIMARY KEY,
creator_id INT NOT NULL,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
row_status VARCHAR(255) NOT NULL DEFAULT 'NORMAL',
content TEXT NOT NULL,
visibility VARCHAR(255) NOT NULL DEFAULT 'PRIVATE'
);
-- memo_organizer
CREATE TABLE memo_organizer (
memo_id INT NOT NULL,
user_id INT NOT NULL,
pinned INT NOT NULL DEFAULT 0,
UNIQUE(memo_id, user_id)
);
-- memo_relation
CREATE TABLE memo_relation (
memo_id INT NOT NULL,
related_memo_id INT NOT NULL,
type VARCHAR(256) NOT NULL,
UNIQUE(memo_id, related_memo_id, type),
FOREIGN KEY (memo_id) REFERENCES memo(id) ON DELETE CASCADE,
FOREIGN KEY (related_memo_id) REFERENCES memo(id) ON DELETE CASCADE
);
-- resource
CREATE TABLE resource (
id SERIAL PRIMARY KEY,
creator_id INT NOT NULL,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
filename TEXT NOT NULL,
blob BYTEA,
external_link TEXT NOT NULL,
type VARCHAR(255) NOT NULL DEFAULT '',
size INT NOT NULL DEFAULT 0,
internal_path VARCHAR(255) NOT NULL DEFAULT '',
memo_id INT DEFAULT NULL
);
-- tag
CREATE TABLE tag (
name VARCHAR(255) NOT NULL,
creator_id INT NOT NULL,
UNIQUE(name, creator_id)
);
-- activity
CREATE TABLE activity (
id SERIAL PRIMARY KEY,
creator_id INT NOT NULL,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
type VARCHAR(255) NOT NULL DEFAULT '',
level VARCHAR(255) NOT NULL DEFAULT 'INFO',
payload TEXT NOT NULL
);
-- storage
CREATE TABLE storage (
id SERIAL PRIMARY KEY,
name VARCHAR(256) NOT NULL,
type VARCHAR(256) NOT NULL,
config TEXT NOT NULL
);
-- idp
CREATE TABLE idp (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
type TEXT NOT NULL,
identifier_filter VARCHAR(256) NOT NULL DEFAULT '',
config TEXT NOT NULL
);
-- inbox
CREATE TABLE inbox (
id SERIAL PRIMARY KEY,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
sender_id INT NOT NULL,
receiver_id INT NOT NULL,
status TEXT NOT NULL,
message TEXT NOT NULL
);
-- webhook
CREATE TABLE webhook (
id SERIAL PRIMARY KEY,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
row_status TEXT NOT NULL DEFAULT 'NORMAL',
creator_id INT NOT NULL,
name TEXT NOT NULL,
url TEXT NOT NULL
);

@ -0,0 +1,163 @@
-- drop all tables first (PostgreSQL style)
DROP TABLE IF EXISTS migration_history CASCADE;
DROP TABLE IF EXISTS system_setting CASCADE;
DROP TABLE IF EXISTS "user" CASCADE;
DROP TABLE IF EXISTS user_setting CASCADE;
DROP TABLE IF EXISTS memo CASCADE;
DROP TABLE IF EXISTS memo_organizer CASCADE;
DROP TABLE IF EXISTS memo_relation CASCADE;
DROP TABLE IF EXISTS resource CASCADE;
DROP TABLE IF EXISTS tag CASCADE;
DROP TABLE IF EXISTS activity CASCADE;
DROP TABLE IF EXISTS storage CASCADE;
DROP TABLE IF EXISTS idp CASCADE;
DROP TABLE IF EXISTS inbox CASCADE;
DROP TABLE IF EXISTS webhook CASCADE;
-- migration_history
CREATE TABLE migration_history (
version VARCHAR(255) NOT NULL PRIMARY KEY,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- system_setting
CREATE TABLE system_setting (
name VARCHAR(255) NOT NULL PRIMARY KEY,
value TEXT NOT NULL,
description TEXT NOT NULL
);
-- user
CREATE TABLE "user" (
id SERIAL PRIMARY KEY,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
row_status VARCHAR(255) NOT NULL DEFAULT 'NORMAL',
username VARCHAR(255) NOT NULL UNIQUE,
role VARCHAR(255) NOT NULL DEFAULT 'USER',
email VARCHAR(255) NOT NULL DEFAULT '',
nickname VARCHAR(255) NOT NULL DEFAULT '',
password_hash VARCHAR(255) NOT NULL,
avatar_url TEXT NOT NULL
);
-- user_setting
CREATE TABLE user_setting (
user_id INT NOT NULL,
key VARCHAR(255) NOT NULL,
value TEXT NOT NULL,
UNIQUE(user_id, key),
FOREIGN KEY (user_id) REFERENCES "user"(id) ON DELETE CASCADE
);
-- memo
CREATE TABLE memo (
id SERIAL PRIMARY KEY,
creator_id INT NOT NULL,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
row_status VARCHAR(255) NOT NULL DEFAULT 'NORMAL',
content TEXT NOT NULL,
visibility VARCHAR(255) NOT NULL DEFAULT 'PRIVATE'
);
-- memo_organizer
CREATE TABLE memo_organizer (
memo_id INT NOT NULL,
user_id INT NOT NULL,
pinned INT NOT NULL DEFAULT 0,
UNIQUE(memo_id, user_id)
);
-- memo_relation
CREATE TABLE memo_relation (
memo_id INT NOT NULL,
related_memo_id INT NOT NULL,
type VARCHAR(256) NOT NULL,
UNIQUE(memo_id, related_memo_id, type),
FOREIGN KEY (memo_id) REFERENCES memo(id) ON DELETE CASCADE,
FOREIGN KEY (related_memo_id) REFERENCES memo(id) ON DELETE CASCADE
);
-- resource
CREATE TABLE resource (
id SERIAL PRIMARY KEY,
creator_id INT NOT NULL,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
filename TEXT NOT NULL,
blob BYTEA,
external_link TEXT NOT NULL,
type VARCHAR(255) NOT NULL DEFAULT '',
size INT NOT NULL DEFAULT 0,
internal_path VARCHAR(255) NOT NULL DEFAULT '',
memo_id INT DEFAULT NULL
);
-- tag
CREATE TABLE tag (
name VARCHAR(255) NOT NULL,
creator_id INT NOT NULL,
UNIQUE(name, creator_id)
);
-- activity
CREATE TABLE activity (
id SERIAL PRIMARY KEY,
creator_id INT NOT NULL,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
type VARCHAR(255) NOT NULL DEFAULT '',
level VARCHAR(255) NOT NULL DEFAULT 'INFO',
payload TEXT NOT NULL
);
-- storage
CREATE TABLE storage (
id SERIAL PRIMARY KEY,
name VARCHAR(256) NOT NULL,
type VARCHAR(256) NOT NULL,
config TEXT NOT NULL
);
-- idp
CREATE TABLE idp (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
type TEXT NOT NULL,
identifier_filter VARCHAR(256) NOT NULL DEFAULT '',
config TEXT NOT NULL
);
-- inbox
CREATE TABLE inbox (
id SERIAL PRIMARY KEY,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
sender_id INT NOT NULL,
receiver_id INT NOT NULL,
status TEXT NOT NULL,
message TEXT NOT NULL
);
-- webhook
CREATE TABLE webhook (
id SERIAL PRIMARY KEY,
created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
row_status TEXT NOT NULL DEFAULT 'NORMAL',
creator_id INT NOT NULL,
name TEXT NOT NULL,
url TEXT NOT NULL
);

@ -0,0 +1,79 @@
package postgres
import (
"context"
"time"
"github.com/Masterminds/squirrel"
"github.com/usememos/memos/store"
)
func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) {
qb := squirrel.Select("version", "created_ts").
From("migration_history").
OrderBy("created_ts DESC")
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*store.MigrationHistory, 0)
for rows.Next() {
var migrationHistory store.MigrationHistory
var createdTs time.Time
if err := rows.Scan(&migrationHistory.Version, &createdTs); err != nil {
return nil, err
}
migrationHistory.CreatedTs = createdTs.UnixNano()
list = append(list, &migrationHistory)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) {
qb := squirrel.Insert("migration_history").
Columns("version").
Values(upsert.Version).
Suffix("ON CONFLICT (version) DO UPDATE SET version = ?", upsert.Version)
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, query, args...)
if err != nil {
return nil, err
}
var migrationHistory store.MigrationHistory
var createdTs time.Time
query, args, err = squirrel.Select("version", "created_ts").
From("migration_history").
Where(squirrel.Eq{"version": upsert.Version}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return nil, err
}
if err := d.db.QueryRowContext(ctx, query, args...).Scan(&migrationHistory.Version, &createdTs); err != nil {
return nil, err
}
migrationHistory.CreatedTs = createdTs.UnixNano()
return &migrationHistory, nil
}

@ -0,0 +1,207 @@
package postgres
import (
"context"
"embed"
"fmt"
"io/fs"
"regexp"
"sort"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/server/version"
"github.com/usememos/memos/store"
)
const (
latestSchemaFileName = "LATEST__SCHEMA.sql"
)
//go:embed migration
var migrationFS embed.FS
func (d *DB) Migrate(ctx context.Context) error {
if d.profile.IsDev() {
return d.nonProdMigrate(ctx)
}
return d.prodMigrate(ctx)
}
func (d *DB) nonProdMigrate(ctx context.Context) error {
rows, err := d.db.QueryContext(ctx, "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema';")
if err != nil {
return errors.Errorf("failed to query database tables: %s", err)
}
if rows.Err() != nil {
return errors.Errorf("failed to query database tables: %s", err)
}
defer rows.Close()
var tables []string
for rows.Next() {
var table string
err := rows.Scan(&table)
if err != nil {
return errors.Errorf("failed to scan table name: %s", err)
}
tables = append(tables, table)
}
if len(tables) != 0 {
return nil
}
println("no tables in the database. start migration")
buf, err := migrationFS.ReadFile("migration/dev/" + latestSchemaFileName)
if err != nil {
return errors.Errorf("failed to read latest schema file: %s", err)
}
stmt := string(buf)
if _, err := d.db.ExecContext(ctx, stmt); err != nil {
return errors.Errorf("failed to exec SQL %s: %s", stmt, err)
}
// In demo mode, we should seed the database.
if d.profile.Mode == "demo" {
if err := d.seed(ctx); err != nil {
return errors.Wrap(err, "failed to seed")
}
}
return nil
}
func (d *DB) prodMigrate(ctx context.Context) error {
currentVersion := version.GetCurrentVersion(d.profile.Mode)
migrationHistoryList, err := d.FindMigrationHistoryList(ctx, &store.FindMigrationHistory{})
// If there is no migration history, we should apply the latest schema.
if err != nil || len(migrationHistoryList) == 0 {
buf, err := migrationFS.ReadFile("migration/prod/" + latestSchemaFileName)
if err != nil {
return errors.Errorf("failed to read latest schema file: %s", err)
}
stmt := string(buf)
if _, err := d.db.ExecContext(ctx, stmt); err != nil {
return errors.Errorf("failed to exec SQL %s: %s", stmt, err)
}
if _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
Version: currentVersion,
}); err != nil {
return errors.Wrap(err, "failed to upsert migration history")
}
return nil
}
migrationHistoryVersionList := []string{}
for _, migrationHistory := range migrationHistoryList {
migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version)
}
sort.Sort(version.SortVersion(migrationHistoryVersionList))
latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
if !version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
return nil
}
println("start migrate")
for _, minorVersion := range getMinorVersionList() {
normalizedVersion := minorVersion + ".0"
if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) {
println("applying migration for", normalizedVersion)
if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
return errors.Wrap(err, "failed to apply minor version migration")
}
}
}
println("end migrate")
return nil
}
func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error {
filenames, err := fs.Glob(migrationFS, fmt.Sprintf("migration/prod/%s/*.sql", minorVersion))
if err != nil {
return errors.Wrap(err, "failed to read ddl files")
}
sort.Strings(filenames)
// Loop over all migration files and execute them in order.
for _, filename := range filenames {
buf, err := migrationFS.ReadFile(filename)
if err != nil {
return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename)
}
for _, stmt := range strings.Split(string(buf), ";") {
if strings.TrimSpace(stmt) == "" {
continue
}
if _, err := d.db.ExecContext(ctx, stmt); err != nil {
return errors.Wrapf(err, "migrate error: %s", stmt)
}
}
}
// Upsert the newest version to migration_history.
version := minorVersion + ".0"
if _, err = d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{Version: version}); err != nil {
return errors.Wrapf(err, "failed to upsert migration history with version: %s", version)
}
return nil
}
//go:embed seed
var seedFS embed.FS
func (d *DB) seed(ctx context.Context) error {
filenames, err := fs.Glob(seedFS, "seed/*.sql")
if err != nil {
return errors.Wrap(err, "failed to read seed files")
}
sort.Strings(filenames)
// Loop over all seed files and execute them in order.
for _, filename := range filenames {
buf, err := seedFS.ReadFile(filename)
if err != nil {
return errors.Wrapf(err, "failed to read seed file, filename=%s", filename)
}
for _, stmt := range strings.Split(string(buf), ";") {
if strings.TrimSpace(stmt) == "" {
continue
}
if _, err := d.db.ExecContext(ctx, stmt); err != nil {
return errors.Wrapf(err, "seed error: %s", stmt)
}
}
}
return nil
}
// minorDirRegexp is a regular expression for minor version directory.
var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`)
func getMinorVersionList() []string {
minorVersionList := []string{}
if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error {
if err != nil {
return err
}
if file.IsDir() && minorDirRegexp.MatchString(path) {
minorVersionList = append(minorVersionList, file.Name())
}
return nil
}); err != nil {
panic(err)
}
sort.Sort(version.SortVersion(minorVersionList))
return minorVersionList
}

@ -0,0 +1,87 @@
package postgres
import (
"context"
"database/sql"
"log"
// Import the PostgreSQL driver.
_ "github.com/lib/pq"
"github.com/pkg/errors"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"
)
type DB struct {
db *sql.DB
profile *profile.Profile
// Add any other fields as needed
}
func NewDB(profile *profile.Profile) (store.Driver, error) {
if profile == nil {
return nil, errors.New("profile is nil")
}
// Open the PostgreSQL connection
db, err := sql.Open("postgres", profile.DSN)
if err != nil {
log.Printf("Failed to open database: %s", err)
return nil, errors.Wrapf(err, "failed to open database: %s", profile.DSN)
}
var driver store.Driver = &DB{
db: db,
profile: profile,
}
// Return the DB struct
return driver, nil
}
func (d *DB) GetDB() *sql.DB {
return d.db
}
func (d *DB) Vacuum(ctx context.Context) error {
tx, err := d.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
if err := vacuumMemo(ctx, tx); err != nil {
return err
}
if err := vacuumResource(ctx, tx); err != nil {
return err
}
if err := vacuumUserSetting(ctx, tx); err != nil {
return err
}
if err := vacuumMemoOrganizer(ctx, tx); err != nil {
return err
}
if err := vacuumMemoRelations(ctx, tx); err != nil {
return err
}
if err := vacuumTag(ctx, tx); err != nil {
// Prevent revive warning.
return err
}
return tx.Commit()
}
func (*DB) BackupTo(context.Context, string) error {
return errors.New("Please use postgresdump to backup")
}
func (*DB) GetCurrentDBSize(context.Context) (int64, error) {
return 0, errors.New("unimplemented")
}
func (d *DB) Close() error {
return d.db.Close()
}

@ -0,0 +1,229 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
)
func (d *DB) CreateResource(ctx context.Context, create *store.Resource) (*store.Resource, error) {
qb := squirrel.Insert("resource").Columns("filename", "blob", "external_link", "type", "size", "creator_id", "internal_path")
values := []any{create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath}
if create.ID != 0 {
qb = qb.Columns("id")
values = append(values, create.ID)
}
if create.CreatedTs != 0 {
qb = qb.Columns("created_ts")
values = append(values, time.Unix(0, create.CreatedTs))
}
if create.UpdatedTs != 0 {
qb = qb.Columns("updated_ts")
values = append(values, time.Unix(0, create.UpdatedTs))
}
if create.MemoID != nil {
qb = qb.Columns("memo_id")
values = append(values, *create.MemoID)
}
qb = qb.Values(values...).Suffix("RETURNING id")
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
var id int32
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
if err != nil {
return nil, err
}
list, err := d.ListResources(ctx, &store.FindResource{ID: &id})
if err != nil {
return nil, err
}
if len(list) != 1 {
return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list))
}
return list[0], nil
}
func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*store.Resource, error) {
qb := squirrel.Select("id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path", "memo_id").From("resource")
if v := find.ID; v != nil {
qb = qb.Where(squirrel.Eq{"id": *v})
}
if v := find.CreatorID; v != nil {
qb = qb.Where(squirrel.Eq{"creator_id": *v})
}
if v := find.Filename; v != nil {
qb = qb.Where(squirrel.Eq{"filename": *v})
}
if v := find.MemoID; v != nil {
qb = qb.Where(squirrel.Eq{"memo_id": *v})
}
if find.HasRelatedMemo {
qb = qb.Where("memo_id IS NOT NULL")
}
if find.GetBlob {
qb = qb.Columns("blob")
}
qb = qb.GroupBy("id").OrderBy("created_ts DESC")
if find.Limit != nil {
qb = qb.Limit(uint64(*find.Limit))
if find.Offset != nil {
qb = qb.Offset(uint64(*find.Offset))
}
}
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*store.Resource, 0)
for rows.Next() {
resource := store.Resource{}
var memoID sql.NullInt32
var createdTs, updatedTs time.Time
dests := []any{
&resource.ID,
&resource.Filename,
&resource.ExternalLink,
&resource.Type,
&resource.Size,
&resource.CreatorID,
&createdTs,
&updatedTs,
&resource.InternalPath,
&memoID,
}
if find.GetBlob {
dests = append(dests, &resource.Blob)
}
if err := rows.Scan(dests...); err != nil {
return nil, err
}
resource.CreatedTs = createdTs.UnixNano()
resource.UpdatedTs = updatedTs.UnixNano()
if memoID.Valid {
resource.MemoID = &memoID.Int32
}
list = append(list, &resource)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) UpdateResource(ctx context.Context, update *store.UpdateResource) (*store.Resource, error) {
qb := squirrel.Update("resource")
if v := update.UpdatedTs; v != nil {
qb = qb.Set("updated_ts", time.Unix(0, *v))
}
if v := update.Filename; v != nil {
qb = qb.Set("filename", *v)
}
if v := update.InternalPath; v != nil {
qb = qb.Set("internal_path", *v)
}
if v := update.MemoID; v != nil {
qb = qb.Set("memo_id", *v)
}
if v := update.Blob; v != nil {
qb = qb.Set("blob", v)
}
qb = qb.Where(squirrel.Eq{"id": update.ID})
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
return nil, err
}
list, err := d.ListResources(ctx, &store.FindResource{ID: &update.ID})
if err != nil {
return nil, err
}
if len(list) != 1 {
return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list))
}
return list[0], nil
}
func (d *DB) DeleteResource(ctx context.Context, delete *store.DeleteResource) error {
qb := squirrel.Delete("resource").Where(squirrel.Eq{"id": delete.ID})
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
result, err := d.db.ExecContext(ctx, query, args...)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
if err := d.Vacuum(ctx); err != nil {
// Prevent linter warning.
return err
}
return nil
}
func vacuumResource(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery
subQuery, subArgs, err := squirrel.Select("id").From("user").PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Now, build the main delete query using the subquery
query, args, err := squirrel.Delete("resource").
Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
}

@ -0,0 +1,4 @@
TRUNCATE TABLE memo_organizer;
TRUNCATE TABLE resource;
TRUNCATE TABLE memo;
TRUNCATE TABLE user;

@ -0,0 +1,44 @@
INSERT INTO "user" (
id,
username,
role,
email,
nickname,
row_status,
avatar_url,
password_hash
)
VALUES
(
101,
'memos-demo',
'HOST',
'demo@usememos.com',
'Derobot',
'NORMAL',
'',
-- raw password: secret
'$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK'
),
(
102,
'jack',
'USER',
'jack@usememos.com',
'Jack',
'NORMAL',
'',
-- raw password: secret
'$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK'
),
(
103,
'bob',
'USER',
'bob@usememos.com',
'Bob',
'ARCHIVED',
'',
-- raw password: secret
'$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK'
);

@ -0,0 +1,34 @@
INSERT INTO memo (id, content, creator_id)
VALUES
(
1,
'#Hello 👋 Welcome to memos.',
101
);
INSERT INTO memo (id, content, creator_id, visibility)
VALUES
(
2,
E'#TODO\n- [x] Take more photos about **🌄 sunset**\n- [x] Clean the room\n- [ ] Read *📖 The Little Prince*\n(👆 click to toggle status)',
101,
'PROTECTED'
),
(
3,
E'**[Slash](https://github.com/yourselfhosted/slash)**: A bookmarking and url shortener, save and share your links very easily.\n**[SQL Chat](https://www.sqlchat.ai)**: Chat-based SQL Client',
101,
'PUBLIC'
),
(
4,
E'#TODO\n- [x] Take more photos about **🌄 sunset**\n- [ ] Clean the classroom\n- [ ] Watch *👦 The Boys*\n(👆 click to toggle status)',
102,
'PROTECTED'
),
(
5,
'三人行,必有我师焉!👨‍🏫',
102,
'PUBLIC'
);

@ -0,0 +1,5 @@
INSERT INTO
memo_organizer (memo_id, user_id, pinned)
VALUES
(1, 101, 1),
(3, 101, 1);

@ -0,0 +1,6 @@
INSERT INTO
tag (name, creator_id)
VALUES
('Hello', 101),
('TODO', 101),
('TODO', 102);

@ -0,0 +1,125 @@
package postgres
import (
"context"
"github.com/Masterminds/squirrel"
"github.com/usememos/memos/store"
)
func (d *DB) CreateStorage(ctx context.Context, create *store.Storage) (*store.Storage, error) {
qb := squirrel.Insert("storage").Columns("name", "type", "config")
values := []any{create.Name, create.Type, create.Config}
if create.ID != 0 {
qb = qb.Columns("id")
values = append(values, create.ID)
}
qb = qb.Values(values...).Suffix("RETURNING id")
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
err = d.db.QueryRowContext(ctx, query, args...).Scan(&create.ID)
if err != nil {
return nil, err
}
return create, nil
}
func (d *DB) ListStorages(ctx context.Context, find *store.FindStorage) ([]*store.Storage, error) {
qb := squirrel.Select("id", "name", "type", "config").From("storage").OrderBy("id DESC")
if find.ID != nil {
qb = qb.Where(squirrel.Eq{"id": *find.ID})
}
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.Storage{}
for rows.Next() {
storage := &store.Storage{}
if err := rows.Scan(&storage.ID, &storage.Name, &storage.Type, &storage.Config); err != nil {
return nil, err
}
list = append(list, storage)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) UpdateStorage(ctx context.Context, update *store.UpdateStorage) (*store.Storage, error) {
qb := squirrel.Update("storage")
if update.Name != nil {
qb = qb.Set("name", *update.Name)
}
if update.Config != nil {
qb = qb.Set("config", *update.Config)
}
qb = qb.Where(squirrel.Eq{"id": update.ID})
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, query, args...)
if err != nil {
return nil, err
}
storage := &store.Storage{}
query, args, err = squirrel.Select("id", "name", "type", "config").
From("storage").
Where(squirrel.Eq{"id": update.ID}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return nil, err
}
if err := d.db.QueryRowContext(ctx, query, args...).Scan(&storage.ID, &storage.Name, &storage.Type, &storage.Config); err != nil {
return nil, err
}
return storage, nil
}
func (d *DB) DeleteStorage(ctx context.Context, delete *store.DeleteStorage) error {
qb := squirrel.Delete("storage").Where(squirrel.Eq{"id": delete.ID})
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
result, err := d.db.ExecContext(ctx, query, args...)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
return nil
}

@ -0,0 +1,61 @@
package postgres
import (
"context"
"github.com/Masterminds/squirrel"
"github.com/usememos/memos/store"
)
func (d *DB) UpsertSystemSetting(ctx context.Context, upsert *store.SystemSetting) (*store.SystemSetting, error) {
qb := squirrel.Insert("system_setting").
Columns("name", "value", "description").
Values(upsert.Name, upsert.Value, upsert.Description)
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, query, args...)
if err != nil {
return nil, err
}
return upsert, nil
}
func (d *DB) ListSystemSettings(ctx context.Context, find *store.FindSystemSetting) ([]*store.SystemSetting, error) {
qb := squirrel.Select("name", "value", "description").From("system_setting")
if find.Name != "" {
qb = qb.Where(squirrel.Eq{"name": find.Name})
}
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.SystemSetting{}
for rows.Next() {
systemSetting := &store.SystemSetting{}
if err := rows.Scan(&systemSetting.Name, &systemSetting.Value, &systemSetting.Description); err != nil {
return nil, err
}
list = append(list, systemSetting)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}

@ -0,0 +1,113 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/Masterminds/squirrel"
"github.com/usememos/memos/store"
)
func (d *DB) UpsertTag(ctx context.Context, upsert *store.Tag) (*store.Tag, error) {
builder := squirrel.Insert("tag").
Columns("name", "creator_id").
Values(upsert.Name, upsert.CreatorID). // on conflict is not necessary, as only the pair of name and creator_id is unique
PlaceholderFormat(squirrel.Dollar)
query, args, err := builder.ToSql()
if err != nil {
return nil, err
}
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
return nil, err
}
return upsert, nil
}
func (d *DB) ListTags(ctx context.Context, find *store.FindTag) ([]*store.Tag, error) {
builder := squirrel.Select("name", "creator_id").From("tag").
Where("1 = 1").
OrderBy("name ASC").
PlaceholderFormat(squirrel.Dollar)
if find.CreatorID != 0 {
builder = builder.Where("creator_id = ?", find.CreatorID)
}
query, args, err := builder.ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.Tag{}
for rows.Next() {
tag := &store.Tag{}
if err := rows.Scan(
&tag.Name,
&tag.CreatorID,
); err != nil {
return nil, err
}
list = append(list, tag)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteTag(ctx context.Context, delete *store.DeleteTag) error {
builder := squirrel.Delete("tag").
Where(squirrel.Eq{"name": delete.Name, "creator_id": delete.CreatorID}).
PlaceholderFormat(squirrel.Dollar)
query, args, err := builder.ToSql()
if err != nil {
return err
}
result, err := d.db.ExecContext(ctx, query, args...)
if err != nil {
return err
}
if _, err = result.RowsAffected(); err != nil {
return err
}
return nil
}
func vacuumTag(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery for creator_id
subQuery, subArgs, err := squirrel.Select("id").From("\"user\"").PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Now, build the main delete query using the subquery
query, args, err := squirrel.Delete("tag").
Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
}

@ -0,0 +1,225 @@
package postgres
import (
"context"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
)
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
// Start building the insert statement
builder := squirrel.Insert("\"user\"").PlaceholderFormat(squirrel.Dollar)
columns := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"}
builder = builder.Columns(columns...)
values := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
if create.RowStatus != "" {
builder = builder.Columns("row_status")
values = append(values, create.RowStatus)
}
if create.CreatedTs != 0 {
builder = builder.Columns("created_ts")
values = append(values, squirrel.Expr("TO_TIMESTAMP(?)", create.CreatedTs))
}
if create.UpdatedTs != 0 {
builder = builder.Columns("updated_ts")
values = append(values, squirrel.Expr("TO_TIMESTAMP(?)", create.UpdatedTs))
}
if create.ID != 0 {
builder = builder.Columns("id")
values = append(values, create.ID)
}
builder = builder.Values(values...)
builder = builder.Suffix("RETURNING id")
// Prepare the final query
query, args, err := builder.ToSql()
if err != nil {
return nil, err
}
// Execute the query and get the returned ID
var id int32
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
if err != nil {
return nil, err
}
// Use the returned ID to retrieve the full user object
user, err := d.GetUser(ctx, &store.FindUser{ID: &id})
if err != nil {
return nil, err
}
return user, nil
}
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
// Start building the update statement
builder := squirrel.Update("\"user\"").PlaceholderFormat(squirrel.Dollar)
// Conditionally add set clauses
if v := update.UpdatedTs; v != nil {
builder = builder.Set("updated_ts", squirrel.Expr("to_timestamp(?)", *v))
}
if v := update.RowStatus; v != nil {
builder = builder.Set("row_status", *v)
}
if v := update.Username; v != nil {
builder = builder.Set("username", *v)
}
if v := update.Email; v != nil {
builder = builder.Set("email", *v)
}
if v := update.Nickname; v != nil {
builder = builder.Set("nickname", *v)
}
if v := update.AvatarURL; v != nil {
builder = builder.Set("avatar_url", *v)
}
if v := update.PasswordHash; v != nil {
builder = builder.Set("password_hash", *v)
}
// Add the WHERE clause
builder = builder.Where(squirrel.Eq{"id": update.ID})
// Prepare the final query
query, args, err := builder.ToSql()
if err != nil {
return nil, err
}
// Execute the query with the context
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
return nil, err
}
// Retrieve the updated user
user, err := d.GetUser(ctx, &store.FindUser{ID: &update.ID})
if err != nil {
return nil, err
}
return user, nil
}
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
// Start building the SELECT statement
builder := squirrel.Select("id", "username", "role", "email", "nickname", "password_hash", "avatar_url",
"FLOOR(EXTRACT(EPOCH FROM created_ts)) AS created_ts", "FLOOR(EXTRACT(EPOCH FROM updated_ts)) AS updated_ts", "row_status").
From("\"user\"").
PlaceholderFormat(squirrel.Dollar)
// 1 = 1 is often used as a no-op in SQL, ensuring there's always a WHERE clause
builder = builder.Where("1 = 1")
// Conditionally add where clauses
if v := find.ID; v != nil {
builder = builder.Where(squirrel.Eq{"id": *v})
}
if v := find.Username; v != nil {
builder = builder.Where(squirrel.Eq{"username": *v})
}
if v := find.Role; v != nil {
builder = builder.Where(squirrel.Eq{"role": *v})
}
if v := find.Email; v != nil {
builder = builder.Where(squirrel.Eq{"email": *v})
}
if v := find.Nickname; v != nil {
builder = builder.Where(squirrel.Eq{"nickname": *v})
}
// Add ordering
builder = builder.OrderBy("created_ts DESC", "row_status DESC")
// Prepare the final query
query, args, err := builder.ToSql()
if err != nil {
return nil, err
}
// Execute the query with the context
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*store.User, 0)
for rows.Next() {
var user store.User
if err := rows.Scan(
&user.ID,
&user.Username,
&user.Role,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.AvatarURL,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
); err != nil {
return nil, err
}
list = append(list, &user)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) {
list, err := d.ListUsers(ctx, find)
if err != nil {
return nil, err
}
if len(list) != 1 {
return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list))
}
return list[0], nil
}
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
// Start building the DELETE statement
builder := squirrel.Delete("\"user\"").
PlaceholderFormat(squirrel.Dollar).
Where(squirrel.Eq{"id": delete.ID})
// Prepare the final query
query, args, err := builder.ToSql()
if err != nil {
return err
}
// Execute the query with the context
result, err := d.db.ExecContext(ctx, query, args...)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
if err := d.Vacuum(ctx); err != nil {
// Prevent linter warning.
return err
}
return nil
}

@ -0,0 +1,194 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) {
// Construct the query using Squirrel
query, args, err := squirrel.
Insert("user_setting").
Columns("user_id", "key", "value").
Values(upsert.UserID, upsert.Key, upsert.Value).
PlaceholderFormat(squirrel.Dollar).
// no need to specify ON CONFLICT clause, as the primary key is (user_id, key)
ToSql()
if err != nil {
return nil, err
}
// Execute the query
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
return nil, err
}
return upsert, nil
}
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) {
// Start building the query
qb := squirrel.Select("user_id", "key", "value").From("user_setting").Where("1 = 1").PlaceholderFormat(squirrel.Dollar)
// Add conditions based on the provided find parameters
if v := find.Key; v != "" {
qb = qb.Where(squirrel.Eq{"key": v})
}
if v := find.UserID; v != nil {
qb = qb.Where(squirrel.Eq{"user_id": *v})
}
// Finalize the query
query, args, err := qb.ToSql()
if err != nil {
return nil, err
}
// Execute the query
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
// Process the rows
userSettingList := make([]*store.UserSetting, 0)
for rows.Next() {
var userSetting store.UserSetting
if err := rows.Scan(
&userSetting.UserID,
&userSetting.Key,
&userSetting.Value,
); err != nil {
return nil, err
}
userSettingList = append(userSettingList, &userSetting)
}
if err := rows.Err(); err != nil {
return nil, err
}
return userSettingList, nil
}
func (d *DB) UpsertUserSettingV1(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
var valueString string
if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
valueBytes, err := protojson.Marshal(upsert.GetAccessTokens())
if err != nil {
return nil, err
}
valueString = string(valueBytes)
} else {
return nil, errors.New("invalid user setting key")
}
// Construct the query using Squirrel
query, args, err := squirrel.
Insert("user_setting").
Columns("user_id", "key", "value").
Values(upsert.UserId, upsert.Key.String(), valueString).
Suffix("ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value").
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return nil, err
}
// Execute the query
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
return nil, err
}
return upsert, nil
}
func (d *DB) ListUserSettingsV1(ctx context.Context, find *store.FindUserSettingV1) ([]*storepb.UserSetting, error) {
// Start building the query using Squirrel
qb := squirrel.Select("user_id", "key", "value").From("user_setting").PlaceholderFormat(squirrel.Dollar)
// Add conditions based on the provided find parameters
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
qb = qb.Where(squirrel.Eq{"key": v.String()})
}
if v := find.UserID; v != nil {
qb = qb.Where(squirrel.Eq{"user_id": *v})
}
// Finalize the query
query, args, err := qb.ToSql()
if err != nil {
return nil, err
}
// Execute the query
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
// Process the rows
userSettingList := make([]*storepb.UserSetting, 0)
for rows.Next() {
userSetting := &storepb.UserSetting{}
var keyString, valueString string
if err := rows.Scan(
&userSetting.UserId,
&keyString,
&valueString,
); err != nil {
return nil, err
}
userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString])
if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
accessTokensUserSetting := &storepb.AccessTokensUserSetting{}
if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil {
return nil, err
}
userSetting.Value = &storepb.UserSetting_AccessTokens{
AccessTokens: accessTokensUserSetting,
}
} else {
// Skip unknown user setting v1 key
continue
}
userSettingList = append(userSettingList, userSetting)
}
if err := rows.Err(); err != nil {
return nil, err
}
return userSettingList, nil
}
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery
subQuery, subArgs, err := squirrel.Select("id").From("\"user\"").PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Now, build the main delete query using the subquery
query, args, err := squirrel.Delete("user_setting").
Where(fmt.Sprintf("user_id NOT IN (%s)", subQuery), subArgs...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
}

@ -0,0 +1,148 @@
package postgres
import (
"context"
"time"
"github.com/Masterminds/squirrel"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (d *DB) CreateWebhook(ctx context.Context, create *storepb.Webhook) (*storepb.Webhook, error) {
qb := squirrel.Insert("webhook").Columns("name", "url", "creator_id")
values := []any{create.Name, create.Url, create.CreatorId}
if create.Id != 0 {
qb = qb.Columns("id")
values = append(values, create.Id)
}
qb = qb.Values(values...).Suffix("RETURNING id")
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
err = d.db.QueryRowContext(ctx, query, args...).Scan(&create.Id)
if err != nil {
return nil, err
}
create, err = d.GetWebhook(ctx, &store.FindWebhook{ID: &create.Id})
if err != nil {
return nil, err
}
return create, nil
}
func (d *DB) ListWebhooks(ctx context.Context, find *store.FindWebhook) ([]*storepb.Webhook, error) {
qb := squirrel.Select("id", "created_ts", "updated_ts", "row_status", "creator_id", "name", "url").From("webhook").OrderBy("id DESC")
if find.ID != nil {
qb = qb.Where(squirrel.Eq{"id": *find.ID})
}
if find.CreatorID != nil {
qb = qb.Where(squirrel.Eq{"creator_id": *find.CreatorID})
}
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*storepb.Webhook{}
for rows.Next() {
webhook := &storepb.Webhook{}
var rowStatus string
var createdTs, updatedTs time.Time
if err := rows.Scan(
&webhook.Id,
&createdTs,
&updatedTs,
&rowStatus,
&webhook.CreatorId,
&webhook.Name,
&webhook.Url,
); err != nil {
return nil, err
}
webhook.CreatedTs = createdTs.UnixNano()
webhook.UpdatedTs = updatedTs.UnixNano()
webhook.RowStatus = storepb.RowStatus(storepb.RowStatus_value[rowStatus])
list = append(list, webhook)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) GetWebhook(ctx context.Context, find *store.FindWebhook) (*storepb.Webhook, error) {
list, err := d.ListWebhooks(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}
func (d *DB) UpdateWebhook(ctx context.Context, update *store.UpdateWebhook) (*storepb.Webhook, error) {
qb := squirrel.Update("webhook")
if update.RowStatus != nil {
qb = qb.Set("row_status", update.RowStatus.String())
}
if update.Name != nil {
qb = qb.Set("name", *update.Name)
}
if update.URL != nil {
qb = qb.Set("url", *update.URL)
}
qb = qb.Where(squirrel.Eq{"id": update.ID})
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, query, args...)
if err != nil {
return nil, err
}
webhook, err := d.GetWebhook(ctx, &store.FindWebhook{ID: &update.ID})
if err != nil {
return nil, err
}
return webhook, nil
}
func (d *DB) DeleteWebhook(ctx context.Context, delete *store.DeleteWebhook) error {
qb := squirrel.Delete("webhook").Where(squirrel.Eq{"id": delete.ID})
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
_, err = d.db.ExecContext(ctx, query, args...)
return err
}

@ -49,4 +49,13 @@ func TestMemoStore(t *testing.T) {
})
require.NoError(t, err)
require.Equal(t, 0, len(memoList))
memoList, err = ts.ListMemos(ctx, &store.FindMemo{
CreatorID: &user.ID,
VisibilityList: []store.Visibility{
store.Public,
},
})
require.NoError(t, err)
require.Equal(t, 0, len(memoList))
}

Loading…
Cancel
Save