Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions crates/goose-server/src/routes/action_required.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::routes::errors::ErrorResponse;
use crate::state::AppState;
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
use axum::{extract::State, routing::post, Json, Router};
use goose::permission::permission_confirmation::PrincipalType;
use goose::permission::{Permission, PermissionConfirmation};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -34,7 +35,7 @@ fn default_principal_type() -> PrincipalType {
pub async fn confirm_tool_action(
State(state): State<Arc<AppState>>,
Json(request): Json<ConfirmToolActionRequest>,
) -> Result<Json<Value>, StatusCode> {
) -> Result<Json<Value>, ErrorResponse> {
let agent = state.get_agent_for_route(request.session_id).await?;
let permission = match request.action.as_str() {
"always_allow" => Permission::AlwaysAllow,
Expand Down Expand Up @@ -72,6 +73,7 @@ mod tests {
mod integration_tests {
use super::*;
use axum::{body::Body, http::Request};
use http::StatusCode;
use tower::ServiceExt;

#[tokio::test(flavor = "multi_thread")]
Expand Down
152 changes: 91 additions & 61 deletions crates/goose-server/src/routes/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
///
/// This module provides endpoints for audio transcription using OpenAI's Whisper API.
/// The OpenAI API key must be configured in the backend for this to work.
use crate::routes::errors::ErrorResponse;
use crate::state::AppState;
use axum::{
http::StatusCode,
Expand Down Expand Up @@ -44,18 +45,22 @@ struct WhisperResponse {
fn validate_audio_input(
audio: &str,
mime_type: &str,
) -> Result<(Vec<u8>, &'static str), StatusCode> {
) -> Result<(Vec<u8>, &'static str), ErrorResponse> {
// Decode the base64 audio data
let audio_bytes = BASE64.decode(audio).map_err(|_| StatusCode::BAD_REQUEST)?;
let audio_bytes = BASE64
.decode(audio)
.map_err(|_| ErrorResponse::bad_request("Invalid base64 audio data"))?;

// Check file size
if audio_bytes.len() > MAX_AUDIO_SIZE_BYTES {
tracing::warn!(
"Audio file too large: {} bytes (max: {} bytes)",
audio_bytes.len(),
MAX_AUDIO_SIZE_BYTES
);
return Err(StatusCode::PAYLOAD_TOO_LARGE);
return Err(ErrorResponse {
message: format!(
"Audio file too large: {} bytes (max: {} bytes)",
audio_bytes.len(),
MAX_AUDIO_SIZE_BYTES
),
status: StatusCode::PAYLOAD_TOO_LARGE,
});
}

// Determine file extension based on MIME type
Expand All @@ -68,19 +73,28 @@ fn validate_audio_input(
"audio/m4a" => "m4a",
"audio/wav" => "wav",
"audio/x-wav" => "wav",
_ => return Err(StatusCode::UNSUPPORTED_MEDIA_TYPE),
_ => {
return Err(ErrorResponse {
message: format!("Unsupported audio format: {}", mime_type),
status: StatusCode::UNSUPPORTED_MEDIA_TYPE,
})
}
};

Ok((audio_bytes, file_extension))
}

/// Get OpenAI configuration (API key and host)
fn get_openai_config() -> Result<(String, String), StatusCode> {
fn get_openai_config() -> Result<(String, String), ErrorResponse> {
let config = goose::config::Config::global();

let api_key: String = config.get_secret("OPENAI_API_KEY").map_err(|e| {
tracing::error!("Failed to get OpenAI API key: {:?}", e);
StatusCode::PRECONDITION_FAILED
let api_key: String = config.get_secret("OPENAI_API_KEY").map_err(|e| match e {
goose::config::ConfigError::NotFound(_) => ErrorResponse {
message: "OpenAI API key not configured. Please set OPENAI_API_KEY in settings."
.to_string(),
status: StatusCode::PRECONDITION_FAILED,
},
_ => ErrorResponse::internal(format!("Failed to get OpenAI API key: {:?}", e)),
})?;

let openai_host = match config.get("OPENAI_HOST", false) {
Expand All @@ -101,7 +115,7 @@ async fn send_openai_request(
mime_type: &str,
api_key: &str,
openai_host: &str,
) -> Result<WhisperResponse, StatusCode> {
) -> Result<WhisperResponse, ErrorResponse> {
tracing::info!("Using OpenAI host: {}", openai_host);
tracing::info!(
"Audio file size: {} bytes, extension: {}, mime_type: {}",
Expand All @@ -115,8 +129,7 @@ async fn send_openai_request(
.file_name(format!("audio.{}", file_extension))
.mime_str(mime_type)
.map_err(|e| {
tracing::error!("Failed to create multipart part: {:?}", e);
StatusCode::INTERNAL_SERVER_ERROR
ErrorResponse::internal(format!("Failed to create multipart part: {:?}", e))
})?;

let form = reqwest::multipart::Form::new()
Expand All @@ -130,10 +143,7 @@ async fn send_openai_request(
let client = Client::builder()
.timeout(Duration::from_secs(OPENAI_TIMEOUT_SECONDS))
.build()
.map_err(|e| {
tracing::error!("Failed to create HTTP client: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
.map_err(|e| ErrorResponse::internal(format!("Failed to create HTTP client: {}", e)))?;

tracing::info!(
"Sending request to OpenAI: {}/v1/audio/transcriptions",
Expand All @@ -148,14 +158,18 @@ async fn send_openai_request(
.await
.map_err(|e| {
if e.is_timeout() {
tracing::error!(
"OpenAI API request timed out after {}s",
OPENAI_TIMEOUT_SECONDS
);
StatusCode::GATEWAY_TIMEOUT
ErrorResponse {
message: format!(
"OpenAI API request timed out after {}s",
OPENAI_TIMEOUT_SECONDS
),
status: StatusCode::GATEWAY_TIMEOUT,
}
} else {
tracing::error!("Failed to send request to OpenAI: {}", e);
StatusCode::SERVICE_UNAVAILABLE
ErrorResponse {
message: format!("Failed to send request to OpenAI: {}", e),
status: StatusCode::SERVICE_UNAVAILABLE,
}
}
})?;

Expand All @@ -171,20 +185,27 @@ async fn send_openai_request(

// Check for specific error codes
if status == 401 {
tracing::error!("OpenAI API key appears to be invalid or unauthorized");
return Err(StatusCode::UNAUTHORIZED);
return Err(ErrorResponse {
message: "OpenAI API key appears to be invalid or unauthorized".to_string(),
status: StatusCode::UNAUTHORIZED,
});
} else if status == 429 {
tracing::error!("OpenAI API quota or rate limit exceeded");
return Err(StatusCode::TOO_MANY_REQUESTS);
return Err(ErrorResponse {
message: "OpenAI API quota or rate limit exceeded".to_string(),
status: StatusCode::TOO_MANY_REQUESTS,
});
}

return Err(StatusCode::BAD_GATEWAY);
return Err(ErrorResponse {
message: format!("OpenAI API error: {}", error_text),
status: StatusCode::BAD_GATEWAY,
});
}

let whisper_response: WhisperResponse = response.json().await.map_err(|e| {
tracing::error!("Failed to parse OpenAI response: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let whisper_response: WhisperResponse = response
.json()
.await
.map_err(|e| ErrorResponse::internal(format!("Failed to parse OpenAI response: {}", e)))?;

Ok(whisper_response)
}
Expand All @@ -208,7 +229,7 @@ async fn send_openai_request(
/// - 503: Service Unavailable (network error)
async fn transcribe_handler(
Json(request): Json<TranscribeRequest>,
) -> Result<Json<TranscribeResponse>, StatusCode> {
) -> Result<Json<TranscribeResponse>, ErrorResponse> {
let (audio_bytes, file_extension) = validate_audio_input(&request.audio, &request.mime_type)?;
let (api_key, openai_host) = get_openai_config()?;

Expand All @@ -232,7 +253,7 @@ async fn transcribe_handler(
/// Requires an ElevenLabs API key with speech-to-text access.
async fn transcribe_elevenlabs_handler(
Json(request): Json<TranscribeElevenLabsRequest>,
) -> Result<Json<TranscribeResponse>, StatusCode> {
) -> Result<Json<TranscribeResponse>, ErrorResponse> {
let (audio_bytes, file_extension) = validate_audio_input(&request.audio, &request.mime_type)?;

// Get the ElevenLabs API key from config (after input validation)
Expand Down Expand Up @@ -266,17 +287,17 @@ async fn transcribe_elevenlabs_handler(
key
}
None => {
tracing::error!(
return Err(ErrorResponse::bad_request(format!(
"ElevenLabs API key is not a string, found: {:?}",
value
);
return Err(StatusCode::PRECONDITION_FAILED);
)));
Comment on lines +290 to +293
Copy link

Copilot AI Jan 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error message includes the raw value for ELEVENLABS_API_KEY when it isn’t a string, which can leak a secret back to the client; avoid echoing the stored value (return a generic message and keep details in server logs only).

Copilot uses AI. Check for mistakes.
}
}
}
Err(_) => {
tracing::error!("No ElevenLabs API key found in configuration");
return Err(StatusCode::PRECONDITION_FAILED);
return Err(ErrorResponse::bad_request(
"No ElevenLabs API key found in configuration",
));
}
}
}
Expand All @@ -286,7 +307,7 @@ async fn transcribe_elevenlabs_handler(
let part = reqwest::multipart::Part::bytes(audio_bytes)
.file_name(format!("audio.{}", file_extension))
.mime_str(&request.mime_type)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
.map_err(|_| ErrorResponse::internal("Failed to create multipart part"))?;

let form = reqwest::multipart::Form::new()
.part("file", part) // Changed from "audio" to "file"
Expand All @@ -298,10 +319,7 @@ async fn transcribe_elevenlabs_handler(
let client = Client::builder()
.timeout(Duration::from_secs(OPENAI_TIMEOUT_SECONDS))
.build()
.map_err(|e| {
tracing::error!("Failed to create HTTP client: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
.map_err(|e| ErrorResponse::internal(format!("Failed to create HTTP client: {}", e)))?;

let response = client
.post("https://api.elevenlabs.io/v1/speech-to-text")
Expand All @@ -311,14 +329,18 @@ async fn transcribe_elevenlabs_handler(
.await
.map_err(|e| {
if e.is_timeout() {
tracing::error!(
"ElevenLabs API request timed out after {}s",
OPENAI_TIMEOUT_SECONDS
);
StatusCode::GATEWAY_TIMEOUT
ErrorResponse {
message: format!(
"ElevenLabs API request timed out after {}s",
OPENAI_TIMEOUT_SECONDS
),
status: StatusCode::GATEWAY_TIMEOUT,
}
} else {
tracing::error!("Failed to send request to ElevenLabs: {}", e);
StatusCode::SERVICE_UNAVAILABLE
ErrorResponse {
message: format!("Failed to send request to ElevenLabs: {}", e),
status: StatusCode::SERVICE_UNAVAILABLE,
}
}
})?;

Expand All @@ -329,12 +351,21 @@ async fn transcribe_elevenlabs_handler(

// Check for specific error codes
if error_text.contains("Unauthorized") || error_text.contains("Invalid API key") {
return Err(StatusCode::UNAUTHORIZED);
return Err(ErrorResponse {
message: "ElevenLabs API key is invalid or unauthorized".to_string(),
status: StatusCode::UNAUTHORIZED,
});
} else if error_text.contains("quota") || error_text.contains("limit") {
return Err(StatusCode::PAYMENT_REQUIRED);
return Err(ErrorResponse {
message: "ElevenLabs API quota or rate limit exceeded".to_string(),
status: StatusCode::PAYMENT_REQUIRED,
});
}

return Err(StatusCode::BAD_GATEWAY);
return Err(ErrorResponse {
message: format!("ElevenLabs API error: {}", error_text),
status: StatusCode::BAD_GATEWAY,
});
}

// Parse ElevenLabs response
Expand All @@ -347,8 +378,7 @@ async fn transcribe_elevenlabs_handler(
}

let elevenlabs_response: ElevenLabsResponse = response.json().await.map_err(|e| {
tracing::error!("Failed to parse ElevenLabs response: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
ErrorResponse::internal(format!("Failed to parse ElevenLabs response: {}", e))
})?;

Ok(Json(TranscribeResponse {
Expand All @@ -359,7 +389,7 @@ async fn transcribe_elevenlabs_handler(
/// Check if dictation providers are configured
///
/// Returns configuration status for dictation providers
async fn check_dictation_config() -> Result<Json<serde_json::Value>, StatusCode> {
async fn check_dictation_config() -> Result<Json<serde_json::Value>, ErrorResponse> {
let config = goose::config::Config::global();

// Check if ElevenLabs API key is configured
Expand Down
Loading
Loading