mirror of https://github.com/usememos/memos
feat: implement openai integration (#1245)
* feat: implement openai integration * chore: updatepull/1246/head
parent
dd625d8edc
commit
df04e852bf
@ -0,0 +1,5 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
type OpenAICompletionRequest struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
}
|
@ -0,0 +1,70 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatCompletionMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionChoice struct {
|
||||||
|
Message *ChatCompletionMessage `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionResponse struct {
|
||||||
|
Error interface{} `json:"error"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []ChatCompletionChoice `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func PostChatCompletion(prompt string, apiKey string) (string, error) {
|
||||||
|
requestBody := strings.NewReader(`{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [{"role": "user", "content": "` + prompt + `"}]
|
||||||
|
}`)
|
||||||
|
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the API key in the request header
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
|
||||||
|
// Send the request to OpenAI's API
|
||||||
|
client := http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Read the response body
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
chatCompletionResponse := ChatCompletionResponse{}
|
||||||
|
err = json.Unmarshal(responseBody, &chatCompletionResponse)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if chatCompletionResponse.Error != nil {
|
||||||
|
errorBytes, err := json.Marshal(chatCompletionResponse.Error)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return "", errors.New(string(errorBytes))
|
||||||
|
}
|
||||||
|
if len(chatCompletionResponse.Choices) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return chatCompletionResponse.Choices[0].Message.Content, nil
|
||||||
|
}
|
@ -0,0 +1,68 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TextCompletionChoice struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TextCompletionResponse struct {
|
||||||
|
Error interface{} `json:"error"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []TextCompletionChoice `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func PostTextCompletion(prompt string, apiKey string) (string, error) {
|
||||||
|
requestBody := strings.NewReader(`{
|
||||||
|
"prompt": "` + prompt + `",
|
||||||
|
"temperature": 0.5,
|
||||||
|
"max_tokens": 100,
|
||||||
|
"n": 1,
|
||||||
|
"stop": "."
|
||||||
|
}`)
|
||||||
|
req, err := http.NewRequest("POST", "https://api.openai.com/v1/completions", requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the API key in the request header
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
|
||||||
|
// Send the request to OpenAI's API
|
||||||
|
client := http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Read the response body
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
textCompletionResponse := TextCompletionResponse{}
|
||||||
|
err = json.Unmarshal(responseBody, &textCompletionResponse)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if textCompletionResponse.Error != nil {
|
||||||
|
errorBytes, err := json.Marshal(textCompletionResponse.Error)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return "", errors.New(string(errorBytes))
|
||||||
|
}
|
||||||
|
if len(textCompletionResponse.Choices) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return textCompletionResponse.Choices[0].Text, nil
|
||||||
|
}
|
@ -0,0 +1,105 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/usememos/memos/api"
|
||||||
|
"github.com/usememos/memos/common"
|
||||||
|
"github.com/usememos/memos/plugin/openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) registerOpenAIRoutes(g *echo.Group) {
|
||||||
|
g.POST("/opanai/chat-completion", func(c echo.Context) error {
|
||||||
|
ctx := c.Request().Context()
|
||||||
|
openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
|
||||||
|
Name: api.SystemSettingOpenAIAPIKeyName,
|
||||||
|
})
|
||||||
|
if err != nil && common.ErrorCode(err) != common.NotFound {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
openAIApiKey := ""
|
||||||
|
if openAIApiKeySetting != nil {
|
||||||
|
err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if openAIApiKey == "" {
|
||||||
|
return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
completionRequest := api.OpenAICompletionRequest{}
|
||||||
|
if err := json.NewDecoder(c.Request().Body).Decode(&completionRequest); err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post chat completion request").SetInternal(err)
|
||||||
|
}
|
||||||
|
if completionRequest.Prompt == "" {
|
||||||
|
return echo.NewHTTPError(http.StatusBadRequest, "Prompt is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := openai.PostChatCompletion(completionRequest.Prompt, openAIApiKey)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post chat completion").SetInternal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, composeResponse(result))
|
||||||
|
})
|
||||||
|
|
||||||
|
g.POST("/opanai/text-completion", func(c echo.Context) error {
|
||||||
|
ctx := c.Request().Context()
|
||||||
|
openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
|
||||||
|
Name: api.SystemSettingOpenAIAPIKeyName,
|
||||||
|
})
|
||||||
|
if err != nil && common.ErrorCode(err) != common.NotFound {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
openAIApiKey := ""
|
||||||
|
if openAIApiKeySetting != nil {
|
||||||
|
err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if openAIApiKey == "" {
|
||||||
|
return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
textCompletion := api.OpenAICompletionRequest{}
|
||||||
|
if err := json.NewDecoder(c.Request().Body).Decode(&textCompletion); err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post text completion request").SetInternal(err)
|
||||||
|
}
|
||||||
|
if textCompletion.Prompt == "" {
|
||||||
|
return echo.NewHTTPError(http.StatusBadRequest, "Prompt is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := openai.PostTextCompletion(textCompletion.Prompt, openAIApiKey)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post text completion").SetInternal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, composeResponse(result))
|
||||||
|
})
|
||||||
|
|
||||||
|
g.GET("/opanai/enabled", func(c echo.Context) error {
|
||||||
|
ctx := c.Request().Context()
|
||||||
|
openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
|
||||||
|
Name: api.SystemSettingOpenAIAPIKeyName,
|
||||||
|
})
|
||||||
|
if err != nil && common.ErrorCode(err) != common.NotFound {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
openAIApiKey := ""
|
||||||
|
if openAIApiKeySetting != nil {
|
||||||
|
err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, composeResponse(openAIApiKey != ""))
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,130 @@
|
|||||||
|
import { Button, Textarea } from "@mui/joy";
|
||||||
|
import { reverse } from "lodash-es";
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import * as api from "../helpers/api";
|
||||||
|
import useLoading from "../hooks/useLoading";
|
||||||
|
import { marked } from "../labs/marked";
|
||||||
|
import Icon from "./Icon";
|
||||||
|
import { generateDialog } from "./Dialog";
|
||||||
|
import toastHelper from "./Toast";
|
||||||
|
import showSettingDialog from "./SettingDialog";
|
||||||
|
|
||||||
|
type Props = DialogProps;
|
||||||
|
|
||||||
|
interface History {
|
||||||
|
question: string;
|
||||||
|
answer: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const AskAIDialog: React.FC<Props> = (props: Props) => {
|
||||||
|
const { destroy, hide } = props;
|
||||||
|
const fetchingState = useLoading(false);
|
||||||
|
const [historyList, setHistoryList] = useState<History[]>([]);
|
||||||
|
const [isEnabled, setIsEnabled] = useState<boolean>(true);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
api.checkOpenAIEnabled().then(({ data }) => {
|
||||||
|
const { data: enabled } = data;
|
||||||
|
setIsEnabled(enabled);
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleGotoSystemSetting = () => {
|
||||||
|
showSettingDialog("system");
|
||||||
|
destroy();
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleQuestionTextareaKeyDown = async (event: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
||||||
|
if (event.key === "Enter") {
|
||||||
|
event.preventDefault();
|
||||||
|
const question = event.currentTarget.value;
|
||||||
|
event.currentTarget.value = "";
|
||||||
|
|
||||||
|
fetchingState.setLoading();
|
||||||
|
try {
|
||||||
|
await askQuestion(question);
|
||||||
|
} catch (error: any) {
|
||||||
|
console.error(error);
|
||||||
|
toastHelper.error(error.response.data.error);
|
||||||
|
}
|
||||||
|
fetchingState.setFinish();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const askQuestion = async (question: string) => {
|
||||||
|
if (question === "") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const {
|
||||||
|
data: { data: answer },
|
||||||
|
} = await api.postChatCompletion(question);
|
||||||
|
setHistoryList([
|
||||||
|
...historyList,
|
||||||
|
{
|
||||||
|
question,
|
||||||
|
answer: answer.replace(/^\n\n/, ""),
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<div className="dialog-header-container">
|
||||||
|
<p className="title-text flex flex-row items-center">
|
||||||
|
<Icon.Bot className="mr-1 w-5 h-auto opacity-80" />
|
||||||
|
Ask AI
|
||||||
|
</p>
|
||||||
|
<button className="btn close-btn" onClick={() => hide()}>
|
||||||
|
<Icon.X />
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div className="dialog-content-container !w-112">
|
||||||
|
<Textarea className="w-full" placeholder="Ask anything…" onKeyDown={handleQuestionTextareaKeyDown} />
|
||||||
|
{fetchingState.isLoading && (
|
||||||
|
<p className="w-full py-2 mt-4 flex flex-row justify-center items-center">
|
||||||
|
<Icon.Loader className="w-5 h-auto animate-spin" />
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
{reverse(historyList).map((history, index) => (
|
||||||
|
<div key={index} className="w-full flex flex-col justify-start items-start mt-4 space-y-2">
|
||||||
|
<div className="w-full flex flex-row justify-start items-start pr-6">
|
||||||
|
<span className="word-break rounded shadow px-3 py-2 opacity-80 bg-gray-100 dark:bg-zinc-700">{history.question}</span>
|
||||||
|
</div>
|
||||||
|
<div className="w-full flex flex-row justify-end items-start pl-8 space-x-2">
|
||||||
|
<div className="memo-content-wrapper !w-auto flex flex-col justify-start items-start rounded shadow px-3 py-2 bg-gray-100 dark:bg-zinc-700">
|
||||||
|
<div className="memo-content-text">{marked(history.answer)}</div>
|
||||||
|
</div>
|
||||||
|
<Icon.Bot className="mt-2 flex-shrink-0 mr-1 w-6 h-auto opacity-80" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
{!isEnabled && (
|
||||||
|
<div className="w-full flex flex-col justify-center items-center mt-4 space-y-2">
|
||||||
|
<p>You have not set up your OpenAI API key.</p>
|
||||||
|
<Button onClick={() => handleGotoSystemSetting()}>Go to settings</Button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
function showAskAIDialog() {
|
||||||
|
const dialogname = "ask-ai-dialog";
|
||||||
|
const dialogElement = document.body.querySelector(`div.${dialogname}`);
|
||||||
|
if (dialogElement) {
|
||||||
|
dialogElement.classList.remove("showoff");
|
||||||
|
dialogElement.classList.add("showup");
|
||||||
|
} else {
|
||||||
|
generateDialog(
|
||||||
|
{
|
||||||
|
className: dialogname,
|
||||||
|
dialogName: dialogname,
|
||||||
|
},
|
||||||
|
AskAIDialog
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default showAskAIDialog;
|
@ -1,7 +1,6 @@
|
|||||||
interface DialogProps {
|
|
||||||
destroy: FunctionType;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface DialogCallback {
|
interface DialogCallback {
|
||||||
destroy: FunctionType;
|
destroy: FunctionType;
|
||||||
|
hide: FunctionType;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DialogProps = DialogCallback;
|
||||||
|
Loading…
Reference in New Issue