feat: implement openai integration (#1245)

* feat: implement openai integration

* chore: update
pull/1246/head
boojack 2 years ago committed by GitHub
parent dd625d8edc
commit df04e852bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,5 @@
package api
type OpenAICompletionRequest struct {
Prompt string `json:"prompt"`
}

@ -27,6 +27,8 @@ const (
SystemSettingCustomizedProfileName SystemSettingName = "customizedProfile"
// SystemSettingStorageServiceIDName is the key type of storage service ID.
SystemSettingStorageServiceIDName SystemSettingName = "storageServiceId"
// SystemSettingOpenAIAPIKeyName is the key type of OpenAI API key.
SystemSettingOpenAIAPIKeyName SystemSettingName = "openAIApiKey"
)
// CustomizedProfile is the struct definition for SystemSettingCustomizedProfileName system setting item.
@ -63,6 +65,8 @@ func (key SystemSettingName) String() string {
return "customizedProfile"
case SystemSettingStorageServiceIDName:
return "storageServiceId"
case SystemSettingOpenAIAPIKeyName:
return "openAIApiKey"
}
return ""
}
@ -161,6 +165,12 @@ func (upsert SystemSettingUpsert) Validate() error {
return fmt.Errorf("failed to unmarshal system setting storage service id value")
}
return nil
} else if upsert.Name == SystemSettingOpenAIAPIKeyName {
value := ""
err := json.Unmarshal([]byte(upsert.Value), &value)
if err != nil {
return fmt.Errorf("failed to unmarshal system setting openai api key value")
}
} else {
return fmt.Errorf("invalid system setting name")
}

@ -0,0 +1,70 @@
package openai
import (
"encoding/json"
"errors"
"io"
"net/http"
"strings"
)
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatCompletionChoice struct {
Message *ChatCompletionMessage `json:"message"`
}
type ChatCompletionResponse struct {
Error interface{} `json:"error"`
Model string `json:"model"`
Choices []ChatCompletionChoice `json:"choices"`
}
func PostChatCompletion(prompt string, apiKey string) (string, error) {
requestBody := strings.NewReader(`{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "` + prompt + `"}]
}`)
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", requestBody)
if err != nil {
return "", err
}
// Set the API key in the request header
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
// Send the request to OpenAI's API
client := http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
// Read the response body
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
chatCompletionResponse := ChatCompletionResponse{}
err = json.Unmarshal(responseBody, &chatCompletionResponse)
if err != nil {
return "", err
}
if chatCompletionResponse.Error != nil {
errorBytes, err := json.Marshal(chatCompletionResponse.Error)
if err != nil {
return "", err
}
return "", errors.New(string(errorBytes))
}
if len(chatCompletionResponse.Choices) == 0 {
return "", nil
}
return chatCompletionResponse.Choices[0].Message.Content, nil
}

@ -0,0 +1,68 @@
package openai
import (
"encoding/json"
"errors"
"io"
"net/http"
"strings"
)
type TextCompletionChoice struct {
Text string `json:"text"`
}
type TextCompletionResponse struct {
Error interface{} `json:"error"`
Model string `json:"model"`
Choices []TextCompletionChoice `json:"choices"`
}
func PostTextCompletion(prompt string, apiKey string) (string, error) {
requestBody := strings.NewReader(`{
"prompt": "` + prompt + `",
"temperature": 0.5,
"max_tokens": 100,
"n": 1,
"stop": "."
}`)
req, err := http.NewRequest("POST", "https://api.openai.com/v1/completions", requestBody)
if err != nil {
return "", err
}
// Set the API key in the request header
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
// Send the request to OpenAI's API
client := http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
// Read the response body
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
textCompletionResponse := TextCompletionResponse{}
err = json.Unmarshal(responseBody, &textCompletionResponse)
if err != nil {
return "", err
}
if textCompletionResponse.Error != nil {
errorBytes, err := json.Marshal(textCompletionResponse.Error)
if err != nil {
return "", err
}
return "", errors.New(string(errorBytes))
}
if len(textCompletionResponse.Choices) == 0 {
return "", nil
}
return textCompletionResponse.Choices[0].Text, nil
}

