Skip to content

Commit

Permalink
Fixes support for cross region inference (#242)
Browse files Browse the repository at this point in the history
Fixes #239 

1. Fixes streaming for cross-region inference models ids. 
2. Adds support for Cohere models in ChatBedrockConverse.
3. Also, adds standard integration tests for Cohere.
  • Loading branch information
3coins authored Oct 22, 2024
1 parent ff10faf commit ee32da0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 6 deletions.
13 changes: 8 additions & 5 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,15 +400,18 @@ class Joke(BaseModel):
@model_validator(mode="before")
@classmethod
def set_disable_streaming(cls, values: Dict) -> Any:
values["provider"] = (
values.get("provider")
or (values.get("model_id", values["model"])).split(".")[0]
model_id = values.get("model_id", values.get("model"))
model_parts = model_id.split(".")
values["provider"] = values.get("provider") or (
model_parts[-2] if len(model_parts) > 1 else model_parts[0]
)

# As of 08/05/24 only Anthropic models support streamed tool calling
# As of 09/15/24 Anthropic and Cohere models support streamed tool calling
if "disable_streaming" not in values:
values["disable_streaming"] = (
False if "anthropic" in values["provider"] else "tool_calling"
False
if values["provider"] in ["anthropic", "cohere"]
else "tool_calling"
)
return values

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,28 @@ def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None
super().test_tool_message_histories_list_content(model)


class TestBedrockCohereStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatBedrockConverse

@property
def chat_model_params(self) -> dict:
return {"model": "cohere.command-r-plus-v1:0"}

@property
def standard_chat_model_params(self) -> dict:
return {"temperature": 0, "max_tokens": 100, "stop": []}

@pytest.mark.xfail(reason="Cohere models don't support tool_choice.")
def test_structured_few_shot_examples(self, model: BaseChatModel) -> None:
pass

@pytest.mark.xfail(reason="Cohere models don't support tool_choice.")
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
pass


def test_structured_output_snake_case() -> None:
model = ChatBedrockConverse(
model="anthropic.claude-3-sonnet-20240229-v1:0", temperature=0
Expand Down
18 changes: 17 additions & 1 deletion libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test chat model integration."""

import base64
from typing import Dict, List, Tuple, Type, cast
from typing import Dict, List, Tuple, Type, Union, cast

import pytest
from langchain_core.language_models import BaseChatModel
Expand Down Expand Up @@ -399,3 +399,19 @@ def test_standard_tracing_params() -> None:
"ls_temperature": 0.1,
"ls_max_tokens": 10,
}


@pytest.mark.parametrize(
"model_id, disable_streaming",
[
("anthropic.claude-3-5-sonnet-20240620-v1:0", False),
("us.anthropic.claude-3-haiku-20240307-v1:0", False),
("cohere.command-r-v1:0", False),
("meta.llama3-1-405b-instruct-v1:0", "tool_calling"),
],
)
def test_set_disable_streaming(
model_id: str, disable_streaming: Union[bool, str]
) -> None:
llm = ChatBedrockConverse(model=model_id, region_name="us-west-2")
assert llm.disable_streaming == disable_streaming

0 comments on commit ee32da0

Please sign in to comment.