Fixes streaming for llama3 models. (#116)
3coins authored Jul 22, 2024
1 parent 8a17693 commit 6829f50
Showing 3 changed files with 27 additions and 22 deletions.
22 changes: 9 additions & 13 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -565,20 +565,16 @@ def _generate(
                 )
             else:
                 usage_metadata = None
+
         llm_output["model_id"] = self.model_id
-        if len(tool_calls) > 0:
-            msg = AIMessage(
-                content=completion,
-                additional_kwargs=llm_output,
-                tool_calls=cast(List[ToolCall], tool_calls),
-                usage_metadata=usage_metadata,
-            )
-        else:
-            msg = AIMessage(
-                content=completion,
-                additional_kwargs=llm_output,
-                usage_metadata=usage_metadata,
-            )
+
+        msg = AIMessage(
+            content=completion,
+            additional_kwargs=llm_output,
+            tool_calls=cast(List[ToolCall], tool_calls),
+            usage_metadata=usage_metadata,
+        )
+
         return ChatResult(
             generations=[
                 ChatGeneration(
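The collapsed branch works because AIMessage treats an empty tool_calls list the same as an omitted one, so a single constructor call now covers both the tool-calling and plain-text paths. A minimal sketch of that equivalence (the model_id value is illustrative only):

from langchain_core.messages import AIMessage

# With no tool calls, passing tool_calls=[] leaves msg.tool_calls empty,
# which is what the deleted else-branch produced.
msg = AIMessage(
    content="Hello!",
    additional_kwargs={"model_id": "meta.llama3-8b-instruct-v1:0"},
    tool_calls=[],
    usage_metadata=None,
)
assert msg.tool_calls == []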
10 changes: 9 additions & 1 deletion libs/aws/langchain_aws/llms/bedrock.py
@@ -117,7 +117,11 @@ def _stream_response_to_generation_chunk(
             return None
     else:
         # chunk obj format varies with provider
-        generation_info = {k: v for k, v in stream_response.items() if k != output_key}
+        generation_info = {
+            k: v
+            for k, v in stream_response.items()
+            if k not in [output_key, "prompt_token_count", "generation_token_count"]
+        }
         return GenerationChunk(
             text=(
                 stream_response[output_key]
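For meta (Llama) models the stream chunks carry running token counters alongside the generated text; the widened filter keeps those counters out of generation_info. A small sketch of the comprehension applied to an illustrative chunk (the payload shape is assumed for the example, not quoted from the Bedrock docs):

# Illustrative Llama stream chunk; field values are made up.
stream_response = {
    "generation": "Hello",
    "prompt_token_count": None,
    "generation_token_count": 2,
    "stop_reason": None,
}
output_key = "generation"

generation_info = {
    k: v
    for k, v in stream_response.items()
    if k not in [output_key, "prompt_token_count", "generation_token_count"]
}
print(generation_info)  # {'stop_reason': None} -- token counters filtered out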
@@ -347,6 +351,10 @@ def prepare_output_stream(
                 yield _get_invocation_metrics_chunk(chunk_obj)
                 return

+            elif provider == "meta" and chunk_obj.get("stop_reason", "") == "stop":
+                yield _get_invocation_metrics_chunk(chunk_obj)
+                return
+
             elif messages_api and (chunk_obj.get("type") == "message_stop"):
                 yield _get_invocation_metrics_chunk(chunk_obj)
                 return
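Llama streams mark completion with a final chunk whose stop_reason is "stop"; the new branch ends the stream there and surfaces invocation metrics the same way the mistral and messages-API branches already do. A standalone sketch of the check against an assumed final chunk (the amazon-bedrock-invocationMetrics payload is illustrative):

# Assumed shape of the last chunk a meta/Llama model emits when it stops.
chunk_obj = {
    "generation": "",
    "stop_reason": "stop",
    "amazon-bedrock-invocationMetrics": {
        "inputTokenCount": 12,
        "outputTokenCount": 5,
    },
}
provider = "meta"

if provider == "meta" and chunk_obj.get("stop_reason", "") == "stop":
    # At this point prepare_output_stream yields one final metrics chunk and
    # returns, so callers can attach usage information to the end of the stream.
    print(chunk_obj["amazon-bedrock-invocationMetrics"])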
17 changes: 9 additions & 8 deletions libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -84,17 +84,18 @@ def test_chat_bedrock_streaming() -> None:
 @pytest.mark.scheduled
 def test_chat_bedrock_streaming_llama3() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
-    callback_handler = FakeCallbackHandler()
     chat = ChatBedrock(  # type: ignore[call-arg]
-        model_id="meta.llama3-8b-instruct-v1:0",
-        streaming=True,
-        callbacks=[callback_handler],
-        verbose=True,
+        model_id="meta.llama3-8b-instruct-v1:0"
     )
     message = HumanMessage(content="Hello")
-    response = chat([message])
-    assert callback_handler.llm_streams > 0
-    assert isinstance(response, BaseMessage)
+
+    response = AIMessageChunk(content="")
+    for chunk in chat.stream([message]):
+        response += chunk  # type: ignore[assignment]
+
+    assert response.content
+    assert response.response_metadata
+    assert response.usage_metadata


 @pytest.mark.scheduled
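The rewritten test drives the public stream() API and folds the chunks together; adding AIMessageChunk objects concatenates their content and, on recent langchain-core releases, merges response and usage metadata, which is what the final assertions check. A small offline sketch of that aggregation (chunk values are made up; a real run would take them from chat.stream):

from langchain_core.messages import AIMessageChunk

# Stand-ins for what chat.stream([message]) would yield from Bedrock.
chunks = [
    AIMessageChunk(content="Hel"),
    AIMessageChunk(
        content="lo",
        usage_metadata={"input_tokens": 12, "output_tokens": 2, "total_tokens": 14},
    ),
]

response = AIMessageChunk(content="")
for chunk in chunks:
    response += chunk

assert response.content == "Hello"
assert response.usage_metadata  # carried over from the chunk that reported usage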
