diff --git a/crates/goose-acp/src/server.rs b/crates/goose-acp/src/server.rs
index a939097d968d..87932bf9faee 100644
--- a/crates/goose-acp/src/server.rs
+++ b/crates/goose-acp/src/server.rs
@@ -298,6 +298,7 @@ impl GooseAcpAgent {
             toolshim: false,
             toolshim_model: None,
             fast_model: None,
+            request_params: None,
         };
         let provider = create(&provider_name, model_config).await?;
         let goose_mode = config
diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs
index 51a647685850..176a42f78a0b 100644
--- a/crates/goose-server/src/routes/agent.rs
+++ b/crates/goose-server/src/routes/agent.rs
@@ -48,6 +48,8 @@ pub struct UpdateProviderRequest {
     provider: String,
     model: Option<String>,
     session_id: String,
+    context_limit: Option<usize>,
+    request_params: Option<HashMap<String, Value>>,
 }

 #[derive(Deserialize, utoipa::ToSchema)]
@@ -528,12 +530,15 @@ async fn update_agent_provider(
         }
     };

-    let model_config = ModelConfig::new(&model).map_err(|e| {
-        (
-            StatusCode::BAD_REQUEST,
-            format!("Invalid model config: {}", e),
-        )
-    })?;
+    let model_config = ModelConfig::new(&model)
+        .map_err(|e| {
+            (
+                StatusCode::BAD_REQUEST,
+                format!("Invalid model config: {}", e),
+            )
+        })?
+        .with_context_limit(payload.context_limit)
+        .with_request_params(payload.request_params);

     let new_provider = create(&payload.provider, model_config).await.map_err(|e| {
         (
diff --git a/crates/goose/src/context_mgmt/mod.rs b/crates/goose/src/context_mgmt/mod.rs
index 313bef217ac0..b251a72d0109 100644
--- a/crates/goose/src/context_mgmt/mod.rs
+++ b/crates/goose/src/context_mgmt/mod.rs
@@ -444,6 +444,7 @@ mod tests {
                 toolshim: false,
                 toolshim_model: None,
                 fast_model: None,
+                request_params: None,
             },
             max_tool_responses: None,
         }
diff --git a/crates/goose/src/model.rs b/crates/goose/src/model.rs
index c360e209bde4..bc9b73c9b1ab 100644
--- a/crates/goose/src/model.rs
+++ b/crates/goose/src/model.rs
@@ -1,10 +1,39 @@
 use once_cell::sync::Lazy;
 use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use std::collections::HashMap;
 use thiserror::Error;
 use utoipa::ToSchema;

 const DEFAULT_CONTEXT_LIMIT: usize = 128_000;

+#[derive(Debug, Clone, Deserialize)]
+struct PredefinedModel {
+    name: String,
+    #[serde(default)]
+    context_limit: Option<usize>,
+    #[serde(default)]
+    request_params: Option<HashMap<String, Value>>,
+}
+
+fn get_predefined_models() -> Vec<PredefinedModel> {
+    static PREDEFINED_MODELS: Lazy<Vec<PredefinedModel>> =
+        Lazy::new(|| match std::env::var("GOOSE_PREDEFINED_MODELS") {
+            Ok(json_str) => serde_json::from_str(&json_str).unwrap_or_else(|e| {
+                tracing::warn!("Failed to parse GOOSE_PREDEFINED_MODELS: {}", e);
+                Vec::new()
+            }),
+            Err(_) => Vec::new(),
+        });
+    PREDEFINED_MODELS.clone()
+}
+
+fn find_predefined_model(model_name: &str) -> Option<PredefinedModel> {
+    get_predefined_models()
+        .into_iter()
+        .find(|m| m.name == model_name)
+}
+
 #[derive(Error, Debug)]
 pub enum ConfigError {
     #[error("Environment variable '{0}' not found")]
@@ -80,6 +109,9 @@ pub struct ModelConfig {
     pub toolshim: bool,
     pub toolshim_model: Option<String>,
     pub fast_model: Option<String>,
+    /// Provider-specific request parameters (e.g., anthropic_beta headers)
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub request_params: Option<HashMap<String, Value>>,
 }

 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -97,7 +129,26 @@ impl ModelConfig {
         model_name: String,
         context_env_var: Option<&str>,
     ) -> Result<Self, ConfigError> {
-        let context_limit = Self::parse_context_limit(&model_name, None, context_env_var)?;
+        let predefined = find_predefined_model(&model_name);
+
+        let context_limit = if let Some(ref pm) = predefined {
+            if let Some(env_var) = context_env_var {
+                if let Ok(val) = std::env::var(env_var) {
+                    Some(Self::validate_context_limit(&val, env_var)?)
+                } else {
+                    pm.context_limit
+                }
+            } else if let Ok(val) = std::env::var("GOOSE_CONTEXT_LIMIT") {
+                Some(Self::validate_context_limit(&val, "GOOSE_CONTEXT_LIMIT")?)
+            } else {
+                pm.context_limit
+            }
+        } else {
+            Self::parse_context_limit(&model_name, None, context_env_var)?
+        };
+
+        let request_params = predefined.and_then(|pm| pm.request_params);
+
         let temperature = Self::parse_temperature()?;
         let max_tokens = Self::parse_max_tokens()?;
         let toolshim = Self::parse_toolshim()?;
@@ -111,6 +162,7 @@ impl ModelConfig {
             toolshim,
             toolshim_model,
             fast_model: None,
+            request_params,
         })
     }

@@ -285,6 +337,11 @@ impl ModelConfig {
         self
     }

+    pub fn with_request_params(mut self, params: Option<HashMap<String, Value>>) -> Self {
+        self.request_params = params;
+        self
+    }
+
     pub fn use_fast_model(&self) -> Self {
         if let Some(fast_model) = &self.fast_model {
             let mut config = self.clone();
diff --git a/crates/goose/src/providers/formats/databricks.rs b/crates/goose/src/providers/formats/databricks.rs
index c9438aebf3b8..e723cca2f0aa 100644
--- a/crates/goose/src/providers/formats/databricks.rs
+++ b/crates/goose/src/providers/formats/databricks.rs
@@ -642,6 +642,15 @@ pub fn create_request(
         apply_cache_control_for_claude(&mut payload);
     }

+    // Add request_params to the payload (e.g., anthropic_beta for extended context)
+    if let Some(params) = &model_config.request_params {
+        if let Some(obj) = payload.as_object_mut() {
+            for (key, value) in params {
+                obj.insert(key.clone(), value.clone());
+            }
+        }
+    }
+
     Ok(payload)
 }

@@ -1013,6 +1022,7 @@ mod tests {
             toolshim: false,
             toolshim_model: None,
             fast_model: None,
+            request_params: None,
         };
         let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
         let obj = request.as_object().unwrap();
@@ -1044,6 +1054,7 @@ mod tests {
             toolshim: false,
             toolshim_model: None,
             fast_model: None,
+            request_params: None,
         };
         let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
         assert_eq!(request["reasoning_effort"], "high");
@@ -1360,6 +1371,7 @@ mod tests {
             toolshim: false,
             toolshim_model: None,
             fast_model: None,
+            request_params: None,
         };

         let messages = vec![
@@ -1411,6 +1423,7 @@ mod tests {
             toolshim: false,
             toolshim_model: None,
             fast_model: None,
+            request_params: None,
         };

         let messages = vec![Message::user().with_text("Hello")];
diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs
index 4ca843bfa5dd..5cd7f4edb7e7 100644
--- a/crates/goose/src/providers/formats/openai.rs
+++ b/crates/goose/src/providers/formats/openai.rs
@@ -1295,6 +1295,7 @@ mod tests {
             toolshim: false,
             toolshim_model: None,
             fast_model: None,
+            request_params: None,
         };
         let request = create_request(
             &model_config,
@@ -1334,6 +1335,7 @@ mod tests {
             toolshim: false,
             toolshim_model: None,
             fast_model: None,
+            request_params: None,
         };
         let request = create_request(
             &model_config,
@@ -1374,6 +1376,7 @@ mod tests {
             toolshim: false,
             toolshim_model: None,
             fast_model: None,
+            request_params: None,
         };
         let request = create_request(
             &model_config,
diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json
index 429bf91e0af1..6961c28044d7 100644
--- a/ui/desktop/openapi.json
+++ b/ui/desktop/openapi.json
@@ -4524,6 +4524,12 @@
           "model_name": {
             "type": "string"
           },
+          "request_params": {
+            "type": "object",
+            "description": "Provider-specific request parameters (e.g., anthropic_beta headers)",
+            "additionalProperties": {},
+            "nullable": true
+          },
           "temperature": {
             "type": "number",
             "format": "float",
@@ -6327,6 +6333,11 @@
           "session_id"
         ],
         "properties": {
+          "context_limit": {
+            "type": "integer",
+            "nullable": true,
+            "minimum": 0
+          },
           "model": {
             "type": "string",
             "nullable": true
@@ -6334,6 +6345,11 @@
           },
           "provider": {
             "type": "string"
          },
+          "request_params": {
+            "type": "object",
+            "additionalProperties": {},
+            "nullable": true
+          },
           "session_id": {
             "type": "string"
           }
diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts
index 6a2e2e9fc64d..a7a12e946869 100644
--- a/ui/desktop/src/api/types.gen.ts
+++ b/ui/desktop/src/api/types.gen.ts
@@ -542,6 +542,12 @@ export type ModelConfig = {
     fast_model?: string | null;
     max_tokens?: number | null;
     model_name: string;
+    /**
+     * Provider-specific request parameters (e.g., anthropic_beta headers)
+     */
+    request_params?: {
+        [key: string]: unknown;
+    } | null;
     temperature?: number | null;
     toolshim: boolean;
     toolshim_model?: string | null;
@@ -1168,8 +1174,12 @@ export type UpdateFromSessionRequest = {
 };

 export type UpdateProviderRequest = {
+    context_limit?: number | null;
     model?: string | null;
     provider: string;
+    request_params?: {
+        [key: string]: unknown;
+    } | null;
     session_id: string;
 };

diff --git a/ui/desktop/src/components/ChatInput.tsx b/ui/desktop/src/components/ChatInput.tsx
index 20e3cd87db5b..c5eadc70843d 100644
--- a/ui/desktop/src/components/ChatInput.tsx
+++ b/ui/desktop/src/components/ChatInput.tsx
@@ -31,6 +31,7 @@ import { getSession, Message } from '../api';
 import CreateRecipeFromSessionModal from './recipes/CreateRecipeFromSessionModal';
 import CreateEditRecipeModal from './recipes/CreateEditRecipeModal';
 import { getInitialWorkingDir } from '../utils/workingDir';
+import { getPredefinedModelsFromEnv } from './settings/models/predefinedModelsUtils';
 import {
   trackFileAttached,
   trackVoiceDictation,
@@ -448,6 +449,15 @@ export default function ChatInput({
       return;
     }

+    // First, check predefined models from environment (highest priority)
+    const predefinedModels = getPredefinedModelsFromEnv();
+    const predefinedModel = predefinedModels.find((m) => m.name === model);
+    if (predefinedModel?.context_limit) {
+      setTokenLimit(predefinedModel.context_limit);
+      setIsTokenLimitLoaded(true);
+      return;
+    }
+
     const providers = await getProviders(true);

     // Find the provider details for the current provider
diff --git a/ui/desktop/src/components/ModelAndProviderContext.tsx b/ui/desktop/src/components/ModelAndProviderContext.tsx
index c35f5517d3e9..e8e370897bc1 100644
--- a/ui/desktop/src/components/ModelAndProviderContext.tsx
+++ b/ui/desktop/src/components/ModelAndProviderContext.tsx
@@ -53,6 +53,8 @@ export const ModelAndProviderProvider: React.FC =
           session_id: sessionId,
           provider: providerName,
           model: modelName,
+          context_limit: model.context_limit,
+          request_params: model.request_params,
         },
       });
     }
diff --git a/ui/desktop/src/components/alerts/AlertBox.tsx b/ui/desktop/src/components/alerts/AlertBox.tsx
index d3c5bde81e0a..e28ac2cea014 100644
--- a/ui/desktop/src/components/alerts/AlertBox.tsx
+++ b/ui/desktop/src/components/alerts/AlertBox.tsx
@@ -24,6 +24,17 @@ const alertStyles: Record<AlertType, string> = {
   [AlertType.Info]: 'dark:bg-white dark:text-black bg-black text-white',
 };

+const formatTokenCount = (count: number): string => {
+  if (count >= 1000000) {
+    const millions = count / 1000000;
+    return millions % 1 === 0 ? `${millions.toFixed(0)}M` : `${millions.toFixed(1)}M`;
+  } else if (count >= 1000) {
+    const thousands = count / 1000;
+    return thousands % 1 === 0 ? `${thousands.toFixed(0)}k` : `${thousands.toFixed(1)}k`;
+  }
+  return count.toString();
+};
+
 export const AlertBox = ({ alert, className }: AlertBoxProps) => {
   const { read } = useConfig();
   const [isEditingThreshold, setIsEditingThreshold] = useState(false);
@@ -242,18 +253,14 @@ export const AlertBox = ({ alert, className }: AlertBoxProps) => {
-            {alert.progress!.current >= 1000
-              ? (alert.progress!.current / 1000).toFixed(1) + 'k'
-              : alert.progress!.current}
+            {formatTokenCount(alert.progress!.current)}
             {Math.round((alert.progress!.current / alert.progress!.total) * 100)}%
-            {alert.progress!.total >= 1000
-              ? (alert.progress!.total / 1000).toFixed(0) + 'k'
-              : alert.progress!.total}
+            {formatTokenCount(alert.progress!.total)}
       {alert.showCompactButton && alert.onCompact && (
diff --git a/ui/desktop/src/components/settings/models/modelInterface.ts b/ui/desktop/src/components/settings/models/modelInterface.ts
index e37729dfbb68..9ff34dff4084 100644
--- a/ui/desktop/src/components/settings/models/modelInterface.ts
+++ b/ui/desktop/src/components/settings/models/modelInterface.ts
@@ -7,6 +7,8 @@ export default interface Model {
   lastUsed?: string;
   alias?: string; // optional model display name
   subtext?: string; // goes below model name if not the provider
+  context_limit?: number; // optional context limit override
+  request_params?: Record<string, unknown>; // provider-specific request parameters
 }

 export function createModelStruct(
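
A minimal sketch of how a `GOOSE_PREDEFINED_MODELS` entry is expected to flow into `ModelConfig`, based on the `PredefinedModel` struct above. The model name, context limit, and the `anthropic_beta` value are placeholders, and the `goose::model` path plus the assumption that `ModelConfig::new` routes through the constructor changed above are mine, not part of this change:

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // GOOSE_PREDEFINED_MODELS is a JSON array of PredefinedModel entries:
    // `name` is required, `context_limit` and `request_params` are optional
    // (#[serde(default)]); unparseable JSON falls back to an empty list.
    std::env::set_var(
        "GOOSE_PREDEFINED_MODELS",
        r#"[{
            "name": "example-model",
            "context_limit": 1000000,
            "request_params": { "anthropic_beta": "example-beta-flag" }
        }]"#,
    );

    // Assumed entry point; the predefined context_limit is used unless
    // GOOSE_CONTEXT_LIMIT (or a provider-specific env var) overrides it.
    let config = goose::model::ModelConfig::new("example-model")?;

    assert_eq!(
        config
            .request_params
            .as_ref()
            .and_then(|p| p.get("anthropic_beta")),
        Some(&json!("example-beta-flag"))
    );
    Ok(())
}
```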
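The Databricks format change copies each `request_params` entry into the top level of the outgoing request body. A standalone sketch of that merge, with placeholder key and value, mirroring the loop added to `create_request()`:

```rust
use std::collections::HashMap;

use serde_json::{json, Value};

// Same merge as the block added to crates/goose/src/providers/formats/databricks.rs:
// every request_params entry is inserted at the top level of the JSON payload.
fn merge_request_params(payload: &mut Value, params: &HashMap<String, Value>) {
    if let Some(obj) = payload.as_object_mut() {
        for (key, value) in params {
            obj.insert(key.clone(), value.clone());
        }
    }
}

fn main() {
    let mut payload = json!({ "model": "example-model", "messages": [] });
    let params = HashMap::from([("anthropic_beta".to_string(), json!("example-beta-flag"))]);

    merge_request_params(&mut payload, &params);

    // The extra key rides along at the top level of the request body.
    assert_eq!(payload["anthropic_beta"], json!("example-beta-flag"));
}
```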