diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs index 2aacfc9e0b9c..51951bae8160 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -1,14 +1,22 @@ use anyhow::Result; +use async_stream::try_stream; use async_trait::async_trait; -use serde_json::Value; +use futures::TryStreamExt; +use serde_json::{json, Value}; +use std::io; +use tokio::pin; +use tokio_stream::StreamExt; +use tokio_util::codec::{FramedRead, LinesCodec}; +use tokio_util::io::StreamReader; use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; +use super::formats::openai::response_to_streaming_message; use super::retry::ProviderRetry; use super::utils::{ emit_debug_trace, get_model, handle_response_google_compat, handle_response_openai_compat, - is_google_model, + handle_status_openai_compat, is_google_model, }; use crate::config::signup_tetrate::TETRATE_DEFAULT_MODEL; use crate::conversation::message::Message; @@ -38,6 +46,7 @@ pub struct TetrateProvider { #[serde(skip)] api_client: ApiClient, model: ModelConfig, + supports_streaming: bool, } impl_provider_default!(TetrateProvider); @@ -56,7 +65,11 @@ impl TetrateProvider { .with_header("HTTP-Referer", "https://block.github.io/goose")? .with_header("X-Title", "Goose")?; - Ok(Self { api_client, model }) + Ok(Self { + api_client, + model, + supports_streaming: true, + }) } async fn post(&self, payload: &Value) -> Result { @@ -178,6 +191,49 @@ impl Provider for TetrateProvider { Ok((message, ProviderUsage::new(model, usage))) } + async fn stream( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result { + let mut payload = create_request( + &self.model, + system, + messages, + tools, + &super::utils::ImageFormat::OpenAi, + )?; + + // Enable streaming + payload["stream"] = json!(true); + payload["stream_options"] = json!({ + "include_usage": true, + }); + + let response = self + .api_client + .response_post("v1/chat/completions", &payload) + .await?; + + let response = handle_status_openai_compat(response).await?; + let stream = response.bytes_stream().map_err(io::Error::other); + let model_config = self.model.clone(); + + Ok(Box::pin(try_stream! { + let stream_reader = StreamReader::new(stream); + let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from); + + let message_stream = response_to_streaming_message(framed); + pin!(message_stream); + while let Some(message) = message_stream.next().await { + let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; + emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); + yield (message, usage); + } + })) + } + /// Fetch supported models from Tetrate Agent Router Service API (only models with tool support) async fn fetch_supported_models(&self) -> Result>, ProviderError> { // Use the existing api_client which already has authentication configured @@ -251,4 +307,8 @@ impl Provider for TetrateProvider { models.sort(); Ok(Some(models)) } + + fn supports_streaming(&self) -> bool { + self.supports_streaming + } } diff --git a/crates/goose/tests/tetrate_streaming.rs b/crates/goose/tests/tetrate_streaming.rs new file mode 100644 index 000000000000..27d5ea1f5183 --- /dev/null +++ b/crates/goose/tests/tetrate_streaming.rs @@ -0,0 +1,315 @@ +use anyhow::Result; +use futures::StreamExt; +use goose::conversation::message::{Message, MessageContent}; +use goose::model::ModelConfig; +use goose::providers::base::Provider; +use goose::providers::tetrate::TetrateProvider; +use rmcp::model::Tool; +use rmcp::object; +use serial_test::serial; + +/// Test module for Tetrate Agent Router Service streaming functionality +#[cfg(test)] +mod tetrate_streaming_tests { + use super::*; + + fn create_test_provider() -> Result { + // Create a test provider with the default model + let model_config = ModelConfig::new("claude-3-5-sonnet-latest")?; + TetrateProvider::from_env(model_config) + } + + #[tokio::test] + #[serial] + #[ignore] // Ignore by default, run with --ignored flag when API key is available + async fn test_tetrate_streaming_basic() -> Result<()> { + let provider = create_test_provider()?; + + let messages = vec![Message::user().with_text("Count from 1 to 5, one number at a time.")]; + + let mut stream = provider + .stream( + "You are a helpful assistant that counts numbers.", + &messages, + &[], + ) + .await?; + + let mut chunk_count = 0; + let mut content_chunks = Vec::new(); + + while let Some(result) = stream.next().await { + let (message, usage) = result?; + chunk_count += 1; + + if let Some(msg) = message { + let text = msg.as_concat_text(); + if !text.is_empty() { + content_chunks.push(text); + } + } + + // Check if we have usage information in the final chunk + if usage.is_some() { + println!("Received usage information in chunk {}", chunk_count); + } + } + + assert!(chunk_count > 0, "Should receive at least one chunk"); + assert!(!content_chunks.is_empty(), "Should receive some content"); + + let full_content = content_chunks.join(""); + println!("Full streamed content: {}", full_content); + + // Verify the response contains numbers + assert!( + full_content.contains('1'), + "Response should contain number 1" + ); + assert!( + full_content.contains('5'), + "Response should contain number 5" + ); + + Ok(()) + } + + #[tokio::test] + #[serial] + #[ignore] + async fn test_tetrate_streaming_with_tools() -> Result<()> { + let provider = create_test_provider()?; + + // Define a simple tool + let weather_tool = Tool::new( + "get_weather", + "Get the current weather for a location", + object!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] + }), + ); + + let messages = vec![Message::user().with_text("What's the weather in San Francisco?")]; + + let mut stream = provider + .stream( + "You are a helpful assistant with access to weather information.", + &messages, + &[weather_tool], + ) + .await?; + + let mut received_tool_call = false; + let mut chunk_count = 0; + + while let Some(result) = stream.next().await { + let (message, _usage) = result?; + chunk_count += 1; + + if let Some(msg) = message { + // Check if message contains tool requests + for content in &msg.content { + if matches!(content, MessageContent::ToolRequest(_)) { + received_tool_call = true; + println!("Received tool call in chunk {}", chunk_count); + } + } + } + } + + assert!(chunk_count > 0, "Should receive at least one chunk"); + // Note: Tool calls might not be supported in streaming for all models + // This is more of a capability test than a requirement + if received_tool_call { + println!("✓ Streaming with tools is supported"); + } else { + println!("⚠ Streaming with tools may not be fully supported"); + } + + Ok(()) + } + + #[tokio::test] + #[serial] + #[ignore] + async fn test_tetrate_streaming_empty_response() -> Result<()> { + let provider = create_test_provider()?; + + // This might result in a very short or empty response + let messages = vec![Message::user().with_text("")]; + + let mut stream = provider + .stream("You are a helpful assistant.", &messages, &[]) + .await?; + + let mut chunk_count = 0; + + while let Some(result) = stream.next().await { + let (_message, _usage) = result?; + chunk_count += 1; + } + + // Even with empty input, we should get at least one chunk (possibly with finish_reason) + assert!( + chunk_count > 0, + "Should receive at least one chunk even with empty input" + ); + + Ok(()) + } + + #[tokio::test] + #[serial] + #[ignore] + async fn test_tetrate_streaming_long_response() -> Result<()> { + let provider = create_test_provider()?; + + let messages = vec![Message::user().with_text( + "Write a detailed 3-paragraph essay about the importance of streaming in modern APIs.", + )]; + + let mut stream = provider + .stream( + "You are a helpful assistant that writes detailed essays.", + &messages, + &[], + ) + .await?; + + let mut chunk_count = 0; + let mut total_content_length = 0; + + while let Some(result) = stream.next().await { + let (message, usage) = result?; + chunk_count += 1; + + if let Some(msg) = message { + let text = msg.as_concat_text(); + total_content_length += text.len(); + } + + // Final chunk should have usage information + if let Some(usage_info) = usage { + println!("Final usage: {:?}", usage_info.usage); + assert!( + usage_info.usage.output_tokens.unwrap_or(0) > 0, + "Should have output tokens" + ); + } + } + + println!( + "Received {} chunks with total content length: {}", + chunk_count, total_content_length + ); + + // For a detailed essay, we expect multiple chunks and substantial content + assert!( + chunk_count > 5, + "Long response should be streamed in multiple chunks" + ); + assert!( + total_content_length > 100, + "Essay should have substantial content" + ); + + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_tetrate_streaming_error_handling() -> Result<()> { + // Test with invalid API key to ensure error handling works + std::env::set_var("TETRATE_API_KEY", "invalid-key-for-testing"); + + let model_config = ModelConfig::new("claude-3-5-sonnet-latest")?; + let provider = TetrateProvider::from_env(model_config)?; + + let messages = vec![Message::user().with_text("Hello")]; + + let result = provider + .stream("You are a helpful assistant.", &messages, &[]) + .await; + + // We expect this to fail with an authentication error + assert!(result.is_err(), "Should fail with invalid API key"); + + // Clean up + std::env::remove_var("TETRATE_API_KEY"); + + Ok(()) + } + + #[tokio::test] + #[serial] + #[ignore] + async fn test_tetrate_streaming_concurrent_streams() -> Result<()> { + let provider = create_test_provider()?; + + // Create multiple concurrent streams + let messages1 = vec![Message::user().with_text("Say 'Stream 1'")]; + let messages2 = vec![Message::user().with_text("Say 'Stream 2'")]; + + let stream1 = provider + .stream("You are a helpful assistant.", &messages1, &[]) + .await?; + + let stream2 = provider + .stream("You are a helpful assistant.", &messages2, &[]) + .await?; + + // Process both streams concurrently + let (result1, result2) = tokio::join!( + process_stream(stream1, "Stream 1"), + process_stream(stream2, "Stream 2") + ); + + let content1 = result1?; + let content2 = result2?; + + println!("Stream 1 content: {}", content1); + println!("Stream 2 content: {}", content2); + + assert!( + content1.contains("Stream 1") || content1.contains("1"), + "First stream should mention Stream 1" + ); + assert!( + content2.contains("Stream 2") || content2.contains("2"), + "Second stream should mention Stream 2" + ); + + Ok(()) + } + + // Helper function to process a stream and collect content + async fn process_stream( + mut stream: goose::providers::base::MessageStream, + label: &str, + ) -> Result { + let mut content = String::new(); + let mut chunk_count = 0; + + while let Some(result) = stream.next().await { + let (message, _usage) = result?; + chunk_count += 1; + + if let Some(msg) = message { + let text = msg.as_concat_text(); + if !text.is_empty() { + content.push_str(&text); + } + } + } + + println!("{}: Received {} chunks", label, chunk_count); + Ok(content) + } +}