Skip to content

Commit

Permalink
Merge pull request #5769 from ryanhex53/fix-model-multi@
Browse files Browse the repository at this point in the history
Custom model names can include the `@` symbol by itself.
  • Loading branch information
Dogtiti authored Nov 6, 2024
2 parents 00d6cb2 + 8e2484f commit f3603e5
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 11 deletions.
4 changes: 2 additions & 2 deletions app/api/common.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { NextRequest, NextResponse } from "next/server";
import { getServerSideConfig } from "../config/server";
import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
import { isModelAvailableInServer } from "../utils/model";
import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
import { getModelProvider, isModelAvailableInServer } from "../utils/model";

const serverConfig = getServerSideConfig();

Expand Down Expand Up @@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) {
.filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
.forEach((m) => {
const [fullName, displayName] = m.split("=");
const [_, providerName] = fullName.split("@");
const [_, providerName] = getModelProvider(fullName);
if (providerName === "azure" && !displayName) {
const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
"deployments/",
Expand Down
3 changes: 2 additions & 1 deletion app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ import { createTTSPlayer } from "../utils/audio";
import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts";

import { isEmpty } from "lodash-es";
import { getModelProvider } from "../utils/model";

const localStorage = safeLocalStorage();

Expand Down Expand Up @@ -645,7 +646,7 @@ export function ChatActions(props: {
onClose={() => setShowModelSelector(false)}
onSelection={(s) => {
if (s.length === 0) return;
const [model, providerName] = s[0].split("@");
const [model, providerName] = getModelProvider(s[0]);
chatStore.updateCurrentSession((session) => {
session.mask.modelConfig.model = model as ModelType;
session.mask.modelConfig.providerName =
Expand Down
9 changes: 7 additions & 2 deletions app/components/model-config.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { ListItem, Select } from "./ui-lib";
import { useAllModels } from "../utils/hooks";
import { groupBy } from "lodash-es";
import styles from "./model-config.module.scss";
import { getModelProvider } from "../utils/model";

export function ModelConfigList(props: {
modelConfig: ModelConfig;
Expand All @@ -28,7 +29,9 @@ export function ModelConfigList(props: {
value={value}
align="left"
onChange={(e) => {
const [model, providerName] = e.currentTarget.value.split("@");
const [model, providerName] = getModelProvider(
e.currentTarget.value,
);
props.updateConfig((config) => {
config.model = ModalConfigValidator.model(model);
config.providerName = providerName as ServiceProvider;
Expand Down Expand Up @@ -247,7 +250,9 @@ export function ModelConfigList(props: {
aria-label={Locale.Settings.CompressModel.Title}
value={compressModelValue}
onChange={(e) => {
const [model, providerName] = e.currentTarget.value.split("@");
const [model, providerName] = getModelProvider(
e.currentTarget.value,
);
props.updateConfig((config) => {
config.compressModel = ModalConfigValidator.model(model);
config.compressProviderName = providerName as ServiceProvider;
Expand Down
5 changes: 3 additions & 2 deletions app/store/access.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { getClientConfig } from "../config/client";
import { createPersistStore } from "../utils/store";
import { ensure } from "../utils/clone";
import { DEFAULT_CONFIG } from "./config";
import { getModelProvider } from "../utils/model";

let fetchState = 0; // 0 not fetch, 1 fetching, 2 done

Expand Down Expand Up @@ -226,9 +227,9 @@ export const useAccessStore = createPersistStore(
.then((res) => {
const defaultModel = res.defaultModel ?? "";
if (defaultModel !== "") {
const [model, providerName] = defaultModel.split("@");
const [model, providerName] = getModelProvider(defaultModel);
DEFAULT_CONFIG.modelConfig.model = model;
DEFAULT_CONFIG.modelConfig.providerName = providerName;
DEFAULT_CONFIG.modelConfig.providerName = providerName as any;
}

return res;
Expand Down
19 changes: 15 additions & 4 deletions app/utils/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ const sortModelTable = (models: ReturnType<typeof collectModels>) =>
}
});

/**
* get model name and provider from a formatted string,
* e.g. `gpt-4@OpenAi` or `claude-3-5-sonnet@20240620@Google`
* @param modelWithProvider model name with provider separated by last `@` char,
* @returns [model, provider] tuple, if no `@` char found, provider is undefined
*/
export function getModelProvider(modelWithProvider: string): [string, string?] {
const [model, provider] = modelWithProvider.split(/@(?!.*@)/);
return [model, provider];
}

export function collectModelTable(
models: readonly LLMModel[],
customModels: string,
Expand Down Expand Up @@ -79,10 +90,10 @@ export function collectModelTable(
);
} else {
// 1. find model by name, and set available value
const [customModelName, customProviderName] = name.split("@");
const [customModelName, customProviderName] = getModelProvider(name);
let count = 0;
for (const fullName in modelTable) {
const [modelName, providerName] = fullName.split("@");
const [modelName, providerName] = getModelProvider(fullName);
if (
customModelName == modelName &&
(customProviderName === undefined ||
Expand All @@ -102,7 +113,7 @@ export function collectModelTable(
}
// 2. if model not exists, create new model with available value
if (count === 0) {
let [customModelName, customProviderName] = name.split("@");
let [customModelName, customProviderName] = getModelProvider(name);
const provider = customProvider(
customProviderName || customModelName,
);
Expand Down Expand Up @@ -139,7 +150,7 @@ export function collectModelTableWithDefaultModel(
for (const key of Object.keys(modelTable)) {
if (
modelTable[key].available &&
key.split("@").shift() == defaultModel
getModelProvider(key)[0] == defaultModel
) {
modelTable[key].isDefault = true;
break;
Expand Down
31 changes: 31 additions & 0 deletions test/model-provider.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import { getModelProvider } from "../app/utils/model";

describe("getModelProvider", () => {
test("should return model and provider when input contains '@'", () => {
const input = "model@provider";
const [model, provider] = getModelProvider(input);
expect(model).toBe("model");
expect(provider).toBe("provider");
});

test("should return model and undefined provider when input does not contain '@'", () => {
const input = "model";
const [model, provider] = getModelProvider(input);
expect(model).toBe("model");
expect(provider).toBeUndefined();
});

test("should handle multiple '@' characters correctly", () => {
const input = "model@provider@extra";
const [model, provider] = getModelProvider(input);
expect(model).toBe("model@provider");
expect(provider).toBe("extra");
});

test("should return empty strings when input is empty", () => {
const input = "";
const [model, provider] = getModelProvider(input);
expect(model).toBe("");
expect(provider).toBeUndefined();
});
});

0 comments on commit f3603e5

Please sign in to comment.