From df04e852bff47b562c2dd31461f874a89e0989be Mon Sep 17 00:00:00 2001 From: boojack Date: Sat, 4 Mar 2023 18:22:10 +0800 Subject: [PATCH] feat: implement openai integration (#1245) * feat: implement openai integration * chore: update --- api/openai.go | 5 + api/system_setting.go | 10 ++ plugin/openai/chat_completion.go | 70 ++++++++++ plugin/openai/text_completion.go | 68 +++++++++ server/openai.go | 105 ++++++++++++++ server/server.go | 1 + server/system.go | 2 +- web/src/components/AskAIDialog.tsx | 130 ++++++++++++++++++ web/src/components/CreateResourceDialog.tsx | 2 +- web/src/components/Dialog/BaseDialog.tsx | 9 +- web/src/components/Header.tsx | 10 +- web/src/components/Settings/SystemSection.tsx | 84 +++++++---- web/src/helpers/api.ts | 16 +++ web/src/types/view.d.ts | 7 +- 14 files changed, 487 insertions(+), 32 deletions(-) create mode 100644 api/openai.go create mode 100644 plugin/openai/chat_completion.go create mode 100644 plugin/openai/text_completion.go create mode 100644 server/openai.go create mode 100644 web/src/components/AskAIDialog.tsx diff --git a/api/openai.go b/api/openai.go new file mode 100644 index 00000000..ca5c5ec5 --- /dev/null +++ b/api/openai.go @@ -0,0 +1,5 @@ +package api + +type OpenAICompletionRequest struct { + Prompt string `json:"prompt"` +} diff --git a/api/system_setting.go b/api/system_setting.go index 61ff2e11..5069fd55 100644 --- a/api/system_setting.go +++ b/api/system_setting.go @@ -27,6 +27,8 @@ const ( SystemSettingCustomizedProfileName SystemSettingName = "customizedProfile" // SystemSettingStorageServiceIDName is the key type of storage service ID. SystemSettingStorageServiceIDName SystemSettingName = "storageServiceId" + // SystemSettingOpenAIAPIKeyName is the key type of OpenAI API key. + SystemSettingOpenAIAPIKeyName SystemSettingName = "openAIApiKey" ) // CustomizedProfile is the struct definition for SystemSettingCustomizedProfileName system setting item. @@ -63,6 +65,8 @@ func (key SystemSettingName) String() string { return "customizedProfile" case SystemSettingStorageServiceIDName: return "storageServiceId" + case SystemSettingOpenAIAPIKeyName: + return "openAIApiKey" } return "" } @@ -161,6 +165,12 @@ func (upsert SystemSettingUpsert) Validate() error { return fmt.Errorf("failed to unmarshal system setting storage service id value") } return nil + } else if upsert.Name == SystemSettingOpenAIAPIKeyName { + value := "" + err := json.Unmarshal([]byte(upsert.Value), &value) + if err != nil { + return fmt.Errorf("failed to unmarshal system setting openai api key value") + } } else { return fmt.Errorf("invalid system setting name") } diff --git a/plugin/openai/chat_completion.go b/plugin/openai/chat_completion.go new file mode 100644 index 00000000..3b5657e0 --- /dev/null +++ b/plugin/openai/chat_completion.go @@ -0,0 +1,70 @@ +package openai + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "strings" +) + +type ChatCompletionMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ChatCompletionChoice struct { + Message *ChatCompletionMessage `json:"message"` +} + +type ChatCompletionResponse struct { + Error interface{} `json:"error"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` +} + +func PostChatCompletion(prompt string, apiKey string) (string, error) { + requestBody := strings.NewReader(`{ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "` + prompt + `"}] + }`) + req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", requestBody) + if err != nil { + return "", err + } + + // Set the API key in the request header + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + // Send the request to OpenAI's API + client := http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // Read the response body + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + chatCompletionResponse := ChatCompletionResponse{} + err = json.Unmarshal(responseBody, &chatCompletionResponse) + if err != nil { + return "", err + } + if chatCompletionResponse.Error != nil { + errorBytes, err := json.Marshal(chatCompletionResponse.Error) + if err != nil { + return "", err + } + return "", errors.New(string(errorBytes)) + } + if len(chatCompletionResponse.Choices) == 0 { + return "", nil + } + return chatCompletionResponse.Choices[0].Message.Content, nil +} diff --git a/plugin/openai/text_completion.go b/plugin/openai/text_completion.go new file mode 100644 index 00000000..26cf2b64 --- /dev/null +++ b/plugin/openai/text_completion.go @@ -0,0 +1,68 @@ +package openai + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "strings" +) + +type TextCompletionChoice struct { + Text string `json:"text"` +} + +type TextCompletionResponse struct { + Error interface{} `json:"error"` + Model string `json:"model"` + Choices []TextCompletionChoice `json:"choices"` +} + +func PostTextCompletion(prompt string, apiKey string) (string, error) { + requestBody := strings.NewReader(`{ + "prompt": "` + prompt + `", + "temperature": 0.5, + "max_tokens": 100, + "n": 1, + "stop": "." + }`) + req, err := http.NewRequest("POST", "https://api.openai.com/v1/completions", requestBody) + if err != nil { + return "", err + } + + // Set the API key in the request header + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + // Send the request to OpenAI's API + client := http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // Read the response body + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + textCompletionResponse := TextCompletionResponse{} + err = json.Unmarshal(responseBody, &textCompletionResponse) + if err != nil { + return "", err + } + if textCompletionResponse.Error != nil { + errorBytes, err := json.Marshal(textCompletionResponse.Error) + if err != nil { + return "", err + } + return "", errors.New(string(errorBytes)) + } + if len(textCompletionResponse.Choices) == 0 { + return "", nil + } + return textCompletionResponse.Choices[0].Text, nil +} diff --git a/server/openai.go b/server/openai.go new file mode 100644 index 00000000..0761cc34 --- /dev/null +++ b/server/openai.go @@ -0,0 +1,105 @@ +package server + +import ( + "encoding/json" + "net/http" + + "github.com/labstack/echo/v4" + "github.com/usememos/memos/api" + "github.com/usememos/memos/common" + "github.com/usememos/memos/plugin/openai" +) + +func (s *Server) registerOpenAIRoutes(g *echo.Group) { + g.POST("/opanai/chat-completion", func(c echo.Context) error { + ctx := c.Request().Context() + openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ + Name: api.SystemSettingOpenAIAPIKeyName, + }) + if err != nil && common.ErrorCode(err) != common.NotFound { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err) + } + + openAIApiKey := "" + if openAIApiKeySetting != nil { + err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err) + } + } + if openAIApiKey == "" { + return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set") + } + + completionRequest := api.OpenAICompletionRequest{} + if err := json.NewDecoder(c.Request().Body).Decode(&completionRequest); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post chat completion request").SetInternal(err) + } + if completionRequest.Prompt == "" { + return echo.NewHTTPError(http.StatusBadRequest, "Prompt is required") + } + + result, err := openai.PostChatCompletion(completionRequest.Prompt, openAIApiKey) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post chat completion").SetInternal(err) + } + + return c.JSON(http.StatusOK, composeResponse(result)) + }) + + g.POST("/opanai/text-completion", func(c echo.Context) error { + ctx := c.Request().Context() + openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ + Name: api.SystemSettingOpenAIAPIKeyName, + }) + if err != nil && common.ErrorCode(err) != common.NotFound { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err) + } + + openAIApiKey := "" + if openAIApiKeySetting != nil { + err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err) + } + } + if openAIApiKey == "" { + return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set") + } + + textCompletion := api.OpenAICompletionRequest{} + if err := json.NewDecoder(c.Request().Body).Decode(&textCompletion); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post text completion request").SetInternal(err) + } + if textCompletion.Prompt == "" { + return echo.NewHTTPError(http.StatusBadRequest, "Prompt is required") + } + + result, err := openai.PostTextCompletion(textCompletion.Prompt, openAIApiKey) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post text completion").SetInternal(err) + } + + return c.JSON(http.StatusOK, composeResponse(result)) + }) + + g.GET("/opanai/enabled", func(c echo.Context) error { + ctx := c.Request().Context() + openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ + Name: api.SystemSettingOpenAIAPIKeyName, + }) + if err != nil && common.ErrorCode(err) != common.NotFound { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err) + } + + openAIApiKey := "" + if openAIApiKeySetting != nil { + err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err) + } + } + + return c.JSON(http.StatusOK, composeResponse(openAIApiKey != "")) + }) +} diff --git a/server/server.go b/server/server.go index 30e4ed42..a114dd7b 100644 --- a/server/server.go +++ b/server/server.go @@ -117,6 +117,7 @@ func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) { s.registerTagRoutes(apiGroup) s.registerStorageRoutes(apiGroup) s.registerIdentityProviderRoutes(apiGroup) + s.registerOpenAIRoutes(apiGroup) return s, nil } diff --git a/server/system.go b/server/system.go index 5726c91a..a2b2a717 100644 --- a/server/system.go +++ b/server/system.go @@ -59,7 +59,7 @@ func (s *Server) registerSystemRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting list").SetInternal(err) } for _, systemSetting := range systemSettingList { - if systemSetting.Name == api.SystemSettingServerID || systemSetting.Name == api.SystemSettingSecretSessionName { + if systemSetting.Name == api.SystemSettingServerID || systemSetting.Name == api.SystemSettingSecretSessionName || systemSetting.Name == api.SystemSettingOpenAIAPIKeyName { continue } diff --git a/web/src/components/AskAIDialog.tsx b/web/src/components/AskAIDialog.tsx new file mode 100644 index 00000000..d6e6a9a1 --- /dev/null +++ b/web/src/components/AskAIDialog.tsx @@ -0,0 +1,130 @@ +import { Button, Textarea } from "@mui/joy"; +import { reverse } from "lodash-es"; +import { useEffect, useState } from "react"; +import * as api from "../helpers/api"; +import useLoading from "../hooks/useLoading"; +import { marked } from "../labs/marked"; +import Icon from "./Icon"; +import { generateDialog } from "./Dialog"; +import toastHelper from "./Toast"; +import showSettingDialog from "./SettingDialog"; + +type Props = DialogProps; + +interface History { + question: string; + answer: string; +} + +const AskAIDialog: React.FC = (props: Props) => { + const { destroy, hide } = props; + const fetchingState = useLoading(false); + const [historyList, setHistoryList] = useState([]); + const [isEnabled, setIsEnabled] = useState(true); + + useEffect(() => { + api.checkOpenAIEnabled().then(({ data }) => { + const { data: enabled } = data; + setIsEnabled(enabled); + }); + }, []); + + const handleGotoSystemSetting = () => { + showSettingDialog("system"); + destroy(); + }; + + const handleQuestionTextareaKeyDown = async (event: React.KeyboardEvent) => { + if (event.key === "Enter") { + event.preventDefault(); + const question = event.currentTarget.value; + event.currentTarget.value = ""; + + fetchingState.setLoading(); + try { + await askQuestion(question); + } catch (error: any) { + console.error(error); + toastHelper.error(error.response.data.error); + } + fetchingState.setFinish(); + } + }; + + const askQuestion = async (question: string) => { + if (question === "") { + return; + } + + const { + data: { data: answer }, + } = await api.postChatCompletion(question); + setHistoryList([ + ...historyList, + { + question, + answer: answer.replace(/^\n\n/, ""), + }, + ]); + }; + + return ( + <> +
+

+ + Ask AI +

+ +
+
+