@@ -454,12 +454,20 @@ def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[st
454
454
azure .identity .DefaultAzureCredential (), "https://cognitiveservices.azure.com/.default"
455
455
)
456
456
457
+ def _configure_openai_config_for_bedrock (self , config : Dict [str , Any ], openai_config : Dict [str , Any ]) -> None :
458
+ """Update openai_config with AWS credentials from config."""
459
+ required_keys = ["aws_access_key" , "aws_secret_key" , "aws_session_token" , "aws_region" ]
460
+
461
+ for key in required_keys :
462
+ if key in config :
463
+ openai_config [key ] = config [key ]
464
+
457
465
def _register_default_client (self , config : Dict [str , Any ], openai_config : Dict [str , Any ]) -> None :
458
466
"""Create a client with the given config to override openai_config,
459
467
after removing extra kwargs.
460
468
461
469
For Azure models/deployment names there's a convenience modification of model removing dots in
462
- the it's value (Azure deploment names can't have dots). I.e. if you have Azure deployment name
470
+ the it's value (Azure deployment names can't have dots). I.e. if you have Azure deployment name
463
471
"gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot
464
472
from the name and create a client that connects to "gpt-35-turbo" Azure deployment.
465
473
"""
@@ -485,6 +493,8 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
485
493
client = GeminiClient (** openai_config )
486
494
self ._clients .append (client )
487
495
elif api_type is not None and api_type .startswith ("anthropic" ):
496
+ if "api_key" not in config :
497
+ self ._configure_openai_config_for_bedrock (config , openai_config )
488
498
if anthropic_import_exception :
489
499
raise ImportError ("Please install `anthropic` to use Anthropic API." )
490
500
client = AnthropicClient (** openai_config )
0 commit comments