diff --git a/src/api/providers/__tests__/bedrock.spec.ts b/src/api/providers/__tests__/bedrock.spec.ts index d0b5771b5c3..d728fbb91e0 100644 --- a/src/api/providers/__tests__/bedrock.spec.ts +++ b/src/api/providers/__tests__/bedrock.spec.ts @@ -36,7 +36,12 @@ vi.mock("@aws-sdk/client-bedrock-runtime", () => { import { AwsBedrockHandler } from "../bedrock" import { ConverseStreamCommand, BedrockRuntimeClient, ConverseCommand } from "@aws-sdk/client-bedrock-runtime" -import { BEDROCK_1M_CONTEXT_MODEL_IDS, BEDROCK_SERVICE_TIER_MODEL_IDS, bedrockModels, ApiProviderError } from "@roo-code/types" +import { + BEDROCK_1M_CONTEXT_MODEL_IDS, + BEDROCK_SERVICE_TIER_MODEL_IDS, + bedrockModels, + ApiProviderError, +} from "@roo-code/types" import type { Anthropic } from "@anthropic-ai/sdk" @@ -371,6 +376,103 @@ describe("AwsBedrockHandler", () => { expect(result.modelId).toBe("ap.anthropic.claude-3-5-sonnet-20241022-v2:0") // Should be preserved as-is }) }) + + describe("AWS GovCloud and China partition support", () => { + it("should parse AWS GovCloud ARNs (arn:aws-us-gov:bedrock:...)", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "test", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-gov-west-1", + }) + + const parseArn = (handler as any).parseArn.bind(handler) + + const result = parseArn( + "arn:aws-us-gov:bedrock:us-gov-west-1:123456789012:inference-profile/us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0", + ) + + expect(result.isValid).toBe(true) + expect(result.region).toBe("us-gov-west-1") + expect(result.modelType).toBe("inference-profile") + }) + + it("should parse AWS China ARNs (arn:aws-cn:bedrock:...)", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "test", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "cn-north-1", + }) + + const parseArn = (handler as any).parseArn.bind(handler) + + const result = parseArn( + "arn:aws-cn:bedrock:cn-north-1:123456789012:inference-profile/anthropic.claude-3-sonnet-20240229-v1:0", + ) + + expect(result.isValid).toBe(true) + expect(result.region).toBe("cn-north-1") + expect(result.modelType).toBe("inference-profile") + }) + + it("should accept GovCloud custom ARN in handler constructor", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-gov-west-1", + awsCustomArn: + "arn:aws-us-gov:bedrock:us-gov-west-1:123456789012:inference-profile/us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0", + }) + + // Should not throw and should return valid model info + const modelInfo = handler.getModel() + expect(modelInfo.id).toBe( + "arn:aws-us-gov:bedrock:us-gov-west-1:123456789012:inference-profile/us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0", + ) + expect(modelInfo.info).toBeDefined() + }) + + it("should accept China region custom ARN in handler constructor", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "cn-north-1", + awsCustomArn: + "arn:aws-cn:bedrock:cn-north-1:123456789012:inference-profile/anthropic.claude-3-sonnet-20240229-v1:0", + }) + + // Should not throw and should return valid model info + const modelInfo = handler.getModel() + expect(modelInfo.id).toBe( + "arn:aws-cn:bedrock:cn-north-1:123456789012:inference-profile/anthropic.claude-3-sonnet-20240229-v1:0", + ) + expect(modelInfo.info).toBeDefined() + }) + + it("should detect region mismatch in GovCloud ARN", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "test", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + }) + + const parseArn = (handler as any).parseArn.bind(handler) + + // Region in ARN (us-gov-west-1) doesn't match provided region (us-east-1) + const result = parseArn( + "arn:aws-us-gov:bedrock:us-gov-west-1:123456789012:inference-profile/us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0", + "us-east-1", + ) + + expect(result.isValid).toBe(true) + expect(result.region).toBe("us-gov-west-1") + expect(result.errorMessage).toContain("Region mismatch") + }) + }) }) describe("image handling", () => { diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 27dce1bb9fe..8ac4e1ba017 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -928,8 +928,12 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH * represent literal characters in the AWS ARN format, not filesystem paths. This regex will function consistently across Windows, * macOS, Linux, and any other operating system where JavaScript runs. * + * Supports any AWS partition (aws, aws-us-gov, aws-cn, or future partitions). + * The partition is not captured since we don't need to use it. + * * This matches ARNs like: * - Foundation Model: arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-v2 + * - GovCloud Inference Profile: arn:aws-us-gov:bedrock:us-gov-west-1:123456789012:inference-profile/us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0 * - Prompt Router: arn:aws:bedrock:us-west-2:123456789012:prompt-router/anthropic-claude * - Inference Profile: arn:aws:bedrock:us-west-2:123456789012:inference-profile/anthropic.claude-v2 * - Cross Region Inference Profile: arn:aws:bedrock:us-west-2:123456789012:inference-profile/us.anthropic.claude-3-5-sonnet-20241022-v2:0 @@ -937,13 +941,13 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH * - Imported Model: arn:aws:bedrock:us-west-2:123456789012:imported-model/my-imported-model * * match[0] - The entire matched string - * match[1] - The region (e.g., "us-east-1") + * match[1] - The region (e.g., "us-east-1", "us-gov-west-1") * match[2] - The account ID (can be empty string for AWS-managed resources) * match[3] - The resource type (e.g., "foundation-model") * match[4] - The resource ID (e.g., "anthropic.claude-3-sonnet-20240229-v1:0") */ - const arnRegex = /^arn:aws:(?:bedrock|sagemaker):([^:]+):([^:]*):(?:([^\/]+)\/([\w\.\-:]+)|([^\/]+))$/ + const arnRegex = /^arn:[^:]+:(?:bedrock|sagemaker):([^:]+):([^:]*):(?:([^\/]+)\/([\w\.\-:]+)|([^\/]+))$/ let match = arn.match(arnRegex) if (match && match[1] && match[3] && match[4]) {