From 916d3c1558ceda76b31f6f628e81a73779746d4f Mon Sep 17 00:00:00 2001
From: Fedor
Date: Tue, 30 Apr 2024 12:12:17 +0100
Subject: [PATCH 1/5] Adding support for Llama3 models.

---
 libs/aws/langchain_aws/chat_models/bedrock.py | 45 +++++++++++++++++--
 libs/aws/langchain_aws/llms/bedrock.py        |  3 ++
 2 files changed, 45 insertions(+), 3 deletions(-)

diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 5fa7182e..a5e59b4e 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -47,6 +47,42 @@ def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
     )
 
 
+def _convert_one_message_to_text_llama3(message: BaseMessage) -> str:
+    if isinstance(message, ChatMessage):
+        message_text = (
+            f"<|begin_of_text|><|start_header_id|>{message.role}"
+            f"<|end_header_id|>{message.content}<|eot_id|>"
+        )
+    elif isinstance(message, HumanMessage):
+        message_text = (
+            f"<|begin_of_text|><|start_header_id|>user"
+            f"<|end_header_id|>{message.content}<|eot_id|>"
+        )
+    elif isinstance(message, AIMessage):
+        message_text = (
+            f"<|begin_of_text|><|start_header_id|>assistant"
+            f"<|end_header_id|>{message.content}<|eot_id|>"
+        )
+    elif isinstance(message, SystemMessage):
+        message_text = (
+            f"<|begin_of_text|><|start_header_id|>system"
+            f"<|end_header_id|>{message.content}<|eot_id|>"
+        )
+    else:
+        raise ValueError(f"Got unknown type {message}")
+
+    return message_text
+
+
+def convert_messages_to_prompt_llama3(messages: List[BaseMessage]) -> str:
+    """Convert a list of messages to a prompt for llama."""
+
+    return "\n".join(
+        [_convert_one_message_to_text_llama3(message) for message in messages]
+        + ["<|start_header_id|>assistant<|end_header_id|>\n\n"]
+    )
+
+
 def _convert_one_message_to_text_anthropic(
     message: BaseMessage,
     human_prompt: str,
@@ -226,12 +262,15 @@ class ChatPromptAdapter:
 
     @classmethod
     def convert_messages_to_prompt(
-        cls, provider: str, messages: List[BaseMessage]
+        cls, provider: str, messages: List[BaseMessage], model: str
     ) -> str:
         if provider == "anthropic":
             prompt = convert_messages_to_prompt_anthropic(messages=messages)
         elif provider == "meta":
-            prompt = convert_messages_to_prompt_llama(messages=messages)
+            if "llama3" in model:
+                prompt = convert_messages_to_prompt_llama3(messages=messages)
+            else:
+                prompt = convert_messages_to_prompt_llama(messages=messages)
         elif provider == "mistral":
             prompt = convert_messages_to_prompt_mistral(messages=messages)
         elif provider == "amazon":
@@ -309,7 +348,7 @@ def _stream(
             )
         else:
             prompt = ChatPromptAdapter.convert_messages_to_prompt(
-                provider=provider, messages=messages
+                provider=provider, messages=messages, model=self._get_model()
             )
 
         for chunk in self._prepare_input_and_invoke_stream(
diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index c3dc1fc8..f0eedbeb 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -461,6 +461,9 @@ def _get_provider(self) -> str:
 
         return self.model_id.split(".")[0]
 
+    def _get_model(self) -> str:
+        return self.model_id.split(".")[1]
+
     @property
     def _model_is_anthropic(self) -> bool:
         return self._get_provider() == "anthropic"
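Note: the routing added in patch 1 keys off the Bedrock model ID. A minimal sketch of how the new _get_model() helper and the "llama3" substring check select the Llama 3 prompt builder, assuming model IDs keep the "provider.model" form used above (e.g. "meta.llama3-8b-instruct-v1:0"):

    model_id = "meta.llama3-8b-instruct-v1:0"
    provider = model_id.split(".")[0]  # "meta" -- what _get_provider() returns
    model = model_id.split(".")[1]     # "llama3-8b-instruct-v1:0" -- what _get_model() returns
    # ChatPromptAdapter picks convert_messages_to_prompt_llama3 when this holds:
    assert provider == "meta" and "llama3" in model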
From 7dfefbc88a4d6c181d2a21e178739fad58e2f43c Mon Sep 17 00:00:00 2001
From: Fedor
Date: Tue, 30 Apr 2024 16:03:52 +0100
Subject: [PATCH 2/5] Remove extraneous <|begin_of_text|> tokens as only one is needed

---
 libs/aws/langchain_aws/chat_models/bedrock.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index a5e59b4e..922c71c2 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -50,22 +50,22 @@ def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
 def _convert_one_message_to_text_llama3(message: BaseMessage) -> str:
     if isinstance(message, ChatMessage):
         message_text = (
-            f"<|begin_of_text|><|start_header_id|>{message.role}"
+            f"<|start_header_id|>{message.role}"
             f"<|end_header_id|>{message.content}<|eot_id|>"
         )
     elif isinstance(message, HumanMessage):
         message_text = (
-            f"<|begin_of_text|><|start_header_id|>user"
+            f"<|start_header_id|>user"
             f"<|end_header_id|>{message.content}<|eot_id|>"
         )
     elif isinstance(message, AIMessage):
         message_text = (
-            f"<|begin_of_text|><|start_header_id|>assistant"
+            f"<|start_header_id|>assistant"
             f"<|end_header_id|>{message.content}<|eot_id|>"
         )
     elif isinstance(message, SystemMessage):
         message_text = (
-            f"<|begin_of_text|><|start_header_id|>system"
+            f"<|start_header_id|>system"
             f"<|end_header_id|>{message.content}<|eot_id|>"
         )
     else:
@@ -78,7 +78,8 @@ def convert_messages_to_prompt_llama3(messages: List[BaseMessage]) -> str:
     """Convert a list of messages to a prompt for llama."""
 
     return "\n".join(
-        [_convert_one_message_to_text_llama3(message) for message in messages]
+        ["<|begin_of_text|>"]
+        + [_convert_one_message_to_text_llama3(message) for message in messages]
         + ["<|start_header_id|>assistant<|end_header_id|>\n\n"]
     )
 

From d3a82906c807fbe9486ea3638966c9dd4b0d42c5 Mon Sep 17 00:00:00 2001
From: Fedor
Date: Wed, 1 May 2024 09:58:05 +0100
Subject: [PATCH 3/5] linting

---
 libs/aws/langchain_aws/chat_models/bedrock.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 922c71c2..cbeb0d6f 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -55,8 +55,7 @@ def _convert_one_message_to_text_llama3(message: BaseMessage) -> str:
         )
     elif isinstance(message, HumanMessage):
         message_text = (
-            f"<|start_header_id|>user"
-            f"<|end_header_id|>{message.content}<|eot_id|>"
+            f"<|start_header_id|>user" f"<|end_header_id|>{message.content}<|eot_id|>"
         )
     elif isinstance(message, AIMessage):
         message_text = (
@@ -65,8 +64,7 @@ def _convert_one_message_to_text_llama3(message: BaseMessage) -> str:
         )
     elif isinstance(message, SystemMessage):
         message_text = (
-            f"<|start_header_id|>system"
-            f"<|end_header_id|>{message.content}<|eot_id|>"
+            f"<|start_header_id|>system" f"<|end_header_id|>{message.content}<|eot_id|>"
         )
     else:
         raise ValueError(f"Got unknown type {message}")

From 52ce5c598b4f83cbb9f8d568a3544be30eb749f1 Mon Sep 17 00:00:00 2001
From: Fedor
Date: Tue, 7 May 2024 13:20:05 +0100
Subject: [PATCH 4/5] lint fix

---
 libs/aws/langchain_aws/chat_models/bedrock.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index cbeb0d6f..5562ea8b 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -385,7 +385,7 @@ def _generate(
             )
         else:
             prompt = ChatPromptAdapter.convert_messages_to_prompt(
-                provider=provider, messages=messages
+                provider=provider, messages=messages, model=self._get_model()
             )
 
         if stop:
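Note: with patches 2-4 applied, convert_messages_to_prompt_llama3 emits a single leading <|begin_of_text|> token, one header/content/<|eot_id|> block per message, and a trailing assistant header that cues the model to respond. A short sketch of the resulting prompt, assuming the patched langchain_aws (and langchain_core) is importable:

    from langchain_core.messages import HumanMessage, SystemMessage

    from langchain_aws.chat_models.bedrock import convert_messages_to_prompt_llama3

    prompt = convert_messages_to_prompt_llama3(
        [SystemMessage(content="You are helpful."), HumanMessage(content="Hello")]
    )
    print(prompt)
    # <|begin_of_text|>
    # <|start_header_id|>system<|end_header_id|>You are helpful.<|eot_id|>
    # <|start_header_id|>user<|end_header_id|>Hello<|eot_id|>
    # <|start_header_id|>assistant<|end_header_id|>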
From 1f212f4587297c2c81bc690190a1e85241f1da5d Mon Sep 17 00:00:00 2001
From: Fedor
Date: Wed, 8 May 2024 08:19:37 +0100
Subject: [PATCH 5/5] add integration test

---
 .../chat_models/test_bedrock.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
index 31f00caa..e7ba4ef2 100644
--- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
+++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -1,4 +1,5 @@
 """Test Bedrock chat model."""
+
 from typing import Any, cast
 
 import pytest
@@ -73,6 +74,22 @@ def test_chat_bedrock_streaming() -> None:
     assert isinstance(response, BaseMessage)
 
 
+@pytest.mark.scheduled
+def test_chat_bedrock_streaming_llama3() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    chat = ChatBedrock(  # type: ignore[call-arg]
+        model_id="meta.llama3-8b-instruct-v1:0",
+        streaming=True,
+        callbacks=[callback_handler],
+        verbose=True,
+    )
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    assert callback_handler.llm_streams > 0
+    assert isinstance(response, BaseMessage)
+
+
 @pytest.mark.scheduled
 def test_chat_bedrock_streaming_generation_info() -> None:
     """Test that generation info is preserved when streaming."""
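Note: the integration test above drives streaming through the callback handler. An equivalent ad-hoc usage sketch, assuming AWS credentials and Bedrock access to meta.llama3-8b-instruct-v1:0 are configured in the environment:

    from langchain_aws import ChatBedrock
    from langchain_core.messages import HumanMessage

    # Streamed generation; .stream() comes from the langchain_core Runnable API.
    chat = ChatBedrock(model_id="meta.llama3-8b-instruct-v1:0", streaming=True)
    for chunk in chat.stream([HumanMessage(content="Hello")]):
        print(chunk.content, end="", flush=True)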