Skip to content

Commit

Permalink
Fixes get provider for model_ids with region prefix. Added unit tests. (
Browse files Browse the repository at this point in the history
#184)

Fixes #178
  • Loading branch information
3coins authored Sep 6, 2024
1 parent 533effb commit 4aebf17
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
14 changes: 13 additions & 1 deletion libs/aws/langchain_aws/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,15 +613,27 @@ def _identifying_params(self) -> Dict[str, Any]:
}

def _get_provider(self) -> str:
# If provider supplied by user, return as-is
if self.provider:
return self.provider

# If model_id is an arn, can't extract provider from model_id,
# so this requires passing in the provider by user
if self.model_id.startswith("arn"):
raise ValueError(
"Model provider should be supplied when passing a model ARN as "
"model_id"
)

return self.model_id.split(".")[0]
# If model_id has region prefixed to them,
# for example eu.anthropic.claude-3-haiku-20240307-v1:0,
# provider is the second part, otherwise, the first part
parts = self.model_id.split(".", maxsplit=2)
return (
parts[1]
if (len(parts) > 1 and parts[0].lower() in {"eu", "us", "ap", "sa"})
else parts[0]
)

def _get_model(self) -> str:
return self.model_id.split(".", maxsplit=1)[-1]
Expand Down
33 changes: 33 additions & 0 deletions libs/aws/tests/unit_tests/chat_models/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# type:ignore

"""Test chat model integration."""

from contextlib import nullcontext
from typing import Any, Callable, Dict, Literal, Type, cast

import pytest
Expand Down Expand Up @@ -419,3 +422,33 @@ def test_standard_tracing_params() -> None:
"ls_model_name": "foo",
"ls_temperature": 0.1,
}


@pytest.mark.parametrize(
"model_id, provider, expected_provider, expectation",
[
(
"eu.anthropic.claude-3-haiku-20240307-v1:0",
None,
"anthropic",
nullcontext(),
),
("meta.llama3-1-405b-instruct-v1:0", None, "meta", nullcontext()),
(
"arn:aws:bedrock:us-east-1::custom-model/cohere.command-r-v1:0/MyCustomModel2",
"cohere",
"cohere",
nullcontext(),
),
(
"arn:aws:bedrock:us-east-1::custom-model/cohere.command-r-v1:0/MyCustomModel2",
None,
"cohere",
pytest.raises(ValueError),
),
],
)
def test__get_provider(model_id, provider, expected_provider, expectation) -> None:
llm = ChatBedrock(model_id=model_id, provider=provider, region_name="us-west-2")
with expectation:
assert llm._get_provider() == expected_provider

0 comments on commit 4aebf17

Please sign in to comment.