Skip to content

Commit 535d0a4

Browse files
Unit Tests for llama-index-llms-bedrock-converse (run-llama#16379)
* Removed unused `llama-index-llms-anthropic` dependency. Incremented to `0.3.0`. * Expanded range for `pytest` and `pytest-mock` to support `pytest-asyncio` * Unit tests for main functions * remove lock * make coverage checks more narrow * rename test * update makefile * wrong arg names * even better workflow * improve check * try again? * ok, i think this works * Streamlined unit tests. * Consolidated mock exception --------- Co-authored-by: Logan Markewich <[email protected]>
1 parent af6ea71 commit 535d0a4

File tree

5 files changed

+234
-14
lines changed

5 files changed

+234
-14
lines changed

Diff for: .github/workflows/coverage.yml

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
name: Check Coverage

on:
  push:
    branches:
      - main
  pull_request:

env:
  POETRY_VERSION: "1.8.3"

jobs:
  test:
    runs-on: ubuntu-latest-unit-tester
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.9"]
    steps:
      - name: clear space
        env:
          CI: true
        shell: bash
        run: rm -rf /opt/hostedtoolcache/*
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: update rustc
        run: rustup update stable
      - name: Install Poetry
        run: pipx install poetry==${{ env.POETRY_VERSION }}
      - name: Set up python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: "poetry"
          cache-dependency-path: "**/poetry.lock"
      - uses: pantsbuild/actions/init-pants@v5-scie-pants
        with:
          # v0 makes it easy to bust the cache if needed
          # just increase the integer to start with a fresh cache
          # FIX: the matrix key is `python-version`; the original used
          # `matrix.python_version` (underscore), which evaluates to an empty
          # string and silently collapses all cache keys to the same value.
          gha-cache-key: v1-py${{ matrix.python-version }}
          named-caches-hash: v1-py${{ matrix.python-version }}
          pants-python-version: ${{ matrix.python-version }}
          pants-ci-config: pants.toml
      - name: Check BUILD files
        run: |
          pants tailor --check :: -docs/::
      - name: Run coverage checks on changed packages
        run: |
          # Get the changed files
          CHANGED_FILES=$(pants list --changed-since=origin/main)

          # Find which roots contain changed files
          FILTER_PATTERNS="["
          for file in $CHANGED_FILES; do
            root=$(echo "$file" | cut -d'/' -f1,2,3)
            if [[ ! "$FILTER_PATTERNS" =~ "$root" ]]; then
              FILTER_PATTERNS="${FILTER_PATTERNS}'${root}',"
            fi
          done

          # remove the last comma and close the bracket
          FILTER_PATTERNS="${FILTER_PATTERNS%,}]"

          echo "Coverage filter patterns: $FILTER_PATTERNS"

          pants --level=error --no-local-cache test \
            --test-use-coverage \
            --changed-since=origin/main \
            --coverage-py-filter="$FILTER_PATTERNS"

Diff for: Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
1111
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
1212

1313
test: ## Run tests via pants
14-
pants --level=error --no-local-cache --changed-since=origin/main --changed-dependents=transitive test
14+
pants --level=error --no-local-cache --changed-since=origin/main --changed-dependents=transitive --no-test-use-coverage test
1515

1616
test-core: ## Run tests via pants
1717
pants --no-local-cache test llama-index-core/::

Diff for: llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ jupyter = "^1.0.0"
4141
mypy = "0.991"
4242
pre-commit = "3.2.0"
4343
pylint = "2.15.10"
44-
pytest = "7.2.1"
45-
pytest-mock = "3.11.1"
44+
pytest = ">=7.2.1"
45+
pytest-asyncio = "^0.24.0"
46+
pytest-mock = ">=3.11.1"
4647
ruff = "0.0.292"
4748
tree-sitter-languages = "^1.8.0"
4849
types-Deprecated = ">=0.1.0"
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,162 @@
1-
from llama_index.core.base.llms.base import BaseLLM
1+
import pytest
22
from llama_index.llms.bedrock_converse import BedrockConverse
3+
from llama_index.core.base.llms.types import (
4+
ChatMessage,
5+
ChatResponse,
6+
MessageRole,
7+
CompletionResponse,
8+
)
9+
from llama_index.core.callbacks import CallbackManager
310

11+
# Expected values
12+
EXP_RESPONSE = "Test"
13+
EXP_STREAM_RESPONSE = ["Test ", "value"]
14+
EXP_MAX_TOKENS = 100
15+
EXP_TEMPERATURE = 0.7
16+
EXP_MODEL = "anthropic.claude-v2"
417

5-
def test_text_inference_embedding_class():
6-
names_of_base_classes = [b.__name__ for b in BedrockConverse.__mro__]
7-
assert BaseLLM.__name__ in names_of_base_classes
18+
# Reused chat message and prompt
19+
messages = [ChatMessage(role=MessageRole.USER, content="Test")]
20+
prompt = "Test"
21+
22+
23+
class MockExceptions:
    """Mimics the `exceptions` attribute of a botocore client."""

    class ThrottlingException(Exception):
        """Stand-in for the real client's ThrottlingException."""
26+
27+
28+
class AsyncMockClient:
29+
def __init__(self) -> "AsyncMockClient":
30+
self.exceptions = MockExceptions()
31+
32+
async def __aenter__(self) -> "AsyncMockClient":
33+
return self
34+
35+
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
36+
pass
37+
38+
async def converse(self, *args, **kwargs):
39+
return {"output": {"message": {"content": [{"text": EXP_RESPONSE}]}}}
40+
41+
async def converse_stream(self, *args, **kwargs):
42+
async def stream_generator():
43+
for element in EXP_STREAM_RESPONSE:
44+
yield {"contentBlockDelta": {"delta": {"text": element}}}
45+
46+
return {"stream": stream_generator()}
47+
48+
49+
class MockClient:
    """Synchronous stand-in for a boto3 bedrock-runtime client.

    Mirrors AsyncMockClient: `converse` and `converse_stream` return
    minimal Converse-API-shaped payloads from the EXP_* constants.
    """

    # FIX: original annotated __init__ as returning "MockClient";
    # __init__ always returns None.
    def __init__(self) -> None:
        self.exceptions = MockExceptions()

    def converse(self, *args, **kwargs):
        # Minimal non-streaming Converse response shape.
        return {"output": {"message": {"content": [{"text": EXP_RESPONSE}]}}}

    def converse_stream(self, *args, **kwargs):
        # Returns {"stream": <iterator of contentBlockDelta events>}.
        def stream_generator():
            for element in EXP_STREAM_RESPONSE:
                yield {"contentBlockDelta": {"delta": {"text": element}}}

        return {"stream": stream_generator()}
62+
63+
64+
class MockAsyncSession:
    """Stand-in for aioboto3.Session that hands out AsyncMockClient."""

    # FIX: original annotated __init__ as returning "MockAsyncSession";
    # __init__ always returns None.
    def __init__(self, *args, **kwargs) -> None:
        pass

    def client(self, *args, **kwargs):
        return AsyncMockClient()
70+
71+
72+
@pytest.fixture()
def mock_boto3_session(monkeypatch):
    """Patch boto3 so any Session.client(...) call yields the sync mock."""

    def _fake_client(*args, **kwargs):
        return MockClient()

    monkeypatch.setattr("boto3.Session.client", _fake_client)
78+
79+
80+
@pytest.fixture()
def mock_aioboto3_session(monkeypatch):
    """Patch aioboto3 so sessions are MockAsyncSession instances."""
    monkeypatch.setattr("aioboto3.Session", MockAsyncSession)
83+
84+
85+
@pytest.fixture()
def bedrock_converse(mock_boto3_session, mock_aioboto3_session):
    """Build a BedrockConverse wired to the mocked boto3/aioboto3 clients."""
    llm = BedrockConverse(
        model=EXP_MODEL,
        max_tokens=EXP_MAX_TOKENS,
        temperature=EXP_TEMPERATURE,
        callback_manager=CallbackManager(),
    )
    return llm
93+
94+
95+
def test_init(bedrock_converse):
96+
assert bedrock_converse.model == EXP_MODEL
97+
assert bedrock_converse.max_tokens == EXP_MAX_TOKENS
98+
assert bedrock_converse.temperature == EXP_TEMPERATURE
99+
assert bedrock_converse._client is not None
100+
101+
102+
def test_chat(bedrock_converse):
103+
response = bedrock_converse.chat(messages)
104+
105+
assert response.message.role == MessageRole.ASSISTANT
106+
assert response.message.content == EXP_RESPONSE
107+
108+
109+
def test_complete(bedrock_converse):
110+
response = bedrock_converse.complete(prompt)
111+
112+
assert isinstance(response, CompletionResponse)
113+
assert response.text == EXP_RESPONSE
114+
assert response.additional_kwargs["status"] == []
115+
assert response.additional_kwargs["tool_call_id"] == []
116+
assert response.additional_kwargs["tool_calls"] == []
117+
118+
119+
def test_stream_chat(bedrock_converse):
120+
response_stream = bedrock_converse.stream_chat(messages)
121+
122+
for response in response_stream:
123+
assert response.message.role == MessageRole.ASSISTANT
124+
assert response.delta in EXP_STREAM_RESPONSE
125+
126+
127+
@pytest.mark.asyncio()
128+
async def test_achat(bedrock_converse):
129+
response = await bedrock_converse.achat(messages)
130+
131+
assert isinstance(response, ChatResponse)
132+
assert response.message.role == MessageRole.ASSISTANT
133+
assert response.message.content == EXP_RESPONSE
134+
135+
136+
@pytest.mark.asyncio()
137+
async def test_astream_chat(bedrock_converse):
138+
response_stream = await bedrock_converse.astream_chat(messages)
139+
140+
responses = []
141+
async for response in response_stream:
142+
assert response.message.role == MessageRole.ASSISTANT
143+
assert response.delta in EXP_STREAM_RESPONSE
144+
145+
146+
@pytest.mark.asyncio()
147+
async def test_acomplete(bedrock_converse):
148+
response = await bedrock_converse.acomplete(prompt)
149+
150+
assert isinstance(response, CompletionResponse)
151+
assert response.text == EXP_RESPONSE
152+
assert response.additional_kwargs["status"] == []
153+
assert response.additional_kwargs["tool_call_id"] == []
154+
assert response.additional_kwargs["tool_calls"] == []
155+
156+
157+
@pytest.mark.asyncio()
158+
async def test_astream_complete(bedrock_converse):
159+
response_stream = await bedrock_converse.astream_complete(prompt)
160+
161+
async for response in response_stream:
162+
assert response.delta in EXP_STREAM_RESPONSE

Diff for: pants.toml

-7
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@ config = "./pyproject.toml"
2020

2121
[coverage-py]
2222
fail_under = 50
23-
filter = [
24-
'llama-index-core/',
25-
'llama-index-experimental/',
26-
'llama-index-finetuning/',
27-
'llama-index-integrations/',
28-
'llama-index-utils/',
29-
]
3023
global_report = false
3124
report = ["console", "html", "xml"]
3225

0 commit comments

Comments
 (0)