From b7ffca031ebda555c373783820056541307ceba0 Mon Sep 17 00:00:00 2001 From: Yidadaa Date: Fri, 10 Nov 2023 02:43:30 +0800 Subject: [PATCH] feat: close #935 add azure support --- app/api/auth.ts | 18 ++- app/api/common.ts | 31 +++-- app/azure.ts | 9 ++ app/client/api.ts | 16 ++- app/client/platforms/openai.ts | 48 ++++++-- app/components/auth.tsx | 6 +- app/components/chat.tsx | 4 +- app/components/settings.tsx | 187 ++++++++++++++++++++++++------ app/components/ui-lib.module.scss | 2 +- app/components/ui-lib.tsx | 14 +-- app/config/server.ts | 35 ++++-- app/constant.ts | 11 ++ app/locales/cn.ts | 66 ++++++++--- app/locales/en.ts | 66 ++++++++--- app/store/access.ts | 57 +++++++-- app/utils/clone.ts | 7 ++ app/utils/store.ts | 51 ++++---- 17 files changed, 478 insertions(+), 150 deletions(-) create mode 100644 app/azure.ts diff --git a/app/api/auth.ts b/app/api/auth.ts index e0453b2b47f..c1f6e7fdec2 100644 --- a/app/api/auth.ts +++ b/app/api/auth.ts @@ -28,7 +28,7 @@ export function auth(req: NextRequest) { const authToken = req.headers.get("Authorization") ?? ""; // check if it is openai api key or user token - const { accessCode, apiKey: token } = parseApiKey(authToken); + const { accessCode, apiKey } = parseApiKey(authToken); const hashedCode = md5.hash(accessCode ?? "").trim(); @@ -39,7 +39,7 @@ export function auth(req: NextRequest) { console.log("[User IP] ", getIP(req)); console.log("[Time] ", new Date().toLocaleString()); - if (serverConfig.needCode && !serverConfig.codes.has(hashedCode) && !token) { + if (serverConfig.needCode && !serverConfig.codes.has(hashedCode) && !apiKey) { return { error: true, msg: !accessCode ? "empty access code" : "wrong access code", @@ -47,11 +47,17 @@ export function auth(req: NextRequest) { } // if user does not provide an api key, inject system api key - if (!token) { - const apiKey = serverConfig.apiKey; - if (apiKey) { + if (!apiKey) { + const serverApiKey = serverConfig.isAzure + ? serverConfig.azureApiKey + : serverConfig.apiKey; + + if (serverApiKey) { console.log("[Auth] use system api key"); - req.headers.set("Authorization", `Bearer ${apiKey}`); + req.headers.set( + "Authorization", + `${serverConfig.isAzure ? "" : "Bearer "}${serverApiKey}`, + ); } else { console.log("[Auth] admin did not provide an api key"); } diff --git a/app/api/common.ts b/app/api/common.ts index a1decd42f5b..fc877b02db2 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -1,19 +1,24 @@ import { NextRequest, NextResponse } from "next/server"; import { getServerSideConfig } from "../config/server"; import { DEFAULT_MODELS, OPENAI_BASE_URL } from "../constant"; -import { collectModelTable, collectModels } from "../utils/model"; +import { collectModelTable } from "../utils/model"; +import { makeAzurePath } from "../azure"; const serverConfig = getServerSideConfig(); export async function requestOpenai(req: NextRequest) { const controller = new AbortController(); + const authValue = req.headers.get("Authorization") ?? ""; - const openaiPath = `${req.nextUrl.pathname}${req.nextUrl.search}`.replaceAll( + const authHeaderName = serverConfig.isAzure ? "api-key" : "Authorization"; + + let path = `${req.nextUrl.pathname}${req.nextUrl.search}`.replaceAll( "/api/openai/", "", ); - let baseUrl = serverConfig.baseUrl ?? OPENAI_BASE_URL; + let baseUrl = + serverConfig.azureUrl ?? serverConfig.baseUrl ?? OPENAI_BASE_URL; if (!baseUrl.startsWith("http")) { baseUrl = `https://${baseUrl}`; @@ -23,7 +28,7 @@ export async function requestOpenai(req: NextRequest) { baseUrl = baseUrl.slice(0, -1); } - console.log("[Proxy] ", openaiPath); + console.log("[Proxy] ", path); console.log("[Base Url]", baseUrl); console.log("[Org ID]", serverConfig.openaiOrgId); @@ -34,14 +39,24 @@ export async function requestOpenai(req: NextRequest) { 10 * 60 * 1000, ); - const fetchUrl = `${baseUrl}/${openaiPath}`; + if (serverConfig.isAzure) { + if (!serverConfig.azureApiVersion) { + return NextResponse.json({ + error: true, + message: `missing AZURE_API_VERSION in server env vars`, + }); + } + path = makeAzurePath(path, serverConfig.azureApiVersion); + } + + const fetchUrl = `${baseUrl}/${path}`; const fetchOptions: RequestInit = { headers: { "Content-Type": "application/json", "Cache-Control": "no-store", - Authorization: authValue, - ...(process.env.OPENAI_ORG_ID && { - "OpenAI-Organization": process.env.OPENAI_ORG_ID, + [authHeaderName]: authValue, + ...(serverConfig.openaiOrgId && { + "OpenAI-Organization": serverConfig.openaiOrgId, }), }, method: req.method, diff --git a/app/azure.ts b/app/azure.ts new file mode 100644 index 00000000000..48406c55ba5 --- /dev/null +++ b/app/azure.ts @@ -0,0 +1,9 @@ +export function makeAzurePath(path: string, apiVersion: string) { + // should omit /v1 prefix + path = path.replaceAll("v1/", ""); + + // should add api-key to query string + path += `${path.includes("?") ? "&" : "?"}api-version=${apiVersion}`; + + return path; +} diff --git a/app/client/api.ts b/app/client/api.ts index b04dd88b88c..eedd2c9ab48 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -1,5 +1,5 @@ import { getClientConfig } from "../config/client"; -import { ACCESS_CODE_PREFIX } from "../constant"; +import { ACCESS_CODE_PREFIX, Azure, ServiceProvider } from "../constant"; import { ChatMessage, ModelType, useAccessStore } from "../store"; import { ChatGPTApi } from "./platforms/openai"; @@ -127,22 +127,26 @@ export const api = new ClientApi(); export function getHeaders() { const accessStore = useAccessStore.getState(); - let headers: Record = { + const headers: Record = { "Content-Type": "application/json", "x-requested-with": "XMLHttpRequest", }; - const makeBearer = (token: string) => `Bearer ${token.trim()}`; + const isAzure = accessStore.provider === ServiceProvider.Azure; + const authHeader = isAzure ? "api-key" : "Authorization"; + const apiKey = isAzure ? accessStore.azureApiKey : accessStore.openaiApiKey; + + const makeBearer = (s: string) => `${isAzure ? "" : "Bearer "}${s.trim()}`; const validString = (x: string) => x && x.length > 0; // use user's api key first - if (validString(accessStore.token)) { - headers.Authorization = makeBearer(accessStore.token); + if (validString(apiKey)) { + headers[authHeader] = makeBearer(apiKey); } else if ( accessStore.enabledAccessControl() && validString(accessStore.accessCode) ) { - headers.Authorization = makeBearer( + headers[authHeader] = makeBearer( ACCESS_CODE_PREFIX + accessStore.accessCode, ); } diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 4a5ddce7de6..930d606900a 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -1,8 +1,10 @@ import { + ApiPath, DEFAULT_API_HOST, DEFAULT_MODELS, OpenaiPath, REQUEST_TIMEOUT_MS, + ServiceProvider, } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; @@ -14,6 +16,7 @@ import { } from "@fortaine/fetch-event-source"; import { prettyObject } from "@/app/utils/format"; import { getClientConfig } from "@/app/config/client"; +import { makeAzurePath } from "@/app/azure"; export interface OpenAIListModelResponse { object: string; @@ -28,20 +31,35 @@ export class ChatGPTApi implements LLMApi { private disableListModels = true; path(path: string): string { - let openaiUrl = useAccessStore.getState().openaiUrl; - const apiPath = "/api/openai"; + const accessStore = useAccessStore.getState(); - if (openaiUrl.length === 0) { + const isAzure = accessStore.provider === ServiceProvider.Azure; + + if (isAzure && !accessStore.isValidAzure()) { + throw Error( + "incomplete azure config, please check it in your settings page", + ); + } + + let baseUrl = isAzure ? accessStore.azureUrl : accessStore.openaiUrl; + + if (baseUrl.length === 0) { const isApp = !!getClientConfig()?.isApp; - openaiUrl = isApp ? DEFAULT_API_HOST : apiPath; + baseUrl = isApp ? DEFAULT_API_HOST : ApiPath.OpenAI; } - if (openaiUrl.endsWith("/")) { - openaiUrl = openaiUrl.slice(0, openaiUrl.length - 1); + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, baseUrl.length - 1); + } + if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.OpenAI)) { + baseUrl = "https://" + baseUrl; } - if (!openaiUrl.startsWith("http") && !openaiUrl.startsWith(apiPath)) { - openaiUrl = "https://" + openaiUrl; + + if (isAzure) { + path = makeAzurePath(path, accessStore.azureApiVersion); } - return [openaiUrl, path].join("/"); + + return [baseUrl, path].join("/"); } extractMessage(res: any) { @@ -156,14 +174,20 @@ export class ChatGPTApi implements LLMApi { } const text = msg.data; try { - const json = JSON.parse(text); - const delta = json.choices[0].delta.content; + const json = JSON.parse(text) as { + choices: Array<{ + delta: { + content: string; + }; + }>; + }; + const delta = json.choices[0]?.delta?.content; if (delta) { responseText += delta; options.onUpdate?.(responseText, delta); } } catch (e) { - console.error("[Request] parse error", text, msg); + console.error("[Request] parse error", text); } }, onclose() { diff --git a/app/components/auth.tsx b/app/components/auth.tsx index 577d7754240..3e1548a1325 100644 --- a/app/components/auth.tsx +++ b/app/components/auth.tsx @@ -18,7 +18,7 @@ export function AuthPage() { const goChat = () => navigate(Path.Chat); const resetAccessCode = () => { accessStore.update((access) => { - access.token = ""; + access.openaiApiKey = ""; access.accessCode = ""; }); }; // Reset access code to empty string @@ -57,10 +57,10 @@ export function AuthPage() { className={styles["auth-input"]} type="password" placeholder={Locale.Settings.Token.Placeholder} - value={accessStore.token} + value={accessStore.openaiApiKey} onChange={(e) => { accessStore.update( - (access) => (access.token = e.currentTarget.value), + (access) => (access.openaiApiKey = e.currentTarget.value), ); }} /> diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 9afb49f7a66..c27c3eee464 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -998,7 +998,9 @@ function _Chat() { ).then((res) => { if (!res) return; if (payload.key) { - accessStore.update((access) => (access.token = payload.key!)); + accessStore.update( + (access) => (access.openaiApiKey = payload.key!), + ); } if (payload.url) { accessStore.update((access) => (access.openaiUrl = payload.url!)); diff --git a/app/components/settings.tsx b/app/components/settings.tsx index 572c0743a11..178fcec57e9 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -51,10 +51,13 @@ import Locale, { import { copyToClipboard } from "../utils"; import Link from "next/link"; import { + Azure, OPENAI_BASE_URL, Path, RELEASE_URL, STORAGE_KEY, + ServiceProvider, + SlotID, UPDATE_URL, } from "../constant"; import { Prompt, SearchService, usePromptStore } from "../store/prompt"; @@ -580,8 +583,16 @@ export function Settings() { const accessStore = useAccessStore(); const shouldHideBalanceQuery = useMemo(() => { const isOpenAiUrl = accessStore.openaiUrl.includes(OPENAI_BASE_URL); - return accessStore.hideBalanceQuery || isOpenAiUrl; - }, [accessStore.hideBalanceQuery, accessStore.openaiUrl]); + return ( + accessStore.hideBalanceQuery || + isOpenAiUrl || + accessStore.provider === ServiceProvider.Azure + ); + }, [ + accessStore.hideBalanceQuery, + accessStore.openaiUrl, + accessStore.provider, + ]); const usage = { used: updateStore.used, @@ -877,16 +888,16 @@ export function Settings() { - - {showAccessCode ? ( + + {showAccessCode && ( { accessStore.update( (access) => (access.accessCode = e.currentTarget.value), @@ -894,44 +905,152 @@ export function Settings() { }} /> - ) : ( - <> )} - {!accessStore.hideUserApiKey ? ( + {!accessStore.hideUserApiKey && ( <> accessStore.update( - (access) => (access.openaiUrl = e.currentTarget.value), + (access) => + (access.useCustomConfig = e.currentTarget.checked), ) } > - - { - accessStore.update( - (access) => (access.token = e.currentTarget.value), - ); - }} - /> - + {accessStore.useCustomConfig && ( + <> + + + + + {accessStore.provider === "OpenAI" ? ( + <> + + + accessStore.update( + (access) => + (access.openaiUrl = e.currentTarget.value), + ) + } + > + + + { + accessStore.update( + (access) => + (access.openaiApiKey = e.currentTarget.value), + ); + }} + /> + + + ) : ( + <> + + + accessStore.update( + (access) => + (access.azureUrl = e.currentTarget.value), + ) + } + > + + + { + accessStore.update( + (access) => + (access.azureApiKey = e.currentTarget.value), + ); + }} + /> + + + + accessStore.update( + (access) => + (access.azureApiVersion = + e.currentTarget.value), + ) + } + > + + + )} + + )} - ) : null} + )} {!shouldHideBalanceQuery ? ( - | JSX.Element - | null - | undefined; -}) { - return
{props.children}
; +export function List(props: { children: React.ReactNode; id?: string }) { + return ( +
+ {props.children} +
+ ); } export function Loading() { diff --git a/app/config/server.ts b/app/config/server.ts index 007c3973863..2f2e7d7fd8a 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -4,19 +4,28 @@ import { DEFAULT_MODELS } from "../constant"; declare global { namespace NodeJS { interface ProcessEnv { + PROXY_URL?: string; // docker only + OPENAI_API_KEY?: string; CODE?: string; + BASE_URL?: string; - PROXY_URL?: string; - OPENAI_ORG_ID?: string; + OPENAI_ORG_ID?: string; // openai only + VERCEL?: string; - HIDE_USER_API_KEY?: string; // disable user's api key input - DISABLE_GPT4?: string; // allow user to use gpt-4 or not BUILD_MODE?: "standalone" | "export"; BUILD_APP?: string; // is building desktop app + + HIDE_USER_API_KEY?: string; // disable user's api key input + DISABLE_GPT4?: string; // allow user to use gpt-4 or not ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not DISABLE_FAST_LINK?: string; // disallow parse settings from url or not CUSTOM_MODELS?: string; // to control custom models + + // azure only + AZURE_URL?: string; // https://{azure-url}/openai/deployments/{deploy-name} + AZURE_API_KEY?: string; + AZURE_API_VERSION?: string; } } } @@ -41,7 +50,7 @@ export const getServerSideConfig = () => { ); } - let disableGPT4 = !!process.env.DISABLE_GPT4; + const disableGPT4 = !!process.env.DISABLE_GPT4; let customModels = process.env.CUSTOM_MODELS ?? ""; if (disableGPT4) { @@ -51,15 +60,25 @@ export const getServerSideConfig = () => { .join(","); } + const isAzure = !!process.env.AZURE_URL; + return { + baseUrl: process.env.BASE_URL, apiKey: process.env.OPENAI_API_KEY, + openaiOrgId: process.env.OPENAI_ORG_ID, + + isAzure, + azureUrl: process.env.AZURE_URL, + azureApiKey: process.env.AZURE_API_KEY, + azureApiVersion: process.env.AZURE_API_VERSION, + + needCode: ACCESS_CODES.size > 0, code: process.env.CODE, codes: ACCESS_CODES, - needCode: ACCESS_CODES.size > 0, - baseUrl: process.env.BASE_URL, + proxyUrl: process.env.PROXY_URL, - openaiOrgId: process.env.OPENAI_ORG_ID, isVercel: !!process.env.VERCEL, + hideUserApiKey: !!process.env.HIDE_USER_API_KEY, disableGPT4, hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY, diff --git a/app/constant.ts b/app/constant.ts index a97b8782292..fbc0c72e378 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -23,10 +23,12 @@ export enum Path { export enum ApiPath { Cors = "/api/cors", + OpenAI = "/api/openai", } export enum SlotID { AppBody = "app-body", + CustomModel = "custom-model", } export enum FileName { @@ -60,6 +62,11 @@ export const REQUEST_TIMEOUT_MS = 60000; export const EXPORT_MESSAGE_CLASS_NAME = "export-markdown"; +export enum ServiceProvider { + OpenAI = "OpenAI", + Azure = "Azure", +} + export const OpenaiPath = { ChatPath: "v1/chat/completions", UsagePath: "dashboard/billing/usage", @@ -67,6 +74,10 @@ export const OpenaiPath = { ListModelPath: "v1/models", }; +export const Azure = { + ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}", +}; + export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang export const DEFAULT_SYSTEM_TEMPLATE = ` You are ChatGPT, a large language model trained by OpenAI. diff --git a/app/locales/cn.ts b/app/locales/cn.ts index 4cd963fb8e2..e721adef7ae 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -258,11 +258,6 @@ const cn = { Title: "历史消息长度压缩阈值", SubTitle: "当未压缩的历史消息超过该值时,将进行压缩", }, - Token: { - Title: "API Key", - SubTitle: "使用自己的 Key 可绕过密码访问限制", - Placeholder: "OpenAI API Key", - }, Usage: { Title: "余额查询", @@ -273,19 +268,56 @@ const cn = { Check: "重新检查", NoAccess: "输入 API Key 或访问密码查看余额", }, - AccessCode: { - Title: "访问密码", - SubTitle: "管理员已开启加密访问", - Placeholder: "请输入访问密码", - }, - Endpoint: { - Title: "接口地址", - SubTitle: "除默认地址外,必须包含 http(s)://", - }, - CustomModel: { - Title: "自定义模型名", - SubTitle: "增加自定义模型可选项,使用英文逗号隔开", + + Access: { + AccessCode: { + Title: "访问密码", + SubTitle: "管理员已开启加密访问", + Placeholder: "请输入访问密码", + }, + CustomEndpoint: { + Title: "自定义接口", + SubTitle: "是否使用自定义 Azure 或 OpenAI 服务", + }, + Provider: { + Title: "模型服务商", + SubTitle: "切换不同的服务商", + }, + OpenAI: { + ApiKey: { + Title: "API Key", + SubTitle: "使用自定义 OpenAI Key 绕过密码访问限制", + Placeholder: "OpenAI API Key", + }, + + Endpoint: { + Title: "接口地址", + SubTitle: "除默认地址外,必须包含 http(s)://", + }, + }, + Azure: { + ApiKey: { + Title: "接口密钥", + SubTitle: "使用自定义 Azure Key 绕过密码访问限制", + Placeholder: "Azure API Key", + }, + + Endpoint: { + Title: "接口地址", + SubTitle: "样例:", + }, + + ApiVerion: { + Title: "接口版本 (azure api version)", + SubTitle: "选择指定的部分版本", + }, + }, + CustomModel: { + Title: "自定义模型名", + SubTitle: "增加自定义模型可选项,使用英文逗号隔开", + }, }, + Model: "模型 (model)", Temperature: { Title: "随机性 (temperature)", diff --git a/app/locales/en.ts b/app/locales/en.ts index 928c4b72d4e..c6e61ecab04 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -262,11 +262,7 @@ const en: LocaleType = { SubTitle: "Will compress if uncompressed messages length exceeds the value", }, - Token: { - Title: "API Key", - SubTitle: "Use your key to ignore access code limit", - Placeholder: "OpenAI API Key", - }, + Usage: { Title: "Account Balance", SubTitle(used: any, total: any) { @@ -276,19 +272,55 @@ const en: LocaleType = { Check: "Check", NoAccess: "Enter API Key to check balance", }, - AccessCode: { - Title: "Access Code", - SubTitle: "Access control enabled", - Placeholder: "Need Access Code", - }, - Endpoint: { - Title: "Endpoint", - SubTitle: "Custom endpoint must start with http(s)://", - }, - CustomModel: { - Title: "Custom Models", - SubTitle: "Add extra model options, separate by comma", + Access: { + AccessCode: { + Title: "Access Code", + SubTitle: "Access control Enabled", + Placeholder: "Enter Code", + }, + CustomEndpoint: { + Title: "Custom Endpoint", + SubTitle: "Use custom Azure or OpenAI service", + }, + Provider: { + Title: "Model Provider", + SubTitle: "Select Azure or OpenAI", + }, + OpenAI: { + ApiKey: { + Title: "OpenAI API Key", + SubTitle: "User custom OpenAI Api Key", + Placeholder: "sk-xxx", + }, + + Endpoint: { + Title: "OpenAI Endpoint", + SubTitle: "Must starts with http(s):// or use /api/openai as default", + }, + }, + Azure: { + ApiKey: { + Title: "Azure Api Key", + SubTitle: "Check your api key from Azure console", + Placeholder: "Azure Api Key", + }, + + Endpoint: { + Title: "Azure Endpoint", + SubTitle: "Example: ", + }, + + ApiVerion: { + Title: "Azure Api Version", + SubTitle: "Check your api version from azure console", + }, + }, + CustomModel: { + Title: "Custom Models", + SubTitle: "Custom model options, seperated by comma", + }, }, + Model: "Model", Temperature: { Title: "Temperature", diff --git a/app/store/access.ts b/app/store/access.ts index f87e44a2ac4..2abe1e3cc9f 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -1,25 +1,41 @@ -import { DEFAULT_API_HOST, DEFAULT_MODELS, StoreKey } from "../constant"; +import { + ApiPath, + DEFAULT_API_HOST, + ServiceProvider, + StoreKey, +} from "../constant"; import { getHeaders } from "../client/api"; import { getClientConfig } from "../config/client"; import { createPersistStore } from "../utils/store"; +import { ensure } from "../utils/clone"; let fetchState = 0; // 0 not fetch, 1 fetching, 2 done const DEFAULT_OPENAI_URL = - getClientConfig()?.buildMode === "export" ? DEFAULT_API_HOST : "/api/openai/"; -console.log("[API] default openai url", DEFAULT_OPENAI_URL); + getClientConfig()?.buildMode === "export" ? DEFAULT_API_HOST : ApiPath.OpenAI; const DEFAULT_ACCESS_STATE = { - token: "", accessCode: "", + useCustomConfig: false, + + provider: ServiceProvider.OpenAI, + + // openai + openaiUrl: DEFAULT_OPENAI_URL, + openaiApiKey: "", + + // azure + azureUrl: "", + azureApiKey: "", + azureApiVersion: "2023-08-01-preview", + + // server config needCode: true, hideUserApiKey: false, hideBalanceQuery: false, disableGPT4: false, disableFastLink: false, customModels: "", - - openaiUrl: DEFAULT_OPENAI_URL, }; export const useAccessStore = createPersistStore( @@ -31,12 +47,24 @@ export const useAccessStore = createPersistStore( return get().needCode; }, + + isValidOpenAI() { + return ensure(get(), ["openaiUrl", "openaiApiKey"]); + }, + + isValidAzure() { + return ensure(get(), ["azureUrl", "azureApiKey", "azureApiVersion"]); + }, + isAuthorized() { this.fetch(); // has token or has code or disabled access control return ( - !!get().token || !!get().accessCode || !this.enabledAccessControl() + this.isValidOpenAI() || + this.isValidAzure() || + !this.enabledAccessControl() || + (this.enabledAccessControl() && ensure(get(), ["accessCode"])) ); }, fetch() { @@ -64,6 +92,19 @@ export const useAccessStore = createPersistStore( }), { name: StoreKey.Access, - version: 1, + version: 2, + migrate(persistedState, version) { + if (version < 2) { + const state = persistedState as { + token: string; + openaiApiKey: string; + azureApiVersion: string; + }; + state.openaiApiKey = state.token; + state.azureApiVersion = "2023-08-01-preview"; + } + + return persistedState as any; + }, }, ); diff --git a/app/utils/clone.ts b/app/utils/clone.ts index 2958b6b9c35..c42288f7789 100644 --- a/app/utils/clone.ts +++ b/app/utils/clone.ts @@ -1,3 +1,10 @@ export function deepClone(obj: T) { return JSON.parse(JSON.stringify(obj)); } + +export function ensure( + obj: T, + keys: Array<[keyof T][number]>, +) { + return keys.every((k) => obj[k] !== undefined && obj[k] !== null); +} diff --git a/app/utils/store.ts b/app/utils/store.ts index cd151dc4925..684a1911279 100644 --- a/app/utils/store.ts +++ b/app/utils/store.ts @@ -1,5 +1,5 @@ import { create } from "zustand"; -import { persist } from "zustand/middleware"; +import { combine, persist } from "zustand/middleware"; import { Updater } from "../typing"; import { deepClone } from "./clone"; @@ -23,33 +23,42 @@ type SetStoreState = ( replace?: boolean | undefined, ) => void; -export function createPersistStore( - defaultState: T, +export function createPersistStore( + state: T, methods: ( set: SetStoreState>, get: () => T & MakeUpdater, ) => M, persistOptions: SecondParam>>, ) { - return create>()( - persist((set, get) => { - return { - ...defaultState, - ...methods(set as any, get), - - lastUpdateTime: 0, - markUpdate() { - set({ lastUpdateTime: Date.now() } as Partial< - T & M & MakeUpdater - >); + return create( + persist( + combine( + { + ...state, + lastUpdateTime: 0, }, - update(updater) { - const state = deepClone(get()); - updater(state); - get().markUpdate(); - set(state); + (set, get) => { + return { + ...methods(set, get as any), + + markUpdate() { + set({ lastUpdateTime: Date.now() } as Partial< + T & M & MakeUpdater + >); + }, + update(updater) { + const state = deepClone(get()); + updater(state); + set({ + ...state, + lastUpdateTime: Date.now(), + }); + }, + } as M & MakeUpdater; }, - }; - }, persistOptions), + ), + persistOptions as any, + ), ); }