diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 5fa7182e..5562ea8b 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -47,6 +47,41 @@ def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
     )
 
 
+def _convert_one_message_to_text_llama3(message: BaseMessage) -> str:
+    if isinstance(message, ChatMessage):
+        message_text = (
+            f"<|start_header_id|>{message.role}"
+            f"<|end_header_id|>{message.content}<|eot_id|>"
+        )
+    elif isinstance(message, HumanMessage):
+        message_text = (
+            f"<|start_header_id|>user" f"<|end_header_id|>{message.content}<|eot_id|>"
+        )
+    elif isinstance(message, AIMessage):
+        message_text = (
+            f"<|start_header_id|>assistant"
+            f"<|end_header_id|>{message.content}<|eot_id|>"
+        )
+    elif isinstance(message, SystemMessage):
+        message_text = (
+            f"<|start_header_id|>system" f"<|end_header_id|>{message.content}<|eot_id|>"
+        )
+    else:
+        raise ValueError(f"Got unknown type {message}")
+
+    return message_text
+
+
+def convert_messages_to_prompt_llama3(messages: List[BaseMessage]) -> str:
+    """Convert a list of messages to a prompt for llama."""
+
+    return "\n".join(
+        ["<|begin_of_text|>"]
+        + [_convert_one_message_to_text_llama3(message) for message in messages]
+        + ["<|start_header_id|>assistant<|end_header_id|>\n\n"]
+    )
+
+
 def _convert_one_message_to_text_anthropic(
     message: BaseMessage,
     human_prompt: str,
@@ -226,12 +261,15 @@ class ChatPromptAdapter:
 
     @classmethod
     def convert_messages_to_prompt(
-        cls, provider: str, messages: List[BaseMessage]
+        cls, provider: str, messages: List[BaseMessage], model: str
     ) -> str:
         if provider == "anthropic":
             prompt = convert_messages_to_prompt_anthropic(messages=messages)
         elif provider == "meta":
-            prompt = convert_messages_to_prompt_llama(messages=messages)
+            if "llama3" in model:
+                prompt = convert_messages_to_prompt_llama3(messages=messages)
+            else:
+                prompt = convert_messages_to_prompt_llama(messages=messages)
         elif provider == "mistral":
             prompt = convert_messages_to_prompt_mistral(messages=messages)
         elif provider == "amazon":
@@ -309,7 +347,7 @@ def _stream(
             )
         else:
             prompt = ChatPromptAdapter.convert_messages_to_prompt(
-                provider=provider, messages=messages
+                provider=provider, messages=messages, model=self._get_model()
             )
 
             for chunk in self._prepare_input_and_invoke_stream(
@@ -347,7 +385,7 @@ def _generate(
             )
         else:
             prompt = ChatPromptAdapter.convert_messages_to_prompt(
-                provider=provider, messages=messages
+                provider=provider, messages=messages, model=self._get_model()
             )
 
         if stop:
diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index c3dc1fc8..f0eedbeb 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -461,6 +461,9 @@ def _get_provider(self) -> str:
         return self.model_id.split(".")[0]
 
+    def _get_model(self) -> str:
+        return self.model_id.split(".")[1]
+
     @property
     def _model_is_anthropic(self) -> bool:
         return self._get_provider() == "anthropic"
 
diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
index 31f00caa..e7ba4ef2 100644
--- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
+++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -1,4 +1,5 @@
 """Test Bedrock chat model."""
+
 from typing import Any, cast
 
 import pytest
@@ -73,6 +74,22 @@ def test_chat_bedrock_streaming() -> None:
     assert isinstance(response, BaseMessage)
 
 
+@pytest.mark.scheduled
+def test_chat_bedrock_streaming_llama3() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    chat = ChatBedrock(  # type: ignore[call-arg]
+        model_id="meta.llama3-8b-instruct-v1:0",
+        streaming=True,
+        callbacks=[callback_handler],
+        verbose=True,
+    )
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    assert callback_handler.llm_streams > 0
+    assert isinstance(response, BaseMessage)
+
+
 @pytest.mark.scheduled
 def test_chat_bedrock_streaming_generation_info() -> None:
     """Test that generation info is preserved when streaming."""
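
For reviewers, a minimal sketch of what the new `convert_messages_to_prompt_llama3` produces, assuming a build of `langchain-aws` that includes this patch; the message contents are illustrative:

```python
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_aws.chat_models.bedrock import convert_messages_to_prompt_llama3

prompt = convert_messages_to_prompt_llama3(
    [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
    ]
)
# Each message is wrapped in <|start_header_id|>/<|end_header_id|> and
# terminated with <|eot_id|>; the prompt opens with <|begin_of_text|>.
assert prompt == (
    "<|begin_of_text|>\n"
    "<|start_header_id|>system<|end_header_id|>"
    "You are a helpful assistant.<|eot_id|>\n"
    "<|start_header_id|>user<|end_header_id|>Hello<|eot_id|>\n"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
```

The trailing empty `assistant` header is what cues the model to generate its reply at that point.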
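The Llama 2 vs. Llama 3 routing in `ChatPromptAdapter.convert_messages_to_prompt` rests on a substring check against the new `_get_model()` helper, which returns the segment of `model_id` after the provider prefix. A small illustration of that split (the Llama 2 ID is shown only for contrast):

```python
# "meta.llama3-8b-instruct-v1:0" splits on "." into provider and model parts.
model_id = "meta.llama3-8b-instruct-v1:0"
provider, model = model_id.split(".")[0], model_id.split(".")[1]
assert provider == "meta"
assert model == "llama3-8b-instruct-v1:0"
assert "llama3" in model                        # -> convert_messages_to_prompt_llama3
assert "llama3" not in "llama2-70b-chat-v1"     # -> legacy llama2 prompt format
```

Because it is a substring check rather than an exact match, any future `meta` model name containing "llama3" will also pick up the new prompt format.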