From 6829f50cc3a6e26e2003a48e2a92a4e8974672a7 Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Mon, 22 Jul 2024 10:37:36 -0700
Subject: [PATCH] Fixes streaming for llama3 models. (#116)

---
 libs/aws/langchain_aws/chat_models/bedrock.py | 22 ++++++++-----------
 libs/aws/langchain_aws/llms/bedrock.py        | 10 ++++++++-
 .../chat_models/test_bedrock.py               | 17 +++++++++--------
 3 files changed, 27 insertions(+), 22 deletions(-)

diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 574445ce..7eb2592a 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -565,20 +565,16 @@ def _generate(
                 )
             else:
                 usage_metadata = None
+
         llm_output["model_id"] = self.model_id
-        if len(tool_calls) > 0:
-            msg = AIMessage(
-                content=completion,
-                additional_kwargs=llm_output,
-                tool_calls=cast(List[ToolCall], tool_calls),
-                usage_metadata=usage_metadata,
-            )
-        else:
-            msg = AIMessage(
-                content=completion,
-                additional_kwargs=llm_output,
-                usage_metadata=usage_metadata,
-            )
+
+        msg = AIMessage(
+            content=completion,
+            additional_kwargs=llm_output,
+            tool_calls=cast(List[ToolCall], tool_calls),
+            usage_metadata=usage_metadata,
+        )
+
         return ChatResult(
             generations=[
                 ChatGeneration(
diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index 4f7da276..2eb53c5e 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -117,7 +117,11 @@ def _stream_response_to_generation_chunk(
             return None
     else:
         # chunk obj format varies with provider
-        generation_info = {k: v for k, v in stream_response.items() if k != output_key}
+        generation_info = {
+            k: v
+            for k, v in stream_response.items()
+            if k not in [output_key, "prompt_token_count", "generation_token_count"]
+        }
         return GenerationChunk(
             text=(
                 stream_response[output_key]
@@ -347,6 +351,10 @@ def prepare_output_stream(
                 yield _get_invocation_metrics_chunk(chunk_obj)
                 return
 
+            elif provider == "meta" and chunk_obj.get("stop_reason", "") == "stop":
+                yield _get_invocation_metrics_chunk(chunk_obj)
+                return
+
             elif messages_api and (chunk_obj.get("type") == "message_stop"):
                 yield _get_invocation_metrics_chunk(chunk_obj)
                 return
diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
index 02d8cb22..260d31f4 100644
--- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
+++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -84,17 +84,18 @@ def test_chat_bedrock_streaming() -> None:
 @pytest.mark.scheduled
 def test_chat_bedrock_streaming_llama3() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
-    callback_handler = FakeCallbackHandler()
     chat = ChatBedrock(  # type: ignore[call-arg]
-        model_id="meta.llama3-8b-instruct-v1:0",
-        streaming=True,
-        callbacks=[callback_handler],
-        verbose=True,
+        model_id="meta.llama3-8b-instruct-v1:0"
     )
     message = HumanMessage(content="Hello")
-    response = chat([message])
-    assert callback_handler.llm_streams > 0
-    assert isinstance(response, BaseMessage)
+
+    response = AIMessageChunk(content="")
+    for chunk in chat.stream([message]):
+        response += chunk  # type: ignore[assignment]
+
+    assert response.content
+    assert response.response_metadata
+    assert response.usage_metadata
 
 
 @pytest.mark.scheduled
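
Note (not part of the patch): a minimal usage sketch of the streaming path this
change fixes, mirroring the updated integration test above. It assumes AWS
credentials and Bedrock access to meta.llama3-8b-instruct-v1:0 are configured
in the environment; availability and the exact response_metadata keys may vary.

# Stream a llama3 Bedrock model and accumulate the chunks, as the updated
# integration test does. Region/credentials are taken from the environment.
from langchain_aws import ChatBedrock
from langchain_core.messages import AIMessageChunk, HumanMessage

chat = ChatBedrock(model_id="meta.llama3-8b-instruct-v1:0")  # type: ignore[call-arg]

response = AIMessageChunk(content="")
for chunk in chat.stream([HumanMessage(content="Hello")]):
    response += chunk  # merge streamed chunks into a single message

print(response.content)            # generated text
print(response.response_metadata)  # metadata from the final (stop_reason) chunk
print(response.usage_metadata)     # token counts; the test above asserts these
                                   # are populated for llama3 streaming after this fix

Accumulating chunks with += is the same pattern the new test uses; the final
chunk carries the invocation metrics that the added "meta"/"stop" branch in
prepare_output_stream now yields.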