Skip to content
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ tmp/
# will have compiled files and executables
debug/
target/
.goose/

# These are backup files generated by rustfmt
**/*.rs.bk
Expand Down
3 changes: 3 additions & 0 deletions crates/goose/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ static MODEL_SPECIFIC_LIMITS: Lazy<HashMap<&'static str, usize>> = Lazy::new(||
// Meta Llama models, https://github.com/meta-llama/llama-models/tree/main?tab=readme-ov-file#llama-models-1
map.insert("llama3.2", 128_000);
map.insert("llama3.3", 128_000);

// x.ai Grok models, https://docs.x.ai/docs/overview
map.insert("grok", 131_072);
map
});

Expand Down
5 changes: 4 additions & 1 deletion crates/goose/src/providers/factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use super::{
sagemaker_tgi::SageMakerTgiProvider,
snowflake::SnowflakeProvider,
venice::VeniceProvider,
xai::XaiProvider,
};
use crate::model::ModelConfig;
use anyhow::Result;
Expand Down Expand Up @@ -52,6 +53,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
SageMakerTgiProvider::metadata(),
VeniceProvider::metadata(),
SnowflakeProvider::metadata(),
XaiProvider::metadata(),
]
}

Expand Down Expand Up @@ -128,6 +130,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>>
"venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)),
"snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)),
"github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)),
"xai" => Ok(Arc::new(XaiProvider::from_env(model)?)),
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
}
}
Expand Down Expand Up @@ -259,7 +262,7 @@ mod tests {
}

// Set only the required lead model
env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");
env::set_var("GOOSE_LEAD_MODEL", "grok-3");

// This should use defaults for all other values
let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@ pub mod toolshim;
pub mod utils;
pub mod utils_universal_openai_stream;
pub mod venice;
pub mod xai;

pub use factory::{create, providers};
181 changes: 181 additions & 0 deletions crates/goose/src/providers/xai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
use super::errors::ProviderError;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
use crate::providers::utils::get_model;
use anyhow::Result;
use async_trait::async_trait;
use mcp_core::Tool;
use reqwest::{Client, StatusCode};
use serde_json::Value;
use std::time::Duration;
use url::Url;

pub const XAI_API_HOST: &str = "https://api.x.ai/v1";
pub const XAI_DEFAULT_MODEL: &str = "grok-3";
pub const XAI_KNOWN_MODELS: &[&str] = &[
"grok-3",
"grok-3-fast",
"grok-3-mini",
"grok-3-mini-fast",
"grok-2-vision-1212",
"grok-2-image-1212",
"grok-2-1212",
"grok-3-latest",
"grok-3-fast-latest",
"grok-3-mini-latest",
"grok-3-mini-fast-latest",
"grok-2-vision",
"grok-2-vision-latest",
"grok-2-image",
"grok-2-image-latest",
"grok-2",
"grok-2-latest",
];

pub const XAI_DOC_URL: &str = "https://docs.x.ai/docs/overview";

#[derive(serde::Serialize)]
pub struct XaiProvider {
#[serde(skip)]
client: Client,
host: String,
api_key: String,
model: ModelConfig,
}

impl Default for XaiProvider {
fn default() -> Self {
let model = ModelConfig::new(XaiProvider::metadata().default_model);
XaiProvider::from_env(model).expect("Failed to initialize xAI provider")
}
}

impl XaiProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global();
let api_key: String = config.get_secret("XAI_API_KEY")?;
let host: String = config
.get_param("XAI_HOST")
.unwrap_or_else(|_| XAI_API_HOST.to_string());

let client = Client::builder()
.timeout(Duration::from_secs(600))
.build()?;

Ok(Self {
client,
host,
api_key,
model,
})
}

async fn post(&self, payload: Value) -> anyhow::Result<Value, ProviderError> {
// Ensure the host ends with a slash for proper URL joining
let host = if self.host.ends_with('/') {
self.host.clone()
} else {
format!("{}/", self.host)
};
let base_url = Url::parse(&host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url.join("chat/completions").map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;

tracing::debug!("xAI API URL: {}", url);
tracing::debug!("xAI request model: {:?}", self.model.model_name);

let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&payload)
.send()
.await?;

let status = response.status();
let payload: Option<Value> = response.json().await.ok();

match status {
StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ),
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}", status, payload)))
}
StatusCode::PAYLOAD_TOO_LARGE => {
Err(ProviderError::ContextLengthExceeded(format!("{:?}", payload)))
}
StatusCode::TOO_MANY_REQUESTS => {
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
Err(ProviderError::ServerError(format!("{:?}", payload)))
}
_ => {
tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))
}
}
}
}

#[async_trait]
impl Provider for XaiProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
"xai",
"xAI",
"Grok models from xAI, including reasoning and multimodal capabilities",
XAI_DEFAULT_MODEL,
XAI_KNOWN_MODELS.to_vec(),
XAI_DOC_URL,
vec![
ConfigKey::new("XAI_API_KEY", true, true, None),
ConfigKey::new("XAI_HOST", false, false, Some(XAI_API_HOST)),
],
)
}

fn get_model_config(&self) -> ModelConfig {
self.model.clone()
}

#[tracing::instrument(
skip(self, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
)]
async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
let payload = create_request(
&self.model,
system,
messages,
tools,
&super::utils::ImageFormat::OpenAi,
)?;

let response = self.post(payload.clone()).await?;

let message = response_to_message(response.clone())?;
let usage = match get_usage(&response) {
Ok(usage) => usage,
Err(ProviderError::UsageError(e)) => {
tracing::debug!("Failed to get usage data: {}", e);
Usage::default()
}
Err(e) => return Err(e),
};
let model = get_model(&response);
super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
Ok((message, ProviderUsage::new(model, usage)))
}
}
1 change: 1 addition & 0 deletions documentation/docs/getting-started/providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Goose relies heavily on tool calling capabilities and currently works best with
| [OpenRouter](https://openrouter.ai/) | API gateway for unified access to various models with features like rate-limiting management. | `OPENROUTER_API_KEY` |
| [Snowflake](https://docs.snowflake.com/user-guide/snowflake-cortex/aisql#choosing-a-model) | Access the latest models using Snowflake Cortex services, including Claude models. **Requires a Snowflake account and programmatic access token (PAT)**. | `SNOWFLAKE_HOST`, `SNOWFLAKE_TOKEN` |
| [Venice AI](https://venice.ai/home) | Provides access to open source models like Llama, Mistral, and Qwen while prioritizing user privacy. **Requires an account and an [API key](https://docs.venice.ai/overview/guides/generating-api-key)**. | `VENICE_API_KEY`, `VENICE_HOST` (optional), `VENICE_BASE_PATH` (optional), `VENICE_MODELS_PATH` (optional) |
| [xAI](https://x.ai/) | Access to xAI's Grok models including grok-3, grok-3-mini, and grok-3-fast with 131,072 token context window. | `XAI_API_KEY`, `XAI_HOST` (optional) |


## Configure Provider
Expand Down
19 changes: 19 additions & 0 deletions ui/desktop/src/components/settings/providers/ProviderRegistry.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,25 @@ export const PROVIDER_REGISTRY: ProviderRegistry[] = [
],
},
},
{
name: 'xAI',
details: {
id: 'xai',
name: 'xAI',
description: 'Access Grok models from xAI, including reasoning and multimodal capabilities',
parameters: [
{
name: 'XAI_API_KEY',
is_secret: true,
},
{
name: 'XAI_HOST',
is_secret: false,
default: 'https://api.x.ai/v1',
},
],
},
},
{
name: 'Google',
details: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import OllamaLogo from './icons/ollama@3x.png';
import DatabricksLogo from './icons/databricks@3x.png';
import OpenRouterLogo from './icons/openrouter@3x.png';
import SnowflakeLogo from './icons/snowflake@3x.png';
import XaiLogo from './icons/xai@3x.png';
import DefaultLogo from './icons/default@3x.png';

// Map provider names to their logos
Expand All @@ -18,6 +19,7 @@ const providerLogos: Record<string, string> = {
databricks: DatabricksLogo,
openrouter: OpenRouterLogo,
snowflake: SnowflakeLogo,
xai: XaiLogo,
default: DefaultLogo,
};

Expand All @@ -30,10 +32,24 @@ export default function ProviderLogo({ providerName }: ProviderLogoProps) {
const logoKey = providerName.toLowerCase();
const logo = providerLogos[logoKey] || DefaultLogo;

// Special handling for xAI logo
const isXai = logoKey === 'xai';
const imageStyle = isXai ? { filter: 'invert(1)', opacity: 0.9 } : {};

// Use smaller size for xAI logo to fit better in circle
const imageClassName = isXai
? 'w-8 h-8 object-contain' // Smaller size for xAI
: 'w-16 h-16 object-contain'; // Default size for others

return (
<div className="flex justify-center mb-2">
<div className="w-12 h-12 bg-black rounded-full overflow-hidden flex items-center justify-center">
<img src={logo} alt={`${providerName} logo`} className="w-16 h-16 object-contain" />
<img
src={logo}
alt={`${providerName} logo`}
className={imageClassName}
style={imageStyle}
/>
</div>
</div>
);
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading