Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 119 additions & 17 deletions crates/goose/src/providers/bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,34 @@ impl BedrockProvider {
}
};

let filtered_secrets = config.all_secrets().map(|map| {
map.into_iter()
.filter(|(key, _)| key != "AWS_BEARER_TOKEN_BEDROCK")
.collect()
});

set_aws_env_vars(config.all_values());
set_aws_env_vars(config.all_secrets());
set_aws_env_vars(filtered_secrets);

// Check for bearer token first to determine if region is required
let bearer_token = match config.get_secret::<String>("AWS_BEARER_TOKEN_BEDROCK") {
Ok(token) => {
let token = token.trim().to_string();
if token.is_empty() {
None
} else {
Some(token)
}
}
Err(_) => None,
};

// Get AWS_REGION from config if explicitly set (optional - SDK can resolve from other sources)
let region = match config.get_param::<String>("AWS_REGION") {
Ok(r) if !r.is_empty() => Some(r),
Ok(_) => None,
Err(_) => None,
};

// Use load_defaults() which supports AWS SSO, profiles, and environment variables
let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
Expand All @@ -76,24 +102,37 @@ impl BedrockProvider {
}
}

// Check for AWS_REGION configuration
if let Ok(region) = config.get_param::<String>("AWS_REGION") {
if !region.is_empty() {
loader = loader.region(aws_config::Region::new(region));
}
// Apply region to loader if explicitly configured
if let Some(ref region) = region {
loader = loader.region(aws_config::Region::new(region.clone()));
}

let sdk_config = loader.load().await;

// Validate credentials or return error back up
sdk_config
.credentials_provider()
.ok_or_else(|| anyhow::anyhow!("No AWS credentials provider configured"))?
.provide_credentials()
.await
.map_err(|e| anyhow::anyhow!("Failed to load AWS credentials: {}. Make sure to run 'aws sso login --profile <your-profile>' if using SSO", e))?;
// Validate region requirement for bearer token auth after SDK config is loaded
// This allows region to be resolved from ~/.aws/config, AWS_DEFAULT_REGION, etc.
if bearer_token.is_some() && sdk_config.region().is_none() {
return Err(anyhow::anyhow!(
"AWS region is required when using AWS_BEARER_TOKEN_BEDROCK authentication. \
Set AWS_REGION, AWS_DEFAULT_REGION, or configure region in your AWS profile."
));
}

let client = Client::new(&sdk_config);
let client = if let Some(bearer_token) = bearer_token {
// Build from sdk_config to inherit all settings (endpoint overrides, timeouts, etc.)
// then override authentication with bearer token
let bedrock_config = aws_sdk_bedrockruntime::Config::new(&sdk_config)
.to_builder()
.bearer_token(aws_sdk_bedrockruntime::config::Token::new(
bearer_token,
None,
))
.build();

Client::from_conf(bedrock_config)
} else {
Self::create_client_with_credentials(&sdk_config).await?
};

let retry_config = Self::load_retry_config(config);

Expand All @@ -105,6 +144,22 @@ impl BedrockProvider {
})
}

async fn create_client_with_credentials(sdk_config: &aws_config::SdkConfig) -> Result<Client> {
sdk_config
.credentials_provider()
.ok_or_else(|| anyhow::anyhow!("No AWS credentials provider configured"))?
.provide_credentials()
.await
.map_err(|e| {
anyhow::anyhow!(
"Failed to load AWS credentials: {}. Make sure to run 'aws sso login --profile <your-profile>' if using SSO",
e
)
})?;

Ok(Client::new(sdk_config))
}

fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
let max_retries = config
.get_param::<usize>("BEDROCK_MAX_RETRIES")
Expand Down Expand Up @@ -212,13 +267,14 @@ impl Provider for BedrockProvider {
ProviderMetadata::new(
"aws_bedrock",
"Amazon Bedrock",
"Run models through Amazon Bedrock. Supports AWS SSO profiles - run 'aws sso login --profile <profile-name>' before using. Configure with AWS_PROFILE and AWS_REGION, or use environment variables/credentials.",
"Run models through Amazon Bedrock. Supports AWS SSO profiles - run 'aws sso login --profile <profile-name>' before using. Configure with AWS_PROFILE and AWS_REGION, use environment variables/credentials, or use AWS_BEARER_TOKEN_BEDROCK for bearer token authentication. Region is required for bearer token auth (can be set via AWS_REGION, AWS_DEFAULT_REGION, or AWS profile).",
BEDROCK_DEFAULT_MODEL,
BEDROCK_KNOWN_MODELS.to_vec(),
BEDROCK_DOC_LINK,
vec![
ConfigKey::new("AWS_PROFILE", true, false, Some("default")),
ConfigKey::new("AWS_REGION", true, false, None),
ConfigKey::new("AWS_PROFILE", false, false, Some("default")),
ConfigKey::new("AWS_REGION", false, false, None),
ConfigKey::new("AWS_BEARER_TOKEN_BEDROCK", false, true, None),
],
)
}
Expand Down Expand Up @@ -276,3 +332,49 @@ impl Provider for BedrockProvider {
Ok((message, provider_usage))
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_metadata_config_keys_have_expected_flags() {
let meta = BedrockProvider::metadata();

let aws_profile = meta
.config_keys
.iter()
.find(|k| k.name == "AWS_PROFILE")
.expect("AWS_PROFILE config key should exist");
assert!(!aws_profile.required, "AWS_PROFILE should not be required");
assert!(
!aws_profile.secret,
"AWS_PROFILE should not be marked as secret"
);

let aws_region = meta
.config_keys
.iter()
.find(|k| k.name == "AWS_REGION")
.expect("AWS_REGION config key should exist");
assert!(!aws_region.required, "AWS_REGION should not be required");
assert!(
!aws_region.secret,
"AWS_REGION should not be marked as secret"
);

let bearer_token = meta
.config_keys
.iter()
.find(|k| k.name == "AWS_BEARER_TOKEN_BEDROCK")
.expect("AWS_BEARER_TOKEN_BEDROCK config key should exist");
assert!(
!bearer_token.required,
"AWS_BEARER_TOKEN_BEDROCK should not be required"
);
assert!(
bearer_token.secret,
"AWS_BEARER_TOKEN_BEDROCK should be marked as secret"
);
}
}
38 changes: 29 additions & 9 deletions crates/goose/tests/providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,15 @@ async fn test_provider(

load_env();

// Check required_vars BEFORE applying env_modifications to avoid
// leaving the environment mutated when skipping
let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err());
if missing_vars {
println!("Skipping {} tests - credentials not configured", name);
TEST_REPORT.record_skip(name);
return Ok(());
}

let mut original_env = HashMap::new();
for &var in required_vars {
if let Ok(val) = std::env::var(var) {
Expand All @@ -418,13 +427,6 @@ async fn test_provider(
}
}

let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err());
if missing_vars {
println!("Skipping {} tests - credentials not configured", name);
TEST_REPORT.record_skip(name);
return Ok(());
}

original_env
};

Expand Down Expand Up @@ -488,7 +490,7 @@ async fn test_azure_provider() -> Result<()> {
#[tokio::test]
async fn test_bedrock_provider_long_term_credentials() -> Result<()> {
test_provider(
"Bedrock",
"aws_bedrock",
BEDROCK_DEFAULT_MODEL,
&["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
None,
Expand All @@ -502,14 +504,32 @@ async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> {
HashMap::from_iter([("AWS_ACCESS_KEY_ID", None), ("AWS_SECRET_ACCESS_KEY", None)]);

test_provider(
"Bedrock",
"aws_bedrock",
BEDROCK_DEFAULT_MODEL,
&["AWS_PROFILE"],
Some(env_mods),
)
.await
}

#[tokio::test]
async fn test_bedrock_provider_bearer_token() -> Result<()> {
// Clear standard AWS credentials to ensure bearer token auth is used
let env_mods = HashMap::from_iter([
("AWS_ACCESS_KEY_ID", None),
("AWS_SECRET_ACCESS_KEY", None),
("AWS_PROFILE", None),
]);

test_provider(
"aws_bedrock",
BEDROCK_DEFAULT_MODEL,
&["AWS_BEARER_TOKEN_BEDROCK", "AWS_REGION"],
Some(env_mods),
)
.await
}

#[tokio::test]
async fn test_databricks_provider() -> Result<()> {
test_provider(
Expand Down