Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/goose-server/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ impl<'__s> ToSchema<'__s> for AnnotatedSchema {
super::routes::config_management::get_extensions,
super::routes::config_management::read_all_config,
super::routes::config_management::providers,
super::routes::config_management::get_provider_models,
super::routes::config_management::upsert_permissions,
super::routes::config_management::create_custom_provider,
super::routes::config_management::remove_custom_provider,
Expand Down
115 changes: 112 additions & 3 deletions crates/goose-server/src/routes/config_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::utils::verify_secret_key;
use crate::routes::utils::check_provider_configured;
use crate::state::AppState;
use axum::{
extract::State,
extract::{Path, State},
routing::{delete, get, post},
Json, Router,
};
Expand Down Expand Up @@ -386,6 +386,66 @@ pub async fn providers(
Ok(Json(providers_response))
}

#[utoipa::path(
get,
path = "/config/providers/{name}/models",
params(
("name" = String, Path, description = "Provider name (e.g., openai)")
),
responses(
(status = 200, description = "Models fetched successfully", body = [String]),
(status = 400, description = "Unknown provider, provider not configured, or authentication error"),
(status = 429, description = "Rate limit exceeded"),
(status = 500, description = "Internal server error")
)
)]
pub async fn get_provider_models(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Path(name): Path<String>,
) -> Result<Json<Vec<String>>, StatusCode> {
verify_secret_key(&headers, &state)?;

let all = get_providers();
let Some(metadata) = all.into_iter().find(|m| m.name == name) else {
return Err(StatusCode::BAD_REQUEST);
};
if !check_provider_configured(&metadata) {
return Err(StatusCode::BAD_REQUEST);
}

let model_config =
ModelConfig::new(&metadata.default_model).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let provider = goose::providers::create(&name, model_config)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

match provider.fetch_supported_models().await {
Ok(Some(models)) => Ok(Json(models)),
Ok(None) => Ok(Json(Vec::new())),
Err(provider_error) => {
use goose::providers::errors::ProviderError;
let status_code = match provider_error {
// Permanent misconfigurations - client should fix configuration
ProviderError::Authentication(_) => StatusCode::BAD_REQUEST,
ProviderError::UsageError(_) => StatusCode::BAD_REQUEST,

// Transient errors - client should retry later
ProviderError::RateLimitExceeded(_) => StatusCode::TOO_MANY_REQUESTS,

// All other errors - internal server error
_ => StatusCode::INTERNAL_SERVER_ERROR,
};

tracing::warn!(
"Provider {} failed to fetch models: {}",
name,
provider_error
);
Err(status_code)
}
}
}

