diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 84d7474310c2..debfc9098019 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -64,8 +64,34 @@ impl BedrockProvider { } }; + let filtered_secrets = config.all_secrets().map(|map| { + map.into_iter() + .filter(|(key, _)| key != "AWS_BEARER_TOKEN_BEDROCK") + .collect() + }); + set_aws_env_vars(config.all_values()); - set_aws_env_vars(config.all_secrets()); + set_aws_env_vars(filtered_secrets); + + // Check for bearer token first to determine if region is required + let bearer_token = match config.get_secret::("AWS_BEARER_TOKEN_BEDROCK") { + Ok(token) => { + let token = token.trim().to_string(); + if token.is_empty() { + None + } else { + Some(token) + } + } + Err(_) => None, + }; + + // Get AWS_REGION from config if explicitly set (optional - SDK can resolve from other sources) + let region = match config.get_param::("AWS_REGION") { + Ok(r) if !r.is_empty() => Some(r), + Ok(_) => None, + Err(_) => None, + }; // Use load_defaults() which supports AWS SSO, profiles, and environment variables let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest()); @@ -76,24 +102,37 @@ impl BedrockProvider { } } - // Check for AWS_REGION configuration - if let Ok(region) = config.get_param::("AWS_REGION") { - if !region.is_empty() { - loader = loader.region(aws_config::Region::new(region)); - } + // Apply region to loader if explicitly configured + if let Some(ref region) = region { + loader = loader.region(aws_config::Region::new(region.clone())); } let sdk_config = loader.load().await; - // Validate credentials or return error back up - sdk_config - .credentials_provider() - .ok_or_else(|| anyhow::anyhow!("No AWS credentials provider configured"))? - .provide_credentials() - .await - .map_err(|e| anyhow::anyhow!("Failed to load AWS credentials: {}. Make sure to run 'aws sso login --profile ' if using SSO", e))?; + // Validate region requirement for bearer token auth after SDK config is loaded + // This allows region to be resolved from ~/.aws/config, AWS_DEFAULT_REGION, etc. + if bearer_token.is_some() && sdk_config.region().is_none() { + return Err(anyhow::anyhow!( + "AWS region is required when using AWS_BEARER_TOKEN_BEDROCK authentication. \ + Set AWS_REGION, AWS_DEFAULT_REGION, or configure region in your AWS profile." + )); + } - let client = Client::new(&sdk_config); + let client = if let Some(bearer_token) = bearer_token { + // Build from sdk_config to inherit all settings (endpoint overrides, timeouts, etc.) + // then override authentication with bearer token + let bedrock_config = aws_sdk_bedrockruntime::Config::new(&sdk_config) + .to_builder() + .bearer_token(aws_sdk_bedrockruntime::config::Token::new( + bearer_token, + None, + )) + .build(); + + Client::from_conf(bedrock_config) + } else { + Self::create_client_with_credentials(&sdk_config).await? + }; let retry_config = Self::load_retry_config(config); @@ -105,6 +144,22 @@ impl BedrockProvider { }) } + async fn create_client_with_credentials(sdk_config: &aws_config::SdkConfig) -> Result { + sdk_config + .credentials_provider() + .ok_or_else(|| anyhow::anyhow!("No AWS credentials provider configured"))? + .provide_credentials() + .await + .map_err(|e| { + anyhow::anyhow!( + "Failed to load AWS credentials: {}. Make sure to run 'aws sso login --profile ' if using SSO", + e + ) + })?; + + Ok(Client::new(sdk_config)) + } + fn load_retry_config(config: &crate::config::Config) -> RetryConfig { let max_retries = config .get_param::("BEDROCK_MAX_RETRIES") @@ -212,13 +267,14 @@ impl Provider for BedrockProvider { ProviderMetadata::new( "aws_bedrock", "Amazon Bedrock", - "Run models through Amazon Bedrock. Supports AWS SSO profiles - run 'aws sso login --profile ' before using. Configure with AWS_PROFILE and AWS_REGION, or use environment variables/credentials.", + "Run models through Amazon Bedrock. Supports AWS SSO profiles - run 'aws sso login --profile ' before using. Configure with AWS_PROFILE and AWS_REGION, use environment variables/credentials, or use AWS_BEARER_TOKEN_BEDROCK for bearer token authentication. Region is required for bearer token auth (can be set via AWS_REGION, AWS_DEFAULT_REGION, or AWS profile).", BEDROCK_DEFAULT_MODEL, BEDROCK_KNOWN_MODELS.to_vec(), BEDROCK_DOC_LINK, vec![ - ConfigKey::new("AWS_PROFILE", true, false, Some("default")), - ConfigKey::new("AWS_REGION", true, false, None), + ConfigKey::new("AWS_PROFILE", false, false, Some("default")), + ConfigKey::new("AWS_REGION", false, false, None), + ConfigKey::new("AWS_BEARER_TOKEN_BEDROCK", false, true, None), ], ) } @@ -276,3 +332,49 @@ impl Provider for BedrockProvider { Ok((message, provider_usage)) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_metadata_config_keys_have_expected_flags() { + let meta = BedrockProvider::metadata(); + + let aws_profile = meta + .config_keys + .iter() + .find(|k| k.name == "AWS_PROFILE") + .expect("AWS_PROFILE config key should exist"); + assert!(!aws_profile.required, "AWS_PROFILE should not be required"); + assert!( + !aws_profile.secret, + "AWS_PROFILE should not be marked as secret" + ); + + let aws_region = meta + .config_keys + .iter() + .find(|k| k.name == "AWS_REGION") + .expect("AWS_REGION config key should exist"); + assert!(!aws_region.required, "AWS_REGION should not be required"); + assert!( + !aws_region.secret, + "AWS_REGION should not be marked as secret" + ); + + let bearer_token = meta + .config_keys + .iter() + .find(|k| k.name == "AWS_BEARER_TOKEN_BEDROCK") + .expect("AWS_BEARER_TOKEN_BEDROCK config key should exist"); + assert!( + !bearer_token.required, + "AWS_BEARER_TOKEN_BEDROCK should not be required" + ); + assert!( + bearer_token.secret, + "AWS_BEARER_TOKEN_BEDROCK should be marked as secret" + ); + } +} diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 0ca51c127d1f..68a41bb89041 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -395,6 +395,15 @@ async fn test_provider( load_env(); + // Check required_vars BEFORE applying env_modifications to avoid + // leaving the environment mutated when skipping + let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err()); + if missing_vars { + println!("Skipping {} tests - credentials not configured", name); + TEST_REPORT.record_skip(name); + return Ok(()); + } + let mut original_env = HashMap::new(); for &var in required_vars { if let Ok(val) = std::env::var(var) { @@ -418,13 +427,6 @@ async fn test_provider( } } - let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err()); - if missing_vars { - println!("Skipping {} tests - credentials not configured", name); - TEST_REPORT.record_skip(name); - return Ok(()); - } - original_env }; @@ -488,7 +490,7 @@ async fn test_azure_provider() -> Result<()> { #[tokio::test] async fn test_bedrock_provider_long_term_credentials() -> Result<()> { test_provider( - "Bedrock", + "aws_bedrock", BEDROCK_DEFAULT_MODEL, &["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], None, @@ -502,7 +504,7 @@ async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { HashMap::from_iter([("AWS_ACCESS_KEY_ID", None), ("AWS_SECRET_ACCESS_KEY", None)]); test_provider( - "Bedrock", + "aws_bedrock", BEDROCK_DEFAULT_MODEL, &["AWS_PROFILE"], Some(env_mods), @@ -510,6 +512,24 @@ async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { .await } +#[tokio::test] +async fn test_bedrock_provider_bearer_token() -> Result<()> { + // Clear standard AWS credentials to ensure bearer token auth is used + let env_mods = HashMap::from_iter([ + ("AWS_ACCESS_KEY_ID", None), + ("AWS_SECRET_ACCESS_KEY", None), + ("AWS_PROFILE", None), + ]); + + test_provider( + "aws_bedrock", + BEDROCK_DEFAULT_MODEL, + &["AWS_BEARER_TOKEN_BEDROCK", "AWS_REGION"], + Some(env_mods), + ) + .await +} + #[tokio::test] async fn test_databricks_provider() -> Result<()> { test_provider(