mirror of https://github.com/usememos/memos
				
				
				
			feat: implement memo relation store (#1598)
* feat: implement memo relation store * chore: updatepull/1602/head
							parent
							
								
									7776a6b7c6
								
							
						
					
					
						commit
						fab8a71fd2
					
				@ -0,0 +1,189 @@
 | 
			
		||||
package store
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/usememos/memos/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type MemoRelationType string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	MemoRelationReference  MemoRelationType = "REFERENCE"
 | 
			
		||||
	MemoRelationAdditional MemoRelationType = "ADDITIONAL"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type MemoRelationMessage struct {
 | 
			
		||||
	MemoID        int
 | 
			
		||||
	RelatedMemoID int
 | 
			
		||||
	Type          MemoRelationType
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type FindMemoRelationMessage struct {
 | 
			
		||||
	MemoID        *int
 | 
			
		||||
	RelatedMemoID *int
 | 
			
		||||
	Type          *MemoRelationType
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DeleteMemoRelationMessage struct {
 | 
			
		||||
	MemoID        *int
 | 
			
		||||
	RelatedMemoID *int
 | 
			
		||||
	Type          *MemoRelationType
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelationMessage) (*MemoRelationMessage, error) {
 | 
			
		||||
	tx, err := s.db.BeginTx(ctx, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, FormatError(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer tx.Rollback()
 | 
			
		||||
 | 
			
		||||
	query := `
 | 
			
		||||
		INSERT INTO memo_relation (
 | 
			
		||||
			memo_id,
 | 
			
		||||
			related_memo_id,
 | 
			
		||||
			type
 | 
			
		||||
		)
 | 
			
		||||
		VALUES (?, ?, ?)
 | 
			
		||||
		ON CONFLICT (memo_id, related_memo_id, type) DO UPDATE SET
 | 
			
		||||
			type = EXCLUDED.type
 | 
			
		||||
		RETURNING memo_id, related_memo_id, type
 | 
			
		||||
	`
 | 
			
		||||
	memoRelationMessage := &MemoRelationMessage{}
 | 
			
		||||
	if err := tx.QueryRowContext(
 | 
			
		||||
		ctx,
 | 
			
		||||
		query,
 | 
			
		||||
		create.MemoID,
 | 
			
		||||
		create.RelatedMemoID,
 | 
			
		||||
		create.Type,
 | 
			
		||||
	).Scan(
 | 
			
		||||
		&memoRelationMessage.MemoID,
 | 
			
		||||
		&memoRelationMessage.RelatedMemoID,
 | 
			
		||||
		&memoRelationMessage.Type,
 | 
			
		||||
	); err != nil {
 | 
			
		||||
		return nil, FormatError(err)
 | 
			
		||||
	}
 | 
			
		||||
	if err := tx.Commit(); err != nil {
 | 
			
		||||
		return nil, FormatError(err)
 | 
			
		||||
	}
 | 
			
		||||
	return memoRelationMessage, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelationMessage) ([]*MemoRelationMessage, error) {
 | 
			
		||||
	tx, err := s.db.BeginTx(ctx, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, FormatError(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer tx.Rollback()
 | 
			
		||||
 | 
			
		||||
	list, err := listMemoRelations(ctx, tx, find)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return list, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelationMessage) (*MemoRelationMessage, error) {
 | 
			
		||||
	tx, err := s.db.BeginTx(ctx, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, FormatError(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer tx.Rollback()
 | 
			
		||||
 | 
			
		||||
	list, err := listMemoRelations(ctx, tx, find)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(list) == 0 {
 | 
			
		||||
		return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
 | 
			
		||||
	}
 | 
			
		||||
	return list[0], nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelationMessage) error {
 | 
			
		||||
	tx, err := s.db.BeginTx(ctx, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return FormatError(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer tx.Rollback()
 | 
			
		||||
 | 
			
		||||
	where, args := []string{"TRUE"}, []any{}
 | 
			
		||||
	if delete.MemoID != nil {
 | 
			
		||||
		where, args = append(where, "memo_id = ?"), append(args, delete.MemoID)
 | 
			
		||||
	}
 | 
			
		||||
	if delete.RelatedMemoID != nil {
 | 
			
		||||
		where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID)
 | 
			
		||||
	}
 | 
			
		||||
	if delete.Type != nil {
 | 
			
		||||
		where, args = append(where, "type = ?"), append(args, delete.Type)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	query := `
 | 
			
		||||
		DELETE FROM memo_relation
 | 
			
		||||
		WHERE ` + strings.Join(where, " AND ")
 | 
			
		||||
	if _, err := tx.ExecContext(ctx, query, args...); err != nil {
 | 
			
		||||
		return FormatError(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := tx.Commit(); err != nil {
 | 
			
		||||
		return FormatError(err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelationMessage) ([]*MemoRelationMessage, error) {
 | 
			
		||||
	where, args := []string{"TRUE"}, []any{}
 | 
			
		||||
	if find.MemoID != nil {
 | 
			
		||||
		where, args = append(where, "memo_id = ?"), append(args, find.MemoID)
 | 
			
		||||
	}
 | 
			
		||||
	if find.RelatedMemoID != nil {
 | 
			
		||||
		where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID)
 | 
			
		||||
	}
 | 
			
		||||
	if find.Type != nil {
 | 
			
		||||
		where, args = append(where, "type = ?"), append(args, find.Type)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := tx.QueryContext(ctx, `
 | 
			
		||||
		SELECT
 | 
			
		||||
			memo_id,
 | 
			
		||||
			related_memo_id,
 | 
			
		||||
			type
 | 
			
		||||
		FROM memo_relation
 | 
			
		||||
		WHERE `+strings.Join(where, " AND "), args...)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, FormatError(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	memoRelationMessages := []*MemoRelationMessage{}
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		memoRelationMessage := &MemoRelationMessage{}
 | 
			
		||||
		if err := rows.Scan(
 | 
			
		||||
			&memoRelationMessage.MemoID,
 | 
			
		||||
			&memoRelationMessage.RelatedMemoID,
 | 
			
		||||
			&memoRelationMessage.Type,
 | 
			
		||||
		); err != nil {
 | 
			
		||||
			return nil, FormatError(err)
 | 
			
		||||
		}
 | 
			
		||||
		memoRelationMessages = append(memoRelationMessages, memoRelationMessage)
 | 
			
		||||
	}
 | 
			
		||||
	if err := rows.Err(); err != nil {
 | 
			
		||||
		return nil, FormatError(err)
 | 
			
		||||
	}
 | 
			
		||||
	return memoRelationMessages, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error {
 | 
			
		||||
	if _, err := tx.ExecContext(ctx, `
 | 
			
		||||
		DELETE FROM memo_relation
 | 
			
		||||
		WHERE memo_id NOT IN (SELECT id FROM memo) OR related_memo_id NOT IN (SELECT id FROM memo)
 | 
			
		||||
	`); err != nil {
 | 
			
		||||
		return FormatError(err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,57 @@
 | 
			
		||||
package teststore
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
	"github.com/usememos/memos/api"
 | 
			
		||||
	"github.com/usememos/memos/store"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestMemoRelationStore(t *testing.T) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	ts := NewTestingStore(ctx, t)
 | 
			
		||||
	user, err := createTestingHostUser(ctx, ts)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	memoCreate := &api.MemoCreate{
 | 
			
		||||
		CreatorID:  user.ID,
 | 
			
		||||
		Content:    "test_content",
 | 
			
		||||
		Visibility: api.Public,
 | 
			
		||||
	}
 | 
			
		||||
	memo, err := ts.CreateMemo(ctx, memoCreate)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	require.Equal(t, memoCreate.Content, memo.Content)
 | 
			
		||||
	memoCreate = &api.MemoCreate{
 | 
			
		||||
		CreatorID:  user.ID,
 | 
			
		||||
		Content:    "test_content_2",
 | 
			
		||||
		Visibility: api.Public,
 | 
			
		||||
	}
 | 
			
		||||
	memo2, err := ts.CreateMemo(ctx, memoCreate)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	require.Equal(t, memoCreate.Content, memo2.Content)
 | 
			
		||||
	memoRelationMessage := &store.MemoRelationMessage{
 | 
			
		||||
		MemoID:        memo.ID,
 | 
			
		||||
		RelatedMemoID: memo2.ID,
 | 
			
		||||
		Type:          store.MemoRelationReference,
 | 
			
		||||
	}
 | 
			
		||||
	_, err = ts.UpsertMemoRelation(ctx, memoRelationMessage)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	memoRelation, err := ts.ListMemoRelations(ctx, &store.FindMemoRelationMessage{
 | 
			
		||||
		MemoID: &memo.ID,
 | 
			
		||||
	})
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	require.Equal(t, 1, len(memoRelation))
 | 
			
		||||
	require.Equal(t, memo2.ID, memoRelation[0].RelatedMemoID)
 | 
			
		||||
	require.Equal(t, memo.ID, memoRelation[0].MemoID)
 | 
			
		||||
	require.Equal(t, store.MemoRelationReference, memoRelation[0].Type)
 | 
			
		||||
	err = ts.DeleteMemo(ctx, &api.MemoDelete{
 | 
			
		||||
		ID: memo2.ID,
 | 
			
		||||
	})
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	memoRelation, err = ts.ListMemoRelations(ctx, &store.FindMemoRelationMessage{
 | 
			
		||||
		MemoID: &memo.ID,
 | 
			
		||||
	})
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	require.Equal(t, 0, len(memoRelation))
 | 
			
		||||
}
 | 
			
		||||
					Loading…
					
					
				
		Reference in New Issue