1 change: 1 addition & 0 deletions crates/goose-acp/src/server.rs
@@ -298,6 +298,7 @@ impl GooseAcpAgent {
toolshim: false,
toolshim_model: None,
fast_model: None,
request_params: None,
};
let provider = create(&provider_name, model_config).await?;
let goose_mode = config
17 changes: 11 additions & 6 deletions crates/goose-server/src/routes/agent.rs
@@ -48,6 +48,8 @@ pub struct UpdateProviderRequest {
provider: String,
model: Option<String>,
session_id: String,
context_limit: Option<usize>,
request_params: Option<std::collections::HashMap<String, serde_json::Value>>,
}

#[derive(Deserialize, utoipa::ToSchema)]
@@ -528,12 +530,15 @@ async fn update_agent_provider(
}
};

let model_config = ModelConfig::new(&model).map_err(|e| {
(
StatusCode::BAD_REQUEST,
format!("Invalid model config: {}", e),
)
})?;
let model_config = ModelConfig::new(&model)
.map_err(|e| {
(
StatusCode::BAD_REQUEST,
format!("Invalid model config: {}", e),
)
})?
.with_context_limit(payload.context_limit)
.with_request_params(payload.request_params);

let new_provider = create(&payload.provider, model_config).await.map_err(|e| {
(
1 change: 1 addition & 0 deletions crates/goose/src/context_mgmt/mod.rs
@@ -444,6 +444,7 @@ mod tests {
toolshim: false,
toolshim_model: None,
fast_model: None,
request_params: None,
},
max_tool_responses: None,
}
59 changes: 58 additions & 1 deletion crates/goose/src/model.rs
@@ -1,10 +1,39 @@
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use thiserror::Error;
use utoipa::ToSchema;

const DEFAULT_CONTEXT_LIMIT: usize = 128_000;

#[derive(Debug, Clone, Deserialize)]
struct PredefinedModel {
name: String,
#[serde(default)]
context_limit: Option<usize>,
#[serde(default)]
request_params: Option<HashMap<String, Value>>,
}

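/// Reads GOOSE_PREDEFINED_MODELS once (memoized via `Lazy`) and returns a
/// clone of the parsed list; malformed JSON logs a warning and yields an
/// empty list.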
fn get_predefined_models() -> Vec<PredefinedModel> {
static PREDEFINED_MODELS: Lazy<Vec<PredefinedModel>> =
Lazy::new(|| match std::env::var("GOOSE_PREDEFINED_MODELS") {
Ok(json_str) => serde_json::from_str(&json_str).unwrap_or_else(|e| {
tracing::warn!("Failed to parse GOOSE_PREDEFINED_MODELS: {}", e);
Vec::new()
}),
Err(_) => Vec::new(),
});
PREDEFINED_MODELS.clone()
}

fn find_predefined_model(model_name: &str) -> Option<PredefinedModel> {
get_predefined_models()
.into_iter()
.find(|m| m.name == model_name)
}
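// For reference, GOOSE_PREDEFINED_MODELS is expected to hold a JSON array
// matching PredefinedModel. A hypothetical value (the model name and beta
// flag are invented for illustration):
//
//   [{"name": "claude-sonnet-4",
//     "context_limit": 1000000,
//     "request_params": {"anthropic_beta": "context-1m-2025-08-07"}}]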

#[derive(Error, Debug)]
pub enum ConfigError {
#[error("Environment variable '{0}' not found")]
@@ -80,6 +109,9 @@ pub struct ModelConfig {
pub toolshim: bool,
pub toolshim_model: Option<String>,
pub fast_model: Option<String>,
/// Provider-specific request parameters (e.g., anthropic_beta headers)
#[serde(default, skip_serializing_if = "Option::is_none")]
pub request_params: Option<HashMap<String, Value>>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -97,7 +129,26 @@ impl ModelConfig {
model_name: String,
context_env_var: Option<&str>,
) -> Result<Self, ConfigError> {
let context_limit = Self::parse_context_limit(&model_name, None, context_env_var)?;
let predefined = find_predefined_model(&model_name);

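// Context-limit precedence for predefined models: a provider-specific env
// var (when `context_env_var` names one that is set) wins; otherwise
// GOOSE_CONTEXT_LIMIT is consulted, but only when no provider-specific env
// var name was supplied; otherwise the model's predefined `context_limit`.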
let context_limit = if let Some(ref pm) = predefined {
if let Some(env_var) = context_env_var {
if let Ok(val) = std::env::var(env_var) {
Some(Self::validate_context_limit(&val, env_var)?)
} else {
pm.context_limit
}
} else if let Ok(val) = std::env::var("GOOSE_CONTEXT_LIMIT") {
Some(Self::validate_context_limit(&val, "GOOSE_CONTEXT_LIMIT")?)
} else {
pm.context_limit
}
} else {
Self::parse_context_limit(&model_name, None, context_env_var)?
};

let request_params = predefined.and_then(|pm| pm.request_params);

let temperature = Self::parse_temperature()?;
let max_tokens = Self::parse_max_tokens()?;
let toolshim = Self::parse_toolshim()?;
@@ -111,6 +162,7 @@
toolshim,
toolshim_model,
fast_model: None,
request_params,
})
}

@@ -285,6 +337,11 @@ impl ModelConfig {
self
}

pub fn with_request_params(mut self, params: Option<HashMap<String, Value>>) -> Self {
self.request_params = params;
self
}

pub fn use_fast_model(&self) -> Self {
if let Some(fast_model) = &self.fast_model {
let mut config = self.clone();
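This mirrors how crates/goose-server/src/routes/agent.rs now threads the two new fields through. A minimal sketch of the builders in use (the model name, limit, and parameter values are hypothetical):

use std::collections::HashMap;
use serde_json::Value;

let params: HashMap<String, Value> = HashMap::from([(
    "anthropic_beta".to_string(),
    Value::from("context-1m-2025-08-07"),
)]);
let config = ModelConfig::new("claude-sonnet-4")?
    .with_context_limit(Some(1_000_000))
    .with_request_params(Some(params));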
13 changes: 13 additions & 0 deletions crates/goose/src/providers/formats/databricks.rs
@@ -642,6 +642,15 @@ pub fn create_request(
apply_cache_control_for_claude(&mut payload);
}

// Add request_params to the payload (e.g., anthropic_beta for extended context)
if let Some(params) = &model_config.request_params {
if let Some(obj) = payload.as_object_mut() {
for (key, value) in params {
obj.insert(key.clone(), value.clone());
}
}
}

Ok(payload)
}

@@ -1013,6 +1022,7 @@ mod tests {
toolshim: false,
toolshim_model: None,
fast_model: None,
request_params: None,
};
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
let obj = request.as_object().unwrap();
@@ -1044,6 +1054,7 @@
toolshim: false,
toolshim_model: None,
fast_model: None,
request_params: None,
};
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
assert_eq!(request["reasoning_effort"], "high");
@@ -1360,6 +1371,7 @@
toolshim: false,
toolshim_model: None,
fast_model: None,
request_params: None,
};

let messages = vec![
@@ -1411,6 +1423,7 @@
toolshim: false,
toolshim_model: None,
fast_model: None,
request_params: None,
};

let messages = vec![Message::user().with_text("Hello")];
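The merge loop in create_request lifts each request_params entry to the top level of the request body; because Map::insert replaces existing entries, a colliding key would overwrite a field create_request had already set. A minimal sketch of the effect (key and value hypothetical):

let mut payload = serde_json::json!({"model": "claude-sonnet-4"});
let params = std::collections::HashMap::from([(
    "anthropic_beta".to_string(),
    serde_json::Value::from("context-1m-2025-08-07"),
)]);
if let Some(obj) = payload.as_object_mut() {
    for (key, value) in &params {
        obj.insert(key.clone(), value.clone());
    }
}
assert_eq!(payload["anthropic_beta"], "context-1m-2025-08-07");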
3 changes: 3 additions & 0 deletions crates/goose/src/providers/formats/openai.rs
@@ -1295,6 +1295,7 @@ mod tests {
toolshim: false,
toolshim_model: None,
fast_model: None,
request_params: None,
};
let request = create_request(
&model_config,
@@ -1334,6 +1335,7 @@
toolshim: false,
toolshim_model: None,
fast_model: None,
request_params: None,
};
let request = create_request(
&model_config,
@@ -1374,6 +1376,7 @@
toolshim: false,
toolshim_model: None,
fast_model: None,
request_params: None,
};
let request = create_request(
&model_config,
16 changes: 16 additions & 0 deletions ui/desktop/openapi.json
@@ -4524,6 +4524,12 @@
"model_name": {
"type": "string"
},
"request_params": {
"type": "object",
"description": "Provider-specific request parameters (e.g., anthropic_beta headers)",
"additionalProperties": {},
"nullable": true
},
"temperature": {
"type": "number",
"format": "float",
@@ -6327,13 +6333,23 @@
"session_id"
],
"properties": {
"context_limit": {
"type": "integer",
"nullable": true,
"minimum": 0
},
"model": {
"type": "string",
"nullable": true
},
"provider": {
"type": "string"
},
"request_params": {
"type": "object",
"additionalProperties": {},
"nullable": true
},
"session_id": {
"type": "string"
}
10 changes: 10 additions & 0 deletions ui/desktop/src/api/types.gen.ts
@@ -542,6 +542,12 @@ export type ModelConfig = {
fast_model?: string | null;
max_tokens?: number | null;
model_name: string;
/**
* Provider-specific request parameters (e.g., anthropic_beta headers)
*/
request_params?: {
[key: string]: unknown;
} | null;
temperature?: number | null;
toolshim: boolean;
toolshim_model?: string | null;
@@ -1168,8 +1174,12 @@ export type UpdateFromSessionRequest = {
};

export type UpdateProviderRequest = {
context_limit?: number | null;
model?: string | null;
provider: string;
request_params?: {
[key: string]: unknown;
} | null;
session_id: string;
};

10 changes: 10 additions & 0 deletions ui/desktop/src/components/ChatInput.tsx
@@ -31,6 +31,7 @@ import { getSession, Message } from '../api';
import CreateRecipeFromSessionModal from './recipes/CreateRecipeFromSessionModal';
import CreateEditRecipeModal from './recipes/CreateEditRecipeModal';
import { getInitialWorkingDir } from '../utils/workingDir';
import { getPredefinedModelsFromEnv } from './settings/models/predefinedModelsUtils';
import {
trackFileAttached,
trackVoiceDictation,
@@ -448,6 +449,15 @@ export default function ChatInput({
return;
}

// First, check predefined models from environment (highest priority)
const predefinedModels = getPredefinedModelsFromEnv();
const predefinedModel = predefinedModels.find((m) => m.name === model);
if (predefinedModel?.context_limit) {
setTokenLimit(predefinedModel.context_limit);
setIsTokenLimitLoaded(true);
return;
}

const providers = await getProviders(true);

// Find the provider details for the current provider
2 changes: 2 additions & 0 deletions ui/desktop/src/components/ModelAndProviderContext.tsx
@@ -53,6 +53,8 @@ export const ModelAndProviderProvider: React.FC<ModelAndProviderProviderProps> =
session_id: sessionId,
provider: providerName,
model: modelName,
context_limit: model.context_limit,
request_params: model.request_params,
},
});
}
19 changes: 13 additions & 6 deletions ui/desktop/src/components/alerts/AlertBox.tsx
@@ -24,6 +24,17 @@ const alertStyles: Record<AlertType, string> = {
[AlertType.Info]: 'dark:bg-white dark:text-black bg-black text-white',
};

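// Formats raw token counts for display: 1500 -> "1.5k", 128000 -> "128k",
// 1000000 -> "1M", 1500000 -> "1.5M".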
const formatTokenCount = (count: number): string => {
if (count >= 1000000) {
const millions = count / 1000000;
return millions % 1 === 0 ? `${millions.toFixed(0)}M` : `${millions.toFixed(1)}M`;
} else if (count >= 1000) {
const thousands = count / 1000;
return thousands % 1 === 0 ? `${thousands.toFixed(0)}k` : `${thousands.toFixed(1)}k`;
}
return count.toString();
};

export const AlertBox = ({ alert, className }: AlertBoxProps) => {
const { read } = useConfig();
const [isEditingThreshold, setIsEditingThreshold] = useState(false);
@@ -242,18 +253,14 @@ export const AlertBox = ({ alert, className }: AlertBoxProps) => {
<div className="flex justify-between items-baseline text-[11px]">
<div className="flex gap-1 items-baseline">
<span className={'dark:text-black/60 text-white/60'}>
{alert.progress!.current >= 1000
? (alert.progress!.current / 1000).toFixed(1) + 'k'
: alert.progress!.current}
{formatTokenCount(alert.progress!.current)}
</span>
<span className={'dark:text-black/40 text-white/40'}>
{Math.round((alert.progress!.current / alert.progress!.total) * 100)}%
</span>
</div>
<span className={'dark:text-black/60 text-white/60'}>
{alert.progress!.total >= 1000
? (alert.progress!.total / 1000).toFixed(0) + 'k'
: alert.progress!.total}
{formatTokenCount(alert.progress!.total)}
</span>
</div>
{alert.showCompactButton && alert.onCompact && (
2 changes: 2 additions & 0 deletions ui/desktop/src/components/settings/models/modelInterface.ts
@@ -7,6 +7,8 @@ export default interface Model {
lastUsed?: string;
alias?: string; // optional model display name
subtext?: string; // goes below model name if not the provider
context_limit?: number; // optional context limit override
request_params?: Record<string, unknown>; // provider-specific request parameters
}

export function createModelStruct(