mirror of https://github.com/usememos/memos
feat: support mysql as backend storage driver (#2300)
* Rename checkDSN to checkDataDir * Add option to set DSN and db driver * Add mysql driver skeleton * Add mysql container in compose for debug * Add basic function for mysql driver * Cleanup go mod with tidy * Cleanup go.sum with tidy * Add DeleteUser support for mysql driver * Fix UpdateUser of mysql driver * Add DeleteTag support for mysql driver * Add DeleteResource support for mysql driver * Add UpdateMemo and DeleteMemo support for mysql driver * Add MemoRelation support for mysql driver * Add MemoOrganizer support for mysql driver * Add Idp support for mysql driver * Add Storage support for mysql driver * Add FindMemosVisibilityList support for mysql driver * Add Vacuum support for mysql driver * Add Migration support for mysql driver * Add Migration support for mysql driver * Fix ListMemo failed with referece * Change Activity.CreateTs type in MySQL * Change User.CreateTs type in MySQL * Fix by golangci-lint * Change Resource.CreateTs type in MySQL * Change MigrationHistory.CreateTs type in MySQL * Change Memo.CreateTs type in MySQLpull/2323/head
parent
4ca2b551f5
commit
c72f221fc0
@ -0,0 +1,64 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *Driver) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
|
||||
stmt := `
|
||||
INSERT INTO activity (
|
||||
creator_id,
|
||||
type,
|
||||
level,
|
||||
payload
|
||||
)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`
|
||||
result, err := d.db.ExecContext(ctx, stmt,
|
||||
create.CreatorID,
|
||||
create.Type,
|
||||
create.Level,
|
||||
create.Payload,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to db.Exec")
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to db.LastInsertId")
|
||||
}
|
||||
|
||||
return d.FindActivity(ctx, id)
|
||||
}
|
||||
|
||||
func (d *Driver) FindActivity(ctx context.Context, id int64) (*store.Activity, error) {
|
||||
var activity store.Activity
|
||||
stmt := `
|
||||
SELECT
|
||||
id,
|
||||
creator_id,
|
||||
type,
|
||||
level,
|
||||
payload,
|
||||
UNIX_TIMESTAMP(created_ts)
|
||||
FROM activity
|
||||
WHERE id = ?
|
||||
`
|
||||
if err := d.db.QueryRowContext(ctx, stmt, id).Scan(
|
||||
&activity.ID,
|
||||
&activity.CreatorID,
|
||||
&activity.Type,
|
||||
&activity.Level,
|
||||
&activity.Payload,
|
||||
&activity.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to db.QueryRow")
|
||||
}
|
||||
|
||||
return &activity, nil
|
||||
}
|
@ -0,0 +1,199 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *Driver) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
|
||||
var configBytes []byte
|
||||
if create.Type == store.IdentityProviderOAuth2Type {
|
||||
bytes, err := json.Marshal(create.Config.OAuth2Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
configBytes = bytes
|
||||
} else {
|
||||
return nil, errors.Errorf("unsupported idp type %s", string(create.Type))
|
||||
}
|
||||
|
||||
stmt := `
|
||||
INSERT INTO idp (
|
||||
name,
|
||||
type,
|
||||
identifier_filter,
|
||||
config
|
||||
)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`
|
||||
result, err := d.db.ExecContext(
|
||||
ctx,
|
||||
stmt,
|
||||
create.Name,
|
||||
create.Type,
|
||||
create.IdentifierFilter,
|
||||
string(configBytes),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
create.ID = int32(id)
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *Driver) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
type,
|
||||
identifier_filter,
|
||||
config
|
||||
FROM idp
|
||||
WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var identityProviders []*store.IdentityProvider
|
||||
for rows.Next() {
|
||||
var identityProvider store.IdentityProvider
|
||||
var identityProviderConfig string
|
||||
if err := rows.Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.Name,
|
||||
&identityProvider.Type,
|
||||
&identityProvider.IdentifierFilter,
|
||||
&identityProviderConfig,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if identityProvider.Type == store.IdentityProviderOAuth2Type {
|
||||
oauth2Config := &store.IdentityProviderOAuth2Config{}
|
||||
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProvider.Config = &store.IdentityProviderConfig{
|
||||
OAuth2Config: oauth2Config,
|
||||
}
|
||||
} else {
|
||||
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
|
||||
}
|
||||
identityProviders = append(identityProviders, &identityProvider)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return identityProviders, nil
|
||||
}
|
||||
|
||||
func (d *Driver) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) {
|
||||
list, err := d.ListIdentityProviders(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
identityProvider := list[0]
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *Driver) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.Name; v != nil {
|
||||
set, args = append(set, "name = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.IdentifierFilter; v != nil {
|
||||
set, args = append(set, "identifier_filter = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Config; v != nil {
|
||||
var configBytes []byte
|
||||
if update.Type == store.IdentityProviderOAuth2Type {
|
||||
bytes, err := json.Marshal(update.Config.OAuth2Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
configBytes = bytes
|
||||
} else {
|
||||
return nil, errors.Errorf("unsupported idp type %s", string(update.Type))
|
||||
}
|
||||
set, args = append(set, "config = ?"), append(args, string(configBytes))
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE idp
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
RETURNING id, name, type, identifier_filter, config
|
||||
`
|
||||
_, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var identityProvider store.IdentityProvider
|
||||
var identityProviderConfig string
|
||||
stmt = `SELECT id, name, type, identifier_filter, config FROM idp WHERE id = ?`
|
||||
if err := d.db.QueryRowContext(ctx, stmt, update.ID).Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.Name,
|
||||
&identityProvider.Type,
|
||||
&identityProvider.IdentifierFilter,
|
||||
&identityProviderConfig,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if identityProvider.Type == store.IdentityProviderOAuth2Type {
|
||||
oauth2Config := &store.IdentityProviderOAuth2Config{}
|
||||
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProvider.Config = &store.IdentityProviderConfig{
|
||||
OAuth2Config: oauth2Config,
|
||||
}
|
||||
} else {
|
||||
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
|
||||
}
|
||||
|
||||
return &identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *Driver) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
|
||||
where, args := []string{"id = ?"}, []any{delete.ID}
|
||||
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,311 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/common/util"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *Driver) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
|
||||
stmt := `
|
||||
INSERT INTO memo (
|
||||
creator_id,
|
||||
content,
|
||||
visibility
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
`
|
||||
result, err := d.db.ExecContext(
|
||||
ctx,
|
||||
stmt,
|
||||
create.CreatorID,
|
||||
create.Content,
|
||||
create.Visibility,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var memo store.Memo
|
||||
stmt = `
|
||||
SELECT
|
||||
id,
|
||||
creator_id,
|
||||
content,
|
||||
visibility,
|
||||
UNIX_TIMESTAMP(created_ts),
|
||||
UNIX_TIMESTAMP(updated_ts),
|
||||
row_status
|
||||
FROM memo
|
||||
WHERE id = ?
|
||||
`
|
||||
if err := d.db.QueryRowContext(ctx, stmt, id).Scan(
|
||||
&memo.ID,
|
||||
&memo.CreatorID,
|
||||
&memo.Content,
|
||||
&memo.Visibility,
|
||||
&memo.UpdatedTs,
|
||||
&memo.CreatedTs,
|
||||
&memo.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &memo, nil
|
||||
}
|
||||
|
||||
func (d *Driver) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "memo.id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "memo.creator_id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "memo.row_status = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatedTsBefore; v != nil {
|
||||
where, args = append(where, "UNIX_TIMESTAMP(memo.created_ts) < ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatedTsAfter; v != nil {
|
||||
where, args = append(where, "UNIX_TIMESTAMP(memo.created_ts) > ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Pinned; v != nil {
|
||||
where = append(where, "memo_organizer.pinned = 1")
|
||||
}
|
||||
if v := find.ContentSearch; len(v) != 0 {
|
||||
for _, s := range v {
|
||||
where, args = append(where, "memo.content LIKE ?"), append(args, "%"+s+"%")
|
||||
}
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
list := []string{}
|
||||
for _, visibility := range v {
|
||||
list = append(list, "?")
|
||||
args = append(args, visibility)
|
||||
}
|
||||
where = append(where, fmt.Sprintf("memo.visibility in (%s)", strings.Join(list, ",")))
|
||||
}
|
||||
orders := []string{"pinned DESC"}
|
||||
if find.OrderByUpdatedTs {
|
||||
orders = append(orders, "updated_ts DESC")
|
||||
} else {
|
||||
orders = append(orders, "created_ts DESC")
|
||||
}
|
||||
orders = append(orders, "id DESC")
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
memo.id AS id,
|
||||
memo.creator_id AS creator_id,
|
||||
UNIX_TIMESTAMP(memo.created_ts) AS created_ts,
|
||||
UNIX_TIMESTAMP(memo.updated_ts) AS updated_ts,
|
||||
memo.row_status AS row_status,
|
||||
memo.content AS content,
|
||||
memo.visibility AS visibility,
|
||||
MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned,
|
||||
GROUP_CONCAT(resource.id) AS resource_id_list,
|
||||
(
|
||||
SELECT
|
||||
GROUP_CONCAT(related_memo_id,':',type)
|
||||
FROM
|
||||
memo_relation
|
||||
WHERE
|
||||
memo_relation.memo_id = memo.id
|
||||
GROUP BY
|
||||
memo_relation.memo_id
|
||||
) AS relation_list
|
||||
FROM
|
||||
memo
|
||||
LEFT JOIN
|
||||
memo_organizer ON memo.id = memo_organizer.memo_id
|
||||
LEFT JOIN
|
||||
resource ON memo.id = resource.memo_id
|
||||
WHERE ` + strings.Join(where, " AND ") + `
|
||||
GROUP BY memo.id
|
||||
ORDER BY ` + strings.Join(orders, ", ") + `
|
||||
`
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Memo, 0)
|
||||
for rows.Next() {
|
||||
var memo store.Memo
|
||||
var memoResourceIDList sql.NullString
|
||||
var memoRelationList sql.NullString
|
||||
if err := rows.Scan(
|
||||
&memo.ID,
|
||||
&memo.CreatorID,
|
||||
&memo.CreatedTs,
|
||||
&memo.UpdatedTs,
|
||||
&memo.RowStatus,
|
||||
&memo.Content,
|
||||
&memo.Visibility,
|
||||
&memo.Pinned,
|
||||
&memoResourceIDList,
|
||||
&memoRelationList,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if memoResourceIDList.Valid {
|
||||
idStringList := strings.Split(memoResourceIDList.String, ",")
|
||||
memo.ResourceIDList = make([]int32, 0, len(idStringList))
|
||||
for _, idString := range idStringList {
|
||||
id, err := util.ConvertStringToInt32(idString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
memo.ResourceIDList = append(memo.ResourceIDList, id)
|
||||
}
|
||||
}
|
||||
if memoRelationList.Valid {
|
||||
memo.RelationList = make([]*store.MemoRelation, 0)
|
||||
relatedMemoTypeList := strings.Split(memoRelationList.String, ",")
|
||||
for _, relatedMemoType := range relatedMemoTypeList {
|
||||
relatedMemoTypeList := strings.Split(relatedMemoType, ":")
|
||||
if len(relatedMemoTypeList) != 2 {
|
||||
return nil, errors.Errorf("invalid relation format")
|
||||
}
|
||||
relatedMemoID, err := util.ConvertStringToInt32(relatedMemoTypeList[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
memo.RelationList = append(memo.RelationList, &store.MemoRelation{
|
||||
MemoID: memo.ID,
|
||||
RelatedMemoID: relatedMemoID,
|
||||
Type: store.MemoRelationType(relatedMemoTypeList[1]),
|
||||
})
|
||||
}
|
||||
}
|
||||
list = append(list, &memo)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *Driver) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.CreatedTs; v != nil {
|
||||
set, args = append(set, "created_ts = FROM_UNIXTIME(?)"), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "updated_ts = FROM_UNIXTIME(?)"), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "row_status = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Content; v != nil {
|
||||
set, args = append(set, "content = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Visibility; v != nil {
|
||||
set, args = append(set, "visibility = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE memo
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
`
|
||||
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Driver) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
|
||||
where, args := []string{"id = ?"}, []any{delete.ID}
|
||||
stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := d.Vacuum(ctx); err != nil {
|
||||
// Prevent linter warning.
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Driver) FindMemosVisibilityList(ctx context.Context, memoIDs []int32) ([]store.Visibility, error) {
|
||||
args := make([]any, 0, len(memoIDs))
|
||||
list := make([]string, 0, len(memoIDs))
|
||||
for _, memoID := range memoIDs {
|
||||
args = append(args, memoID)
|
||||
list = append(list, "?")
|
||||
}
|
||||
|
||||
where := fmt.Sprintf("id in (%s)", strings.Join(list, ","))
|
||||
query := `SELECT DISTINCT(visibility) FROM memo WHERE ` + where
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
visibilityList := make([]store.Visibility, 0)
|
||||
for rows.Next() {
|
||||
var visibility store.Visibility
|
||||
if err := rows.Scan(&visibility); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
visibilityList = append(visibilityList, visibility)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return visibilityList, nil
|
||||
}
|
||||
|
||||
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
memo
|
||||
WHERE
|
||||
creator_id NOT IN (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
user
|
||||
)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -0,0 +1,106 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *Driver) UpsertMemoOrganizer(ctx context.Context, upsert *store.MemoOrganizer) (*store.MemoOrganizer, error) {
|
||||
stmt := `
|
||||
INSERT INTO memo_organizer (
|
||||
memo_id,
|
||||
user_id,
|
||||
pinned
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE pinned = ?
|
||||
`
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, upsert.Pinned, upsert.Pinned); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *Driver) GetMemoOrganizer(ctx context.Context, find *store.FindMemoOrganizer) (*store.MemoOrganizer, error) {
|
||||
where, args := []string{}, []any{}
|
||||
if find.MemoID != 0 {
|
||||
where = append(where, "memo_id = ?")
|
||||
args = append(args, find.MemoID)
|
||||
}
|
||||
if find.UserID != 0 {
|
||||
where = append(where, "user_id = ?")
|
||||
args = append(args, find.UserID)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
memo_id,
|
||||
user_id,
|
||||
pinned
|
||||
FROM memo_organizer
|
||||
WHERE %s
|
||||
`, strings.Join(where, " AND "))
|
||||
row := d.db.QueryRowContext(ctx, query, args...)
|
||||
if err := row.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if row == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
memoOrganizer := &store.MemoOrganizer{}
|
||||
if err := row.Scan(
|
||||
&memoOrganizer.MemoID,
|
||||
&memoOrganizer.UserID,
|
||||
&memoOrganizer.Pinned,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return memoOrganizer, nil
|
||||
}
|
||||
|
||||
func (d *Driver) DeleteMemoOrganizer(ctx context.Context, delete *store.DeleteMemoOrganizer) error {
|
||||
where, args := []string{}, []any{}
|
||||
if v := delete.MemoID; v != nil {
|
||||
where, args = append(where, "memo_id = ?"), append(args, *v)
|
||||
}
|
||||
if v := delete.UserID; v != nil {
|
||||
where, args = append(where, "user_id = ?"), append(args, *v)
|
||||
}
|
||||
stmt := `DELETE FROM memo_organizer WHERE ` + strings.Join(where, " AND ")
|
||||
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
memo_organizer
|
||||
WHERE
|
||||
memo_id NOT IN (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
memo
|
||||
)
|
||||
OR user_id NOT IN (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
user
|
||||
)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -0,0 +1,118 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *Driver) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
|
||||
stmt := `
|
||||
INSERT INTO memo_relation (
|
||||
memo_id,
|
||||
related_memo_id,
|
||||
type
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE type = ?
|
||||
`
|
||||
_, err := d.db.ExecContext(
|
||||
ctx,
|
||||
stmt,
|
||||
create.MemoID,
|
||||
create.RelatedMemoID,
|
||||
create.Type,
|
||||
create.Type,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memoRelation := store.MemoRelation{
|
||||
MemoID: create.MemoID,
|
||||
RelatedMemoID: create.RelatedMemoID,
|
||||
Type: create.Type,
|
||||
}
|
||||
|
||||
return &memoRelation, nil
|
||||
}
|
||||
|
||||
func (d *Driver) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
|
||||
where, args := []string{"TRUE"}, []any{}
|
||||
if find.MemoID != nil {
|
||||
where, args = append(where, "memo_id = ?"), append(args, find.MemoID)
|
||||
}
|
||||
if find.RelatedMemoID != nil {
|
||||
where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID)
|
||||
}
|
||||
if find.Type != nil {
|
||||
where, args = append(where, "type = ?"), append(args, find.Type)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
memo_id,
|
||||
related_memo_id,
|
||||
type
|
||||
FROM memo_relation
|
||||
WHERE `+strings.Join(where, " AND "), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.MemoRelation{}
|
||||
for rows.Next() {
|
||||
memoRelation := &store.MemoRelation{}
|
||||
if err := rows.Scan(
|
||||
&memoRelation.MemoID,
|
||||
&memoRelation.RelatedMemoID,
|
||||
&memoRelation.Type,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, memoRelation)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *Driver) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
|
||||
where, args := []string{"TRUE"}, []any{}
|
||||
if delete.MemoID != nil {
|
||||
where, args = append(where, "memo_id = ?"), append(args, delete.MemoID)
|
||||
}
|
||||
if delete.RelatedMemoID != nil {
|
||||
where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID)
|
||||
}
|
||||
if delete.Type != nil {
|
||||
where, args = append(where, "type = ?"), append(args, delete.Type)
|
||||
}
|
||||
stmt := `
|
||||
DELETE FROM memo_relation
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error {
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
DELETE FROM memo_relation
|
||||
WHERE memo_id NOT IN (SELECT id FROM memo) OR related_memo_id NOT IN (SELECT id FROM memo)
|
||||
`); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,182 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/server/version"
|
||||
)
|
||||
|
||||
const (
|
||||
latestSchemaFileName = "LATEST__SCHEMA.sql"
|
||||
)
|
||||
|
||||
//go:embed migration
|
||||
var migrationFS embed.FS
|
||||
|
||||
func (d *Driver) Migrate(ctx context.Context) error {
|
||||
if d.profile.IsDev() {
|
||||
return d.nonProdMigrate(ctx)
|
||||
}
|
||||
|
||||
return d.prodMigrate(ctx)
|
||||
}
|
||||
|
||||
func (d *Driver) nonProdMigrate(ctx context.Context) error {
|
||||
buf, err := migrationFS.ReadFile("migration/dev/" + latestSchemaFileName)
|
||||
if err != nil {
|
||||
return errors.Errorf("failed to read latest schema file: %s", err)
|
||||
}
|
||||
|
||||
for _, stmt := range strings.Split(string(buf), ";") {
|
||||
stmt = strings.TrimSpace(stmt)
|
||||
if stmt == "" {
|
||||
continue
|
||||
}
|
||||
_, err := d.db.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return errors.Errorf("failed to exec SQL %s: %s", stmt, err)
|
||||
}
|
||||
}
|
||||
|
||||
// In demo mode, we should seed the database.
|
||||
if d.profile.Mode == "demo" {
|
||||
if err := d.seed(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to seed")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (d *Driver) prodMigrate(ctx context.Context) error {
|
||||
currentVersion := version.GetCurrentVersion(d.profile.Mode)
|
||||
migrationHistoryList, err := d.FindMigrationHistoryList(ctx, &MigrationHistoryFind{})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to find migration history")
|
||||
}
|
||||
if len(migrationHistoryList) == 0 {
|
||||
_, err := d.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{
|
||||
Version: currentVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to upsert migration history")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
migrationHistoryVersionList := []string{}
|
||||
for _, migrationHistory := range migrationHistoryList {
|
||||
migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version)
|
||||
}
|
||||
sort.Sort(version.SortVersion(migrationHistoryVersionList))
|
||||
latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
|
||||
|
||||
if !version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
|
||||
return nil
|
||||
}
|
||||
|
||||
println("start migrate")
|
||||
for _, minorVersion := range getMinorVersionList() {
|
||||
normalizedVersion := minorVersion + ".0"
|
||||
if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) {
|
||||
println("applying migration for", normalizedVersion)
|
||||
if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
|
||||
return errors.Wrap(err, "failed to apply minor version migration")
|
||||
}
|
||||
}
|
||||
}
|
||||
println("end migrate")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Driver) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error {
|
||||
filenames, err := fs.Glob(migrationFS, fmt.Sprintf("%s/%s/*.sql", "migration/prod", minorVersion))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read ddl files")
|
||||
}
|
||||
|
||||
sort.Strings(filenames)
|
||||
// Loop over all migration files and execute them in order.
|
||||
for _, filename := range filenames {
|
||||
buf, err := migrationFS.ReadFile(filename)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename)
|
||||
}
|
||||
for _, stmt := range strings.Split(string(buf), ";") {
|
||||
if strings.TrimSpace(stmt) == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := d.db.ExecContext(ctx, stmt); err != nil {
|
||||
return errors.Wrapf(err, "migrate error: %s", stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert the newest version to migration_history.
|
||||
version := minorVersion + ".0"
|
||||
if _, err = d.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{Version: version}); err != nil {
|
||||
return errors.Wrapf(err, "failed to upsert migration history with version: %s", version)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//go:embed seed
|
||||
var seedFS embed.FS
|
||||
|
||||
func (d *Driver) seed(ctx context.Context) error {
|
||||
filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed"))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read seed files")
|
||||
}
|
||||
|
||||
sort.Strings(filenames)
|
||||
|
||||
// Loop over all seed files and execute them in order.
|
||||
for _, filename := range filenames {
|
||||
buf, err := seedFS.ReadFile(filename)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read seed file, filename=%s", filename)
|
||||
}
|
||||
|
||||
for _, stmt := range strings.Split(string(buf), ";") {
|
||||
if strings.TrimSpace(stmt) == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := d.db.ExecContext(ctx, stmt); err != nil {
|
||||
return errors.Wrapf(err, "seed error: %s", stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// minorDirRegexp is a regular expression for minor version directory.
|
||||
var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`)
|
||||
|
||||
func getMinorVersionList() []string {
|
||||
minorVersionList := []string{}
|
||||
|
||||
if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if file.IsDir() && minorDirRegexp.MatchString(path) {
|
||||
minorVersionList = append(minorVersionList, file.Name())
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sort.Sort(version.SortVersion(minorVersionList))
|
||||
|
||||
return minorVersionList
|
||||
}
|
@ -0,0 +1,131 @@
|
||||
-- activity
|
||||
CREATE TABLE IF NOT EXISTS `activity` (
|
||||
`id` int NOT NULL AUTO_INCREMENT,
|
||||
`creator_id` int NOT NULL,
|
||||
`created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`type` varchar(255) NOT NULL DEFAULT '',
|
||||
`level` varchar(255) NOT NULL DEFAULT 'INFO',
|
||||
`payload` text NOT NULL,
|
||||
PRIMARY KEY (`id`),
|
||||
CONSTRAINT `activity_chk_1` CHECK ((`level` in (_utf8mb4'INFO',_utf8mb4'WARN',_utf8mb4'ERROR')))
|
||||
);
|
||||
|
||||
-- idp
|
||||
CREATE TABLE IF NOT EXISTS `idp` (
|
||||
`id` int NOT NULL AUTO_INCREMENT,
|
||||
`name` text NOT NULL,
|
||||
`type` text NOT NULL,
|
||||
`identifier_filter` varchar(256) NOT NULL DEFAULT '',
|
||||
`config` text NOT NULL,
|
||||
PRIMARY KEY (`id`)
|
||||
);
|
||||
|
||||
-- memo
|
||||
CREATE TABLE IF NOT EXISTS `memo` (
|
||||
`id` int NOT NULL AUTO_INCREMENT,
|
||||
`creator_id` int NOT NULL,
|
||||
`created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`row_status` varchar(255) NOT NULL DEFAULT 'NORMAL',
|
||||
`content` text NOT NULL,
|
||||
`visibility` varchar(255) NOT NULL DEFAULT 'PRIVATE',
|
||||
PRIMARY KEY (`id`),
|
||||
KEY `creator_id` (`creator_id`),
|
||||
KEY `visibility` (`visibility`),
|
||||
CONSTRAINT `memo_chk_1` CHECK ((`row_status` in (_utf8mb4'NORMAL',_utf8mb4'ARCHIVED'))),
|
||||
CONSTRAINT `memo_chk_2` CHECK ((`visibility` in (_utf8mb4'PUBLIC',_utf8mb4'PROTECTED',_utf8mb4'PRIVATE')))
|
||||
);
|
||||
|
||||
-- memo_organizer
|
||||
CREATE TABLE IF NOT EXISTS `memo_organizer` (
|
||||
`memo_id` int NOT NULL,
|
||||
`user_id` int NOT NULL,
|
||||
`pinned` int NOT NULL DEFAULT '0',
|
||||
UNIQUE KEY `memo_id` (`memo_id`,`user_id`),
|
||||
CONSTRAINT `memo_organizer_chk_1` CHECK ((`pinned` in (0,1)))
|
||||
);
|
||||
|
||||
-- memo_relation
|
||||
CREATE TABLE IF NOT EXISTS `memo_relation` (
|
||||
`memo_id` int NOT NULL,
|
||||
`related_memo_id` int NOT NULL,
|
||||
`type` varchar(256) NOT NULL,
|
||||
UNIQUE KEY `memo_id` (`memo_id`,`related_memo_id`,`type`)
|
||||
);
|
||||
|
||||
-- migration_history
|
||||
CREATE TABLE IF NOT EXISTS `migration_history` (
|
||||
`version` varchar(255) NOT NULL,
|
||||
`created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (`version`)
|
||||
);
|
||||
|
||||
-- resource
|
||||
CREATE TABLE IF NOT EXISTS `resource` (
|
||||
`id` int NOT NULL AUTO_INCREMENT,
|
||||
`creator_id` int NOT NULL,
|
||||
`created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`filename` text NOT NULL,
|
||||
`blob` blob,
|
||||
`external_link` text NOT NULL,
|
||||
`type` varchar(255) NOT NULL DEFAULT '',
|
||||
`size` int NOT NULL DEFAULT '0',
|
||||
`internal_path` varchar(255) NOT NULL DEFAULT '',
|
||||
`memo_id` int DEFAULT NULL,
|
||||
PRIMARY KEY (`id`),
|
||||
KEY `creator_id` (`creator_id`),
|
||||
KEY `memo_id` (`memo_id`)
|
||||
);
|
||||
|
||||
-- storage
|
||||
CREATE TABLE IF NOT EXISTS `storage` (
|
||||
`id` int NOT NULL AUTO_INCREMENT,
|
||||
`name` varchar(256) NOT NULL,
|
||||
`type` varchar(256) NOT NULL,
|
||||
`config` text NOT NULL,
|
||||
PRIMARY KEY (`id`)
|
||||
);
|
||||
|
||||
-- system_setting
|
||||
CREATE TABLE IF NOT EXISTS `system_setting` (
|
||||
`name` varchar(255) NOT NULL,
|
||||
`value` text NOT NULL,
|
||||
`description` text NOT NULL,
|
||||
PRIMARY KEY (`name`)
|
||||
);
|
||||
|
||||
-- tag
|
||||
CREATE TABLE IF NOT EXISTS `tag` (
|
||||
`name` varchar(255) NOT NULL,
|
||||
`creator_id` int NOT NULL,
|
||||
UNIQUE KEY `name` (`name`,`creator_id`)
|
||||
);
|
||||
|
||||
-- user
|
||||
CREATE TABLE IF NOT EXISTS `user` (
|
||||
`id` int NOT NULL AUTO_INCREMENT,
|
||||
`created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`row_status` varchar(255) NOT NULL DEFAULT 'NORMAL',
|
||||
`username` varchar(255) NOT NULL,
|
||||
`role` varchar(255) NOT NULL DEFAULT 'USER',
|
||||
`email` varchar(255) NOT NULL DEFAULT '',
|
||||
`nickname` varchar(255) NOT NULL DEFAULT '',
|
||||
`password_hash` varchar(255) NOT NULL,
|
||||
`avatar_url` text NOT NULL,
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `username` (`username`),
|
||||
CONSTRAINT `user_chk_1` CHECK ((`row_status` in (_utf8mb4'NORMAL',_utf8mb4'ARCHIVED'))),
|
||||
CONSTRAINT `user_chk_2` CHECK ((`role` in (_utf8mb4'HOST',_utf8mb4'ADMIN',_utf8mb4'USER')))
|
||||
);
|
||||
|
||||
-- user_setting
|
||||
CREATE TABLE IF NOT EXISTS `user_setting` (
|
||||
`user_id` int NOT NULL,
|
||||
`key` varchar(255) NOT NULL,
|
||||
`value` text NOT NULL,
|
||||
UNIQUE KEY `user_id` (`user_id`,`key`)
|
||||
);
|
||||
|
||||
|
@ -0,0 +1,84 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MigrationHistory struct {
|
||||
Version string
|
||||
CreatedTs int64
|
||||
}
|
||||
|
||||
type MigrationHistoryUpsert struct {
|
||||
Version string
|
||||
}
|
||||
|
||||
type MigrationHistoryFind struct {
|
||||
Version *string
|
||||
}
|
||||
|
||||
func (d *Driver) FindMigrationHistoryList(ctx context.Context, find *MigrationHistoryFind) ([]*MigrationHistory, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Version; v != nil {
|
||||
where, args = append(where, "version = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT version, UNIX_TIMESTAMP(created_ts)
|
||||
FROM migration_history
|
||||
WHERE ` + strings.Join(where, " AND ") + `
|
||||
ORDER BY created_ts DESC
|
||||
`
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*MigrationHistory, 0)
|
||||
for rows.Next() {
|
||||
var migrationHistory MigrationHistory
|
||||
if err := rows.Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list = append(list, &migrationHistory)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *Driver) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) {
|
||||
stmt := `
|
||||
INSERT INTO migration_history (version) VALUES (?)
|
||||
ON DUPLICATE KEY UPDATE version = ?
|
||||
`
|
||||
_, err := d.db.ExecContext(ctx, stmt, upsert.Version, upsert.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var migrationHistory MigrationHistory
|
||||
stmt = `
|
||||
SELECT version, UNIX_TIMESTAMP(created_ts)
|
||||
FROM migration_history
|
||||
WHERE version = ?
|
||||
`
|
||||
if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &migrationHistory, nil
|
||||
}
|
@ -0,0 +1,64 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/server/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type Driver struct {
|
||||
db *sql.DB
|
||||
profile *profile.Profile
|
||||
}
|
||||
|
||||
func NewDriver(profile *profile.Profile) (store.Driver, error) {
|
||||
db, err := sql.Open("mysql", profile.DSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
driver := Driver{db: db, profile: profile}
|
||||
return &driver, nil
|
||||
}
|
||||
|
||||
func (d *Driver) Vacuum(ctx context.Context) error {
|
||||
tx, err := d.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if err := vacuumMemo(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumResource(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumUserSetting(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumMemoOrganizer(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumMemoRelations(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumTag(ctx, tx); err != nil {
|
||||
// Prevent revive warning.
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (*Driver) BackupTo(context.Context, string) error {
|
||||
return errors.New("Please use mysqldump to backup")
|
||||
}
|
||||
|
||||
func (d *Driver) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
@ -0,0 +1,217 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *Driver) CreateResource(ctx context.Context, create *store.Resource) (*store.Resource, error) {
|
||||
stmt := `
|
||||
INSERT INTO resource (
|
||||
filename,
|
||||
resource.blob,
|
||||
external_link,
|
||||
type,
|
||||
size,
|
||||
creator_id,
|
||||
internal_path
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
result, err := d.db.ExecContext(
|
||||
ctx,
|
||||
stmt,
|
||||
create.Filename,
|
||||
create.Blob,
|
||||
create.ExternalLink,
|
||||
create.Type,
|
||||
create.Size,
|
||||
create.CreatorID,
|
||||
create.InternalPath,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id32 := int32(id)
|
||||
list, err := d.ListResources(ctx, &store.FindResource{ID: &id32})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list))
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *Driver) ListResources(ctx context.Context, find *store.FindResource) ([]*store.Resource, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "creator_id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Filename; v != nil {
|
||||
where, args = append(where, "filename = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.MemoID; v != nil {
|
||||
where, args = append(where, "memo_id = ?"), append(args, *v)
|
||||
}
|
||||
if find.HasRelatedMemo {
|
||||
where = append(where, "memo_id IS NOT NULL")
|
||||
}
|
||||
|
||||
fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "UNIX_TIMESTAMP(created_ts)", "UNIX_TIMESTAMP(updated_ts)", "internal_path", "memo_id"}
|
||||
if find.GetBlob {
|
||||
fields = append(fields, "resource.blob")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
%s
|
||||
FROM resource
|
||||
WHERE %s
|
||||
GROUP BY id
|
||||
ORDER BY created_ts DESC
|
||||
`, strings.Join(fields, ", "), strings.Join(where, " AND "))
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Resource, 0)
|
||||
for rows.Next() {
|
||||
resource := store.Resource{}
|
||||
var memoID sql.NullInt32
|
||||
dests := []any{
|
||||
&resource.ID,
|
||||
&resource.Filename,
|
||||
&resource.ExternalLink,
|
||||
&resource.Type,
|
||||
&resource.Size,
|
||||
&resource.CreatorID,
|
||||
&resource.CreatedTs,
|
||||
&resource.UpdatedTs,
|
||||
&resource.InternalPath,
|
||||
&memoID,
|
||||
}
|
||||
if find.GetBlob {
|
||||
dests = append(dests, &resource.Blob)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memoID.Valid {
|
||||
resource.MemoID = &memoID.Int32
|
||||
}
|
||||
list = append(list, &resource)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *Driver) UpdateResource(ctx context.Context, update *store.UpdateResource) (*store.Resource, error) {
|
||||
set, args := []string{}, []any{}
|
||||
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "updated_ts = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Filename; v != nil {
|
||||
set, args = append(set, "filename = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.InternalPath; v != nil {
|
||||
set, args = append(set, "internal_path = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.MemoID; v != nil {
|
||||
set, args = append(set, "memo_id = ?"), append(args, *v)
|
||||
}
|
||||
if update.UnbindMemo {
|
||||
set = append(set, "memo_id = NULL")
|
||||
}
|
||||
if v := update.Blob; v != nil {
|
||||
set, args = append(set, "resource.blob = ?"), append(args, v)
|
||||
}
|
||||
|
||||
args = append(args, update.ID)
|
||||
stmt := `
|
||||
UPDATE resource
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
`
|
||||
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list, err := d.ListResources(ctx, &store.FindResource{ID: &update.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list))
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *Driver) DeleteResource(ctx context.Context, delete *store.DeleteResource) error {
|
||||
stmt := `DELETE FROM resource WHERE id = ?`
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := d.Vacuum(ctx); err != nil {
|
||||
// Prevent linter warning.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumResource(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
resource
|
||||
WHERE
|
||||
creator_id NOT IN (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
user
|
||||
)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
TRUNCATE TABLE memo_organizer;
|
||||
TRUNCATE TABLE resource;
|
||||
TRUNCATE TABLE memo;
|
||||
TRUNCATE TABLE user;
|
@ -0,0 +1,45 @@
|
||||
INSERT INTO
|
||||
user (
|
||||
`id`,
|
||||
`username`,
|
||||
`role`,
|
||||
`email`,
|
||||
`nickname`,
|
||||
`row_status`,
|
||||
`avatar_url`,
|
||||
`password_hash`
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
101,
|
||||
'memos-demo',
|
||||
'HOST',
|
||||
'demo@usememos.com',
|
||||
'Derobot',
|
||||
'NORMAL',
|
||||
'',
|
||||
-- raw password: secret
|
||||
'$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK'
|
||||
),
|
||||
(
|
||||
102,
|
||||
'jack',
|
||||
'USER',
|
||||
'jack@usememos.com',
|
||||
'Jack',
|
||||
'NORMAL',
|
||||
'',
|
||||
-- raw password: secret
|
||||
'$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK'
|
||||
),
|
||||
(
|
||||
103,
|
||||
'bob',
|
||||
'USER',
|
||||
'bob@usememos.com',
|
||||
'Bob',
|
||||
'ARCHIVED',
|
||||
'',
|
||||
-- raw password: secret
|
||||
'$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK'
|
||||
);
|
@ -0,0 +1,54 @@
|
||||
INSERT INTO
|
||||
memo (`id`, `content`, `creator_id`)
|
||||
VALUES
|
||||
(
|
||||
1,
|
||||
"#Hello 👋 Welcome to memos.",
|
||||
101
|
||||
);
|
||||
|
||||
INSERT INTO
|
||||
memo (
|
||||
`id`,
|
||||
`content`,
|
||||
`creator_id`,
|
||||
`visibility`
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
2,
|
||||
'#TODO
|
||||
- [x] Take more photos about **🌄 sunset**
|
||||
- [x] Clean the room
|
||||
- [ ] Read *📖 The Little Prince*
|
||||
(👆 click to toggle status)',
|
||||
101,
|
||||
'PROTECTED'
|
||||
),
|
||||
(
|
||||
3,
|
||||
"**[Slash](https://github.com/boojack/slash)**: A bookmarking and url shortener, save and share your links very easily.
|
||||

|
||||
|
||||
**[SQL Chat](https://www.sqlchat.ai)**: Chat-based SQL Client
|
||||
",
|
||||
101,
|
||||
'PUBLIC'
|
||||
),
|
||||
(
|
||||
4,
|
||||
'#TODO
|
||||
- [x] Take more photos about **🌄 sunset**
|
||||
- [ ] Clean the classroom
|
||||
- [ ] Watch *👦 The Boys*
|
||||
(👆 click to toggle status)
|
||||
',
|
||||
102,
|
||||
'PROTECTED'
|
||||
),
|
||||
(
|
||||
5,
|
||||
'三人行,必有我师焉!👨🏫',
|
||||
102,
|
||||
'PUBLIC'
|
||||
);
|
@ -0,0 +1,5 @@
|
||||
INSERT INTO
|
||||
memo_organizer (`memo_id`, `user_id`, `pinned`)
|
||||
VALUES
|
||||
(1, 101, 1),
|
||||
(3, 101, 1);
|
@ -0,0 +1,6 @@
|
||||
INSERT INTO
|
||||
tag (`name`, `creator_id`)
|
||||
VALUES
|
||||
('Hello', 101),
|
||||
('TODO', 101),
|
||||
('TODO', 102);
|
@ -0,0 +1,137 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *Driver) CreateStorage(ctx context.Context, create *store.Storage) (*store.Storage, error) {
|
||||
stmt := `
|
||||
INSERT INTO storage (
|
||||
name,
|
||||
type,
|
||||
config
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
`
|
||||
result, err := d.db.ExecContext(ctx, stmt, create.Name, create.Type, create.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
create.ID = int32(id)
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *Driver) ListStorages(ctx context.Context, find *store.FindStorage) ([]*store.Storage, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *find.ID)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
type,
|
||||
config
|
||||
FROM storage
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY id DESC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Storage{}
|
||||
for rows.Next() {
|
||||
storage := &store.Storage{}
|
||||
if err := rows.Scan(
|
||||
&storage.ID,
|
||||
&storage.Name,
|
||||
&storage.Type,
|
||||
&storage.Config,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, storage)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *Driver) GetStorage(ctx context.Context, find *store.FindStorage) (*store.Storage, error) {
|
||||
list, err := d.ListStorages(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *Driver) UpdateStorage(ctx context.Context, update *store.UpdateStorage) (*store.Storage, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if update.Name != nil {
|
||||
set = append(set, "name = ?")
|
||||
args = append(args, *update.Name)
|
||||
}
|
||||
if update.Config != nil {
|
||||
set = append(set, "config = ?")
|
||||
args = append(args, *update.Config)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE storage
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storage := &store.Storage{}
|
||||
stmt = `SELECT id,name,type,config FROM storage WHERE id = ?`
|
||||
if err := d.db.QueryRowContext(ctx, stmt, update.ID).Scan(
|
||||
&storage.ID,
|
||||
&storage.Name,
|
||||
&storage.Type,
|
||||
&storage.Config,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
func (d *Driver) DeleteStorage(ctx context.Context, delete *store.DeleteStorage) error {
|
||||
stmt := `
|
||||
DELETE FROM storage
|
||||
WHERE id = ?
|
||||
`
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,72 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *Driver) UpsertSystemSetting(ctx context.Context, upsert *store.SystemSetting) (*store.SystemSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO system_setting (
|
||||
name, value, description
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE value = ?, description = ?
|
||||
`
|
||||
_, err := d.db.ExecContext(
|
||||
ctx,
|
||||
stmt,
|
||||
upsert.Name,
|
||||
upsert.Value,
|
||||
upsert.Description,
|
||||
upsert.Value,
|
||||
upsert.Description,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *Driver) ListSystemSettings(ctx context.Context, find *store.FindSystemSetting) ([]*store.SystemSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.Name != "" {
|
||||
where, args = append(where, "name = ?"), append(args, find.Name)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
name,
|
||||
value,
|
||||
description
|
||||
FROM system_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.SystemSetting{}
|
||||
for rows.Next() {
|
||||
systemSettingMessage := &store.SystemSetting{}
|
||||
if err := rows.Scan(
|
||||
&systemSettingMessage.Name,
|
||||
&systemSettingMessage.Value,
|
||||
&systemSettingMessage.Description,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, systemSettingMessage)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
@ -0,0 +1,90 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *Driver) UpsertTag(ctx context.Context, upsert *store.Tag) (*store.Tag, error) {
|
||||
stmt := `
|
||||
INSERT INTO tag (name, creator_id)
|
||||
VALUES (?, ?)
|
||||
ON DUPLICATE KEY UPDATE name = ?
|
||||
`
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.Name, upsert.CreatorID, upsert.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *Driver) ListTags(ctx context.Context, find *store.FindTag) ([]*store.Tag, error) {
|
||||
where, args := []string{"creator_id = ?"}, []any{find.CreatorID}
|
||||
query := `
|
||||
SELECT
|
||||
name,
|
||||
creator_id
|
||||
FROM tag
|
||||
WHERE ` + strings.Join(where, " AND ") + `
|
||||
ORDER BY name ASC
|
||||
`
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Tag{}
|
||||
for rows.Next() {
|
||||
tag := &store.Tag{}
|
||||
if err := rows.Scan(
|
||||
&tag.Name,
|
||||
&tag.CreatorID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list = append(list, tag)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *Driver) DeleteTag(ctx context.Context, delete *store.DeleteTag) error {
|
||||
where, args := []string{"name = ?", "creator_id = ?"}, []any{delete.Name, delete.CreatorID}
|
||||
stmt := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumTag(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
tag
|
||||
WHERE
|
||||
creator_id NOT IN (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
user
|
||||
)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -0,0 +1,205 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *Driver) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
|
||||
stmt := `
|
||||
INSERT INTO user (
|
||||
username,
|
||||
role,
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
avatar_url
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
result, err := d.db.ExecContext(ctx, stmt,
|
||||
create.Username,
|
||||
create.Role,
|
||||
create.Email,
|
||||
create.Nickname,
|
||||
create.PasswordHash,
|
||||
create.AvatarURL,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id64 := int32(id)
|
||||
list, err := d.ListUsers(ctx, &store.FindUser{ID: &id64})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list))
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *Driver) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "updated_ts = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "row_status = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Username; v != nil {
|
||||
set, args = append(set, "username = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
set, args = append(set, "email = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
set, args = append(set, "nickname = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.AvatarURL; v != nil {
|
||||
set, args = append(set, "avatar_url = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
set, args = append(set, "password_hash = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
query := `
|
||||
UPDATE user
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
`
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := &store.User{}
|
||||
query = `
|
||||
SELECT
|
||||
id,
|
||||
username,
|
||||
role,
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
avatar_url,
|
||||
UNIX_TIMESTAMP(created_ts),
|
||||
UNIX_TIMESTAMP(updated_ts),
|
||||
row_status
|
||||
FROM user WHERE id = ?
|
||||
`
|
||||
if err := d.db.QueryRowContext(ctx, query, update.ID).Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (d *Driver) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Username; v != nil {
|
||||
where, args = append(where, "username = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Role; v != nil {
|
||||
where, args = append(where, "role = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Email; v != nil {
|
||||
where, args = append(where, "email = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Nickname; v != nil {
|
||||
where, args = append(where, "nickname = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
id,
|
||||
username,
|
||||
role,
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
avatar_url,
|
||||
UNIX_TIMESTAMP(created_ts),
|
||||
UNIX_TIMESTAMP(updated_ts),
|
||||
row_status
|
||||
FROM user
|
||||
WHERE ` + strings.Join(where, " AND ") + `
|
||||
ORDER BY created_ts DESC, row_status DESC
|
||||
`
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.User, 0)
|
||||
for rows.Next() {
|
||||
var user store.User
|
||||
if err := rows.Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, &user)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *Driver) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
|
||||
result, err := d.db.ExecContext(ctx, `
|
||||
DELETE FROM user WHERE id = ?
|
||||
`, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := d.Vacuum(ctx); err != nil {
|
||||
// Prevent linter warning.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -0,0 +1,169 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *Driver) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO user_setting (user_id,user_setting.key,value)
|
||||
VALUES (?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE value = ?
|
||||
`
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key, upsert.Value, upsert.Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *Driver) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Key; v != "" {
|
||||
where, args = append(where, "user_setting.key = ?"), append(args, v)
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
where, args = append(where, "user_id = ?"), append(args, *find.UserID)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
user_id,
|
||||
user_setting.key,
|
||||
value
|
||||
FROM user_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
userSettingList := make([]*store.UserSetting, 0)
|
||||
for rows.Next() {
|
||||
var userSetting store.UserSetting
|
||||
if err := rows.Scan(
|
||||
&userSetting.UserID,
|
||||
&userSetting.Key,
|
||||
&userSetting.Value,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSettingList = append(userSettingList, &userSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func (d *Driver) UpsertUserSettingV1(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO user_setting (user_id, user_setting.key, value)
|
||||
VALUES (?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE value = ?
|
||||
`
|
||||
var valueString string
|
||||
if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
|
||||
valueBytes, err := protojson.Marshal(upsert.GetAccessTokens())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
valueString = string(valueBytes)
|
||||
} else {
|
||||
return nil, errors.New("invalid user setting key")
|
||||
}
|
||||
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString, valueString); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *Driver) ListUserSettingsV1(ctx context.Context, find *store.FindUserSettingV1) ([]*storepb.UserSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
|
||||
where, args = append(where, "user_setting.key = ?"), append(args, v.String())
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
where, args = append(where, "user_id = ?"), append(args, *find.UserID)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
user_id,
|
||||
user_setting.key,
|
||||
value
|
||||
FROM user_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
userSettingList := make([]*storepb.UserSetting, 0)
|
||||
for rows.Next() {
|
||||
userSetting := &storepb.UserSetting{}
|
||||
var keyString, valueString string
|
||||
if err := rows.Scan(
|
||||
&userSetting.UserId,
|
||||
&keyString,
|
||||
&valueString,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString])
|
||||
if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
|
||||
accessTokensUserSetting := &storepb.AccessTokensUserSetting{}
|
||||
if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Value = &storepb.UserSetting_AccessTokens{
|
||||
AccessTokens: accessTokensUserSetting,
|
||||
}
|
||||
} else {
|
||||
// Skip unknown user setting v1 key.
|
||||
continue
|
||||
}
|
||||
userSettingList = append(userSettingList, userSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
user_setting
|
||||
WHERE
|
||||
user_id NOT IN (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
user
|
||||
)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue