Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions crates/goose/examples/tetrate_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create new PKCE auth flow
let mut auth_flow = TetrateAuth::new()?;

// Get the auth URL that would be opened
let auth_url = auth_flow.get_auth_url();
println!("Auth URL: {}", auth_url);
println!("\nStarting authentication flow...");
println!("Starting authentication flow...");
println!("This will:");
println!("1. Open your browser to the auth page");
println!("2. Start a local server on port 3000");
println!("1. Start a local server on a dynamic port");
println!("2. Open your browser to the auth page");
println!("3. Wait for the callback\n");

// Complete the full flow
Expand Down
56 changes: 26 additions & 30 deletions crates/goose/src/config/signup_tetrate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub const TETRATE_DEFAULT_MODEL: &str = "claude-haiku-4-5";
// Auth endpoints are on the main web domain
const TETRATE_AUTH_URL: &str = "https://router.tetrate.ai/auth";
const TETRATE_TOKEN_URL: &str = "https://router.tetrate.ai/api/api-keys/verify";
const CALLBACK_URL: &str = "http://localhost:3000";
const CALLBACK_BASE: &str = "http://localhost";
const AUTH_TIMEOUT: Duration = Duration::from_secs(180); // 3 minutes

#[derive(Debug)]
Expand Down Expand Up @@ -61,38 +61,16 @@ impl PkceAuthFlow {
})
}

pub fn get_auth_url(&self) -> String {
pub fn get_auth_url(&self, port: u16) -> String {
let callback_url = format!("{}:{}", CALLBACK_BASE, port);
format!(
"{}?callback={}&code_challenge={}",
TETRATE_AUTH_URL,
urlencoding::encode(CALLBACK_URL),
urlencoding::encode(&callback_url),
urlencoding::encode(&self.code_challenge)
)
}

/// Start local server and wait for callback
pub async fn start_server(&mut self) -> Result<String> {
let (code_tx, code_rx) = oneshot::channel::<String>();
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

// Store shutdown sender so we can stop the server later
self.server_shutdown_tx = Some(shutdown_tx);

// Start the server in a background task
tokio::spawn(async move {
if let Err(e) = server::run_callback_server(code_tx, shutdown_rx).await {
eprintln!("Server error: {}", e);
}
});

// Wait for the authorization code with timeout
match timeout(AUTH_TIMEOUT, code_rx).await {
Ok(Ok(code)) => Ok(code),
Ok(Err(_)) => Err(anyhow!("Failed to receive authorization code")),
Err(_) => Err(anyhow!("Authentication timeout - please try again")),
}
}

pub async fn exchange_code(&self, code: String) -> Result<String> {
let client = Client::new();

Expand Down Expand Up @@ -131,9 +109,22 @@ impl PkceAuthFlow {
Ok(token_response.key)
}

/// Complete flow: open browser, wait for callback, exchange code
/// Complete flow: start server, open browser, wait for callback, exchange code
pub async fn complete_flow(&mut self) -> Result<String> {
let auth_url = self.get_auth_url();
let listener = tokio::net::TcpListener::bind(("127.0.0.1", 0)).await?;
let port = listener.local_addr()?.port();

let (code_tx, code_rx) = oneshot::channel::<String>();
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
self.server_shutdown_tx = Some(shutdown_tx);

tokio::spawn(async move {
if let Err(e) = server::run_callback_server(listener, code_tx, shutdown_rx).await {
eprintln!("Server error: {}", e);
}
});

let auth_url = self.get_auth_url(port);

println!("Opening browser for Tetrate Agent Router Service authentication...");
eprintln!("Auth URL: {}", auth_url);
Expand All @@ -143,8 +134,13 @@ impl PkceAuthFlow {
println!("Please open this URL manually: {}", auth_url);
}

println!("Waiting for authentication callback...");
let code = self.start_server().await?;
println!("Waiting for authentication callback on port {}...", port);

let code = match timeout(AUTH_TIMEOUT, code_rx).await {
Ok(Ok(code)) => Ok(code),
Ok(Err(_)) => Err(anyhow!("Failed to receive authorization code")),
Err(_) => Err(anyhow!("Authentication timeout - please try again")),
}?;

println!("Authorization code received. Exchanging for API key...");
eprintln!("Received code: {}", code);
Expand Down
6 changes: 2 additions & 4 deletions crates/goose/src/config/signup_tetrate/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use axum::{
use include_dir::{include_dir, Dir};
use minijinja::{context, Environment};
use serde::Deserialize;
use std::net::SocketAddr;
use tokio::sync::oneshot;

static TEMPLATES_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/config/signup_tetrate/templates");
Expand All @@ -20,14 +19,13 @@ struct CallbackQuery {
error: Option<String>,
}

/// Run the callback server on localhost:3000
/// Run the callback server using the provided listener.
pub async fn run_callback_server(
listener: tokio::net::TcpListener,
code_tx: oneshot::Sender<String>,
shutdown_rx: oneshot::Receiver<()>,
) -> Result<()> {
let app = Router::new().route("/", get(handle_callback));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
let listener = tokio::net::TcpListener::bind(addr).await?;
let state = std::sync::Arc::new(tokio::sync::Mutex::new(Some(code_tx)));

axum::serve(listener, app.with_state(state.clone()).into_make_service())
Expand Down
7 changes: 4 additions & 3 deletions crates/goose/src/config/signup_tetrate/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,16 @@ fn test_code_challenge_generation() {
#[test]
fn test_auth_url_generation() {
let flow = PkceAuthFlow::new().unwrap();
let auth_url = flow.get_auth_url();
let auth_url = flow.get_auth_url(12345);

// Verify URL contains required parameters
assert!(auth_url.contains("callback="));
assert!(auth_url.contains("code_challenge="));
assert!(auth_url.starts_with(TETRATE_AUTH_URL));

// Verify callback URL is properly encoded
assert!(auth_url.contains(&*urlencoding::encode(CALLBACK_URL)));
// Verify callback URL contains the dynamic port
let expected_callback = format!("{}:{}", CALLBACK_BASE, 12345);
assert!(auth_url.contains(&*urlencoding::encode(&expected_callback)));
}

#[test]
Expand Down
Loading