Skip to content

Commit

Permalink
implement streaming and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lspinheiro committed Jan 3, 2025
1 parent bac6b80 commit 2e2a2a5
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from semantic_kernel.kernel import Kernel
from semantic_kernel.functions.kernel_plugin import KernelPlugin
from typing_extensions import AsyncGenerator, Union
from ._kernel_function_from_tool import KernelFunctionFromTool
from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool
from semantic_kernel.contents.function_call_content import FunctionCallContent
from autogen_core import FunctionCall
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent


class SKChatCompletionAdapter(ChatCompletionClient):
Expand Down Expand Up @@ -76,7 +79,7 @@ def _build_execution_settings(self, extra_create_args: Mapping[str, Any], tools:
# If tools are available, configure function choice behavior with auto_invoke disabled
function_choice_behavior = None
if tools:
function_choice_behavior = FunctionChoiceBehavior.NoneInvoke()
function_choice_behavior = FunctionChoiceBehavior.Auto(auto_invoke=extra_create_args.get("auto_invoke", False))

# Create settings with remaining args as extension_data
settings = PromptExecutionSettings(
Expand Down Expand Up @@ -111,6 +114,28 @@ def _sync_tools_with_kernel(self, kernel: Kernel, tools: Sequence[Tool | ToolSch
kernel_function = KernelFunctionFromTool(tool, plugin_name="autogen_tools")
self._tools_plugin.functions[tool.name] = kernel_function

def _process_tool_calls(self, result: ChatMessageContent) -> list[FunctionCall]:
    """Translate SK ``FunctionCallContent`` items on *result* into autogen ``FunctionCall`` objects.

    Non-function-call items are ignored. When a plugin name is present the
    call name is qualified as ``"<plugin>-<function>"``; otherwise the bare
    function name is used. Missing arguments default to the empty JSON object.
    """
    calls: list[FunctionCall] = []
    for entry in result.items:
        if not isinstance(entry, FunctionCallContent):
            continue

        func_name = entry.function_name or entry.name
        prefix = entry.plugin_name or ""
        qualified = f"{prefix}-{func_name}" if prefix else func_name

        calls.append(
            FunctionCall(
                id=entry.id,
                name=qualified,
                arguments=entry.arguments or "{}",
            )
        )
    return calls

async def create(
self,
messages: Sequence[LLMMessage],
Expand Down Expand Up @@ -150,10 +175,19 @@ async def create(

self._total_prompt_tokens += prompt_tokens
self._total_completion_tokens += completion_tokens

# Process content based on whether there are tool calls
content: Union[str, list[FunctionCall]]
if any(isinstance(item, FunctionCallContent) for item in result[0].items):
content = self._process_tool_calls(result[0])
finish_reason = "function_calls"
else:
content = result[0].content
finish_reason = "stop"

return CreateResult(
content=result[0].content,
finish_reason="stop",
content=content,
finish_reason=finish_reason,
usage=RequestUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens
Expand All @@ -169,12 +203,68 @@ async def create_stream(
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
# Very similar to create(), but orchestrates streaming.
# 1. Convert messages -> ChatHistory
# 2. Possibly set function-calling if needed
# 3. Build generator that yields str segments or a final CreateResult
# from SK's get_streaming_chat_message_contents(...)
raise NotImplementedError("create_stream is not implemented")
if "kernel" not in extra_create_args:
raise ValueError("kernel is required in extra_create_args")

kernel = extra_create_args["kernel"]
if not isinstance(kernel, Kernel):
raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel")

chat_history = self._convert_to_chat_history(messages)
settings = self._build_execution_settings(extra_create_args, tools)
self._sync_tools_with_kernel(kernel, tools)

prompt_tokens = 0
completion_tokens = 0
accumulated_content = ""

async for streaming_messages in self._sk_client.get_streaming_chat_message_contents(
chat_history,
settings=settings,
kernel=kernel
):
for msg in streaming_messages:
if not isinstance(msg, StreamingChatMessageContent):
continue

# Track token usage
if msg.metadata and 'usage' in msg.metadata:
usage = msg.metadata['usage']
prompt_tokens = getattr(usage, 'prompt_tokens', 0)
completion_tokens = getattr(usage, 'completion_tokens', 0)

# Check for function calls
if any(isinstance(item, FunctionCallContent) for item in msg.items):
function_calls = self._process_tool_calls(msg)
yield CreateResult(
content=function_calls,
finish_reason="function_calls",
usage=RequestUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens
),
cached=False
)
return

# Handle text content
if msg.content:
accumulated_content += msg.content
yield msg.content

# Final yield if there was text content
if accumulated_content:
self._total_prompt_tokens += prompt_tokens
self._total_completion_tokens += completion_tokens
yield CreateResult(
content=accumulated_content,
finish_reason="stop",
usage=RequestUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens
),
cached=False
)

def actual_usage(self) -> RequestUsage:
return RequestUsage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,20 @@ def __init__(self):
async def run(self, args: CalculatorArgs, cancellation_token: CancellationToken) -> CalculatorResult:
    """Add the two operands and wrap the sum in a ``CalculatorResult``."""
    total = args.a + args.b
    return CalculatorResult(result=total)

@pytest.mark.asyncio
async def test_sk_chat_completion_with_tools():
# Set up Azure OpenAI client with token auth
deployment_name = "gpt-4o-mini"
endpoint = "https://<your-endpoint>.openai.azure.com/"
api_version = "2024-07-18"

# Create SK client
sk_client = AzureChatCompletion(
@pytest.fixture
def sk_client():
    """Build an ``AzureChatCompletion`` client from environment configuration.

    Reads ``AZURE_OPENAI_DEPLOYMENT_NAME``, ``AZURE_OPENAI_ENDPOINT`` and
    ``AZURE_OPENAI_API_KEY``; these must be set for the integration tests
    that use this fixture to run.
    """
    deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")
    endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    api_key = os.getenv("AZURE_OPENAI_API_KEY")

    # Fix: the pasted diff carried both the old and the new api_key lines,
    # leaving a duplicate keyword argument; only the new form is kept.
    return AzureChatCompletion(
        deployment_name=deployment_name,
        endpoint=endpoint,
        api_key=api_key,
    )


@pytest.mark.asyncio
async def test_sk_chat_completion_with_tools(sk_client):
# Create adapter
adapter = SKChatCompletionAdapter(sk_client)

Expand All @@ -63,11 +63,102 @@ async def test_sk_chat_completion_with_tools():
tools=[tool],
extra_create_args={"kernel": kernel}
)

# Verify response
assert isinstance(result.content, list)
assert result.finish_reason == "function_calls"
assert result.usage.prompt_tokens >= 0
assert result.usage.completion_tokens >= 0
assert not result.cached

@pytest.mark.asyncio
async def test_sk_chat_completion_without_tools(sk_client):
    """A plain chat request (no tools) should come back as text with finish_reason 'stop'."""
    adapter = SKChatCompletionAdapter(sk_client)
    kernel = Kernel(memory=NullMemory())

    chat = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="Say hello!", source="user"),
    ]

    result = await adapter.create(
        messages=chat,
        extra_create_args={"kernel": kernel},
    )

    # Plain text reply, normal termination, sane usage accounting, not cached.
    assert isinstance(result.content, str)
    assert result.finish_reason == "stop"
    assert result.usage.prompt_tokens >= 0
    assert result.usage.completion_tokens >= 0
    assert not result.cached

@pytest.mark.asyncio
async def test_sk_chat_completion_stream_with_tools(sk_client):
    """Streaming a prompt that should trigger a tool ends in a function-call CreateResult."""
    adapter = SKChatCompletionAdapter(sk_client)
    kernel = Kernel(memory=NullMemory())
    calculator = CalculatorTool()

    chat = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What is 2 + 2?", source="user"),
    ]

    collected = [
        chunk
        async for chunk in adapter.create_stream(
            messages=chat,
            tools=[calculator],
            extra_create_args={"kernel": kernel},
        )
    ]

    assert len(collected) > 0
    last = collected[-1]
    assert isinstance(last.content, list)  # tool invocations, not text
    assert last.finish_reason == "function_calls"
    assert last.usage.prompt_tokens >= 0
    assert last.usage.completion_tokens >= 0
    assert not last.cached

@pytest.mark.asyncio
async def test_sk_chat_completion_stream_without_tools(sk_client):
    """Streaming plain text: string deltas first, one final CreateResult last."""
    adapter = SKChatCompletionAdapter(sk_client)
    kernel = Kernel(memory=NullMemory())

    chat = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="Say hello!", source="user"),
    ]

    collected = [
        chunk
        async for chunk in adapter.create_stream(
            messages=chat,
            extra_create_args={"kernel": kernel},
        )
    ]

    assert len(collected) > 0
    deltas, last = collected[:-1], collected[-1]

    # Every intermediate chunk is a text delta...
    assert all(isinstance(delta, str) for delta in deltas)

    # ...and the stream terminates with an aggregated CreateResult.
    assert isinstance(last.content, str)
    assert last.finish_reason == "stop"
    assert last.usage.prompt_tokens >= 0
    assert last.usage.completion_tokens >= 0
    assert not last.cached

0 comments on commit 2e2a2a5

Please sign in to comment.