diff --git a/server/router/mcp/README.md b/server/router/mcp/README.md index 0af991fc0..cd2fbf182 100644 --- a/server/router/mcp/README.md +++ b/server/router/mcp/README.md @@ -12,6 +12,35 @@ DELETE /mcp (optional session termination) Transport: [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports) (single endpoint, MCP spec 2025-03-26). +### Tool Filtering + +The default `/mcp` endpoint exposes all tools. Clients can opt into a smaller +tool surface with GitHub-style headers or route aliases: + +| Control | Description | +|---|---| +| `X-MCP-Readonly: true` | Hide and block mutating tools | +| `X-MCP-Toolsets: memos,tags,attachments,relations,reactions` | Limit the default tool list to selected toolsets | +| `X-MCP-Tools: list_tags,get_memo` | Add specific tools to the selected toolset list | +| `X-MCP-Exclude-Tools: delete_memo` | Remove specific tools | + +Equivalent aliases: + +```text +/mcp/readonly +/mcp/x/{toolsets} +/mcp/x/{toolsets}/readonly +``` + +Examples: + +```text +/mcp/x/memos,tags/readonly +X-MCP-Toolsets: memos +X-MCP-Tools: list_tags +X-MCP-Exclude-Tools: delete_memo +``` + ## Capabilities The server advertises the following MCP capabilities: @@ -126,7 +155,9 @@ claude mcp add --scope user --transport http memos http://localhost:5230/mcp \ | File | Responsibility | |---|---| -| `mcp.go` | `MCPService` struct, constructor, route registration, auth middleware | +| `mcp.go` | `MCPService` struct, constructor, route registration, auth middleware, tool filtering | +| `tool_metadata.go` | Toolsets, read-only metadata, annotations, structured result helpers | +| `api_helpers.go` | Conversion helpers for calling API service methods from MCP tools | | `tools_memo.go` | Memo CRUD tools + helpers (JSON types, visibility/access checks) | | `tools_tag.go` | Tag listing tool | | `tools_attachment.go` | Attachment listing, metadata, deletion, linking tools | diff --git a/server/router/mcp/api_helpers.go b/server/router/mcp/api_helpers.go new file mode 100644 index 000000000..d561ea970 --- /dev/null +++ b/server/router/mcp/api_helpers.go @@ -0,0 +1,74 @@ +package mcp + +import ( + "context" + + "github.com/pkg/errors" + + v1pb "github.com/usememos/memos/proto/gen/api/v1" + apiv1 "github.com/usememos/memos/server/router/api/v1" + "github.com/usememos/memos/store" +) + +func visibilityToProto(visibility store.Visibility) v1pb.Visibility { + switch visibility { + case store.Protected: + return v1pb.Visibility_PROTECTED + case store.Public: + return v1pb.Visibility_PUBLIC + default: + return v1pb.Visibility_PRIVATE + } +} + +func rowStatusToProto(rowStatus store.RowStatus) v1pb.State { + switch rowStatus { + case store.Archived: + return v1pb.State_ARCHIVED + default: + return v1pb.State_NORMAL + } +} + +func (s *MCPService) loadMemoJSONByName(ctx context.Context, name string) (memoJSON, error) { + uid, err := parseMemoUID(name) + if err != nil { + return memoJSON{}, err + } + memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) + if err != nil { + return memoJSON{}, errors.Wrap(err, "failed to get memo") + } + if memo == nil { + return memoJSON{}, errors.New("memo not found") + } + return storeMemoToJSONWithStore(ctx, s.store, memo) +} + +func (s *MCPService) loadReactionJSONByID(ctx context.Context, reactionID int32) (reactionJSON, error) { + reaction, err := s.store.GetReaction(ctx, &store.FindReaction{ID: &reactionID}) + if err != nil { + return reactionJSON{}, errors.Wrap(err, "failed to get reaction") + } + if reaction == nil { + return reactionJSON{}, errors.New("reaction not found") + } + creator, err := lookupUsername(ctx, s.store, reaction.CreatorID) + if err != nil { + return reactionJSON{}, errors.Wrap(err, "failed to resolve reaction creator") + } + return reactionJSON{ + ID: reaction.ID, + Creator: creator, + ReactionType: reaction.ReactionType, + CreateTime: reaction.CreatedTs, + }, nil +} + +func (s *MCPService) loadReactionJSONByName(ctx context.Context, name string) (reactionJSON, error) { + _, reactionID, err := apiv1.ExtractMemoReactionIDFromName(name) + if err != nil { + return reactionJSON{}, err + } + return s.loadReactionJSONByID(ctx, reactionID) +} diff --git a/server/router/mcp/mcp.go b/server/router/mcp/mcp.go index 93dcb6c82..a5ec3820c 100644 --- a/server/router/mcp/mcp.go +++ b/server/router/mcp/mcp.go @@ -1,26 +1,42 @@ package mcp import ( + "context" + "fmt" "net/http" + "strings" "github.com/labstack/echo/v5" + "github.com/mark3labs/mcp-go/mcp" mcpserver "github.com/mark3labs/mcp-go/server" "github.com/usememos/memos/internal/profile" "github.com/usememos/memos/server/auth" + apiv1 "github.com/usememos/memos/server/router/api/v1" "github.com/usememos/memos/store" ) +const ( + headerMCPReadonly = "X-MCP-Readonly" + headerMCPToolsets = "X-MCP-Toolsets" + headerMCPTools = "X-MCP-Tools" + headerMCPExcludeTools = "X-MCP-Exclude-Tools" +) + +type mcpRequestConfigContextKey struct{} + type MCPService struct { profile *profile.Profile store *store.Store + apiV1Service *apiv1.APIV1Service authenticator *auth.Authenticator } -func NewMCPService(profile *profile.Profile, store *store.Store, secret string) *MCPService { +func NewMCPService(profile *profile.Profile, store *store.Store, secret string, apiV1Service *apiv1.APIV1Service) *MCPService { return &MCPService{ profile: profile, store: store, + apiV1Service: apiV1Service, authenticator: auth.NewAuthenticator(store, secret), } } @@ -31,6 +47,10 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { mcpserver.WithResourceCapabilities(true, true), mcpserver.WithPromptCapabilities(true), mcpserver.WithLogging(), + mcpserver.WithToolFilter(s.filterTools), + mcpserver.WithToolHandlerMiddleware(s.enforceToolAccess), + mcpserver.WithRecovery(), + mcpserver.WithResourceRecovery(), ) s.registerMemoTools(mcpSrv) s.registerTagTools(mcpSrv) @@ -40,7 +60,9 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { s.registerMemoResources(mcpSrv) s.registerPrompts(mcpSrv) - httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv) + httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv, + mcpserver.WithHTTPContextFunc(s.withRequestConfig), + ) mcpGroup := echoServer.Group("") mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { @@ -52,7 +74,18 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { headers := c.Response().Header() headers.Set("Vary", "Origin") headers.Set("Access-Control-Allow-Origin", origin) - headers.Set("Access-Control-Allow-Headers", "Authorization, Content-Type, Accept, Mcp-Session-Id, MCP-Protocol-Version, Last-Event-ID") + headers.Set("Access-Control-Allow-Headers", strings.Join([]string{ + "Authorization", + "Content-Type", + "Accept", + "Mcp-Session-Id", + "MCP-Protocol-Version", + "Last-Event-ID", + headerMCPReadonly, + headerMCPToolsets, + headerMCPTools, + headerMCPExcludeTools, + }, ", ")) headers.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") if c.Request().Method == http.MethodOptions { return c.NoContent(http.StatusNoContent) @@ -72,4 +105,147 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { } }) mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler)) + mcpGroup.Any("/mcp/readonly", echo.WrapHandler(httpHandler)) + mcpGroup.Any("/mcp/x/:toolsets", echo.WrapHandler(httpHandler)) + mcpGroup.Any("/mcp/x/:toolsets/readonly", echo.WrapHandler(httpHandler)) +} + +func (*MCPService) withRequestConfig(ctx context.Context, r *http.Request) context.Context { + return context.WithValue(ctx, mcpRequestConfigContextKey{}, parseMCPRequestConfig(r)) +} + +func (*MCPService) filterTools(ctx context.Context, tools []mcp.Tool) []mcp.Tool { + cfg := mcpRequestConfigFromContext(ctx) + filtered := make([]mcp.Tool, 0, len(tools)) + for _, tool := range tools { + if cfg.allowsTool(tool.Name) { + filtered = append(filtered, tool) + } + } + return filtered +} + +func (*MCPService) enforceToolAccess(next mcpserver.ToolHandlerFunc) mcpserver.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + cfg := mcpRequestConfigFromContext(ctx) + if !cfg.allowsTool(req.Params.Name) { + return mcp.NewToolResultError(fmt.Sprintf("tool %q is not enabled by MCP configuration", req.Params.Name)), nil + } + return next(ctx, req) + } +} + +type mcpRequestConfig struct { + readOnly bool + toolsets map[string]struct{} + includeTools map[string]struct{} + excludeTools map[string]struct{} +} + +func mcpRequestConfigFromContext(ctx context.Context) mcpRequestConfig { + if cfg, ok := ctx.Value(mcpRequestConfigContextKey{}).(mcpRequestConfig); ok { + return cfg + } + return mcpRequestConfig{} +} + +func parseMCPRequestConfig(r *http.Request) mcpRequestConfig { + cfg := mcpRequestConfig{} + + pathToolsets, pathReadonly := parseMCPPathConfig(r.URL.Path) + cfg.readOnly = pathReadonly || parseBoolHeader(r.Header.Get(headerMCPReadonly)) + cfg.toolsets = mergeStringSets(cfg.toolsets, pathToolsets) + cfg.toolsets = mergeStringSets(cfg.toolsets, parseCommaSet(r.Header.Get(headerMCPToolsets), strings.ToLower)) + cfg.includeTools = parseCommaSet(r.Header.Get(headerMCPTools), keepString) + cfg.excludeTools = parseCommaSet(r.Header.Get(headerMCPExcludeTools), keepString) + return cfg +} + +func parseMCPPathConfig(path string) (map[string]struct{}, bool) { + trimmed := strings.Trim(path, "/") + if trimmed == "mcp/readonly" { + return nil, true + } + const prefix = "mcp/x/" + if !strings.HasPrefix(trimmed, prefix) { + return nil, false + } + + rest := strings.TrimPrefix(trimmed, prefix) + readOnly := false + if strings.HasSuffix(rest, "/readonly") { + readOnly = true + rest = strings.TrimSuffix(rest, "/readonly") + } + return parseCommaSet(rest, strings.ToLower), readOnly +} + +func (cfg mcpRequestConfig) allowsTool(name string) bool { + if _, known := allMCPToolNames[name]; !known { + return false + } + if cfg.readOnly { + if _, mutates := mcpMutationTools[name]; mutates { + return false + } + } + if _, excluded := cfg.excludeTools[name]; excluded { + return false + } + if _, included := cfg.includeTools[name]; included { + return true + } + if len(cfg.toolsets) == 0 { + return true + } + for toolset := range cfg.toolsets { + if _, ok := mcpToolsByToolset[toolset][name]; ok { + return true + } + } + return false +} + +func parseBoolHeader(value string) bool { + switch strings.ToLower(strings.TrimSpace(value)) { + case "1", "t", "true", "y", "yes", "on": + return true + default: + return false + } +} + +func parseCommaSet(value string, normalize func(string) string) map[string]struct{} { + if value == "" { + return nil + } + result := map[string]struct{}{} + for _, item := range strings.Split(value, ",") { + item = strings.TrimSpace(item) + if item == "" { + continue + } + result[normalize(item)] = struct{}{} + } + if len(result) == 0 { + return nil + } + return result +} + +func mergeStringSets(dst map[string]struct{}, src map[string]struct{}) map[string]struct{} { + if len(src) == 0 { + return dst + } + if dst == nil { + dst = map[string]struct{}{} + } + for item := range src { + dst[item] = struct{}{} + } + return dst +} + +func keepString(s string) string { + return s } diff --git a/server/router/mcp/mcp_test.go b/server/router/mcp/mcp_test.go index a4dd1c489..7e1e0dbfb 100644 --- a/server/router/mcp/mcp_test.go +++ b/server/router/mcp/mcp_test.go @@ -1,11 +1,17 @@ package mcp import ( + "bytes" "context" "encoding/json" + "net/http" "net/http/httptest" + "reflect" "testing" + "time" + "unsafe" + "github.com/labstack/echo/v5" "github.com/lithammer/shortuuid/v4" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" @@ -13,6 +19,7 @@ import ( "github.com/usememos/memos/internal/profile" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/server/auth" + apiv1service "github.com/usememos/memos/server/router/api/v1" "github.com/usememos/memos/store" teststore "github.com/usememos/memos/store/test" ) @@ -31,10 +38,12 @@ func newTestMCPService(t *testing.T) *testMCPService { require.NoError(t, stores.Close()) }) - svc := NewMCPService(&profile.Profile{ + profile := &profile.Profile{ Driver: "sqlite", InstanceURL: "https://notes.example.com", - }, stores, "test-secret") + } + apiV1Service := apiv1service.NewAPIV1Service("test-secret", profile, stores) + svc := NewMCPService(profile, stores, "test-secret", apiV1Service) return &testMCPService{ service: svc, store: stores, @@ -115,6 +124,125 @@ func firstText(t *testing.T, result *mcp.CallToolResult) string { return text.Text } +func initializeMCPHTTP(t *testing.T, e *echo.Echo, path string, headers map[string]string) string { + t.Helper() + payload := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2025-06-18", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{ + "name": "mcp-test", + "version": "1.0.0", + }, + }, + } + resp := postMCPHTTP(t, e, path, "", headers, payload) + require.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) + sessionID := resp.Header().Get("Mcp-Session-Id") + require.NotEmpty(t, sessionID) + return sessionID +} + +func callMCPHTTP(t *testing.T, e *echo.Echo, path string, sessionID string, headers map[string]string, method string, params any) map[string]any { + t.Helper() + payload := map[string]any{ + "jsonrpc": "2.0", + "id": 2, + "method": method, + } + if params != nil { + payload["params"] = params + } + resp := postMCPHTTP(t, e, path, sessionID, headers, payload) + require.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) + var decoded map[string]any + require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &decoded)) + return decoded +} + +func postMCPHTTP(t *testing.T, e *echo.Echo, path string, sessionID string, headers map[string]string, payload map[string]any) *httptest.ResponseRecorder { + t.Helper() + body, err := json.Marshal(payload) + require.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + if sessionID != "" { + req.Header.Set("Mcp-Session-Id", sessionID) + } + for key, value := range headers { + req.Header.Set(key, value) + } + resp := httptest.NewRecorder() + e.ServeHTTP(resp, req) + return resp +} + +func toolNamesFromListResponse(t *testing.T, response map[string]any) map[string]struct{} { + t.Helper() + result, ok := response["result"].(map[string]any) + require.True(t, ok, "missing result: %#v", response) + rawTools, ok := result["tools"].([]any) + require.True(t, ok, "missing tools: %#v", result) + names := map[string]struct{}{} + for _, rawTool := range rawTools { + tool, ok := rawTool.(map[string]any) + require.True(t, ok) + name, ok := tool["name"].(string) + require.True(t, ok) + names[name] = struct{}{} + } + return names +} + +func requireToolPresent(t *testing.T, names map[string]struct{}, name string) { + t.Helper() + _, ok := names[name] + require.True(t, ok, "expected tool %q to be present in %#v", name, names) +} + +func requireToolAbsent(t *testing.T, names map[string]struct{}, name string) { + t.Helper() + _, ok := names[name] + require.False(t, ok, "expected tool %q to be absent in %#v", name, names) +} + +func nextSSEEvent(t *testing.T, client *apiv1service.SSEClient) *apiv1service.SSEEvent { + t.Helper() + events := sseClientEvents(t, client) + var data []byte + select { + case eventData, ok := <-events: + require.True(t, ok, "SSE client channel closed") + data = eventData + case <-time.After(time.Second): + t.Fatal("timed out waiting for SSE event") + } + var event apiv1service.SSEEvent + require.NoError(t, json.Unmarshal(data, &event)) + return &event +} + +func requireNoSSEEvent(t *testing.T, client *apiv1service.SSEClient) { + t.Helper() + select { + case eventData, ok := <-sseClientEvents(t, client): + require.True(t, ok, "SSE client channel closed") + t.Fatalf("unexpected SSE event received: %s", string(eventData)) + case <-time.After(150 * time.Millisecond): + } +} + +func sseClientEvents(t *testing.T, client *apiv1service.SSEClient) <-chan []byte { + t.Helper() + field := reflect.ValueOf(client).Elem().FieldByName("events") + events, ok := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface().(chan []byte) + require.True(t, ok) + return events +} + func TestHandleGetMemoAndReadResourceDenyArchivedMemoToNonCreator(t *testing.T) { ts := newTestMCPService(t) owner := ts.createUser(t, "owner") @@ -273,3 +401,205 @@ func TestIsAllowedOrigin(t *testing.T) { require.False(t, ts.service.isAllowedOrigin(req)) }) } + +func TestMCPToolFilteringRoutesAndHeaders(t *testing.T) { + ts := newTestMCPService(t) + e := echo.New() + ts.service.RegisterRoutes(e) + + t.Run("default endpoint lists all tools", func(t *testing.T) { + sessionID := initializeMCPHTTP(t, e, "/mcp", nil) + response := callMCPHTTP(t, e, "/mcp", sessionID, nil, "tools/list", map[string]any{}) + names := toolNamesFromListResponse(t, response) + require.Len(t, names, len(allMCPToolNames)) + requireToolPresent(t, names, "create_memo") + requireToolPresent(t, names, "list_tags") + requireToolPresent(t, names, "upsert_reaction") + }) + + t.Run("readonly header hides and blocks mutation tools", func(t *testing.T) { + headers := map[string]string{headerMCPReadonly: "true"} + sessionID := initializeMCPHTTP(t, e, "/mcp", nil) + response := callMCPHTTP(t, e, "/mcp", sessionID, headers, "tools/list", map[string]any{}) + names := toolNamesFromListResponse(t, response) + requireToolPresent(t, names, "list_memos") + requireToolPresent(t, names, "list_tags") + requireToolAbsent(t, names, "create_memo") + requireToolAbsent(t, names, "delete_memo") + + callResponse := callMCPHTTP(t, e, "/mcp", sessionID, headers, "tools/call", map[string]any{ + "name": "create_memo", + "arguments": map[string]any{"content": "blocked"}, + }) + result, ok := callResponse["result"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, result["isError"]) + rawContent, ok := result["content"].([]any) + require.True(t, ok) + content, ok := rawContent[0].(map[string]any) + require.True(t, ok) + require.Contains(t, content["text"], "not enabled") + }) + + t.Run("readonly alias applies path config", func(t *testing.T) { + sessionID := initializeMCPHTTP(t, e, "/mcp/readonly", nil) + response := callMCPHTTP(t, e, "/mcp/readonly", sessionID, nil, "tools/list", map[string]any{}) + names := toolNamesFromListResponse(t, response) + requireToolPresent(t, names, "get_memo") + requireToolAbsent(t, names, "create_memo") + requireToolAbsent(t, names, "upsert_reaction") + }) + + t.Run("toolsets include and exclude compose", func(t *testing.T) { + headers := map[string]string{ + headerMCPToolsets: "memos", + headerMCPTools: "list_tags", + headerMCPExcludeTools: "get_memo", + } + sessionID := initializeMCPHTTP(t, e, "/mcp", nil) + response := callMCPHTTP(t, e, "/mcp", sessionID, headers, "tools/list", map[string]any{}) + names := toolNamesFromListResponse(t, response) + requireToolPresent(t, names, "list_memos") + requireToolPresent(t, names, "list_tags") + requireToolAbsent(t, names, "get_memo") + requireToolAbsent(t, names, "list_attachments") + }) + + t.Run("path toolsets and readonly compose", func(t *testing.T) { + sessionID := initializeMCPHTTP(t, e, "/mcp/x/memos,tags/readonly", nil) + response := callMCPHTTP(t, e, "/mcp/x/memos,tags/readonly", sessionID, nil, "tools/list", map[string]any{}) + names := toolNamesFromListResponse(t, response) + requireToolPresent(t, names, "list_memos") + requireToolPresent(t, names, "list_tags") + requireToolAbsent(t, names, "create_memo") + requireToolAbsent(t, names, "list_attachments") + }) + + t.Run("unknown toolset returns empty tool list", func(t *testing.T) { + sessionID := initializeMCPHTTP(t, e, "/mcp/x/notreal", nil) + response := callMCPHTTP(t, e, "/mcp/x/notreal", sessionID, nil, "tools/list", map[string]any{}) + names := toolNamesFromListResponse(t, response) + require.Empty(t, names) + }) +} + +func TestMCPMemoAndReactionMutationsEmitSSEEvents(t *testing.T) { + ts := newTestMCPService(t) + user := ts.createUser(t, "author") + ctx := withUser(context.Background(), user.ID) + client := ts.service.apiV1Service.SSEHub.Subscribe(user.ID, store.RoleUser) + defer ts.service.apiV1Service.SSEHub.Unsubscribe(client) + + createResult, err := ts.service.handleCreateMemo(ctx, toolRequest("create_memo", map[string]any{ + "content": "created from MCP", + "visibility": "PRIVATE", + })) + require.NoError(t, err) + require.False(t, createResult.IsError) + createEvent := nextSSEEvent(t, client) + require.Equal(t, apiv1service.SSEEventMemoCreated, createEvent.Type) + + var created memoJSON + require.NoError(t, json.Unmarshal([]byte(firstText(t, createResult)), &created)) + + updateResult, err := ts.service.handleUpdateMemo(ctx, toolRequest("update_memo", map[string]any{ + "name": created.Name, + "content": "updated from MCP", + })) + require.NoError(t, err) + require.False(t, updateResult.IsError) + updateEvent := nextSSEEvent(t, client) + require.Equal(t, apiv1service.SSEEventMemoUpdated, updateEvent.Type) + + commentResult, err := ts.service.handleCreateMemoComment(ctx, toolRequest("create_memo_comment", map[string]any{ + "name": created.Name, + "content": "comment from MCP", + })) + require.NoError(t, err) + require.False(t, commentResult.IsError) + commentEvent := nextSSEEvent(t, client) + require.Equal(t, apiv1service.SSEEventMemoCommentCreated, commentEvent.Type) + require.Equal(t, created.Name, commentEvent.Name) + + upsertReactionResult, err := ts.service.handleUpsertReaction(ctx, toolRequest("upsert_reaction", map[string]any{ + "name": created.Name, + "reaction_type": "👍", + })) + require.NoError(t, err) + require.False(t, upsertReactionResult.IsError) + reactionEvent := nextSSEEvent(t, client) + require.Equal(t, apiv1service.SSEEventReactionUpserted, reactionEvent.Type) + + var reaction reactionJSON + require.NoError(t, json.Unmarshal([]byte(firstText(t, upsertReactionResult)), &reaction)) + deleteReactionResult, err := ts.service.handleDeleteReaction(ctx, toolRequest("delete_reaction", map[string]any{ + "id": float64(reaction.ID), + })) + require.NoError(t, err) + require.False(t, deleteReactionResult.IsError) + deleteReactionEvent := nextSSEEvent(t, client) + require.Equal(t, apiv1service.SSEEventReactionDeleted, deleteReactionEvent.Type) + + deleteResult, err := ts.service.handleDeleteMemo(ctx, toolRequest("delete_memo", map[string]any{ + "name": created.Name, + })) + require.NoError(t, err) + require.False(t, deleteResult.IsError) + deleteEvent := nextSSEEvent(t, client) + require.Equal(t, apiv1service.SSEEventMemoDeleted, deleteEvent.Type) +} + +func TestMCPRelationAndAttachmentMutationsEmitMemoUpdated(t *testing.T) { + ts := newTestMCPService(t) + user := ts.createUser(t, "owner") + ctx := withUser(context.Background(), user.ID) + source := ts.createMemo(t, user.ID, store.Private, "source") + target := ts.createMemo(t, user.ID, store.Private, "target") + attachment := ts.createAttachment(t, user.ID, nil) + client := ts.service.apiV1Service.SSEHub.Subscribe(user.ID, store.RoleUser) + defer ts.service.apiV1Service.SSEHub.Unsubscribe(client) + + relationResult, err := ts.service.handleCreateMemoRelation(ctx, toolRequest("create_memo_relation", map[string]any{ + "name": "memos/" + source.UID, + "related_memo": "memos/" + target.UID, + })) + require.NoError(t, err) + require.False(t, relationResult.IsError) + relationEvent := nextSSEEvent(t, client) + require.Equal(t, apiv1service.SSEEventMemoUpdated, relationEvent.Type) + require.Equal(t, "memos/"+source.UID, relationEvent.Name) + + duplicateRelationResult, err := ts.service.handleCreateMemoRelation(ctx, toolRequest("create_memo_relation", map[string]any{ + "name": "memos/" + source.UID, + "related_memo": "memos/" + target.UID, + })) + require.NoError(t, err) + require.False(t, duplicateRelationResult.IsError) + requireNoSSEEvent(t, client) + + selfRelationResult, err := ts.service.handleCreateMemoRelation(ctx, toolRequest("create_memo_relation", map[string]any{ + "name": "memos/" + source.UID, + "related_memo": "memos/" + source.UID, + })) + require.NoError(t, err) + require.True(t, selfRelationResult.IsError) + require.Contains(t, firstText(t, selfRelationResult), "itself") + + linkResult, err := ts.service.handleLinkAttachmentToMemo(ctx, toolRequest("link_attachment_to_memo", map[string]any{ + "name": "attachments/" + attachment.UID, + "memo": "memos/" + source.UID, + })) + require.NoError(t, err) + require.False(t, linkResult.IsError) + attachmentEvent := nextSSEEvent(t, client) + require.Equal(t, apiv1service.SSEEventMemoUpdated, attachmentEvent.Type) + require.Equal(t, "memos/"+source.UID, attachmentEvent.Name) + + relinkResult, err := ts.service.handleLinkAttachmentToMemo(ctx, toolRequest("link_attachment_to_memo", map[string]any{ + "name": "attachments/" + attachment.UID, + "memo": "memos/" + source.UID, + })) + require.NoError(t, err) + require.False(t, relinkResult.IsError) + requireNoSSEEvent(t, client) +} diff --git a/server/router/mcp/tool_metadata.go b/server/router/mcp/tool_metadata.go new file mode 100644 index 000000000..61c4284af --- /dev/null +++ b/server/router/mcp/tool_metadata.go @@ -0,0 +1,102 @@ +package mcp + +import "github.com/mark3labs/mcp-go/mcp" + +var mcpToolsByToolset = map[string]map[string]struct{}{ + "memos": stringSet( + "list_memos", + "get_memo", + "create_memo", + "update_memo", + "delete_memo", + "search_memos", + "list_memo_comments", + "create_memo_comment", + ), + "tags": stringSet( + "list_tags", + ), + "attachments": stringSet( + "list_attachments", + "get_attachment", + "delete_attachment", + "link_attachment_to_memo", + ), + "relations": stringSet( + "list_memo_relations", + "create_memo_relation", + "delete_memo_relation", + ), + "reactions": stringSet( + "list_reactions", + "upsert_reaction", + "delete_reaction", + ), +} + +var allMCPToolNames = func() map[string]struct{} { + names := map[string]struct{}{} + for _, tools := range mcpToolsByToolset { + for name := range tools { + names[name] = struct{}{} + } + } + return names +}() + +var mcpMutationTools = stringSet( + "create_memo", + "update_memo", + "delete_memo", + "create_memo_comment", + "delete_attachment", + "link_attachment_to_memo", + "create_memo_relation", + "delete_memo_relation", + "upsert_reaction", + "delete_reaction", +) + +type deletedJSON struct { + Deleted bool `json:"deleted"` +} + +func stringSet(values ...string) map[string]struct{} { + result := make(map[string]struct{}, len(values)) + for _, value := range values { + result[value] = struct{}{} + } + return result +} + +func readOnlyToolOptions(title string, description string, opts ...mcp.ToolOption) []mcp.ToolOption { + return annotatedToolOptions(title, description, true, false, true, false, opts...) +} + +func createToolOptions(title string, description string, idempotent bool, opts ...mcp.ToolOption) []mcp.ToolOption { + return annotatedToolOptions(title, description, false, false, idempotent, false, opts...) +} + +func updateToolOptions(title string, description string, opts ...mcp.ToolOption) []mcp.ToolOption { + return annotatedToolOptions(title, description, false, true, false, false, opts...) +} + +func annotatedToolOptions(title string, description string, readOnly bool, destructive bool, idempotent bool, openWorld bool, opts ...mcp.ToolOption) []mcp.ToolOption { + base := []mcp.ToolOption{ + mcp.WithTitleAnnotation(title), + mcp.WithDescription(description), + mcp.WithReadOnlyHintAnnotation(readOnly), + mcp.WithDestructiveHintAnnotation(destructive), + mcp.WithIdempotentHintAnnotation(idempotent), + mcp.WithOpenWorldHintAnnotation(openWorld), + } + return append(base, opts...) +} + +func newToolResultJSON(v any) (*mcp.CallToolResult, error) { + return mcp.NewToolResultJSON(v) +} + +func newDeletedToolResult() (*mcp.CallToolResult, error) { + return newToolResultJSON(deletedJSON{Deleted: true}) +} diff --git a/server/router/mcp/tools_attachment.go b/server/router/mcp/tools_attachment.go index 2e2b3f571..86b14745e 100644 --- a/server/router/mcp/tools_attachment.go +++ b/server/router/mcp/tools_attachment.go @@ -9,6 +9,7 @@ import ( mcpserver "github.com/mark3labs/mcp-go/server" "github.com/pkg/errors" + v1pb "github.com/usememos/memos/proto/gen/api/v1" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/server/auth" "github.com/usememos/memos/store" @@ -26,6 +27,11 @@ type attachmentJSON struct { Memo string `json:"memo,omitempty"` } +type attachmentListJSON struct { + Attachments []attachmentJSON `json:"attachments"` + HasMore bool `json:"has_more"` +} + func storeAttachmentToJSON(ctx context.Context, stores *store.Store, a *store.Attachment) (attachmentJSON, error) { creator, err := lookupUsername(ctx, stores, a.CreatorID) if err != nil { @@ -98,26 +104,34 @@ func parseAttachmentUID(name string) (string, error) { func (s *MCPService) registerAttachmentTools(mcpSrv *mcpserver.MCPServer) { mcpSrv.AddTool(mcp.NewTool("list_attachments", - mcp.WithDescription("List attachments owned by the authenticated user. Supports pagination and optional filtering by linked memo."), - mcp.WithNumber("page_size", mcp.Description("Maximum attachments to return (1–100, default 20)")), - mcp.WithNumber("page", mcp.Description("Zero-based page index (default 0)")), - mcp.WithString("memo", mcp.Description(`Filter by linked memo resource name, e.g. "memos/abc123"`)), + readOnlyToolOptions("List attachments", "List attachments owned by the authenticated user. Supports pagination and optional filtering by linked memo.", + mcp.WithNumber("page_size", mcp.Description("Maximum attachments to return (1–100, default 20)")), + mcp.WithNumber("page", mcp.Description("Zero-based page index (default 0)")), + mcp.WithString("memo", mcp.Description(`Filter by linked memo resource name, e.g. "memos/abc123"`)), + mcp.WithOutputSchema[attachmentListJSON](), + )..., ), s.handleListAttachments) mcpSrv.AddTool(mcp.NewTool("get_attachment", - mcp.WithDescription("Get a single attachment's metadata by resource name. Requires authentication."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)), + readOnlyToolOptions("Get attachment", "Get a single attachment's metadata by resource name. Requires authentication.", + mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)), + mcp.WithOutputSchema[attachmentJSON](), + )..., ), s.handleGetAttachment) mcpSrv.AddTool(mcp.NewTool("delete_attachment", - mcp.WithDescription("Permanently delete an attachment and its stored file. Requires authentication and ownership."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)), + updateToolOptions("Delete attachment", "Permanently delete an attachment and its stored file. Requires authentication and ownership.", + mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)), + mcp.WithOutputSchema[deletedJSON](), + )..., ), s.handleDeleteAttachment) mcpSrv.AddTool(mcp.NewTool("link_attachment_to_memo", - mcp.WithDescription("Link an existing attachment to a memo. Requires authentication and ownership of the attachment."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)), - mcp.WithString("memo", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + createToolOptions("Link attachment to memo", "Link an existing attachment to a memo. Requires authentication and ownership of the attachment.", true, + mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)), + mcp.WithString("memo", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + mcp.WithOutputSchema[attachmentJSON](), + )..., ), s.handleLinkAttachmentToMemo) } @@ -189,15 +203,7 @@ func (s *MCPService) handleListAttachments(ctx context.Context, req mcp.CallTool results[i] = result } - type listResponse struct { - Attachments []attachmentJSON `json:"attachments"` - HasMore bool `json:"has_more"` - } - out, err := marshalJSON(listResponse{Attachments: results, HasMore: hasMore}) - if err != nil { - return nil, err - } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(attachmentListJSON{Attachments: results, HasMore: hasMore}) } func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -224,16 +230,11 @@ func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRe if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil } - out, err := marshalJSON(result) - if err != nil { - return nil, err - } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(result) } func (s *MCPService) handleDeleteAttachment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - userID, err := extractUserID(ctx) - if err != nil { + if _, err := extractUserID(ctx); err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -242,18 +243,10 @@ func (s *MCPService) handleDeleteAttachment(ctx context.Context, req mcp.CallToo return mcp.NewToolResultError(err.Error()), nil } - attachment, err := s.store.GetAttachment(ctx, &store.FindAttachment{UID: &uid, CreatorID: &userID}) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to find attachment: %v", err)), nil - } - if attachment == nil { - return mcp.NewToolResultError("attachment not found"), nil - } - - if err := s.store.DeleteAttachment(ctx, &store.DeleteAttachment{ID: attachment.ID}); err != nil { + if _, err := s.apiV1Service.DeleteAttachment(ctx, &v1pb.DeleteAttachmentRequest{Name: "attachments/" + uid}); err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to delete attachment: %v", err)), nil } - return mcp.NewToolResultText(`{"deleted":true}`), nil + return newDeletedToolResult() } func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -293,9 +286,30 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal return mcp.NewToolResultError(err.Error()), nil } - if err := s.store.UpdateAttachment(ctx, &store.UpdateAttachment{ - ID: attachment.ID, - MemoID: &memo.ID, + currentAttachments, err := s.store.ListAttachments(ctx, &store.FindAttachment{MemoID: &memo.ID}) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to list memo attachments: %v", err)), nil + } + requestAttachments := make([]*v1pb.Attachment, 0, len(currentAttachments)+1) + var currentTarget *store.Attachment + for _, current := range currentAttachments { + requestAttachments = append(requestAttachments, &v1pb.Attachment{Name: "attachments/" + current.UID}) + if current.ID == attachment.ID { + currentTarget = current + } + } + if currentTarget != nil { + result, err := storeAttachmentToJSON(ctx, s.store, currentTarget) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil + } + return newToolResultJSON(result) + } + requestAttachments = append(requestAttachments, &v1pb.Attachment{Name: "attachments/" + uid}) + + if _, err := s.apiV1Service.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{ + Name: "memos/" + memoUID, + Attachments: requestAttachments, }); err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to link attachment: %v", err)), nil } @@ -309,9 +323,5 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil } - out, err := marshalJSON(result) - if err != nil { - return nil, err - } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(result) } diff --git a/server/router/mcp/tools_memo.go b/server/router/mcp/tools_memo.go index 47e8a2298..86539159d 100644 --- a/server/router/mcp/tools_memo.go +++ b/server/router/mcp/tools_memo.go @@ -2,52 +2,19 @@ package mcp import ( "context" - "encoding/json" "fmt" - "regexp" "strings" - "github.com/lithammer/shortuuid/v4" "github.com/mark3labs/mcp-go/mcp" mcpserver "github.com/mark3labs/mcp-go/server" "github.com/pkg/errors" + "google.golang.org/protobuf/types/known/fieldmaskpb" - storepb "github.com/usememos/memos/proto/gen/store" + v1pb "github.com/usememos/memos/proto/gen/api/v1" "github.com/usememos/memos/server/auth" "github.com/usememos/memos/store" ) -// tagRegexp matches #tag patterns in memo content. -// A tag must start with a letter and contain no whitespace or # characters. -var tagRegexp = regexp.MustCompile(`(?:^|\s)#([A-Za-z][^\s#]*)`) - -// extractTags does a best-effort extraction of #tags from raw markdown content. -// It is used when creating or updating memos via MCP to pre-populate Payload.Tags. -// The full markdown service may later rebuild a more accurate payload. -func extractTags(content string) []string { - matches := tagRegexp.FindAllStringSubmatch(content, -1) - seen := make(map[string]struct{}, len(matches)) - tags := make([]string, 0, len(matches)) - for _, m := range matches { - tag := m[1] - if _, ok := seen[tag]; !ok { - seen[tag] = struct{}{} - tags = append(tags, tag) - } - } - return tags -} - -// buildPayload constructs a MemoPayload with tags extracted from content. -// Returns nil when no tags are found so the store omits the payload entirely. -func buildPayload(content string) *storepb.MemoPayload { - tags := extractTags(content) - if len(tags) == 0 { - return nil - } - return &storepb.MemoPayload{Tags: tags} -} - // propertyJSON is the serialisable form of MemoPayload.Property. type propertyJSON struct { HasLink bool `json:"has_link"` @@ -72,6 +39,11 @@ type memoJSON struct { Parent string `json:"parent,omitempty"` } +type memoListJSON struct { + Memos []memoJSON `json:"memos"` + HasMore bool `json:"has_more"` +} + func storeMemoToJSON(m *store.Memo) memoJSON { j := memoJSON{ Name: "memos/" + m.UID, @@ -205,75 +177,81 @@ func extractUserID(ctx context.Context) (int32, error) { return id, nil } -func marshalJSON(v any) (string, error) { - b, err := json.Marshal(v) - if err != nil { - return "", err - } - return string(b), nil -} - func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) { mcpSrv.AddTool(mcp.NewTool("list_memos", - mcp.WithDescription("List memos visible to the caller. Authenticated users see their own memos plus public and protected memos; unauthenticated callers see only public memos."), - mcp.WithNumber("page_size", mcp.Description("Maximum memos to return (1–100, default 20)")), - mcp.WithNumber("page", mcp.Description("Zero-based page index for pagination (default 0)")), - mcp.WithString("state", - mcp.Enum("NORMAL", "ARCHIVED"), - mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"), - ), - mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")), - mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)), + readOnlyToolOptions("List memos", "List memos visible to the caller. Authenticated users see their own memos plus public and protected memos; unauthenticated callers see only public memos.", + mcp.WithNumber("page_size", mcp.Description("Maximum memos to return (1–100, default 20)")), + mcp.WithNumber("page", mcp.Description("Zero-based page index for pagination (default 0)")), + mcp.WithString("state", + mcp.Enum("NORMAL", "ARCHIVED"), + mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"), + ), + mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")), + mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)), + mcp.WithOutputSchema[memoListJSON](), + )..., ), s.handleListMemos) mcpSrv.AddTool(mcp.NewTool("get_memo", - mcp.WithDescription("Get a single memo by resource name. Public memos are accessible without authentication."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + readOnlyToolOptions("Get memo", "Get a single memo by resource name. Public memos are accessible without authentication.", + mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + mcp.WithOutputSchema[memoJSON](), + )..., ), s.handleGetMemo) mcpSrv.AddTool(mcp.NewTool("create_memo", - mcp.WithDescription("Create a new memo. Requires authentication."), - mcp.WithString("content", mcp.Required(), mcp.Description("Memo content in Markdown. Use #tag syntax for tagging.")), - mcp.WithString("visibility", - mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"), - mcp.Description("Visibility (default: PRIVATE)"), - ), + createToolOptions("Create memo", "Create a new memo. Requires authentication.", false, + mcp.WithString("content", mcp.Required(), mcp.Description("Memo content in Markdown. Use #tag syntax for tagging.")), + mcp.WithString("visibility", + mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"), + mcp.Description("Visibility (default: PRIVATE)"), + ), + mcp.WithOutputSchema[memoJSON](), + )..., ), s.handleCreateMemo) mcpSrv.AddTool(mcp.NewTool("update_memo", - mcp.WithDescription("Update a memo's content, visibility, pin state, or archive state. Requires authentication and ownership. Omit any field to leave it unchanged."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), - mcp.WithString("content", mcp.Description("New Markdown content")), - mcp.WithString("visibility", - mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"), - mcp.Description("New visibility"), - ), - mcp.WithBoolean("pinned", mcp.Description("Pin or unpin the memo")), - mcp.WithString("state", - mcp.Enum("NORMAL", "ARCHIVED"), - mcp.Description("Set to ARCHIVED to archive, NORMAL to restore"), - ), + updateToolOptions("Update memo", "Update a memo's content, visibility, pin state, or archive state. Requires authentication and ownership. Omit any field to leave it unchanged.", + mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + mcp.WithString("content", mcp.Description("New Markdown content")), + mcp.WithString("visibility", + mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"), + mcp.Description("New visibility"), + ), + mcp.WithBoolean("pinned", mcp.Description("Pin or unpin the memo")), + mcp.WithString("state", + mcp.Enum("NORMAL", "ARCHIVED"), + mcp.Description("Set to ARCHIVED to archive, NORMAL to restore"), + ), + mcp.WithOutputSchema[memoJSON](), + )..., ), s.handleUpdateMemo) mcpSrv.AddTool(mcp.NewTool("delete_memo", - mcp.WithDescription("Permanently delete a memo. Requires authentication and ownership."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + updateToolOptions("Delete memo", "Permanently delete a memo. Requires authentication and ownership.", + mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + mcp.WithOutputSchema[deletedJSON](), + )..., ), s.handleDeleteMemo) mcpSrv.AddTool(mcp.NewTool("search_memos", - mcp.WithDescription("Search memo content. Authenticated users search their own and visible memos; unauthenticated callers search public memos only."), - mcp.WithString("query", mcp.Required(), mcp.Description("Text to search for in memo content")), + readOnlyToolOptions("Search memos", "Search memo content. Authenticated users search their own and visible memos; unauthenticated callers search public memos only.", + mcp.WithString("query", mcp.Required(), mcp.Description("Text to search for in memo content")), + )..., ), s.handleSearchMemos) mcpSrv.AddTool(mcp.NewTool("list_memo_comments", - mcp.WithDescription("List comments on a memo. Visibility rules for comments match those of the parent memo."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + readOnlyToolOptions("List memo comments", "List comments on a memo. Visibility rules for comments match those of the parent memo.", + mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + )..., ), s.handleListMemoComments) mcpSrv.AddTool(mcp.NewTool("create_memo_comment", - mcp.WithDescription("Add a comment to a memo. The comment inherits the parent memo's visibility. Requires authentication."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name to comment on, e.g. "memos/abc123"`)), - mcp.WithString("content", mcp.Required(), mcp.Description("Comment content in Markdown")), + createToolOptions("Create memo comment", "Add a comment to a memo. The comment inherits the parent memo's visibility. Requires authentication.", false, + mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name to comment on, e.g. "memos/abc123"`)), + mcp.WithString("content", mcp.Required(), mcp.Description("Comment content in Markdown")), + mcp.WithOutputSchema[memoJSON](), + )..., ), s.handleCreateMemoComment) } @@ -342,15 +320,7 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques results[i] = result } - type listResponse struct { - Memos []memoJSON `json:"memos"` - HasMore bool `json:"has_more"` - } - out, err := marshalJSON(listResponse{Memos: results, HasMore: hasMore}) - if err != nil { - return nil, err - } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(memoListJSON{Memos: results, HasMore: hasMore}) } func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -376,16 +346,11 @@ func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil } - out, err := marshalJSON(result) - if err != nil { - return nil, err - } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(result) } func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - userID, err := extractUserID(ctx) - if err != nil { + if _, err := extractUserID(ctx); err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -398,31 +363,25 @@ func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolReque return mcp.NewToolResultError(err.Error()), nil } - memo, err := s.store.CreateMemo(ctx, &store.Memo{ - UID: shortuuid.New(), - CreatorID: userID, - Content: content, - Visibility: visibility, - Payload: buildPayload(content), + created, err := s.apiV1Service.CreateMemo(ctx, &v1pb.CreateMemoRequest{ + Memo: &v1pb.Memo{ + Content: content, + Visibility: visibilityToProto(visibility), + }, }) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to create memo: %v", err)), nil } - result, err := storeMemoToJSONWithStore(ctx, s.store, memo) + result, err := s.loadMemoJSONByName(ctx, created.Name) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil - } - out, err := marshalJSON(result) - if err != nil { - return nil, err + return mcp.NewToolResultError(err.Error()), nil } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(result) } func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - userID, err := extractUserID(ctx) - if err != nil { + if _, err := extractUserID(ctx); err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -431,66 +390,56 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque return mcp.NewToolResultError(err.Error()), nil } - memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil - } - if memo == nil { - return mcp.NewToolResultError("memo not found"), nil - } - if err := checkMemoOwnership(memo, userID); err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - update := &store.UpdateMemo{ID: memo.ID} + update := &v1pb.Memo{Name: "memos/" + uid} + updateMask := &fieldmaskpb.FieldMask{} args := req.GetArguments() if v := req.GetString("content", ""); v != "" { - update.Content = &v - update.Payload = buildPayload(v) + update.Content = v + updateMask.Paths = append(updateMask.Paths, "content") } if v := req.GetString("visibility", ""); v != "" { vis, err := parseVisibility(v) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - update.Visibility = &vis + update.Visibility = visibilityToProto(vis) + updateMask.Paths = append(updateMask.Paths, "visibility") } if v := req.GetString("state", ""); v != "" { rs, err := parseRowStatus(v) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - update.RowStatus = &rs + update.State = rowStatusToProto(rs) + updateMask.Paths = append(updateMask.Paths, "state") } if _, ok := args["pinned"]; ok { - pinned := req.GetBool("pinned", false) - update.Pinned = &pinned + update.Pinned = req.GetBool("pinned", false) + updateMask.Paths = append(updateMask.Paths, "pinned") } - if err := s.store.UpdateMemo(ctx, update); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to update memo: %v", err)), nil + if len(updateMask.Paths) == 0 { + return mcp.NewToolResultError("at least one field must be provided to update"), nil } - updated, err := s.store.GetMemo(ctx, &store.FindMemo{ID: &memo.ID}) + updated, err := s.apiV1Service.UpdateMemo(ctx, &v1pb.UpdateMemoRequest{ + Memo: update, + UpdateMask: updateMask, + }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to fetch updated memo: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to update memo: %v", err)), nil } - result, err := storeMemoToJSONWithStore(ctx, s.store, updated) + result, err := s.loadMemoJSONByName(ctx, updated.Name) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil - } - out, err := marshalJSON(result) - if err != nil { - return nil, err + return mcp.NewToolResultError(err.Error()), nil } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(result) } func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - userID, err := extractUserID(ctx) - if err != nil { + if _, err := extractUserID(ctx); err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -499,21 +448,10 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque return mcp.NewToolResultError(err.Error()), nil } - memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil - } - if memo == nil { - return mcp.NewToolResultError("memo not found"), nil - } - if err := checkMemoOwnership(memo, userID); err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil { + if _, err := s.apiV1Service.DeleteMemo(ctx, &v1pb.DeleteMemoRequest{Name: "memos/" + uid}); err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to delete memo: %v", err)), nil } - return mcp.NewToolResultText(`{"deleted":true}`), nil + return newDeletedToolResult() } func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -557,11 +495,7 @@ func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequ } results[i] = result } - out, err := marshalJSON(results) - if err != nil { - return nil, err - } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(results) } func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -592,8 +526,7 @@ func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToo return mcp.NewToolResultError(fmt.Sprintf("failed to list relations: %v", err)), nil } if len(relations) == 0 { - out, _ := marshalJSON([]memoJSON{}) - return mcp.NewToolResultText(out), nil + return newToolResultJSON([]memoJSON{}) } commentIDs := make([]int32, len(relations)) @@ -626,11 +559,7 @@ func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToo results = append(results, result) } } - out, err := marshalJSON(results) - if err != nil { - return nil, err - } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(results) } func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -659,33 +588,20 @@ func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallTo return mcp.NewToolResultError(err.Error()), nil } - comment, err := s.store.CreateMemo(ctx, &store.Memo{ - UID: shortuuid.New(), - CreatorID: userID, - Content: content, - Visibility: parent.Visibility, - Payload: buildPayload(content), - ParentUID: &parent.UID, + comment, err := s.apiV1Service.CreateMemoComment(ctx, &v1pb.CreateMemoCommentRequest{ + Name: "memos/" + uid, + Comment: &v1pb.Memo{ + Content: content, + Visibility: visibilityToProto(parent.Visibility), + }, }) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to create comment: %v", err)), nil } - if _, err = s.store.UpsertMemoRelation(ctx, &store.MemoRelation{ - MemoID: comment.ID, - RelatedMemoID: parent.ID, - Type: store.MemoRelationComment, - }); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to link comment: %v", err)), nil - } - - result, err := storeMemoToJSONWithStore(ctx, s.store, comment) + result, err := s.loadMemoJSONByName(ctx, comment.Name) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil - } - out, err := marshalJSON(result) - if err != nil { - return nil, err + return mcp.NewToolResultError(err.Error()), nil } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(result) } diff --git a/server/router/mcp/tools_reaction.go b/server/router/mcp/tools_reaction.go index 46e4c5d44..705c898a4 100644 --- a/server/router/mcp/tools_reaction.go +++ b/server/router/mcp/tools_reaction.go @@ -7,6 +7,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" mcpserver "github.com/mark3labs/mcp-go/server" + v1pb "github.com/usememos/memos/proto/gen/api/v1" "github.com/usememos/memos/server/auth" "github.com/usememos/memos/store" ) @@ -20,19 +21,24 @@ type reactionJSON struct { func (s *MCPService) registerReactionTools(mcpSrv *mcpserver.MCPServer) { mcpSrv.AddTool(mcp.NewTool("list_reactions", - mcp.WithDescription("List all reactions on a memo. Returns reaction type and creator for each reaction."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + readOnlyToolOptions("List reactions", "List all reactions on a memo. Returns reaction type and creator for each reaction.", + mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + )..., ), s.handleListReactions) mcpSrv.AddTool(mcp.NewTool("upsert_reaction", - mcp.WithDescription("Add a reaction (emoji) to a memo. If the same reaction already exists from the same user, this is a no-op. Requires authentication."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), - mcp.WithString("reaction_type", mcp.Required(), mcp.Description(`Reaction emoji, e.g. "👍", "❤️", "🎉"`)), + createToolOptions("Upsert reaction", "Add a reaction (emoji) to a memo. If the same reaction already exists from the same user, this is a no-op. Requires authentication.", true, + mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + mcp.WithString("reaction_type", mcp.Required(), mcp.Description(`Reaction emoji, e.g. "👍", "❤️", "🎉"`)), + mcp.WithOutputSchema[reactionJSON](), + )..., ), s.handleUpsertReaction) mcpSrv.AddTool(mcp.NewTool("delete_reaction", - mcp.WithDescription("Remove a reaction by its ID. Requires authentication and ownership of the reaction."), - mcp.WithNumber("id", mcp.Required(), mcp.Description("Reaction ID to delete")), + updateToolOptions("Delete reaction", "Remove a reaction by its ID. Requires authentication and ownership of the reaction.", + mcp.WithNumber("id", mcp.Required(), mcp.Description("Reaction ID to delete")), + mcp.WithOutputSchema[deletedJSON](), + )..., ), s.handleDeleteReaction) } @@ -83,11 +89,7 @@ func (s *MCPService) handleListReactions(ctx context.Context, req mcp.CallToolRe } } - out, err := marshalJSON(results) - if err != nil { - return nil, err - } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(results) } func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -133,34 +135,26 @@ func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolR } contentID := "memos/" + uid - reaction, err := s.store.UpsertReaction(ctx, &store.Reaction{ - CreatorID: userID, - ContentID: contentID, - ReactionType: reactionType, + reaction, err := s.apiV1Service.UpsertMemoReaction(ctx, &v1pb.UpsertMemoReactionRequest{ + Name: contentID, + Reaction: &v1pb.Reaction{ + ContentId: contentID, + ReactionType: reactionType, + }, }) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to upsert reaction: %v", err)), nil } - creator, err := lookupUsername(ctx, s.store, reaction.CreatorID) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to resolve reaction creator: %v", err)), nil - } - out, err := marshalJSON(reactionJSON{ - ID: reaction.ID, - Creator: creator, - ReactionType: reaction.ReactionType, - CreateTime: reaction.CreatedTs, - }) + result, err := s.loadReactionJSONByName(ctx, reaction.Name) if err != nil { - return nil, err + return mcp.NewToolResultError(err.Error()), nil } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(result) } func (s *MCPService) handleDeleteReaction(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - userID, err := extractUserID(ctx) - if err != nil { + if _, err := extractUserID(ctx); err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -176,12 +170,11 @@ func (s *MCPService) handleDeleteReaction(ctx context.Context, req mcp.CallToolR if reaction == nil { return mcp.NewToolResultError("reaction not found"), nil } - if reaction.CreatorID != userID { - return mcp.NewToolResultError("permission denied: can only delete your own reactions"), nil - } - if err := s.store.DeleteReaction(ctx, &store.DeleteReaction{ID: reactionID}); err != nil { + if _, err := s.apiV1Service.DeleteMemoReaction(ctx, &v1pb.DeleteMemoReactionRequest{ + Name: fmt.Sprintf("%s/reactions/%d", reaction.ContentID, reactionID), + }); err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to delete reaction: %v", err)), nil } - return mcp.NewToolResultText(`{"deleted":true}`), nil + return newDeletedToolResult() } diff --git a/server/router/mcp/tools_relation.go b/server/router/mcp/tools_relation.go index 127bb16fe..e395f45c4 100644 --- a/server/router/mcp/tools_relation.go +++ b/server/router/mcp/tools_relation.go @@ -7,6 +7,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" mcpserver "github.com/mark3labs/mcp-go/server" + v1pb "github.com/usememos/memos/proto/gen/api/v1" "github.com/usememos/memos/server/auth" "github.com/usememos/memos/store" ) @@ -19,24 +20,29 @@ type relationJSON struct { func (s *MCPService) registerRelationTools(mcpSrv *mcpserver.MCPServer) { mcpSrv.AddTool(mcp.NewTool("list_memo_relations", - mcp.WithDescription("List all relations (references and comments) for a memo. Requires read access to the memo."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), - mcp.WithString("type", - mcp.Enum("REFERENCE", "COMMENT"), - mcp.Description("Filter by relation type (optional)"), - ), + readOnlyToolOptions("List memo relations", "List all relations (references and comments) for a memo. Requires read access to the memo.", + mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + mcp.WithString("type", + mcp.Enum("REFERENCE", "COMMENT"), + mcp.Description("Filter by relation type (optional)"), + ), + )..., ), s.handleListMemoRelations) mcpSrv.AddTool(mcp.NewTool("create_memo_relation", - mcp.WithDescription("Create a reference relation between two memos. Requires authentication. For comments, use create_memo_comment instead."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)), - mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)), + createToolOptions("Create memo relation", "Create a reference relation between two memos. Requires authentication. For comments, use create_memo_comment instead.", true, + mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)), + mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)), + mcp.WithOutputSchema[relationJSON](), + )..., ), s.handleCreateMemoRelation) mcpSrv.AddTool(mcp.NewTool("delete_memo_relation", - mcp.WithDescription("Delete a reference relation between two memos. Requires authentication and ownership of the source memo."), - mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)), - mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)), + updateToolOptions("Delete memo relation", "Delete a reference relation between two memos. Requires authentication and ownership of the source memo.", + mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)), + mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)), + mcp.WithOutputSchema[deletedJSON](), + )..., ), s.handleDeleteMemoRelation) } @@ -113,11 +119,7 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo }) } - out, err := marshalJSON(results) - if err != nil { - return nil, err - } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(results) } func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -134,6 +136,9 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT if err != nil { return mcp.NewToolResultError(err.Error()), nil } + if srcUID == dstUID { + return mcp.NewToolResultError("cannot create a relation from a memo to itself"), nil + } srcMemo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &srcUID}) if err != nil { @@ -157,24 +162,24 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT return mcp.NewToolResultError(err.Error()), nil } - relation, err := s.store.UpsertMemoRelation(ctx, &store.MemoRelation{ - MemoID: srcMemo.ID, - RelatedMemoID: dstMemo.ID, - Type: store.MemoRelationReference, - }) + relations, changed, err := s.buildReferenceRelationSet(ctx, srcMemo, &dstMemo.UID, nil) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to create relation: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to build relation set: %v", err)), nil + } + if changed { + if _, err := s.apiV1Service.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{ + Name: "memos/" + srcUID, + Relations: relations, + }); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to create relation: %v", err)), nil + } } - out, err := marshalJSON(relationJSON{ + return newToolResultJSON(relationJSON{ Memo: "memos/" + srcUID, RelatedMemo: "memos/" + dstUID, - Type: string(relation.Type), + Type: string(store.MemoRelationReference), }) - if err != nil { - return nil, err - } - return mcp.NewToolResultText(out), nil } func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -214,13 +219,79 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT return mcp.NewToolResultError(err.Error()), nil } - refType := store.MemoRelationReference - if err := s.store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{ - MemoID: &srcMemo.ID, - RelatedMemoID: &dstMemo.ID, - Type: &refType, - }); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to delete relation: %v", err)), nil + relations, changed, err := s.buildReferenceRelationSet(ctx, srcMemo, nil, &dstMemo.UID) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to build relation set: %v", err)), nil + } + if changed { + if _, err := s.apiV1Service.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{ + Name: "memos/" + srcUID, + Relations: relations, + }); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to delete relation: %v", err)), nil + } + } + return newDeletedToolResult() +} + +func (s *MCPService) buildReferenceRelationSet(ctx context.Context, source *store.Memo, includeUID *string, excludeUID *string) ([]*v1pb.MemoRelation, bool, error) { + referenceType := store.MemoRelationReference + relations, err := s.store.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoIDList: []int32{source.ID}, + Type: &referenceType, + }) + if err != nil { + return nil, false, err + } + + idSet := make(map[int32]struct{}, len(relations)) + for _, relation := range relations { + idSet[relation.RelatedMemoID] = struct{}{} + } + ids := make([]int32, 0, len(idSet)) + for id := range idSet { + ids = append(ids, id) + } + + memosByID := map[int32]*store.Memo{} + if len(ids) > 0 { + memos, err := s.store.ListMemos(ctx, &store.FindMemo{IDList: ids, ExcludeContent: true}) + if err != nil { + return nil, false, err + } + for _, memo := range memos { + memosByID[memo.ID] = memo + } + } + + result := make([]*v1pb.MemoRelation, 0, len(relations)+1) + seenUIDs := map[string]struct{}{} + changed := false + for _, relation := range relations { + relatedMemo := memosByID[relation.RelatedMemoID] + if relatedMemo == nil { + continue + } + if excludeUID != nil && relatedMemo.UID == *excludeUID { + changed = true + continue + } + result = append(result, newReferenceRelation(source.UID, relatedMemo.UID)) + seenUIDs[relatedMemo.UID] = struct{}{} + } + if includeUID != nil { + if _, seen := seenUIDs[*includeUID]; !seen && source.UID != *includeUID { + result = append(result, newReferenceRelation(source.UID, *includeUID)) + changed = true + } + } + return result, changed, nil +} + +func newReferenceRelation(sourceUID string, relatedUID string) *v1pb.MemoRelation { + return &v1pb.MemoRelation{ + Memo: &v1pb.MemoRelation_Memo{Name: "memos/" + sourceUID}, + RelatedMemo: &v1pb.MemoRelation_Memo{Name: "memos/" + relatedUID}, + Type: v1pb.MemoRelation_REFERENCE, } - return mcp.NewToolResultText(`{"deleted":true}`), nil } diff --git a/server/router/mcp/tools_tag.go b/server/router/mcp/tools_tag.go index ded3c5849..ff9199644 100644 --- a/server/router/mcp/tools_tag.go +++ b/server/router/mcp/tools_tag.go @@ -14,7 +14,7 @@ import ( func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) { mcpSrv.AddTool(mcp.NewTool("list_tags", - mcp.WithDescription("List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically."), + readOnlyToolOptions("List tags", "List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically.")..., ), s.handleListTags) } @@ -70,9 +70,5 @@ func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest) } }) - out, err := marshalJSON(entries) - if err != nil { - return nil, err - } - return mcp.NewToolResultText(out), nil + return newToolResultJSON(entries) } diff --git a/server/server.go b/server/server.go index 122504114..8c7e47ddf 100644 --- a/server/server.go +++ b/server/server.go @@ -89,7 +89,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store } // Register MCP server. - mcpService := mcprouter.NewMCPService(s.Profile, s.Store, s.Secret) + mcpService := mcprouter.NewMCPService(s.Profile, s.Store, s.Secret, apiV1Service) mcpService.RegisterRoutes(echoServer) return s, nil