diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs
index 1479a847fe73..95c57a65c414 100644
--- a/crates/goose/src/providers/ollama.rs
+++ b/crates/goose/src/providers/ollama.rs
@@ -1,21 +1,30 @@
 use super::api_client::{ApiClient, AuthMethod};
-use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
 use super::errors::ProviderError;
 use super::retry::ProviderRetry;
-use super::utils::{get_model, handle_response_openai_compat};
+use super::utils::{get_model, handle_response_openai_compat, handle_status_openai_compat};
 use crate::config::custom_providers::CustomProviderConfig;
 use crate::conversation::message::Message;
 use crate::conversation::Conversation;
 use crate::impl_provider_default;
 use crate::model::ModelConfig;
-use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
+use crate::providers::formats::openai::{
+    create_request, get_usage, response_to_message, response_to_streaming_message,
+};
 use crate::utils::safe_truncate;
 use anyhow::Result;
+use async_stream::try_stream;
 use async_trait::async_trait;
+use futures::TryStreamExt;
 use regex::Regex;
 use rmcp::model::Tool;
-use serde_json::Value;
+use serde_json::{json, Value};
+use std::io;
 use std::time::Duration;
+use tokio::pin;
+use tokio_stream::StreamExt;
+use tokio_util::codec::{FramedRead, LinesCodec};
+use tokio_util::io::StreamReader;
 use url::Url;
 
 pub const OLLAMA_HOST: &str = "localhost";
@@ -78,7 +87,7 @@ impl OllamaProvider {
         Ok(Self {
             api_client,
             model,
-            supports_streaming: false,
+            supports_streaming: true,
         })
     }
 
@@ -228,6 +237,45 @@ impl Provider for OllamaProvider {
     fn supports_streaming(&self) -> bool {
         self.supports_streaming
     }
+
+    async fn stream(
+        &self,
+        system: &str,
+        messages: &[Message],
+        tools: &[Tool],
+    ) -> Result<MessageStream, ProviderError> {
+        let mut payload = create_request(
+            &self.model,
+            system,
+            messages,
+            tools,
+            &super::utils::ImageFormat::OpenAi,
+        )?;
+        payload["stream"] = json!(true);
+        payload["stream_options"] = json!({
+            "include_usage": true,
+        });
+
+        let response = self
+            .api_client
+            .response_post("v1/chat/completions", &payload)
+            .await?;
+        let response = handle_status_openai_compat(response).await?;
+        let stream = response.bytes_stream().map_err(io::Error::other);
+        let model_config = self.model.clone();
+
+        Ok(Box::pin(try_stream! {
+            let stream_reader = StreamReader::new(stream);
+            let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from);
+            let message_stream = response_to_streaming_message(framed);
+            pin!(message_stream);
+            while let Some(message) = message_stream.next().await {
+                let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
+                super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default());
+                yield (message, usage);
+            }
+        }))
+    }
 }
 
 impl OllamaProvider {
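
For reference, a minimal self-contained sketch (not part of the diff) of the byte-stream-to-lines pattern the new stream method relies on: adapt a reqwest byte stream into an AsyncRead with StreamReader, then split it into lines with FramedRead and LinesCodec. The URL and model name are placeholders; it assumes reqwest with the "json" and "stream" features, plus tokio, tokio-stream, tokio-util (with the "codec" and "io" features), futures, serde_json, and anyhow.

use futures::TryStreamExt;
use tokio_stream::StreamExt;
use tokio_util::codec::{FramedRead, LinesCodec};
use tokio_util::io::StreamReader;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Placeholder endpoint and model; any chunked/SSE HTTP response works.
    let response = reqwest::Client::new()
        .post("http://localhost:11434/v1/chat/completions")
        .json(&serde_json::json!({
            "model": "llama3.2",
            "messages": [{"role": "user", "content": "Say hi"}],
            "stream": true,
        }))
        .send()
        .await?;

    // reqwest yields Result<Bytes, reqwest::Error>; StreamReader needs an
    // io::Error, so convert the error type before adapting to AsyncRead.
    let bytes = response.bytes_stream().map_err(std::io::Error::other);
    let reader = StreamReader::new(bytes);

    // Frame the byte stream into lines; each SSE line arrives as a String.
    let mut lines = FramedRead::new(reader, LinesCodec::new());
    while let Some(line) = lines.next().await {
        println!("{}", line?);
    }
    Ok(())
}

The real implementation above feeds the framed lines into response_to_streaming_message rather than printing them, and wraps the loop in try_stream! so the result can be returned as a pinned MessageStream.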