@ -0,0 +1,105 @@
package server
import (
"encoding/json"
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
"github.com/usememos/memos/plugin/openai"
)
func (s *Server) registerOpenAIRoutes(g *echo.Group) {
g.POST("/opanai/chat-completion", func(c echo.Context) error {
ctx := c.Request().Context()
openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
Name: api.SystemSettingOpenAIAPIKeyName,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err)
}
openAIApiKey := ""
if openAIApiKeySetting != nil {
err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err)
}
}
if openAIApiKey == "" {
return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set")
}
completionRequest := api.OpenAICompletionRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(&completionRequest); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post chat completion request").SetInternal(err)
}
if completionRequest.Prompt == "" {
return echo.NewHTTPError(http.StatusBadRequest, "Prompt is required")
}
result, err := openai.PostChatCompletion(completionRequest.Prompt, openAIApiKey)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post chat completion").SetInternal(err)
}
return c.JSON(http.StatusOK, composeResponse(result))
})
g.POST("/opanai/text-completion", func(c echo.Context) error {
ctx := c.Request().Context()
openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
Name: api.SystemSettingOpenAIAPIKeyName,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err)
}
openAIApiKey := ""
if openAIApiKeySetting != nil {
err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err)
}
}
if openAIApiKey == "" {
return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set")
}
textCompletion := api.OpenAICompletionRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(&textCompletion); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post text completion request").SetInternal(err)
}
if textCompletion.Prompt == "" {
return echo.NewHTTPError(http.StatusBadRequest, "Prompt is required")
}
result, err := openai.PostTextCompletion(textCompletion.Prompt, openAIApiKey)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post text completion").SetInternal(err)
}
return c.JSON(http.StatusOK, composeResponse(result))
})
g.GET("/opanai/enabled", func(c echo.Context) error {
ctx := c.Request().Context()
openAIApiKeySetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
Name: api.SystemSettingOpenAIAPIKeyName,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err)
}
openAIApiKey := ""
if openAIApiKeySetting != nil {
err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err)
}
}
return c.JSON(http.StatusOK, composeResponse(openAIApiKey != ""))
})
}

@ -117,6 +117,7 @@ func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) {
s.registerTagRoutes(apiGroup)
s.registerStorageRoutes(apiGroup)
s.registerIdentityProviderRoutes(apiGroup)
s.registerOpenAIRoutes(apiGroup)
return s, nil
}

