Fixes streaming for Anthropic tool calls. (#119)
Co-authored-by: ccurme <[email protected]>
3coins and ccurme authored Aug 2, 2024
1 parent b3196b1 commit 0fe4dd9
Showing 3 changed files with 142 additions and 141 deletions.
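
A minimal usage sketch (not part of the commit) of the behavior this fix targets, assuming an illustrative Claude 3 model ID, region, and toy tool; with the change, ChatBedrock streams tool-call chunks incrementally instead of re-wrapping a single non-streamed response:

from langchain_aws import ChatBedrock

def get_weather(city: str) -> str:
    """Return the current weather for a city."""
    return f"Sunny in {city}"

# Model ID and region are assumptions for illustration, not taken from this commit.
llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    region_name="us-east-1",
)
llm_with_tools = llm.bind_tools([get_weather])

# Tool-call chunks now arrive incrementally while streaming.
for chunk in llm_with_tools.stream("What is the weather in Paris?"):
    print(chunk.content, chunk.tool_call_chunks)
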
63 changes: 22 additions & 41 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -1,4 +1,3 @@
import json
import re
from collections import defaultdict
from operator import itemgetter
@@ -20,6 +19,7 @@
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.language_models.chat_models import generate_from_stream
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@@ -29,7 +29,7 @@
SystemMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import ToolCall, ToolMessage, tool_call_chunk
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
@@ -39,7 +39,6 @@
from langchain_aws.function_calling import (
ToolsOutputParser,
_lc_tool_calls_to_anthropic_tool_use_blocks,
_tools_in_params,
convert_to_anthropic_tool,
get_system_message,
)
@@ -434,31 +433,6 @@ def _stream(
provider = self._get_provider()
prompt, system, formatted_messages = None, None, None

if "claude-3" in self._get_model():
if _tools_in_params({**kwargs}):
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
message = result.generations[0].message
if isinstance(message, AIMessage) and message.tool_calls is not None:
tool_call_chunks = [
tool_call_chunk(
name=tool_call["name"],
args=json.dumps(tool_call["args"]),
id=tool_call["id"],
index=idx,
)
for idx, tool_call in enumerate(message.tool_calls)
]
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=message.usage_metadata,
)
yield ChatGenerationChunk(message=message_chunk)
else:
yield cast(ChatGenerationChunk, result.generations[0])
return
if provider == "anthropic":
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
@@ -481,20 +455,23 @@ def _stream(
run_manager=run_manager,
**kwargs,
):
delta = chunk.text
if generation_info := chunk.generation_info:
usage_metadata = generation_info.pop("usage_metadata", None)
if isinstance(chunk, AIMessageChunk):
yield ChatGenerationChunk(message=chunk)
else:
usage_metadata = None
yield ChatGenerationChunk(
message=AIMessageChunk(
content=delta,
response_metadata=chunk.generation_info,
usage_metadata=usage_metadata,
delta = chunk.text
if generation_info := chunk.generation_info:
usage_metadata = generation_info.pop("usage_metadata", None)
else:
usage_metadata = None
yield ChatGenerationChunk(
message=AIMessageChunk(
content=delta,
response_metadata=chunk.generation_info,
usage_metadata=usage_metadata,
)
if chunk.generation_info is not None
else AIMessageChunk(content=delta)
)
if chunk.generation_info is not None
else AIMessageChunk(content=delta)
)

def _generate(
self,
@@ -513,7 +490,12 @@ def _generate(
provider_stop_reason_code = self.provider_stop_reason_key_map.get(
self._get_provider(), "stop_reason"
)
provider = self._get_provider()
if self.streaming:
if provider == "anthropic":
stream_iter = self._stream(messages, stop, run_manager, **kwargs)
return generate_from_stream(stream_iter)

response_metadata: List[Dict[str, Any]] = []
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
@@ -524,7 +506,6 @@ def _generate(
response_metadata, provider_stop_reason_code
)
else:
provider = self._get_provider()
prompt, system, formatted_messages = None, None, None
params: Dict[str, Any] = {**kwargs}

157 changes: 95 additions & 62 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -22,8 +22,9 @@
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LLM, BaseLanguageModel
from langchain_core.messages import ToolCall
from langchain_core.messages.tool import tool_call
from langchain_core.messages import AIMessageChunk, ToolCall
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import tool_call, tool_call_chunk
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env
@@ -89,61 +90,94 @@ def _stream_response_to_generation_chunk(
provider: str,
output_key: str,
messages_api: bool,
) -> Union[GenerationChunk, None]:
coerce_content_to_string: bool,
) -> Union[GenerationChunk, AIMessageChunk, None]: # type ignore[return]
"""Convert a stream response to a generation chunk."""
if messages_api:
msg_type = stream_response.get("type")
if msg_type == "message_start":
usage_info = stream_response.get("message", {}).get("usage", None)
usage_info = _nest_usage_info_token_counts(usage_info)
generation_info = {"usage": usage_info}
return GenerationChunk(text="", generation_info=generation_info)
input_tokens = stream_response["message"]["usage"]["input_tokens"]
return AIMessageChunk(
content="" if coerce_content_to_string else [],
usage_metadata=UsageMetadata(
input_tokens=input_tokens,
output_tokens=0,
total_tokens=input_tokens,
),
)
elif (
msg_type == "content_block_start"
and stream_response["content_block"] is not None
and stream_response["content_block"]["type"] == "tool_use"
):
content_block = stream_response["content_block"]
content_block["index"] = stream_response["index"]
tc_chunk = tool_call_chunk(
index=stream_response["index"],
id=stream_response["content_block"]["id"],
name=stream_response["content_block"]["name"],
args="",
)
return AIMessageChunk(
content=[content_block],
tool_call_chunks=[tc_chunk], # type: ignore
)
elif msg_type == "content_block_delta":
if not stream_response["delta"]:
return GenerationChunk(text="")
return GenerationChunk(
text=stream_response["delta"]["text"],
generation_info=dict(
stop_reason=stream_response.get("stop_reason", None),
return AIMessageChunk(content="")
if stream_response["delta"]["type"] == "text_delta":
if coerce_content_to_string:
return AIMessageChunk(content=stream_response["delta"]["text"])
else:
content_block = stream_response["delta"]
content_block["index"] = stream_response["index"]
content_block["type"] = "text"
return AIMessageChunk(content=[content_block])
elif stream_response["delta"]["type"] == "input_json_delta":
content_block = stream_response["delta"]
content_block["index"] = stream_response["index"]
content_block["type"] = "tool_use"
tc_chunk = {
"index": stream_response["index"],
"id": None,
"name": None,
"args": stream_response["delta"]["partial_json"],
}
return AIMessageChunk(
content=[content_block],
tool_call_chunks=[tc_chunk], # type: ignore
)
elif msg_type == "message_delta":
output_tokens = stream_response["usage"]["output_tokens"]
return AIMessageChunk(
content="",
usage_metadata=UsageMetadata(
input_tokens=0,
output_tokens=output_tokens,
total_tokens=output_tokens,
),
response_metadata={
"stop_reason": stream_response["delta"]["stop_reason"],
"stop_sequence": stream_response["delta"]["stop_sequence"],
},
)
elif msg_type == "message_delta":
usage_info = stream_response.get("usage", None)
usage_info = _nest_usage_info_token_counts(usage_info)
stop_reason = stream_response.get("delta", {}).get("stop_reason")
generation_info = {"stop_reason": stop_reason, "usage": usage_info}
return GenerationChunk(text="", generation_info=generation_info)
else:
return None
else:
# chunk obj format varies with provider
generation_info = {
k: v
for k, v in stream_response.items()
if k not in [output_key, "prompt_token_count", "generation_token_count"]
}
return GenerationChunk(
text=(
stream_response[output_key]
if provider != "mistral"
else stream_response[output_key][0]["text"]
),
generation_info=generation_info,
)


def _nest_usage_info_token_counts(usage_info: dict) -> dict:
"""
Sticking usage info for token counts into lists to
deal with langchain_core.utils.merge_dicts incompatibility
in which integers must be equal to be merged
as seen here: https://github.com/langchain-ai/langchain-aws/pull/20#issuecomment-2118166376
"""
if "input_tokens" in usage_info:
usage_info["input_tokens"] = [usage_info["input_tokens"]]
if "output_tokens" in usage_info:
usage_info["output_tokens"] = [usage_info["output_tokens"]]
return usage_info
# chunk obj format varies with provider
generation_info = {
k: v
for k, v in stream_response.items()
if k not in [output_key, "prompt_token_count", "generation_token_count"]
}
return GenerationChunk(
text=(
stream_response[output_key]
if provider != "mistral"
else stream_response[output_key][0]["text"]
),
generation_info=generation_info,
)


def _combine_generation_info_for_llm_result(
@@ -316,7 +350,8 @@ def prepare_output_stream(
response: Any,
stop: Optional[List[str]] = None,
messages_api: bool = False,
) -> Iterator[GenerationChunk]:
coerce_content_to_string: bool = False,
) -> Iterator[Union[GenerationChunk, AIMessageChunk]]:
stream = response.get("body")

if not stream:
@@ -364,6 +399,7 @@ def prepare_output_stream(
provider=provider,
output_key=output_key,
messages_api=messages_api,
coerce_content_to_string=coerce_content_to_string,
)
if generation_chunk:
yield generation_chunk
@@ -377,7 +413,8 @@ async def aprepare_output_stream(
response: Any,
stop: Optional[List[str]] = None,
messages_api: bool = False,
) -> AsyncIterator[GenerationChunk]:
coerce_content_to_string: bool = False,
) -> AsyncIterator[Union[GenerationChunk, AIMessageChunk]]:
stream = response.get("body")

if not stream:
@@ -413,6 +450,7 @@ async def aprepare_output_stream(
provider=provider,
output_key=output_key,
messages_api=messages_api,
coerce_content_to_string=coerce_content_to_string,
)
if generation_chunk:
yield generation_chunk
@@ -756,7 +794,7 @@ def _prepare_input_and_invoke_stream(
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
) -> Iterator[Union[GenerationChunk, AIMessageChunk]]:
_model_kwargs = self.model_kwargs or {}
provider = self._get_provider()

@@ -783,8 +821,10 @@ def _prepare_input_and_invoke_stream(
messages=messages,
model_kwargs=params,
)
coerce_content_to_string = True
if "claude-3" in self._get_model():
if _tools_in_params(params):
coerce_content_to_string = False
input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
@@ -823,13 +863,12 @@ def _prepare_input_and_invoke_stream(
response,
stop,
True if messages else False,
coerce_content_to_string=coerce_content_to_string,
):
yield chunk
# verify and raise callback error if any middleware intervened
self._get_bedrock_services_signal(chunk.generation_info) # type: ignore[arg-type]

if run_manager is not None:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
if not isinstance(chunk, AIMessageChunk):
self._get_bedrock_services_signal(chunk.generation_info) # type: ignore[arg-type]

async def _aprepare_input_and_invoke_stream(
self,
@@ -839,7 +878,7 @@ async def _aprepare_input_and_invoke_stream(
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
) -> AsyncIterator[Union[GenerationChunk, AIMessageChunk]]:
_model_kwargs = self.model_kwargs or {}
provider = self._get_provider()

@@ -891,12 +930,6 @@ async def _aprepare_input_and_invoke_stream(
True if messages else False,
):
yield chunk
if run_manager is not None and asyncio.iscoroutinefunction(
run_manager.on_llm_new_token
):
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
elif run_manager is not None:
run_manager.on_llm_new_token(chunk.text, chunk=chunk) # type: ignore[unused-coroutine]


class BedrockLLM(LLM, BedrockBase):
@@ -990,7 +1023,7 @@ def _stream(
Yields:
Iterator[GenerationChunk]: Responses from the model.
"""
return self._prepare_input_and_invoke_stream(
return self._prepare_input_and_invoke_stream( # type: ignore
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
)

@@ -1083,7 +1116,7 @@ async def _astream(
async for chunk in self._aprepare_input_and_invoke_stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk
yield chunk # type: ignore

async def _acall(
self,
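
For context, a small illustrative sketch (event payloads invented, shaped after the fields read by the updated _stream_response_to_generation_chunk above) of how the Anthropic messages-API stream events for one tool call are converted into message chunks:

# Invented example events for a single streamed tool call; field names mirror
# the handling added in langchain_aws/llms/bedrock.py above.
from langchain_aws.llms.bedrock import _stream_response_to_generation_chunk

events = [
    {"type": "message_start", "message": {"usage": {"input_tokens": 25}}},
    {"type": "content_block_start", "index": 0,
     "content_block": {"type": "tool_use", "id": "toolu_01", "name": "get_weather"}},
    {"type": "content_block_delta", "index": 0,
     "delta": {"type": "input_json_delta", "partial_json": '{"city": "Par'}},
    {"type": "content_block_delta", "index": 0,
     "delta": {"type": "input_json_delta", "partial_json": 'is"}'}},
    {"type": "message_delta", "usage": {"output_tokens": 7},
     "delta": {"stop_reason": "tool_use", "stop_sequence": None}},
]

# Each event becomes an AIMessageChunk; the tool_use deltas carry
# tool_call_chunks whose partial JSON concatenates when the chunks are merged,
# yielding one complete tool call.
chunks = [
    _stream_response_to_generation_chunk(
        event,
        provider="anthropic",
        output_key="completion",
        messages_api=True,
        coerce_content_to_string=False,
    )
    for event in events
]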