diff --git a/store/db/sqlite/common.go b/store/db/sqlite/common.go new file mode 100644 index 00000000..edd096c5 --- /dev/null +++ b/store/db/sqlite/common.go @@ -0,0 +1,9 @@ +package sqlite + +import "google.golang.org/protobuf/encoding/protojson" + +var ( + protojsonUnmarshaler = protojson.UnmarshalOptions{ + DiscardUnknown: true, + } +) diff --git a/store/db/sqlite/inbox.go b/store/db/sqlite/inbox.go index 92b7caa5..bb99e91a 100644 --- a/store/db/sqlite/inbox.go +++ b/store/db/sqlite/inbox.go @@ -2,26 +2,124 @@ package sqlite import ( "context" + "strings" + "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + + storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) -// nolint func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox, error) { - return nil, nil + messageString := "{}" + if create.Message != nil { + bytes, err := protojson.Marshal(create.Message) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal inbox message") + } + messageString = string(bytes) + } + + fields := []string{"`sender_id`", "`receiver_id`", "`status`", "`message`"} + placeholder := []string{"?", "?", "?", "?"} + args := []any{create.SenderID, create.ReceiverID, create.Status, messageString} + + stmt := "INSERT INTO inbox (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`" + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &create.ID, + &create.CreatedTs, + ); err != nil { + return nil, err + } + + return create, nil } -// nolint func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) { - return nil, nil + where, args := []string{"1 = 1"}, []any{} + + if find.ID != nil { + where, args = append(where, "`id` = ?"), append(args, *find.ID) + } + if find.SenderID != nil { + where, args = append(where, "`sender_id` = ?"), append(args, *find.SenderID) + } + if find.ReceiverID != nil { + where, args = append(where, "`receiver_id` = ?"), append(args, *find.ReceiverID) + } + if find.Status != nil { + where, args = append(where, "`status` = ?"), append(args, *find.Status) + } + + query := "SELECT `id`, `created_ts`, `sender_id`, `receiver_id`, `status`, `message` FROM `inbox` WHERE " + strings.Join(where, " AND ") + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.Inbox{} + for rows.Next() { + inbox := &store.Inbox{} + var messageBytes []byte + if err := rows.Scan( + &inbox.ID, + &inbox.CreatedTs, + &inbox.SenderID, + &inbox.ReceiverID, + &inbox.Status, + &messageBytes, + ); err != nil { + return nil, err + } + + message := &storepb.InboxMessage{} + if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil { + return nil, err + } + inbox.Message = message + list = append(list, inbox) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil } -// nolint func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) { - return nil, nil + set, args := []string{"status"}, []any{update.Status} + args = append(args, update.ID) + query := "UPDATE inbox SET " + strings.Join(set, " = ?, ") + " = ? WHERE id = ? RETURNING `id`, `created_ts`, `sender_id`, `receiver_id`, `status`, `message`" + inbox := &store.Inbox{} + var messageBytes []byte + if err := d.db.QueryRowContext(ctx, query, args...).Scan( + &inbox.ID, + &inbox.CreatedTs, + &inbox.SenderID, + &inbox.ReceiverID, + &inbox.Status, + &messageBytes, + ); err != nil { + return nil, err + } + message := &storepb.InboxMessage{} + if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil { + return nil, err + } + inbox.Message = message + return inbox, nil } -// nolint func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error { + result, err := d.db.ExecContext(ctx, "DELETE FROM inbox WHERE id = ?", delete.ID) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } return nil } diff --git a/store/db/sqlite/migration/dev/LATEST__SCHEMA.sql b/store/db/sqlite/migration/dev/LATEST__SCHEMA.sql index c3f365dd..597d6317 100644 --- a/store/db/sqlite/migration/dev/LATEST__SCHEMA.sql +++ b/store/db/sqlite/migration/dev/LATEST__SCHEMA.sql @@ -141,6 +141,6 @@ CREATE TABLE inbox ( created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), sender_id INTEGER NOT NULL, receiver_id INTEGER NOT NULL, - status TEXT NOT NULL CHECK (status IN ('UNREAD', 'READ', 'ARCHIVED')) DEFAULT 'UNREAD', + status TEXT NOT NULL, message TEXT NOT NULL DEFAULT '{}' ); diff --git a/test/store/inbox_test.go b/test/store/inbox_test.go new file mode 100644 index 00000000..11f07322 --- /dev/null +++ b/test/store/inbox_test.go @@ -0,0 +1,55 @@ +package teststore + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/store" +) + +func TestInboxStore(t *testing.T) { + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + const systemBotID int32 = 0 + create := &store.Inbox{ + SenderID: systemBotID, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{ + Title: "title", + Content: "content", + Link: "link", + }, + } + inbox, err := ts.CreateInbox(ctx, create) + require.NoError(t, err) + require.NotNil(t, inbox) + require.Equal(t, create.Message, inbox.Message) + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ + ReceiverID: &user.ID, + }) + require.NoError(t, err) + require.Equal(t, 1, len(inboxes)) + require.Equal(t, inbox, inboxes[0]) + updatedInbox, err := ts.UpdateInbox(ctx, &store.UpdateInbox{ + ID: inbox.ID, + Status: store.READ, + }) + require.NoError(t, err) + require.NotNil(t, updatedInbox) + require.Equal(t, store.READ, updatedInbox.Status) + err = ts.DeleteInbox(ctx, &store.DeleteInbox{ + ID: inbox.ID, + }) + require.NoError(t, err) + inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ + ReceiverID: &user.ID, + }) + require.NoError(t, err) + require.Equal(t, 0, len(inboxes)) +}