
Commit 9e322f1

Add support for AWS Bedrock LLM Provider (#238)
- Add support for the AWS Bedrock LLM provider. The provider can be accessed by type `aws_bedrock`, and it is compatible with both `langchain` and `llamaindex`.
- Added unit tests for the existing combinations of LLM providers and LLM frameworks. The unit tests are skipped by default, since each of them needs credentials to interact with LLM models, but they should be updated and re-run each time we add or modify an LLM provider/client to make sure the combinations still work.
- Fixed the OpenAI + `llamaindex` LLM client, which was not working.

Closes [AIQ-1213](https://jirasw.nvidia.com/browse/AIQ-1213)

## By Submitting this PR I confirm:

- I am familiar with the [Contributing Guidelines](https://github.com/NVIDIA/AIQToolkit/blob/develop/docs/source/advanced/contributing.md).
- We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license.
  - Any contribution which contains commits that are not Signed-Off will not be accepted.
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:

- Yuchen Zhang (https://github.com/yczhang-nv)
- David Gardner (https://github.com/dagardner-nv)
- https://github.com/liamy-nv
- Matthew Penn (https://github.com/mpenn)
- Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah)
- Ayush Thakur (https://github.com/ayulockin)
- Soumili Nandi (https://github.com/soumilinandi)
- Eric Evans II (https://github.com/ericevans-nv)
- https://github.com/hsin-c
- Zac Wang (https://github.com/zac-wang-nv)
- Hritik Raj (https://github.com/Hritik003)
- Victor Yudin (https://github.com/VictorYudin)
- Dhruv Nandakumar (https://github.com/dnandakumar-nv)
- Michael Demoret (https://github.com/mdemoret-nv)

Approvers:

- Michael Demoret (https://github.com/mdemoret-nv)

URL: #238
1 parent dfb0f1c commit 9e322f1

File tree

12 files changed: +2794 additions, -2239 deletions

docs/source/extend/adding-an-llm-provider.md

Lines changed: 30 additions & 0 deletions
@@ -112,6 +112,36 @@ Similar to the registration function for the provider, the client registration f
In the above example, the `ChatOpenAI` class is imported lazily, allowing the client to be registered without importing the client class until it is needed, thus improving performance and startup times.
:::

## Test the Combination of LLM Provider and Client

After implementing a new LLM provider, it's important to verify that it works correctly with all existing LLM clients. This can be done by writing integration tests. Here's an example of how to test the integration between the NIM LLM provider and the LangChain framework:

```python
@pytest.mark.integration
async def test_nim_langchain_agent():
    """
    Test NIM LLM with LangChain agent. Requires NVIDIA_API_KEY to be set.
    """

    prompt = ChatPromptTemplate.from_messages([("system", "You are a helpful AI assistant."), ("human", "{input}")])

    llm_config = NIMModelConfig(model_name="meta/llama-3.1-70b-instruct", temperature=0.0)

    async with WorkflowBuilder() as builder:
        await builder.add_llm("nim_llm", llm_config)
        llm = await builder.get_llm("nim_llm", wrapper_type=LLMFrameworkEnum.LANGCHAIN)

        agent = prompt | llm

        response = await agent.ainvoke({"input": "What is 1+2?"})
        assert isinstance(response, AIMessage)
        assert response.content is not None
        assert isinstance(response.content, str)
        assert "3" in response.content.lower()
```

Note: Since this test requires an API key, it's marked with `@pytest.mark.integration` to exclude it from CI runs. However, these tests are necessary for maintaining and verifying the functionality of LLM providers and their client integrations.
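
The same pattern extends to the other provider and client combinations. As an illustration, here is a minimal sketch (not part of this commit's diff) of the equivalent check for the LlamaIndex wrapper, assuming LlamaIndex's standard `acomplete` completion API:

```python
@pytest.mark.integration
async def test_nim_llama_index():
    """
    Test NIM LLM with LlamaIndex. Requires NVIDIA_API_KEY to be set.
    """
    llm_config = NIMModelConfig(model_name="meta/llama-3.1-70b-instruct", temperature=0.0)

    async with WorkflowBuilder() as builder:
        await builder.add_llm("nim_llm", llm_config)
        llm = await builder.get_llm("nim_llm", wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)

        # LlamaIndex LLMs expose async completion via `acomplete`.
        response = await llm.acomplete("What is 1+2?")
        assert "3" in response.text
```
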
## Packaging the Provider and Client

The provider and client will need to be bundled into a Python package, which in turn will be registered with AIQ toolkit as a [plugin](../extend/plugins.md). In the `pyproject.toml` file of the package, the `project.entry-points.'aiq.components'` section defines a Python module as the entry point of the plugin. Details on how this is defined are found in the [Entry Point](../extend/plugins.md#entry-point) section of the plugins document. By convention, the entry point module is named `register.py`, but this is not a requirement.
docs/source/extend/integrating-aws-bedrock-models.md

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
<!--
SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# AWS Bedrock Integration

The Agent Intelligence Toolkit supports integration with multiple LLM providers, including AWS Bedrock. This document explains how to integrate AWS Bedrock models into your AIQ Toolkit workflow. To view the full list of supported LLM providers, run `aiq info components -t llm_provider`.

## Configuration

### Prerequisites
Before integrating AWS Bedrock, ensure you have:
- Set up AWS credentials by configuring `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`
- For detailed setup instructions, refer to the [AWS Bedrock setup guide](https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html)

### Example Configuration
Add the AWS Bedrock LLM configuration to your workflow config file. Make sure `region_name` matches the region of your AWS account, and `credentials_profile_name` matches the profile name in your AWS credentials file:

```yaml
llms:
  aws_bedrock_llm:
    _type: aws_bedrock
    model_name: meta.llama3-3-70b-instruct-v1:0
    temperature: 0.0
    max_tokens: 1024
    region_name: us-east-2
    credentials_profile_name: default
```

### Configurable Options
* `model_name`: The name of the AWS Bedrock model to use (required)
* `temperature`: Controls randomness in the output (0.0 to 1.0, default: 0.0)
* `max_tokens`: Maximum number of tokens to generate (must be > 0, default: 1024, used by the LangChain client)
* `context_size`: Maximum number of tokens for context (must be > 0, default: 1024, required for LlamaIndex)
* `region_name`: AWS region where your Bedrock service is hosted (default: "None")
* `base_url`: Custom Bedrock endpoint URL (default: None, needed if you don't want to use the default us-east-1 endpoint)
* `credentials_profile_name`: AWS credentials profile name from the ~/.aws/credentials or ~/.aws/config files (default: None)
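
These options map directly onto the `AWSBedrockModelConfig` Pydantic model, so out-of-range values are rejected when the config is loaded. A minimal sketch (illustrative only, assuming the constraints listed above):

```python
from pydantic import ValidationError

from aiq.llm.aws_bedrock_llm import AWSBedrockModelConfig

# A valid configuration, mirroring the YAML example above.
config = AWSBedrockModelConfig(model_name="meta.llama3-3-70b-instruct-v1:0",
                               temperature=0.0,
                               max_tokens=1024,
                               region_name="us-east-2",
                               credentials_profile_name="default")

# Constraint violations raise a ValidationError, e.g. temperature > 1.0.
try:
    AWSBedrockModelConfig(model_name="meta.llama3-3-70b-instruct-v1:0", temperature=2.0)
except ValidationError as err:
    print(err)
```
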
## Usage in Workflow
Reference the AWS Bedrock LLM in your workflow configuration:

```yaml
workflow:
  _type: react_agent
  llm_name: aws_bedrock_llm
  # ... other workflow configurations
```
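
For illustration, a minimal programmatic sketch of the same wiring (assuming AWS credentials are already configured), using `WorkflowBuilder` the same way the integration tests in this commit do:

```python
import asyncio

from aiq.builder.framework_enum import LLMFrameworkEnum
from aiq.builder.workflow_builder import WorkflowBuilder
from aiq.llm.aws_bedrock_llm import AWSBedrockModelConfig


async def main():
    llm_config = AWSBedrockModelConfig(model_name="meta.llama3-3-70b-instruct-v1:0",
                                       temperature=0.0,
                                       max_tokens=1024,
                                       region_name="us-east-2")

    async with WorkflowBuilder() as builder:
        await builder.add_llm("aws_bedrock_llm", llm_config)
        # Retrieve a LangChain-compatible client backed by the Bedrock provider.
        llm = await builder.get_llm("aws_bedrock_llm", wrapper_type=LLMFrameworkEnum.LANGCHAIN)
        response = await llm.ainvoke("What is 1+2?")
        print(response.content)


asyncio.run(main())
```
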

docs/source/index.md

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ Adding a Custom Evaluator <./extend/custom-evaluator.md>
````diff
 ./extend/adding-a-retriever.md
 ./extend/memory.md
 Adding an LLM Provider <./extend/adding-an-llm-provider.md>
+Integrating AWS Bedrock Models <./extend/integrating-aws-bedrock-models.md>
 ```

 ```{toctree}
````

packages/aiqtoolkit_langchain/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ dependencies = [
```diff
 # version when adding a new package. If unsure, default to using `~=` instead of `==`. Does not apply to aiq packages.
 # Keep sorted!!!
 "aiqtoolkit~=1.2",
+"langchain-aws~=0.2.1",
 "langchain-core~=0.3.7",
 "langchain-nvidia-ai-endpoints~=0.3.5",
 "langchain-milvus~=0.1.5",
```

packages/aiqtoolkit_langchain/src/aiq/plugins/langchain/llm.py

Lines changed: 9 additions & 0 deletions
@@ -16,6 +16,7 @@
```diff
 from aiq.builder.builder import Builder
 from aiq.builder.framework_enum import LLMFrameworkEnum
 from aiq.cli.register_workflow import register_llm_client
+from aiq.llm.aws_bedrock_llm import AWSBedrockModelConfig
 from aiq.llm.nim_llm import NIMModelConfig
 from aiq.llm.openai_llm import OpenAIModelConfig
```

@@ -34,3 +35,11 @@ async def openai_langchain(llm_config: OpenAIModelConfig, builder: Builder):

```diff
     from langchain_openai import ChatOpenAI

     yield ChatOpenAI(**llm_config.model_dump(exclude={"type"}, by_alias=True))
+
+
+@register_llm_client(config_type=AWSBedrockModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+async def aws_bedrock_langchain(llm_config: AWSBedrockModelConfig, builder: Builder):
+
+    from langchain_aws import ChatBedrockConverse
+
+    yield ChatBedrockConverse(**llm_config.model_dump(exclude={"type", "context_size"}, by_alias=True))
```
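
For clarity, a small sketch (illustrative only, derived from the config defaults) of the keyword arguments this registration forwards to `ChatBedrockConverse`: `type` is the AIQ discriminator field and `context_size` is LlamaIndex-only, so both are excluded, while `by_alias=True` serializes `model_name` under the `model` key:

```python
from aiq.llm.aws_bedrock_llm import AWSBedrockModelConfig

config = AWSBedrockModelConfig(model_name="meta.llama3-3-70b-instruct-v1:0",
                               region_name="us-east-2")
kwargs = config.model_dump(exclude={"type", "context_size"}, by_alias=True)
# Roughly: {'model': 'meta.llama3-3-70b-instruct-v1:0', 'temperature': 0.0,
#           'max_tokens': 1024, 'region_name': 'us-east-2',
#           'base_url': None, 'credentials_profile_name': None}
```
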

packages/aiqtoolkit_llama_index/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ dependencies = [
```diff
 # error
 "llama-index-core==0.12.21",
 "llama-index-embeddings-nvidia==0.3.1",
+"llama-index-llms-bedrock==0.3.8",
 "llama-index-llms-nvidia==0.3.1",
 "llama-index-readers-file==0.4.4",
 "llama-index==0.12.21",
```

packages/aiqtoolkit_llama_index/src/aiq/plugins/llama_index/llm.py

Lines changed: 12 additions & 2 deletions
@@ -16,6 +16,7 @@
```diff
 from aiq.builder.builder import Builder
 from aiq.builder.framework_enum import LLMFrameworkEnum
 from aiq.cli.register_workflow import register_llm_client
+from aiq.llm.aws_bedrock_llm import AWSBedrockModelConfig
 from aiq.llm.nim_llm import NIMModelConfig
 from aiq.llm.openai_llm import OpenAIModelConfig
```

@@ -47,7 +48,16 @@ async def openai_llama_index(llm_config: OpenAIModelConfig, builder: Builder):

```diff
     llm = OpenAI(**kwargs)

-    # Disable content blocks
-    llm.supports_content_blocks = False
+    yield llm
+
+
+@register_llm_client(config_type=AWSBedrockModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)
+async def aws_bedrock_llama_index(llm_config: AWSBedrockModelConfig, builder: Builder):
+
+    from llama_index.llms.bedrock import Bedrock
+
+    kwargs = llm_config.model_dump(exclude={"type", "max_tokens"}, by_alias=True)
+
+    llm = Bedrock(**kwargs)

     yield llm
```
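
Note that the exclusion set differs from the LangChain client's: here `max_tokens` is dropped and `context_size` is kept, matching the parameters this `Bedrock` client is given. A small sketch (illustrative only) of the resulting keyword arguments:

```python
from aiq.llm.aws_bedrock_llm import AWSBedrockModelConfig

config = AWSBedrockModelConfig(model_name="meta.llama3-3-70b-instruct-v1:0",
                               region_name="us-east-2")
kwargs = config.model_dump(exclude={"type", "max_tokens"}, by_alias=True)
# Roughly: {'model': 'meta.llama3-3-70b-instruct-v1:0', 'temperature': 0.0,
#           'context_size': 1024, 'region_name': 'us-east-2',
#           'base_url': None, 'credentials_profile_name': None}
```
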

src/aiq/llm/aws_bedrock_llm.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
```python
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pydantic import AliasChoices
from pydantic import ConfigDict
from pydantic import Field

from aiq.builder.builder import Builder
from aiq.builder.llm import LLMProviderInfo
from aiq.cli.register_workflow import register_llm_provider
from aiq.data_models.llm import LLMBaseConfig


class AWSBedrockModelConfig(LLMBaseConfig, name="aws_bedrock"):
    """An AWS Bedrock LLM provider to be used with an LLM client."""

    model_config = ConfigDict(protected_namespaces=())

    # Completion parameters
    model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                            serialization_alias="model",
                            description="The model name for the hosted AWS Bedrock.")
    temperature: float = Field(default=0.0, ge=0.0, le=1.0, description="Sampling temperature in [0, 1].")
    max_tokens: int | None = Field(default=1024,
                                   gt=0,
                                   description="Maximum number of tokens to generate. "
                                   "This field is ONLY required when using AWS Bedrock with LangChain.")
    context_size: int | None = Field(default=1024,
                                     gt=0,
                                     description="Maximum number of tokens for the context window. "
                                     "This field is ONLY required when using AWS Bedrock with LlamaIndex.")

    # Client parameters
    region_name: str | None = Field(default="None", description="AWS region to use.")
    base_url: str | None = Field(
        default=None, description="Bedrock endpoint to use. Needed if you don't want to default to us-east-1 endpoint.")
    credentials_profile_name: str | None = Field(
        default=None, description="The name of the profile in the ~/.aws/credentials or ~/.aws/config files.")


@register_llm_provider(config_type=AWSBedrockModelConfig)
async def aws_bedrock_model(llm_config: AWSBedrockModelConfig, builder: Builder):

    yield LLMProviderInfo(config=llm_config, description="An AWS Bedrock model for use with an LLM client.")
```
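
The `AliasChoices` on `model_name` lets configs supply either `model_name` or `model`, while `serialization_alias` ensures dumps emit the `model` key for the framework clients. A quick sketch (illustrative only):

```python
from aiq.llm.aws_bedrock_llm import AWSBedrockModelConfig

# Both spellings populate the same field thanks to AliasChoices.
cfg_a = AWSBedrockModelConfig(model_name="meta.llama3-3-70b-instruct-v1:0")
cfg_b = AWSBedrockModelConfig(model="meta.llama3-3-70b-instruct-v1:0")
assert cfg_a.model_name == cfg_b.model_name

# Dumping with by_alias=True emits the `model` key the clients expect.
assert cfg_a.model_dump(by_alias=True)["model"] == "meta.llama3-3-70b-instruct-v1:0"
```
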

src/aiq/llm/register.py

Lines changed: 1 addition & 0 deletions
@@ -18,5 +18,6 @@
```diff
 # isort:skip_file

 # Import any providers which need to be automatically registered here
+from . import aws_bedrock_llm
 from . import nim_llm
 from . import openai_llm
```
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
```python
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate

from aiq.builder.framework_enum import LLMFrameworkEnum
from aiq.builder.workflow_builder import WorkflowBuilder
from aiq.llm.aws_bedrock_llm import AWSBedrockModelConfig
from aiq.llm.nim_llm import NIMModelConfig
from aiq.llm.openai_llm import OpenAIModelConfig


@pytest.mark.integration
async def test_nim_langchain_agent():
    """
    Test NIM LLM with LangChain agent. Requires NVIDIA_API_KEY to be set.
    """

    prompt = ChatPromptTemplate.from_messages([("system", "You are a helpful AI assistant."), ("human", "{input}")])

    llm_config = NIMModelConfig(model_name="meta/llama-3.1-70b-instruct", temperature=0.0)

    async with WorkflowBuilder() as builder:
        await builder.add_llm("nim_llm", llm_config)
        llm = await builder.get_llm("nim_llm", wrapper_type=LLMFrameworkEnum.LANGCHAIN)

        agent = prompt | llm

        response = await agent.ainvoke({"input": "What is 1+2?"})
        assert isinstance(response, AIMessage)
        assert response.content is not None
        assert isinstance(response.content, str)
        assert "3" in response.content.lower()


@pytest.mark.integration
async def test_openai_langchain_agent():
    """
    Test OpenAI LLM with LangChain agent. Requires OPENAI_API_KEY to be set.
    """
    prompt = ChatPromptTemplate.from_messages([("system", "You are a helpful AI assistant."), ("human", "{input}")])

    llm_config = OpenAIModelConfig(model_name="gpt-3.5-turbo", temperature=0.0)

    async with WorkflowBuilder() as builder:
        await builder.add_llm("openai_llm", llm_config)
        llm = await builder.get_llm("openai_llm", wrapper_type=LLMFrameworkEnum.LANGCHAIN)

        agent = prompt | llm

        response = await agent.ainvoke({"input": "What is 1+2?"})
        assert isinstance(response, AIMessage)
        assert response.content is not None
        assert isinstance(response.content, str)
        assert "3" in response.content.lower()


@pytest.mark.integration
async def test_aws_bedrock_langchain_agent():
    """
    Test AWS Bedrock LLM with LangChain agent.
    Requires AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to be set.
    See https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html for more information.
    """
    prompt = ChatPromptTemplate.from_messages([("system", "You are a helpful AI assistant."), ("human", "{input}")])

    llm_config = AWSBedrockModelConfig(model_name="meta.llama3-3-70b-instruct-v1:0",
                                       temperature=0.0,
                                       region_name="us-east-2",
                                       max_tokens=1024)

    async with WorkflowBuilder() as builder:
        await builder.add_llm("aws_bedrock_llm", llm_config)
        llm = await builder.get_llm("aws_bedrock_llm", wrapper_type=LLMFrameworkEnum.LANGCHAIN)

        agent = prompt | llm

        response = await agent.ainvoke({"input": "What is 1+2?"})
        assert isinstance(response, AIMessage)
        assert response.content is not None
        assert isinstance(response.content, str)
        assert "3" in response.content.lower()
```
