Adding support for Llama3 models in BedrockChat #32

Merged: 5 commits, May 10, 2024
46 changes: 42 additions & 4 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -47,6 +47,41 @@ def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
)


def _convert_one_message_to_text_llama3(message: BaseMessage) -> str:


I propose one small change here, plus extra logic to support pre-filling (word-in-mouthing) the assistant turn, for things like agent scratchpads living in the assistant turn:

def _convert_one_message_to_text_llama3(message: BaseMessage, lastMessage: bool) -> str:
    if isinstance(message, ChatMessage):
        message_text = f"<|start_header_id|>{message.role.capitalize()}<|end_header_id|>{message}<|eot_id|>"
    elif isinstance(message, HumanMessage):
        message_text = f"<|start_header_id|>user<|end_header_id|>{message.content}<|eot_id|>"
    elif isinstance(message, AIMessage):
        message_text = f"<|start_header_id|>assistant<|end_header_id|>{message.content}"
        if not lastMessage:
            message_text += "<|eot_id|>"
    elif isinstance(message, SystemMessage):
        message_text = f"<|start_header_id|>system<|end_header_id|>{message.content}"
    else:
        raise ValueError(f"Got unknown type {message}")
    return message_text


def convert_messages_to_prompt_llama3(messages: List[BaseMessage]) -> str:
    """Convert a list of messages to a prompt for llama3."""
    return "<|begin_of_text|>" + "".join(
        [_convert_one_message_to_text_llama3(message, i == len(messages) - 1)
         for i, message in enumerate(messages)]
    ) + "<|start_header_id|>assistant<|end_header_id|>"

This would allow stuff along these lines.

messages = [HumanMessage(content="list 5 colors"),AIMessage(content="No, I don't want to!")]

Contributor Author


I am not totally sure I understand the use case.

Shall we do this in a separate PR with a more illustrative example?


Let's say we want to write part of the assistant turn ourselves to steer what it does. This lets us pre-fill the assistant turn without appending an end-of-turn (<|eot_id|>) token, so that generation continues from the partially filled assistant turn instead of starting disjoint from it.

<|start_header_id|>user<|end_header_id|>Can you generate me some xml data<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>Yes here's some xml content, <xml> 

Right now, this implementation would produce the following for that use case,

<|start_header_id|>user<|end_header_id|>Can you generate me some xml data<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>Yes here's some xml content, <xml> <|eot_id|>

This causes the model to kick off a new assistant turn that is no longer aligned with what we've pre-filled it to do.


Yes, this can go in another PR; it's an edge case.
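
For concreteness, a minimal sketch (not from the PR itself) of how that pre-fill would be expressed as LangChain messages, assuming the proposed lastMessage handling above:

from langchain_core.messages import AIMessage, HumanMessage

# The trailing AIMessage seeds the assistant turn; with the proposed change no
# <|eot_id|> is appended to it, so generation continues from "<xml>".
messages = [
    HumanMessage(content="Can you generate me some xml data"),
    AIMessage(content="Yes here's some xml content, <xml>"),
]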

Collaborator


@jonathancaevans
Thanks for explaining your use case. Would you mind opening another issue for this, or even better, a PR to make this change?

    if isinstance(message, ChatMessage):
        message_text = (
            f"<|start_header_id|>{message.role}"
            f"<|end_header_id|>{message.content}<|eot_id|>"
        )
    elif isinstance(message, HumanMessage):
        message_text = (
            f"<|start_header_id|>user" f"<|end_header_id|>{message.content}<|eot_id|>"
        )
    elif isinstance(message, AIMessage):
        message_text = (
            f"<|start_header_id|>assistant"
            f"<|end_header_id|>{message.content}<|eot_id|>"
        )
    elif isinstance(message, SystemMessage):
        message_text = (
            f"<|start_header_id|>system" f"<|end_header_id|>{message.content}<|eot_id|>"
        )
    else:
        raise ValueError(f"Got unknown type {message}")

    return message_text


def convert_messages_to_prompt_llama3(messages: List[BaseMessage]) -> str:
    """Convert a list of messages to a prompt for llama3."""

    return "\n".join(
        ["<|begin_of_text|>"]
        + [_convert_one_message_to_text_llama3(message) for message in messages]
        + ["<|start_header_id|>assistant<|end_header_id|>\n\n"]
    )
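
For illustration (a sketch, not part of the diff): given a single HumanMessage, the merged converter joins the segments with newlines, so the resulting prompt looks like this:

from langchain_core.messages import HumanMessage

prompt = convert_messages_to_prompt_llama3([HumanMessage(content="Hello")])
# prompt ==
# "<|begin_of_text|>\n"
# "<|start_header_id|>user<|end_header_id|>Hello<|eot_id|>\n"
# "<|start_header_id|>assistant<|end_header_id|>\n\n"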


def _convert_one_message_to_text_anthropic(
    message: BaseMessage,
    human_prompt: str,
@@ -226,12 +261,15 @@ class ChatPromptAdapter:

    @classmethod
    def convert_messages_to_prompt(
        cls, provider: str, messages: List[BaseMessage]
        cls, provider: str, messages: List[BaseMessage], model: str
    ) -> str:
        if provider == "anthropic":
            prompt = convert_messages_to_prompt_anthropic(messages=messages)
        elif provider == "meta":
            prompt = convert_messages_to_prompt_llama(messages=messages)
            if "llama3" in model:
                prompt = convert_messages_to_prompt_llama3(messages=messages)
            else:
                prompt = convert_messages_to_prompt_llama(messages=messages)
        elif provider == "mistral":
            prompt = convert_messages_to_prompt_mistral(messages=messages)
        elif provider == "amazon":
@@ -309,7 +347,7 @@ def _stream(
            )
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
                provider=provider, messages=messages
                provider=provider, messages=messages, model=self._get_model()
            )

        for chunk in self._prepare_input_and_invoke_stream(
@@ -347,7 +385,7 @@ def _generate(
            )
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
                provider=provider, messages=messages
                provider=provider, messages=messages, model=self._get_model()
            )

        if stop:
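
End to end, a minimal usage sketch of the new code path (the model_id and message mirror the integration test added below; assumes AWS credentials with Bedrock access are configured):

from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage

# A meta.llama3* model_id makes convert_messages_to_prompt route to the
# llama3 prompt format instead of the llama2 one.
chat = ChatBedrock(model_id="meta.llama3-8b-instruct-v1:0")
response = chat.invoke([HumanMessage(content="Hello")])
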
3 changes: 3 additions & 0 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -461,6 +461,9 @@ def _get_provider(self) -> str:

        return self.model_id.split(".")[0]

    def _get_model(self) -> str:
        return self.model_id.split(".")[1]

    @property
    def _model_is_anthropic(self) -> bool:
        return self._get_provider() == "anthropic"
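
As a quick illustration (example model_id taken from the new integration test) of how _get_model complements _get_provider:

model_id = "meta.llama3-8b-instruct-v1:0"
provider = model_id.split(".")[0]  # "meta" (what _get_provider returns)
model = model_id.split(".")[1]     # "llama3-8b-instruct-v1:0" (what _get_model returns)
assert "llama3" in model           # so ChatPromptAdapter picks the llama3 prompt format
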
17 changes: 17 additions & 0 deletions libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -1,4 +1,5 @@
"""Test Bedrock chat model."""

from typing import Any, cast

import pytest
@@ -73,6 +74,22 @@ def test_chat_bedrock_streaming() -> None:
    assert isinstance(response, BaseMessage)


@pytest.mark.scheduled
def test_chat_bedrock_streaming_llama3() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    chat = ChatBedrock(  # type: ignore[call-arg]
        model_id="meta.llama3-8b-instruct-v1:0",
        streaming=True,
        callbacks=[callback_handler],
        verbose=True,
    )
    message = HumanMessage(content="Hello")
    response = chat([message])
    assert callback_handler.llm_streams > 0
    assert isinstance(response, BaseMessage)


@pytest.mark.scheduled
def test_chat_bedrock_streaming_generation_info() -> None:
"""Test that generation info is preserved when streaming."""