Skip to content

Commit

Permalink
Merge pull request #4923 from ConnectAI-E/refactor-model-table
Browse files Browse the repository at this point in the history
Refactor model table
  • Loading branch information
Dogtiti authored Jul 4, 2024
2 parents 8cb204e + 31d9444 commit c4a6c93
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 34 deletions.
15 changes: 9 additions & 6 deletions app/api/anthropic/[...path]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ import {
Anthropic,
ApiPath,
DEFAULT_MODELS,
ServiceProvider,
ModelProvider,
} from "@/app/constant";
import { prettyObject } from "@/app/utils/format";
import { NextRequest, NextResponse } from "next/server";
import { auth } from "../../auth";
import { collectModelTable } from "@/app/utils/model";
import { isModelAvailableInServer } from "@/app/utils/model";

const ALLOWD_PATH = new Set([Anthropic.ChatPath, Anthropic.ChatPath1]);

Expand Down Expand Up @@ -136,17 +137,19 @@ async function request(req: NextRequest) {
// #1815 try to refuse some request to some models
if (serverConfig.customModels && req.body) {
try {
const modelTable = collectModelTable(
DEFAULT_MODELS,
serverConfig.customModels,
);
const clonedBody = await req.text();
fetchOptions.body = clonedBody;

const jsonBody = JSON.parse(clonedBody) as { model?: string };

// not undefined and is false
if (modelTable[jsonBody?.model ?? ""].available === false) {
if (
isModelAvailableInServer(
serverConfig.customModels,
jsonBody?.model as string,
ServiceProvider.Anthropic as string,
)
) {
return NextResponse.json(
{
error: true,
Expand Down
46 changes: 28 additions & 18 deletions app/api/common.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import { NextRequest, NextResponse } from "next/server";
import { getServerSideConfig } from "../config/server";
import { DEFAULT_MODELS, OPENAI_BASE_URL, GEMINI_BASE_URL } from "../constant";
import { collectModelTable } from "../utils/model";
import {
DEFAULT_MODELS,
OPENAI_BASE_URL,
GEMINI_BASE_URL,
ServiceProvider,
} from "../constant";
import { isModelAvailableInServer } from "../utils/model";
import { makeAzurePath } from "../azure";

const serverConfig = getServerSideConfig();
Expand Down Expand Up @@ -83,17 +88,24 @@ export async function requestOpenai(req: NextRequest) {
// #1815 try to refuse gpt4 request
if (serverConfig.customModels && req.body) {
try {
const modelTable = collectModelTable(
DEFAULT_MODELS,
serverConfig.customModels,
);
const clonedBody = await req.text();
fetchOptions.body = clonedBody;

const jsonBody = JSON.parse(clonedBody) as { model?: string };

// not undefined and is false
if (modelTable[jsonBody?.model ?? ""].available === false) {
if (
isModelAvailableInServer(
serverConfig.customModels,
jsonBody?.model as string,
ServiceProvider.OpenAI as string,
) ||
isModelAvailableInServer(
serverConfig.customModels,
jsonBody?.model as string,
ServiceProvider.Azure as string,
)
) {
return NextResponse.json(
{
error: true,
Expand All @@ -112,24 +124,23 @@ export async function requestOpenai(req: NextRequest) {
try {
const res = await fetch(fetchUrl, fetchOptions);

// Extract the OpenAI-Organization header from the response
const openaiOrganizationHeader = res.headers.get("OpenAI-Organization");
// Extract the OpenAI-Organization header from the response
const openaiOrganizationHeader = res.headers.get("OpenAI-Organization");

// Check if serverConfig.openaiOrgId is defined and not an empty string
if (serverConfig.openaiOrgId && serverConfig.openaiOrgId.trim() !== "") {
// If openaiOrganizationHeader is present, log it; otherwise, log that the header is not present
console.log("[Org ID]", openaiOrganizationHeader);
} else {
console.log("[Org ID] is not set up.");
}
// Check if serverConfig.openaiOrgId is defined and not an empty string
if (serverConfig.openaiOrgId && serverConfig.openaiOrgId.trim() !== "") {
// If openaiOrganizationHeader is present, log it; otherwise, log that the header is not present
console.log("[Org ID]", openaiOrganizationHeader);
} else {
console.log("[Org ID] is not set up.");
}

// to prevent browser prompt for credentials
const newHeaders = new Headers(res.headers);
newHeaders.delete("www-authenticate");
// to disable nginx buffering
newHeaders.set("X-Accel-Buffering", "no");


// Conditionally delete the OpenAI-Organization header from the response if [Org ID] is undefined or empty (not setup in ENV)
// Also, this is to prevent the header from being sent to the client
if (!serverConfig.openaiOrgId || serverConfig.openaiOrgId.trim() === "") {
Expand All @@ -142,7 +153,6 @@ export async function requestOpenai(req: NextRequest) {
// The browser will try to decode the response with brotli and fail
newHeaders.delete("content-encoding");


return new Response(res.body, {
status: res.status,
statusText: res.statusText,
Expand Down
4 changes: 2 additions & 2 deletions app/store/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ export const useAppConfig = createPersistStore(

for (const model of oldModels) {
model.available = false;
modelMap[model.name] = model;
modelMap[`${model.name}@${model?.provider?.id}`] = model;
}

for (const model of newModels) {
model.available = true;
modelMap[model.name] = model;
modelMap[`${model.name}@${model?.provider?.id}`] = model;
}

set(() => ({
Expand Down
43 changes: 35 additions & 8 deletions app/utils/model.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { DEFAULT_MODELS } from "../constant";
import { LLMModel } from "../client/api";

const customProvider = (modelName: string) => ({
id: modelName,
providerName: "",
providerName: "Custom",
providerType: "custom",
});

Expand All @@ -23,7 +24,8 @@ export function collectModelTable(

// default models
models.forEach((m) => {
modelTable[m.name] = {
// using <modelName>@<providerId> as fullName
modelTable[`${m.name}@${m?.provider?.id}`] = {
...m,
displayName: m.name, // 'provider' is copied over if it exists
};
Expand All @@ -45,12 +47,27 @@ export function collectModelTable(
(model) => (model.available = available),
);
} else {
modelTable[name] = {
name,
displayName: displayName || name,
available,
provider: modelTable[name]?.provider ?? customProvider(name), // Use optional chaining
};
// 1. find model by name(), and set available value
let count = 0;
for (const fullName in modelTable) {
if (fullName.split("@").shift() == name) {
count += 1;
modelTable[fullName]["available"] = available;
if (displayName) {
modelTable[fullName]["displayName"] = displayName;
}
}
}
// 2. if model not exists, create new model with available value
if (count === 0) {
const provider = customProvider(name);
modelTable[`${name}@${provider?.id}`] = {
name,
displayName: displayName || name,
available,
provider, // Use optional chaining
};
}
}
});

Expand Down Expand Up @@ -100,3 +117,13 @@ export function collectModelsWithDefaultModel(
const allModels = Object.values(modelTable);
return allModels;
}

export function isModelAvailableInServer(
customModels: string,
modelName: string,
providerName: string,
) {
const fullName = `${modelName}@${providerName}`;
const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
return modelTable[fullName]?.available === false;
}

0 comments on commit c4a6c93

Please sign in to comment.