@ -59,7 +59,7 @@ func (s *Server) registerSystemRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting list").SetInternal(err)
}
for _, systemSetting := range systemSettingList {
if systemSetting.Name == api.SystemSettingServerID || systemSetting.Name == api.SystemSettingSecretSessionName {
if systemSetting.Name == api.SystemSettingServerID || systemSetting.Name == api.SystemSettingSecretSessionName || systemSetting.Name == api.SystemSettingOpenAIAPIKeyName {
continue
}

@ -0,0 +1,130 @@
import { Button, Textarea } from "@mui/joy";
import { reverse } from "lodash-es";
import { useEffect, useState } from "react";
import * as api from "../helpers/api";
import useLoading from "../hooks/useLoading";
import { marked } from "../labs/marked";
import Icon from "./Icon";
import { generateDialog } from "./Dialog";
import toastHelper from "./Toast";
import showSettingDialog from "./SettingDialog";
type Props = DialogProps;
interface History {
question: string;
answer: string;
}
const AskAIDialog: React.FC<Props> = (props: Props) => {
const { destroy, hide } = props;
const fetchingState = useLoading(false);
const [historyList, setHistoryList] = useState<History[]>([]);
const [isEnabled, setIsEnabled] = useState<boolean>(true);
useEffect(() => {
api.checkOpenAIEnabled().then(({ data }) => {
const { data: enabled } = data;
setIsEnabled(enabled);
});
}, []);
const handleGotoSystemSetting = () => {
showSettingDialog("system");
destroy();
};
const handleQuestionTextareaKeyDown = async (event: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (event.key === "Enter") {
event.preventDefault();
const question = event.currentTarget.value;
event.currentTarget.value = "";
fetchingState.setLoading();
try {
await askQuestion(question);
} catch (error: any) {
console.error(error);
toastHelper.error(error.response.data.error);
}
fetchingState.setFinish();
}
};
const askQuestion = async (question: string) => {
if (question === "") {
return;
}
const {
data: { data: answer },
} = await api.postChatCompletion(question);
setHistoryList([
...historyList,
{
question,
answer: answer.replace(/^\n\n/, ""),
},
]);
};
return (
<>
<div className="dialog-header-container">
<p className="title-text flex flex-row items-center">
<Icon.Bot className="mr-1 w-5 h-auto opacity-80" />
Ask AI
</p>
<button className="btn close-btn" onClick={() => hide()}>
<Icon.X />
</button>
</div>
<div className="dialog-content-container !w-112">
<Textarea className="w-full" placeholder="Ask anything…" onKeyDown={handleQuestionTextareaKeyDown} />
{fetchingState.isLoading && (
<p className="w-full py-2 mt-4 flex flex-row justify-center items-center">
<Icon.Loader className="w-5 h-auto animate-spin" />
</p>
)}
{reverse(historyList).map((history, index) => (
<div key={index} className="w-full flex flex-col justify-start items-start mt-4 space-y-2">
<div className="w-full flex flex-row justify-start items-start pr-6">
<span className="word-break rounded shadow px-3 py-2 opacity-80 bg-gray-100 dark:bg-zinc-700">{history.question}</span>
</div>
<div className="w-full flex flex-row justify-end items-start pl-8 space-x-2">
<div className="memo-content-wrapper !w-auto flex flex-col justify-start items-start rounded shadow px-3 py-2 bg-gray-100 dark:bg-zinc-700">
<div className="memo-content-text">{marked(history.answer)}</div>
</div>
<Icon.Bot className="mt-2 flex-shrink-0 mr-1 w-6 h-auto opacity-80" />
</div>
</div>
))}
{!isEnabled && (
<div className="w-full flex flex-col justify-center items-center mt-4 space-y-2">
<p>You have not set up your OpenAI API key.</p>
<Button onClick={() => handleGotoSystemSetting()}>Go to settings</Button>
</div>
)}
</div>
</>
);
};
function showAskAIDialog() {
const dialogname = "ask-ai-dialog";
const dialogElement = document.body.querySelector(`div.${dialogname}`);
if (dialogElement) {
dialogElement.classList.remove("showoff");
dialogElement.classList.add("showup");
} else {
generateDialog(
{
className: dialogname,
dialogName: dialogname,
},
AskAIDialog
);
}
}
export default showAskAIDialog;

@ -234,7 +234,7 @@ const CreateResourceDialog: React.FC<Props> = (props: Props) => {
);
};
function showCreateResourceDialog(props: Omit<Props, "destroy">) {
function showCreateResourceDialog(props: Omit<Props, "destroy" | "hide">) {
generateDialog<Props>(
{
dialogName: "create-resource-dialog",

@ -66,7 +66,7 @@ const BaseDialog: React.FC<Props> = (props: Props) => {
export function generateDialog<T extends DialogProps>(
config: DialogConfig,
DialogComponent: React.FC<T>,
props?: Omit<T, "destroy">
props?: Omit<T, "destroy" | "hide">
): DialogCallback {
const tempDiv = document.createElement("div");
const dialog = createRoot(tempDiv);
@ -85,17 +85,22 @@ export function generateDialog<T extends DialogProps>(
tempDiv.remove();
}, ANIMATION_DURATION);
},
hide: () => {
tempDiv.firstElementChild?.classList.remove("showup");
tempDiv.firstElementChild?.classList.add("showoff");
},
};
const dialogProps = {
...props,
destroy: cbs.destroy,
hide: cbs.hide,
} as T;
const Fragment = (
<Provider store={store}>
<CssVarsProvider theme={theme}>
<BaseDialog destroy={cbs.destroy} clickSpaceDestroy={true} {...config}>
<BaseDialog destroy={cbs.destroy} hide={cbs.hide} clickSpaceDestroy={true} {...config}>
<DialogComponent {...dialogProps} />
</BaseDialog>
</CssVarsProvider>

@ -7,6 +7,8 @@ import Icon from "./Icon";
import showDailyReviewDialog from "./DailyReviewDialog";
import showResourcesDialog from "./ResourcesDialog";
import showSettingDialog from "./SettingDialog";
import showAskAIDialog from "./AskAIDialog";
import showArchivedMemoDialog from "./ArchivedMemoDialog";
import UserBanner from "./UserBanner";
import "../less/header.less";
@ -46,6 +48,12 @@ const Header = () => {
</Link>
{!userStore.isVisitorMode() && (
<>
<button
className="px-4 pr-5 py-2 rounded-lg flex flex-row items-center text-lg dark:text-gray-200 hover:bg-white hover:shadow dark:hover:bg-zinc-700"
onClick={() => showAskAIDialog()}
>
<Icon.Bot className="mr-4 w-6 h-auto opacity-80" /> Ask AI
</button>
<button
className="px-4 pr-5 py-2 rounded-lg flex flex-row items-center text-lg dark:text-gray-200 hover:bg-white hover:shadow dark:hover:bg-zinc-700"
onClick={() => showResourcesDialog()}
@ -54,7 +62,7 @@ const Header = () => {
</button>
<button
className="px-4 pr-5 py-2 rounded-lg flex flex-row items-center text-lg dark:text-gray-200 hover:bg-white hover:shadow dark:hover:bg-zinc-700"
onClick={() => showDailyReviewDialog()}
onClick={() => showArchivedMemoDialog()}
>
<Icon.Archive className="mr-4 w-6 h-auto opacity-80" /> {t("common.archive")}
</button>

@ -1,6 +1,6 @@
import { useEffect, useState } from "react";
import { useTranslation } from "react-i18next";
import { Button, Divider, Switch, Textarea } from "@mui/joy";
import { Button, Divider, Input, Switch, Textarea } from "@mui/joy";
import { useGlobalStore } from "../../store/module";
import * as api from "../../helpers/api";
import toastHelper from "../Toast";
@ -13,6 +13,7 @@ interface State {
dbSize: number;
allowSignUp: boolean;
disablePublicMemos: boolean;
openAIApiKey: string;
additionalStyle: string;
additionalScript: string;
}
@ -34,6 +35,7 @@ const SystemSection = () => {
dbSize: systemStatus.dbSize,
allowSignUp: systemStatus.allowSignUp,
additionalStyle: systemStatus.additionalStyle,
openAIApiKey: "",
additionalScript: systemStatus.additionalScript,
disablePublicMemos: systemStatus.disablePublicMemos,
});
@ -49,6 +51,7 @@ const SystemSection = () => {
dbSize: systemStatus.dbSize,
allowSignUp: systemStatus.allowSignUp,
additionalStyle: systemStatus.additionalStyle,
openAIApiKey: "",
additionalScript: systemStatus.additionalScript,
disablePublicMemos: systemStatus.disablePublicMemos,
});
@ -65,13 +68,6 @@ const SystemSection = () => {
});
};
const handleAdditionalStyleChanged = (value: string) => {
setState({
...state,
additionalStyle: value,
});
};
const handleUpdateCustomizedProfileButtonClick = () => {
showUpdateCustomizedProfileDialog();
};
@ -87,6 +83,33 @@ const SystemSection = () => {
toastHelper.success(t("message.succeed-vacuum-database"));
};
const handleOpenAIApiKeyChanged = (value: string) => {
setState({
...state,
openAIApiKey: value,
});
};
const handleSaveOpenAIApiKey = async () => {
try {
await api.upsertSystemSetting({
name: "openAIApiKey",
value: JSON.stringify(state.openAIApiKey),
});
} catch (error) {
console.error(error);
return;
}
toastHelper.success("OpenAI Api Key updated");
};
const handleAdditionalStyleChanged = (value: string) => {
setState({
...state,
additionalStyle: value,
});
};
const handleSaveAdditionalStyle = async () => {
try {
await api.upsertSystemSetting({
@ -107,19 +130,6 @@ const SystemSection = () => {
});
};
const handleDisablePublicMemosChanged = async (value: boolean) => {
setState({
...state,
disablePublicMemos: value,
});
// Update global store immediately as MemoEditor/Selector is dependent on this value.
dispatch(setGlobalState({ systemStatus: { ...systemStatus, disablePublicMemos: value } }));
await api.upsertSystemSetting({
name: "disablePublicMemos",
value: JSON.stringify(value),
});
};
const handleSaveAdditionalScript = async () => {
try {
await api.upsertSystemSetting({
@ -133,6 +143,19 @@ const SystemSection = () => {
toastHelper.success(t("message.succeed-update-additional-script"));
};
const handleDisablePublicMemosChanged = async (value: boolean) => {
setState({
...state,
disablePublicMemos: value,
});
// Update global store immediately as MemoEditor/Selector is dependent on this value.
dispatch(setGlobalState({ systemStatus: { ...systemStatus, disablePublicMemos: value } }));
await api.upsertSystemSetting({
name: "disablePublicMemos",
value: JSON.stringify(value),
});
};
return (
<div className="section-container system-section-container">
<p className="title-text">{t("common.basic")}</p>
@ -158,6 +181,21 @@ const SystemSection = () => {
<Switch checked={state.disablePublicMemos} onChange={(event) => handleDisablePublicMemosChanged(event.target.checked)} />
</div>
<Divider className="!mt-3 !my-4" />
<div className="form-label">
<span className="normal-text">OpenAI API Key</span>
<Button onClick={handleSaveOpenAIApiKey}>{t("common.save")}</Button>
</div>
<Input
className="w-full"
sx={{
fontFamily: "monospace",
fontSize: "14px",
}}
placeholder="Write only"
value={state.openAIApiKey}
onChange={(event) => handleOpenAIApiKeyChanged(event.target.value)}
/>
<Divider className="!mt-3 !my-4" />
<div className="form-label">
<span className="normal-text">{t("setting.system-section.additional-style")}</span>
<Button onClick={handleSaveAdditionalStyle}>{t("common.save")}</Button>
@ -168,7 +206,7 @@ const SystemSection = () => {
fontFamily: "monospace",
fontSize: "14px",
}}
minRows={4}
minRows={2}
maxRows={4}
placeholder={t("setting.system-section.additional-style-placeholder")}
value={state.additionalStyle}
@ -185,7 +223,7 @@ const SystemSection = () => {
fontFamily: "monospace",
fontSize: "14px",
}}
minRows={4}
minRows={2}
maxRows={4}
placeholder={t("setting.system-section.additional-script-placeholder")}
value={state.additionalScript}

@ -246,6 +246,22 @@ export function deleteIdentityProvider(id: IdentityProviderId) {
return axios.delete(`/api/idp/${id}`);
}
export function postChatCompletion(prompt: string) {
return axios.post<ResponseObject<string>>(`/api/opanai/chat-completion`, {
prompt,
});
}
export function postTextCompletion(prompt: string) {
return axios.post<ResponseObject<string>>(`/api/opanai/text-completion`, {
prompt,
});
}
export function checkOpenAIEnabled() {
return axios.get<ResponseObject<boolean>>(`/api/opanai/enabled`);
}
export async function getRepoStarCount() {
const { data } = await axios.get(`https://api.github.com/repos/usememos/memos`, {
headers: {

@ -1,7 +1,6 @@
interface DialogProps {
destroy: FunctionType;
}
interface DialogCallback {
destroy: FunctionType;
hide: FunctionType;
}
type DialogProps = DialogCallback;

Loading…
Cancel
Save