Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,8 @@ async def create_stream(
tool_calls[current_tool_id] = {
"id": chunk.content_block.id,
"name": chunk.content_block.name,
"input": "", # Will be populated from deltas
"input": json.dumps(chunk.content_block.input),
"partial_json": "", # May be populated from deltas
}

elif chunk.type == "content_block_delta":
Expand All @@ -896,10 +897,15 @@ async def create_stream(
elif hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
if current_tool_id is not None and hasattr(chunk.delta, "partial_json"):
# Accumulate partial JSON for the current tool
tool_calls[current_tool_id]["input"] += chunk.delta.partial_json
tool_calls[current_tool_id]["partial_json"] += chunk.delta.partial_json

elif chunk.type == "content_block_stop":
# End of a content block (could be text or tool)
if current_tool_id is not None:
# If there was partial JSON accumulated, use it as the input
if len(tool_calls[current_tool_id]["partial_json"]) > 0:
tool_calls[current_tool_id]["input"] = tool_calls[current_tool_id]["partial_json"]
del tool_calls[current_tool_id]["partial_json"]
current_tool_id = None

elif chunk.type == "message_delta":
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
import logging
import os
from typing import List, Sequence
Expand All @@ -20,6 +21,7 @@
from autogen_ext.models.anthropic import (
AnthropicBedrockChatCompletionClient,
AnthropicChatCompletionClient,
BaseAnthropicChatCompletionClient,
BedrockInfo,
)

Expand All @@ -34,6 +36,11 @@ def _add_numbers(a: int, b: int) -> int:
return a + b


def _ask_for_input() -> str:
"""Function that asks for user input. Used to test empty input handling, such as in `pass_to_user` tool."""
return "Further input from user"


@pytest.mark.asyncio
async def test_mock_tool_choice_specific_tool() -> None:
"""Test tool_choice parameter with a specific tool using mocks."""
Expand Down Expand Up @@ -999,3 +1006,104 @@ async def test_anthropic_tool_choice_none_value_with_actual_api() -> None:

# Should get a text response, not tool calls
assert isinstance(result.content, str)


def get_client_or_skip(provider: str) -> BaseAnthropicChatCompletionClient:
if provider == "anthropic":
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
pytest.skip("ANTHROPIC_API_KEY not found in environment variables")

return AnthropicChatCompletionClient(
model="claude-3-haiku-20240307",
api_key=api_key,
)
else:
access_key = os.getenv("AWS_ACCESS_KEY_ID")
secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
region = os.getenv("AWS_REGION")
if not access_key or not secret_key or not region:
pytest.skip("AWS credentials not found in environment variables")

model = os.getenv("ANTHROPIC_BEDROCK_MODEL", "us.anthropic.claude-3-haiku-20240307-v1:0")
return AnthropicBedrockChatCompletionClient(
model=model,
bedrock_info=BedrockInfo(
aws_access_key=access_key,
aws_secret_key=secret_key,
aws_region=region,
aws_session_token=os.getenv("AWS_SESSION_TOKEN", ""),
),
model_info=ModelInfo(
vision=False, function_calling=True, json_output=False, family="unknown", structured_output=True
),
)


@pytest.mark.asyncio
@pytest.mark.parametrize("provider", ["anthropic", "bedrock"])
async def test_streaming_tool_usage_with_no_arguments(provider: str) -> None:
"""
Test reading streaming tool usage response with no arguments.
In that case `input` in initial `tool_use` chunk is `{}` and subsequent `partial_json` chunks are empty.
"""
client = get_client_or_skip(provider)

# Define tools
ask_for_input_tool = FunctionTool(
_ask_for_input, description="Ask user for more input", name="ask_for_input", strict=True
)

chunks: List[str | CreateResult] = []
async for chunk in client.create_stream(
messages=[
SystemMessage(content="When user intent is unclear, ask for more input"),
UserMessage(content="Erm...", source="user"),
],
tools=[ask_for_input_tool],
tool_choice="required",
):
chunks.append(chunk)

assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
result: CreateResult = chunks[-1]
assert len(result.content) == 1
content = result.content[-1]
assert isinstance(content, FunctionCall)
assert content.name == "ask_for_input"
assert json.loads(content.arguments) is not None


@pytest.mark.parametrize("provider", ["anthropic", "bedrock"])
@pytest.mark.asyncio
async def test_streaming_tool_usage_with_arguments(provider: str) -> None:
"""
Test reading streaming tool usage response with arguments.
In that case `input` in initial `tool_use` chunk is `{}` but subsequent `partial_json` chunks make up the actual
complete input value.
"""
client = get_client_or_skip(provider)

# Define tools
add_numbers = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")

chunks: List[str | CreateResult] = []
async for chunk in client.create_stream(
messages=[
SystemMessage(content="Use the tools to evaluate calculations"),
UserMessage(content="2 + 2", source="user"),
],
tools=[add_numbers],
tool_choice="required",
):
chunks.append(chunk)

assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
result: CreateResult = chunks[-1]
assert len(result.content) == 1
content = result.content[-1]
assert isinstance(content, FunctionCall)
assert content.name == "add_numbers"
assert json.loads(content.arguments) is not None
Loading