Allow model selection on the frontend #187

Open · wants to merge 2 commits into base: main
10 changes: 9 additions & 1 deletion backend/dataline/api/settings/router.py
@@ -2,7 +2,7 @@

from fastapi import APIRouter, Depends, HTTPException, UploadFile

from dataline.models.user.schema import AvatarOut, UserOut, UserUpdateIn
from dataline.models.user.schema import AllowedModels, AvatarOut, UserOut, UserUpdateIn
from dataline.old_models import SuccessResponse
from dataline.repositories.base import AsyncSession, get_session
from dataline.services.settings import SettingsService
@@ -49,3 +49,11 @@ async def get_info(
) -> SuccessResponse[UserOut]:
user_info = await settings_service.get_user_info(session)
return SuccessResponse(data=user_info)


@router.get("/allowed_models")
async def get_allowed_models(
settings_service: SettingsService = Depends(SettingsService), session: AsyncSession = Depends(get_session)
) -> SuccessResponse[AllowedModels]:
allowed_models = await settings_service.get_user_allowed_models(session)
return SuccessResponse(data=AllowedModels(models=allowed_models))
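To make the new endpoint's contract concrete, here is a minimal sketch of calling it with FastAPI's TestClient. The `dataline.main.app` import path and the presence of a stored OpenAI key are assumptions (neither appears in this diff); the `/settings` prefix is inferred from the frontend call to `/settings/allowed_models`.

```python
# Sketch only, not part of the PR. Assumes the FastAPI app is exposed as
# dataline.main.app and that a user with a valid OpenAI API key already exists.
from fastapi.testclient import TestClient

from dataline.main import app  # assumed import path

client = TestClient(app)
response = client.get("/settings/allowed_models")

assert response.status_code == 200
# SuccessResponse wraps the AllowedModels schema, so the body should look like:
# {"data": {"models": ["gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo", ...]}}
assert isinstance(response.json()["data"]["models"], list)
```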
6 changes: 5 additions & 1 deletion backend/dataline/models/user/schema.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Sequence

import openai
from pydantic import (
@@ -65,3 +65,7 @@ class UserWithKeys(BaseModel):

class AvatarOut(BaseModel):
blob: str


class AllowedModels(BaseModel):
models: Sequence[str]
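As a quick sanity check of the wire format, the new schema serialises to the flat object the frontend's `GetAllowedModelsResult` type expects; a standalone sketch assuming pydantic v2 (which the existing `model_validate` usage implies):

```python
# Sketch only: the AllowedModels schema as added above, exercised in isolation.
from typing import Sequence

from pydantic import BaseModel


class AllowedModels(BaseModel):
    models: Sequence[str]


print(AllowedModels(models=["gpt-4o", "gpt-4-turbo"]).model_dump())
# {'models': ['gpt-4o', 'gpt-4-turbo']}
```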
26 changes: 22 additions & 4 deletions backend/dataline/services/settings.py
@@ -1,6 +1,6 @@
import logging
import mimetypes
from typing import Optional
from typing import Optional, Sequence
from uuid import uuid4

import openai
@@ -18,11 +18,20 @@

logger = logging.getLogger(__name__)

STANDARD_MODELS = {"gpt-4o": "a", "gpt-4-turbo": "b", "gpt-3.5-turbo": "c"}

def model_exists(openai_api_key: SecretStr | str, model: str) -> bool:

def get_allowed_models(openai_api_key: SecretStr | str) -> Sequence[str]:
api_key = openai_api_key.get_secret_value() if isinstance(openai_api_key, SecretStr) else openai_api_key
models = openai.OpenAI(api_key=api_key).models.list()
return model in {model.id for model in models}
return sorted(
[model.id for model in openai.OpenAI(api_key=api_key).models.list() if model.id.startswith("gpt")],
key=lambda x: STANDARD_MODELS.get(x, x),
)


def model_exists(openai_api_key: SecretStr | str, model: str) -> bool:
models = get_allowed_models(openai_api_key)
return model in models


class SettingsService:
@@ -121,6 +130,15 @@ async def get_user_info(self, session: AsyncSession) -> UserOut:

return UserOut.model_validate(user_info)

async def get_user_allowed_models(self, session: AsyncSession) -> Sequence[str]:
user_info = await self.user_repo.get_one_or_none(session)
if user_info is None:
raise NotFoundError("No user or multiple users found")
if user_info.openai_api_key is None:
raise NotFoundError("OpenAI key not set")

return get_allowed_models(user_info.openai_api_key)

async def get_model_details(self, session: AsyncSession) -> UserWithKeys:
user_info = await self.user_repo.get_one_or_none(session)
if user_info is None:
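The sorting in `get_allowed_models` is worth spelling out: the three ids in `STANDARD_MODELS` map to the sentinel keys "a", "b", and "c", so they always sort to the front in that order, while any other `gpt-*` model falls back to its own id and sorts alphabetically after them. A self-contained sketch with a made-up model list:

```python
# Sketch of the sort key used by get_allowed_models, with a hypothetical model
# list standing in for the live OpenAI client response.
STANDARD_MODELS = {"gpt-4o": "a", "gpt-4-turbo": "b", "gpt-3.5-turbo": "c"}

model_ids = ["gpt-3.5-turbo", "gpt-4o-mini", "gpt-4-turbo", "gpt-4o", "dall-e-3"]

ordered = sorted(
    (m for m in model_ids if m.startswith("gpt")),
    key=lambda x: STANDARD_MODELS.get(x, x),
)

# The "standard" models sort first via their "a"/"b"/"c" keys; the rest follow
# alphabetically by their own id, and non-gpt models are filtered out entirely.
print(ordered)  # ['gpt-4o', 'gpt-4-turbo', 'gpt-3.5-turbo', 'gpt-4o-mini']
```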
21 changes: 20 additions & 1 deletion frontend/src/api.ts
@@ -291,14 +291,22 @@ const updateUserInfo = async (options: {
openai_api_key?: string;
langsmith_api_key?: string;
sentry_enabled?: boolean;
preferred_openai_model?: string;
}) => {
const { name, openai_api_key, langsmith_api_key, sentry_enabled } = options;
const {
name,
openai_api_key,
langsmith_api_key,
sentry_enabled,
preferred_openai_model,
} = options;
// send only the filled in fields
const data = {
...(name && { name }),
...(openai_api_key && { openai_api_key }),
...(langsmith_api_key && { langsmith_api_key }),
...(sentry_enabled != null && { sentry_enabled }),
...(preferred_openai_model && { preferred_openai_model }),
};
const response = await backendApi<UpdateUserInfoResult>({
url: `/settings/info`,
@@ -313,11 +321,21 @@ export type GetUserInfoResult = ApiResponse<{
openai_api_key: string;
langsmith_api_key?: string;
sentry_enabled: boolean;
preferred_openai_model: string;
}>;
const getUserInfo = async () => {
return (await backendApi<GetUserInfoResult>({ url: `/settings/info` })).data;
};

export type GetAllowedModelsResult = ApiResponse<{ models: string[] }>;
const getAllowedModels = async () => {
return (
await backendApi<GetAllowedModelsResult>({
url: `/settings/allowed_models`,
})
).data;
};

export type RefreshChartResult = ApiResponse<{
created_at: string;
chartjs_json: string;
@@ -376,5 +394,6 @@ export const api = {
updateAvatar,
updateUserInfo,
getUserInfo,
getAllowedModels,
refreshChart,
};
69 changes: 61 additions & 8 deletions frontend/src/components/Settings/Settings.tsx
@@ -6,9 +6,16 @@ import {
useGetUserProfile,
useUpdateUserInfo,
useUpdateUserAvatar,
useGetAllowedModels,
} from "@/hooks";
import {
Listbox,
ListboxLabel,
ListboxOption,
} from "@components/Catalyst/listbox";
import { Switch } from "@components/Catalyst/switch";
import _ from "lodash";
import { Spinner } from "../Spinner/Spinner";

function classNames(...classes: string[]) {
return classes.filter(Boolean).join(" ");
@@ -17,7 +24,9 @@ function classNames(...classes: string[]) {
export default function Account() {
const { data: profile } = useGetUserProfile();
const { data: avatarUrl } = useGetAvatar();
const { mutate: updateUserInfo } = useUpdateUserInfo();
const { data: allowed_models } = useGetAllowedModels();
const { mutate: updateUserInfo, isPending: isUpdatingUserInfo } =
useUpdateUserInfo();
const { mutate: updateAvatar, isPending } = useUpdateUserAvatar();

const avatarUploadRef = useRef<HTMLInputElement>(null);
@@ -46,6 +55,10 @@ export default function Account() {
userInfo.langsmith_api_key === "**********"
? undefined
: userInfo.langsmith_api_key,
preferred_openai_model:
userInfo.preferred_openai_model === profile?.preferred_openai_model
? undefined
: userInfo.preferred_openai_model,
};
updateUserInfo(updatedUserInfo);
}
@@ -209,7 +222,7 @@
</div>
</div>

{/* Sentry Preference */}
{/* Preferences */}
<div className="grid max-w-7xl grid-cols-1 gap-x-8 gap-y-10 px-4 py-16 sm:px-6 md:grid-cols-3 lg:px-8">
<div>
<h2 className="text-base font-semibold leading-7 text-white">
@@ -219,6 +232,46 @@

<div className="md:col-span-2">
<div className="grid grid-cols-1 gap-x-6 gap-y-8 sm:max-w-xl sm:grid-cols-6">
{/* Preferred Model */}
<div className="sm:col-span-3">
<label
htmlFor="preferred-model"
className="block text-md font-medium leading-6 text-white"
>
Preferred Model
</label>
<div className="mt-2">
<Listbox
name="preferred-model"
defaultValue={profile?.preferred_openai_model}
onChange={(value) =>
setUserInfo((prevUserInfo) => ({
...prevUserInfo!,
preferred_openai_model: value,
}))
}
>
{allowed_models ? (
allowed_models.models.map((model) => (
<ListboxOption value={model} key={model}>
<ListboxLabel>{model}</ListboxLabel>
</ListboxOption>
))
) : (
<ListboxOption
value={profile?.preferred_openai_model}
key={profile?.preferred_openai_model}
>
<ListboxLabel>
{profile?.preferred_openai_model}
</ListboxLabel>
</ListboxOption>
)}
</Listbox>
</div>
</div>

{/* Sentry settings */}
<div className="col-span-full">
<div className="flex items-center gap-x-6">
<label
@@ -245,27 +298,27 @@
</p>
</div>
</div>
<div className="mt-8 flex">
<div className="mt-8 flex gap-x-3 items-center">
<button
disabled={!settingsChanged}
disabled={!settingsChanged || isUpdatingUserInfo}
type="submit"
className={classNames(
"rounded-md px-3 py-2 text-sm font-semibold shadow-sm",
settingsChanged
"rounded-md px-6 py-2 text-sm font-semibold shadow-sm",
settingsChanged && !isUpdatingUserInfo
? "bg-indigo-500 text-white hover:bg-indigo-400 focus-visible:outline focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-indigo-500"
: "text-gray-300 bg-indigo-700"
)}
onClick={updateUserInfoWithKeys}
>
Save
</button>
{isUpdatingUserInfo && <Spinner />}
</div>
</div>
</div>

<div className="grid max-w-7xl grid-cols-1 gap-x-8 gap-y-10 px-4 py-16 sm:px-6 md:grid-cols-3 lg:px-8">
<div></div>
<div className="md:col-span-2">
<div className="md:col-span-2 md:col-start-2">
<div className="grid grid-cols-1 gap-x-6 gap-y-8 sm:max-w-xl sm:grid-cols-6">
<div className="col-span-full">
<div className="max-w-2xl text-white">
16 changes: 15 additions & 1 deletion frontend/src/hooks/settings.ts
@@ -12,6 +12,7 @@ import { useEffect, useState } from "react";
const HEALTH_CHECK_QUERY_KEY = ["HEALTH_CHECK"];
const USER_INFO_QUERY_KEY = ["USER_INFO"];
const AVATAR_QUERY_KEY = ["AVATAR"];
const ALLOWED_MODELS_QUERY_KEY = ["ALLOWED_MODELS"];

export function getBackendStatusQuery(options = {}) {
return queryOptions({
Expand Down Expand Up @@ -164,9 +165,22 @@ export function useUpdateUserInfo(options = {}) {
message: "Error updating user info",
});
},
onSettled() {
onSettled(_, error, variables) {
queryClient.invalidateQueries({ queryKey: USER_INFO_QUERY_KEY });
if (error === null && variables.openai_api_key) {
queryClient.invalidateQueries({ queryKey: ALLOWED_MODELS_QUERY_KEY });
}
},
...options,
});
}

export function useGetAllowedModels() {
const result = useQuery({
queryKey: ALLOWED_MODELS_QUERY_KEY,
queryFn: async () => (await api.getAllowedModels()).data,
staleTime: Infinity,
});

return result;
}