Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions crates/goose/src/providers/formats/gcpvertexai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ pub enum GcpVertexAIModel {
Claude(ClaudeVersion),
/// Gemini model family with specific versions
Gemini(GeminiVersion),
/// MaaS (Model as a Service) models from Model Garden
/// Contains (publisher, full_model_name)
MaaS(String, String),
}

/// Represents available versions of the Claude model for goose.
Expand Down Expand Up @@ -126,6 +129,7 @@ impl fmt::Display for GcpVertexAIModel {
GeminiVersion::Pro25 => "gemini-2.5-pro",
GeminiVersion::Generic(name) => name,
},
Self::MaaS(_, model_name) => model_name,
};
write!(f, "{model_id}")
}
Expand All @@ -137,10 +141,12 @@ impl GcpVertexAIModel {
/// Each model family has a well-known location based on availability:
/// - Claude models default to Ohio (us-east5)
/// - Gemini models default to Iowa (us-central1)
/// - MaaS models default to Iowa (us-central1)
pub fn known_location(&self) -> GcpLocation {
match self {
Self::Claude(_) => GcpLocation::Ohio,
Self::Gemini(_) => GcpLocation::Iowa,
Self::MaaS(_, _) => GcpLocation::Iowa,
}
}
}
Expand All @@ -162,6 +168,15 @@ impl TryFrom<&str> for GcpVertexAIModel {
"gemini-2.5-pro-preview-05-06" => Ok(Self::Gemini(GeminiVersion::Pro25Preview)),
"gemini-2.5-flash" => Ok(Self::Gemini(GeminiVersion::Flash25)),
"gemini-2.5-pro" => Ok(Self::Gemini(GeminiVersion::Pro25)),
// MaaS models (Model as a Service from Model Garden)
_ if s.ends_with("-maas") => {
let publisher = s
.split('-')
.next()
.ok_or_else(|| ModelError::UnsupportedModel(s.to_string()))?
.to_string();
Ok(Self::MaaS(publisher, s.to_string()))
}
// Generic models based on prefix matching
_ if s.starts_with("claude-") => {
Ok(Self::Claude(ClaudeVersion::Generic(s.to_string())))
Expand Down Expand Up @@ -202,28 +217,32 @@ impl RequestContext {

/// Returns the provider associated with the model.
pub fn provider(&self) -> ModelProvider {
match self.model {
match &self.model {
GcpVertexAIModel::Claude(_) => ModelProvider::Anthropic,
GcpVertexAIModel::Gemini(_) => ModelProvider::Google,
GcpVertexAIModel::MaaS(publisher, _) => ModelProvider::MaaS(publisher.clone()),
}
}
}

/// Represents available model providers.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelProvider {
/// Anthropic provider (Claude models)
Anthropic,
/// Google provider (Gemini models)
Google,
/// MaaS provider (Model as a Service from Model Garden)
MaaS(String),
}

impl ModelProvider {
/// Returns the string representation of the provider.
pub fn as_str(&self) -> &'static str {
pub fn as_str(&self) -> String {
match self {
Self::Anthropic => "anthropic",
Self::Google => "google",
Self::Anthropic => "anthropic".to_string(),
Self::Google => "google".to_string(),
Self::MaaS(publisher) => publisher.clone(),
}
}
}
Expand Down Expand Up @@ -305,6 +324,12 @@ pub fn create_request(
GcpVertexAIModel::Gemini(_) => {
create_google_request(model_config, system, messages, tools)?
}
GcpVertexAIModel::MaaS(_, _) => {
// TODO: Branch on publisher for format selection once we know which
// MaaS providers use which formats (e.g., OpenAI vs Google format)
// For now, default to Google format since most use generateContent endpoint
create_google_request(model_config, system, messages, tools)?
}
};

Ok((request, context))
Expand All @@ -322,6 +347,7 @@ pub fn response_to_message(response: Value, request_context: RequestContext) ->
match request_context.provider() {
ModelProvider::Anthropic => anthropic::response_to_message(&response),
ModelProvider::Google => google::response_to_message(response),
ModelProvider::MaaS(_) => google::response_to_message(response),
}
}

Expand All @@ -337,6 +363,7 @@ pub fn get_usage(data: &Value, request_context: &RequestContext) -> Result<Usage
match request_context.provider() {
ModelProvider::Anthropic => anthropic::get_usage(data),
ModelProvider::Google => google::get_usage(data),
ModelProvider::MaaS(_) => google::get_usage(data),
}
}

Expand Down
9 changes: 7 additions & 2 deletions crates/goose/src/providers/gcpvertexai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ impl GcpVertexAIProvider {
let endpoint = match provider {
ModelProvider::Anthropic => "streamRawPredict",
ModelProvider::Google => "generateContent",
ModelProvider::MaaS(_) => "generateContent",
};

// Construct path for URL
Expand Down Expand Up @@ -586,8 +587,12 @@ mod tests {

#[test]
fn test_model_provider_conversion() {
assert_eq!(ModelProvider::Anthropic.as_str(), "anthropic");
assert_eq!(ModelProvider::Google.as_str(), "google");
assert_eq!(ModelProvider::Anthropic.as_str(), "anthropic".to_string());
assert_eq!(ModelProvider::Google.as_str(), "google".to_string());
assert_eq!(
ModelProvider::MaaS("qwen".to_string()).as_str(),
"qwen".to_string()
);
}

#[test]
Expand Down