mirror of https://github.com/usememos/memos

feat: implement openai integration (#1245)

* feat: implement openai integration
* chore: update

pull/1246/head
parent dd625d8edc
commit df04e852bf
@@ -0,0 +1,5 @@
package api

type OpenAICompletionRequest struct {
	Prompt string `json:"prompt"`
}
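For orientation, a minimal sketch of the request body that the new completion endpoints decode into api.OpenAICompletionRequest; this standalone program is illustrative only and is not part of the commit.

package main

import (
	"encoding/json"
	"fmt"

	"github.com/usememos/memos/api"
)

func main() {
	// The completion routes registered further down decode a JSON body of this shape.
	body := []byte(`{"prompt": "Summarize my latest memo"}`)
	request := api.OpenAICompletionRequest{}
	if err := json.Unmarshal(body, &request); err != nil {
		panic(err)
	}
	fmt.Println(request.Prompt) // prints: Summarize my latest memo
}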
@@ -0,0 +1,70 @@
package openai

import (
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"strings"
)

type ChatCompletionMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type ChatCompletionChoice struct {
	Message *ChatCompletionMessage `json:"message"`
}

type ChatCompletionResponse struct {
	Error   interface{}            `json:"error"`
	Model   string                 `json:"model"`
	Choices []ChatCompletionChoice `json:"choices"`
}

func PostChatCompletion(prompt string, apiKey string) (string, error) {
	requestBody := strings.NewReader(`{
		"model": "gpt-3.5-turbo",
		"messages": [{"role": "user", "content": "` + prompt + `"}]
	}`)
	req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", requestBody)
	if err != nil {
		return "", err
	}

	// Set the API key in the request header
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)

	// Send the request to OpenAI's API
	client := http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Read the response body
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}

	chatCompletionResponse := ChatCompletionResponse{}
	err = json.Unmarshal(responseBody, &chatCompletionResponse)
	if err != nil {
		return "", err
	}
	if chatCompletionResponse.Error != nil {
		errorBytes, err := json.Marshal(chatCompletionResponse.Error)
		if err != nil {
			return "", err
		}
		return "", errors.New(string(errorBytes))
	}
	if len(chatCompletionResponse.Choices) == 0 {
		return "", nil
	}
	return chatCompletionResponse.Choices[0].Message.Content, nil
}
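Note that PostChatCompletion splices the prompt straight into a JSON string, so a prompt containing a double quote or a newline would produce an invalid request body. A minimal sketch of a marshalled alternative, assuming it sits in the same openai package (which already imports encoding/json and io) with an additional "bytes" import; buildChatCompletionBody and the chatCompletionRequest struct are hypothetical helpers, not part of this commit.

// Sketch: marshal a typed request so the prompt is escaped automatically.
type chatCompletionRequest struct {
	Model    string                  `json:"model"`
	Messages []ChatCompletionMessage `json:"messages"`
}

func buildChatCompletionBody(prompt string) (io.Reader, error) {
	payload := chatCompletionRequest{
		Model:    "gpt-3.5-turbo",
		Messages: []ChatCompletionMessage{{Role: "user", Content: prompt}},
	}
	data, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	return bytes.NewReader(data), nil
}

PostChatCompletion could then hand the returned reader to http.NewRequest in place of the hand-built string.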
@@ -0,0 +1,68 @@
package openai

import (
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"strings"
)

type TextCompletionChoice struct {
	Text string `json:"text"`
}

type TextCompletionResponse struct {
	Error   interface{}            `json:"error"`
	Model   string                 `json:"model"`
	Choices []TextCompletionChoice `json:"choices"`
}

func PostTextCompletion(prompt string, apiKey string) (string, error) {
	requestBody := strings.NewReader(`{
		"prompt": "` + prompt + `",
		"temperature": 0.5,
		"max_tokens": 100,
		"n": 1,
		"stop": "."
	}`)
	req, err := http.NewRequest("POST", "https://api.openai.com/v1/completions", requestBody)
	if err != nil {
		return "", err
	}

	// Set the API key in the request header
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)

	// Send the request to OpenAI's API
	client := http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Read the response body
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}

	textCompletionResponse := TextCompletionResponse{}
	err = json.Unmarshal(responseBody, &textCompletionResponse)
	if err != nil {
		return "", err
	}
	if textCompletionResponse.Error != nil {
		errorBytes, err := json.Marshal(textCompletionResponse.Error)
		if err != nil {
			return "", err
		}
		return "", errors.New(string(errorBytes))
	}
	if len(textCompletionResponse.Choices) == 0 {
		return "", nil
	}
	return textCompletionResponse.Choices[0].Text, nil
}
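The text-completion body above carries no "model" field, which the /v1/completions endpoint expects, and the prompt is again concatenated into raw JSON. Below is a sketch under the same assumptions as the chat-completion example; the struct, the helper, and the "text-davinci-003" model name are illustrative choices rather than values taken from this commit.

// Sketch: a typed body for /v1/completions that names a model and escapes the prompt.
type textCompletionRequest struct {
	Model       string  `json:"model"`
	Prompt      string  `json:"prompt"`
	Temperature float64 `json:"temperature"`
	MaxTokens   int     `json:"max_tokens"`
	N           int     `json:"n"`
	Stop        string  `json:"stop"`
}

func buildTextCompletionBody(prompt string) (io.Reader, error) {
	payload := textCompletionRequest{
		Model:       "text-davinci-003", // assumed model; the commit does not set one
		Prompt:      prompt,
		Temperature: 0.5,
		MaxTokens:   100,
		N:           1,
		Stop:        ".",
	}
	data, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	return bytes.NewReader(data), nil
}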
@@ -0,0 +1,105 @@
package server

import (
	"encoding/json"
	"net/http"

	"github.com/labstack/echo/v4"
	"github.com/usememos/memos/api"
	"github.com/usememos/memos/common"
	"github.com/usememos/memos/plugin/openai"
)

func (s *Server) registerOpenAIRoutes(g *echo.Group) {
	g.POST("/opanai/chat-completion", func(c echo.Context) error {
		ctx := c.Request().Context()
		openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
			Name: api.SystemSettingOpenAIAPIKeyName,
		})
		if err != nil && common.ErrorCode(err) != common.NotFound {
			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err)
		}

		openAIApiKey := ""
		if openAIApiKeySetting != nil {
			err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey)
			if err != nil {
				return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err)
			}
		}
		if openAIApiKey == "" {
			return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set")
		}

		completionRequest := api.OpenAICompletionRequest{}
		if err := json.NewDecoder(c.Request().Body).Decode(&completionRequest); err != nil {
			return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post chat completion request").SetInternal(err)
		}
		if completionRequest.Prompt == "" {
			return echo.NewHTTPError(http.StatusBadRequest, "Prompt is required")
		}

		result, err := openai.PostChatCompletion(completionRequest.Prompt, openAIApiKey)
		if err != nil {
			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post chat completion").SetInternal(err)
		}

		return c.JSON(http.StatusOK, composeResponse(result))
	})

	g.POST("/opanai/text-completion", func(c echo.Context) error {
		ctx := c.Request().Context()
		openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
			Name: api.SystemSettingOpenAIAPIKeyName,
		})
		if err != nil && common.ErrorCode(err) != common.NotFound {
			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err)
		}

		openAIApiKey := ""
		if openAIApiKeySetting != nil {
			err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey)
			if err != nil {
				return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err)
			}
		}
		if openAIApiKey == "" {
			return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set")
		}

		textCompletion := api.OpenAICompletionRequest{}
		if err := json.NewDecoder(c.Request().Body).Decode(&textCompletion); err != nil {
			return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post text completion request").SetInternal(err)
		}
		if textCompletion.Prompt == "" {
			return echo.NewHTTPError(http.StatusBadRequest, "Prompt is required")
		}

		result, err := openai.PostTextCompletion(textCompletion.Prompt, openAIApiKey)
		if err != nil {
			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post text completion").SetInternal(err)
		}

		return c.JSON(http.StatusOK, composeResponse(result))
	})

	g.GET("/opanai/enabled", func(c echo.Context) error {
		ctx := c.Request().Context()
		openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
			Name: api.SystemSettingOpenAIAPIKeyName,
		})
		if err != nil && common.ErrorCode(err) != common.NotFound {
			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err)
		}

		openAIApiKey := ""
		if openAIApiKeySetting != nil {
			err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey)
			if err != nil {
				return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err)
			}
		}

		return c.JSON(http.StatusOK, composeResponse(openAIApiKey != ""))
	})
}
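All three handlers above repeat the same API-key lookup. A minimal sketch of a shared helper built only from calls already used in this file; getOpenAIAPIKey and its "context" import are illustrative additions, not part of the commit.

// Sketch: fetch the OpenAI API key once; an empty string means the key is unset.
func (s *Server) getOpenAIAPIKey(ctx context.Context) (string, error) {
	openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
		Name: api.SystemSettingOpenAIAPIKeyName,
	})
	if err != nil && common.ErrorCode(err) != common.NotFound {
		return "", err
	}

	openAIApiKey := ""
	if openAIApiKeySetting != nil {
		if err := json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey); err != nil {
			return "", err
		}
	}
	return openAIApiKey, nil
}

Each route body would then shrink to this call plus its endpoint-specific request decoding and openai.* invocation.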
@@ -0,0 +1,130 @@
import { Button, Textarea } from "@mui/joy";
import { reverse } from "lodash-es";
import { useEffect, useState } from "react";
import * as api from "../helpers/api";
import useLoading from "../hooks/useLoading";
import { marked } from "../labs/marked";
import Icon from "./Icon";
import { generateDialog } from "./Dialog";
import toastHelper from "./Toast";
import showSettingDialog from "./SettingDialog";

type Props = DialogProps;

interface History {
  question: string;
  answer: string;
}

const AskAIDialog: React.FC<Props> = (props: Props) => {
  const { destroy, hide } = props;
  const fetchingState = useLoading(false);
  const [historyList, setHistoryList] = useState<History[]>([]);
  const [isEnabled, setIsEnabled] = useState<boolean>(true);

  useEffect(() => {
    api.checkOpenAIEnabled().then(({ data }) => {
      const { data: enabled } = data;
      setIsEnabled(enabled);
    });
  }, []);

  const handleGotoSystemSetting = () => {
    showSettingDialog("system");
    destroy();
  };

  const handleQuestionTextareaKeyDown = async (event: React.KeyboardEvent<HTMLTextAreaElement>) => {
    if (event.key === "Enter") {
      event.preventDefault();
      const question = event.currentTarget.value;
      event.currentTarget.value = "";

      fetchingState.setLoading();
      try {
        await askQuestion(question);
      } catch (error: any) {
        console.error(error);
        toastHelper.error(error.response.data.error);
      }
      fetchingState.setFinish();
    }
  };

  const askQuestion = async (question: string) => {
    if (question === "") {
      return;
    }

    const {
      data: { data: answer },
    } = await api.postChatCompletion(question);
    setHistoryList([
      ...historyList,
      {
        question,
        answer: answer.replace(/^\n\n/, ""),
      },
    ]);
  };

  return (
    <>
      <div className="dialog-header-container">
        <p className="title-text flex flex-row items-center">
          <Icon.Bot className="mr-1 w-5 h-auto opacity-80" />
          Ask AI
        </p>
        <button className="btn close-btn" onClick={() => hide()}>
          <Icon.X />
        </button>
      </div>
      <div className="dialog-content-container !w-112">
        <Textarea className="w-full" placeholder="Ask anything…" onKeyDown={handleQuestionTextareaKeyDown} />
        {fetchingState.isLoading && (
          <p className="w-full py-2 mt-4 flex flex-row justify-center items-center">
            <Icon.Loader className="w-5 h-auto animate-spin" />
          </p>
        )}
        {reverse(historyList).map((history, index) => (
          <div key={index} className="w-full flex flex-col justify-start items-start mt-4 space-y-2">
            <div className="w-full flex flex-row justify-start items-start pr-6">
              <span className="word-break rounded shadow px-3 py-2 opacity-80 bg-gray-100 dark:bg-zinc-700">{history.question}</span>
            </div>
            <div className="w-full flex flex-row justify-end items-start pl-8 space-x-2">
              <div className="memo-content-wrapper !w-auto flex flex-col justify-start items-start rounded shadow px-3 py-2 bg-gray-100 dark:bg-zinc-700">
                <div className="memo-content-text">{marked(history.answer)}</div>
              </div>
              <Icon.Bot className="mt-2 flex-shrink-0 mr-1 w-6 h-auto opacity-80" />
            </div>
          </div>
        ))}
        {!isEnabled && (
          <div className="w-full flex flex-col justify-center items-center mt-4 space-y-2">
            <p>You have not set up your OpenAI API key.</p>
            <Button onClick={() => handleGotoSystemSetting()}>Go to settings</Button>
          </div>
        )}
      </div>
    </>
  );
};

function showAskAIDialog() {
  const dialogname = "ask-ai-dialog";
  const dialogElement = document.body.querySelector(`div.${dialogname}`);
  if (dialogElement) {
    dialogElement.classList.remove("showoff");
    dialogElement.classList.add("showup");
  } else {
    generateDialog(
      {
        className: dialogname,
        dialogName: dialogname,
      },
      AskAIDialog
    );
  }
}

export default showAskAIDialog;
@@ -1,7 +1,6 @@
-interface DialogProps {
-  destroy: FunctionType;
-}
-
 interface DialogCallback {
   destroy: FunctionType;
+  hide: FunctionType;
 }
+
+type DialogProps = DialogCallback;