#[derive(Serialize, ToSchema)]
pub struct PricingData {
pub provider: String,
Expand Down Expand Up @@ -771,6 +831,7 @@ pub fn routes(state: Arc<AppState>) -> Router {
.route("/config/extensions", post(add_extension))
.route("/config/extensions/{name}", delete(remove_extension))
.route("/config/providers", get(providers))
.route("/config/providers/{name}/models", get(get_provider_models))
.route("/config/pricing", post(get_pricing))
.route("/config/init", post(init_config))
.route("/config/backup", post(backup_config))
Expand All @@ -790,8 +851,7 @@ pub fn routes(state: Arc<AppState>) -> Router {
mod tests {
use super::*;

#[tokio::test]
async fn test_read_model_limits() {
async fn create_test_state() -> Arc<AppState> {
let test_state = AppState::new(
Arc::new(goose::agents::Agent::default()),
"test".to_string(),
Expand All @@ -805,6 +865,12 @@ mod tests {
.await
.unwrap();
test_state.set_scheduler(sched).await;
test_state
}

#[tokio::test]
async fn test_read_model_limits() {
let test_state = create_test_state().await;
let mut headers = HeaderMap::new();
headers.insert("X-Secret-Key", "test".parse().unwrap());

Expand All @@ -829,4 +895,47 @@ mod tests {
assert!(gpt4_limit.is_some());
assert_eq!(gpt4_limit.unwrap().context_limit, 128_000);
}

#[tokio::test]
async fn test_get_provider_models_unknown_provider() {
let test_state = create_test_state().await;
let mut headers = HeaderMap::new();
headers.insert("X-Secret-Key", "test".parse().unwrap());

let result = get_provider_models(
State(test_state),
headers,
Path("unknown_provider".to_string()),
)
.await;

assert!(result.is_err());
assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
}

#[tokio::test]
async fn test_get_provider_models_openai_configured() {
std::env::set_var("OPENAI_API_KEY", "test-key");

let test_state = create_test_state().await;
let mut headers = HeaderMap::new();
headers.insert("X-Secret-Key", "test".parse().unwrap());

let result =
get_provider_models(State(test_state), headers, Path("openai".to_string())).await;

// The response should be BAD_REQUEST since the API key is invalid (authentication error)
assert!(
result.is_err(),
"Expected error response from OpenAI provider with invalid key"
);
let status_code = result.unwrap_err();

assert!(status_code == StatusCode::BAD_REQUEST,
"Expected BAD_REQUEST (authentication error) or INTERNAL_SERVER_ERROR (other errors), got: {}",
status_code
);

std::env::remove_var("OPENAI_API_KEY");
}
}
43 changes: 43 additions & 0 deletions ui/desktop/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,49 @@
}
}
},
"/config/providers/{name}/models": {
"get": {
"tags": [
"super::routes::config_management"
],
"operationId": "get_provider_models",
"parameters": [
{
"name": "name",
"in": "path",
"description": "Provider name (e.g., openai)",
"required": true,
"schema": {
"type": "string"
}
}
],
"responses": {
"200": {
"description": "Models fetched successfully",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"type": "string"
}
}
}
}
},
"400": {
"description": "Unknown provider, provider not configured, or authentication error"
},
"429": {
"description": "Rate limit exceeded"
},
"500": {
"description": "Internal server error"
}
}
}
},
"/config/read": {
"post": {
"tags": [
Expand Down
9 changes: 8 additions & 1 deletion ui/desktop/src/api/sdk.gen.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// This file is auto-generated by @hey-api/openapi-ts

import type { Options as ClientOptions, TDataShape, Client } from './client';
import type { AddSubRecipesData, AddSubRecipesResponses, AddSubRecipesErrors, ExtendPromptData, ExtendPromptResponses, ExtendPromptErrors, UpdateSessionConfigData, UpdateSessionConfigResponses, UpdateSessionConfigErrors, GetToolsData, GetToolsResponses, GetToolsErrors, UpdateAgentProviderData, UpdateAgentProviderResponses, UpdateAgentProviderErrors, UpdateRouterToolSelectorData, UpdateRouterToolSelectorResponses, UpdateRouterToolSelectorErrors, ReadAllConfigData, ReadAllConfigResponses, BackupConfigData, BackupConfigResponses, BackupConfigErrors, CreateCustomProviderData, CreateCustomProviderResponses, CreateCustomProviderErrors, RemoveCustomProviderData, RemoveCustomProviderResponses, RemoveCustomProviderErrors, GetExtensionsData, GetExtensionsResponses, GetExtensionsErrors, AddExtensionData, AddExtensionResponses, AddExtensionErrors, RemoveExtensionData, RemoveExtensionResponses, RemoveExtensionErrors, InitConfigData, InitConfigResponses, InitConfigErrors, UpsertPermissionsData, UpsertPermissionsResponses, UpsertPermissionsErrors, ProvidersData, ProvidersResponses, ReadConfigData, ReadConfigResponses, ReadConfigErrors, RecoverConfigData, RecoverConfigResponses, RecoverConfigErrors, RemoveConfigData, RemoveConfigResponses, RemoveConfigErrors, UpsertConfigData, UpsertConfigResponses, UpsertConfigErrors, ValidateConfigData, ValidateConfigResponses, ValidateConfigErrors, ConfirmPermissionData, ConfirmPermissionResponses, ConfirmPermissionErrors, ManageContextData, ManageContextResponses, ManageContextErrors, CreateRecipeData, CreateRecipeResponses, CreateRecipeErrors, DecodeRecipeData, DecodeRecipeResponses, DecodeRecipeErrors, EncodeRecipeData, EncodeRecipeResponses, EncodeRecipeErrors, ScanRecipeData, ScanRecipeResponses, CreateScheduleData, CreateScheduleResponses, CreateScheduleErrors, DeleteScheduleData, DeleteScheduleResponses, DeleteScheduleErrors, ListSchedulesData, ListSchedulesResponses, ListSchedulesErrors, UpdateScheduleData, UpdateScheduleResponses, UpdateScheduleErrors, InspectRunningJobData, InspectRunningJobResponses, InspectRunningJobErrors, KillRunningJobData, KillRunningJobResponses, PauseScheduleData, PauseScheduleResponses, PauseScheduleErrors, RunNowHandlerData, RunNowHandlerResponses, RunNowHandlerErrors, SessionsHandlerData, SessionsHandlerResponses, SessionsHandlerErrors, UnpauseScheduleData, UnpauseScheduleResponses, UnpauseScheduleErrors, ListSessionsData, ListSessionsResponses, ListSessionsErrors, GetSessionHistoryData, GetSessionHistoryResponses, GetSessionHistoryErrors } from './types.gen';
import type { AddSubRecipesData, AddSubRecipesResponses, AddSubRecipesErrors, ExtendPromptData, ExtendPromptResponses, ExtendPromptErrors, UpdateSessionConfigData, UpdateSessionConfigResponses, UpdateSessionConfigErrors, GetToolsData, GetToolsResponses, GetToolsErrors, UpdateAgentProviderData, UpdateAgentProviderResponses, UpdateAgentProviderErrors, UpdateRouterToolSelectorData, UpdateRouterToolSelectorResponses, UpdateRouterToolSelectorErrors, ReadAllConfigData, ReadAllConfigResponses, BackupConfigData, BackupConfigResponses, BackupConfigErrors, CreateCustomProviderData, CreateCustomProviderResponses, CreateCustomProviderErrors, RemoveCustomProviderData, RemoveCustomProviderResponses, RemoveCustomProviderErrors, GetExtensionsData, GetExtensionsResponses, GetExtensionsErrors, AddExtensionData, AddExtensionResponses, AddExtensionErrors, RemoveExtensionData, RemoveExtensionResponses, RemoveExtensionErrors, InitConfigData, InitConfigResponses, InitConfigErrors, UpsertPermissionsData, UpsertPermissionsResponses, UpsertPermissionsErrors, ProvidersData, ProvidersResponses, GetProviderModelsData, GetProviderModelsResponses, GetProviderModelsErrors, ReadConfigData, ReadConfigResponses, ReadConfigErrors, RecoverConfigData, RecoverConfigResponses, RecoverConfigErrors, RemoveConfigData, RemoveConfigResponses, RemoveConfigErrors, UpsertConfigData, UpsertConfigResponses, UpsertConfigErrors, ValidateConfigData, ValidateConfigResponses, ValidateConfigErrors, ConfirmPermissionData, ConfirmPermissionResponses, ConfirmPermissionErrors, ManageContextData, ManageContextResponses, ManageContextErrors, CreateRecipeData, CreateRecipeResponses, CreateRecipeErrors, DecodeRecipeData, DecodeRecipeResponses, DecodeRecipeErrors, EncodeRecipeData, EncodeRecipeResponses, EncodeRecipeErrors, ScanRecipeData, ScanRecipeResponses, CreateScheduleData, CreateScheduleResponses, CreateScheduleErrors, DeleteScheduleData, DeleteScheduleResponses, DeleteScheduleErrors, ListSchedulesData, ListSchedulesResponses, ListSchedulesErrors, UpdateScheduleData, UpdateScheduleResponses, UpdateScheduleErrors, InspectRunningJobData, InspectRunningJobResponses, InspectRunningJobErrors, KillRunningJobData, KillRunningJobResponses, PauseScheduleData, PauseScheduleResponses, PauseScheduleErrors, RunNowHandlerData, RunNowHandlerResponses, RunNowHandlerErrors, SessionsHandlerData, SessionsHandlerResponses, SessionsHandlerErrors, UnpauseScheduleData, UnpauseScheduleResponses, UnpauseScheduleErrors, ListSessionsData, ListSessionsResponses, ListSessionsErrors, GetSessionHistoryData, GetSessionHistoryResponses, GetSessionHistoryErrors } from './types.gen';
import { client as _heyApiClient } from './client.gen';

export type Options<TData extends TDataShape = TDataShape, ThrowOnError extends boolean = boolean> = ClientOptions<TData, ThrowOnError> & {
Expand Down Expand Up @@ -158,6 +158,13 @@ export const providers = <ThrowOnError extends boolean = false>(options?: Option
});
};

export const getProviderModels = <ThrowOnError extends boolean = false>(options: Options<GetProviderModelsData, ThrowOnError>) => {
return (options.client ?? _heyApiClient).get<GetProviderModelsResponses, GetProviderModelsErrors, ThrowOnError>({
url: '/config/providers/{name}/models',
...options
});
};

export const readConfig = <ThrowOnError extends boolean = false>(options: Options<ReadConfigData, ThrowOnError>) => {
return (options.client ?? _heyApiClient).post<ReadConfigResponses, ReadConfigErrors, ThrowOnError>({
url: '/config/read',
Expand Down
36 changes: 36 additions & 0 deletions ui/desktop/src/api/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,42 @@ export type ProvidersResponses = {

export type ProvidersResponse2 = ProvidersResponses[keyof ProvidersResponses];

export type GetProviderModelsData = {
body?: never;
path: {
/**
* Provider name (e.g., openai)
*/
name: string;
};
query?: never;
url: '/config/providers/{name}/models';
};

export type GetProviderModelsErrors = {
/**
* Unknown provider, provider not configured, or authentication error
*/
400: unknown;
/**
* Rate limit exceeded
*/
429: unknown;
/**
* Internal server error
*/
500: unknown;
};

export type GetProviderModelsResponses = {
/**
* Models fetched successfully
*/
200: Array<string>;
};

export type GetProviderModelsResponse = GetProviderModelsResponses[keyof GetProviderModelsResponses];

export type ReadConfigData = {
body: ConfigKeyQuery;
path?: never;
Expand Down
19 changes: 19 additions & 0 deletions ui/desktop/src/components/ConfigContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
addExtension as apiAddExtension,
removeExtension as apiRemoveExtension,
providers,
getProviderModels as apiGetProviderModels,
} from '../api';
import type {
ConfigResponse,
Expand Down Expand Up @@ -39,6 +40,7 @@ interface ConfigContextType {
removeExtension: (name: string) => Promise<void>;
getProviders: (b: boolean) => Promise<ProviderDetails[]>;
getExtensions: (b: boolean) => Promise<FixedExtensionEntry[]>;
getProviderModels: (providerName: string) => Promise<string[]>;
disableAllExtensions: () => Promise<void>;
enableBotExtensions: (extensions: ExtensionConfig[]) => Promise<void>;
}
Expand Down Expand Up @@ -185,6 +187,21 @@ export const ConfigProvider: React.FC<ConfigProviderProps> = ({ children }) => {
[providersList]
);

const getProviderModels = useCallback(async (providerName: string): Promise<string[]> => {
try {
const response = await apiGetProviderModels({
path: { name: providerName },
headers: {
'X-Secret-Key': await window.electron.getSecretKey(),
},
});
return response.data || [];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should use the generated API. you should never have to call getSecretKey

} catch (error) {
console.error(`Failed to fetch models for provider ${providerName}:`, error);
return [];
}
}, []);

useEffect(() => {
// Load all configuration data and providers on mount
(async () => {
Expand Down Expand Up @@ -242,6 +259,7 @@ export const ConfigProvider: React.FC<ConfigProviderProps> = ({ children }) => {
toggleExtension,
getProviders,
getExtensions,
getProviderModels,
disableAllExtensions,
enableBotExtensions,
};
Expand All @@ -257,6 +275,7 @@ export const ConfigProvider: React.FC<ConfigProviderProps> = ({ children }) => {
toggleExtension,
getProviders,
getExtensions,
getProviderModels,
reloadConfig,
]);

Expand Down
Loading
Loading