-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
initial sk model adapter implementation
- Loading branch information
1 parent
501d8bb
commit f77cd39
Showing
6 changed files
with
672 additions
and
374 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 5 additions & 0 deletions
5
python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from ._sk_chat_completion_adapter import SKChatCompletionAdapter | ||
|
||
__all__ = [ | ||
"SKChatCompletionAdapter" | ||
] |
62 changes: 62 additions & 0 deletions
62
...packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_kernel_function_from_tool.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from semantic_kernel.functions.kernel_function import KernelFunction | ||
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata | ||
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata | ||
from semantic_kernel.functions.function_result import FunctionResult | ||
from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext | ||
from semantic_kernel.exceptions import FunctionExecutionException | ||
from autogen_core.tools import BaseTool | ||
|
||
class KernelFunctionFromTool(KernelFunction): | ||
def __init__(self, tool: BaseTool, plugin_name: str | None = None): | ||
# Build up KernelFunctionMetadata. You can also parse the tool’s schema for parameters. | ||
parameters = [ | ||
KernelParameterMetadata( | ||
name="args", | ||
description="JSON arguments for the tool", | ||
default_value=None, | ||
type_="dict", | ||
type_object=dict, | ||
is_required=True, | ||
) | ||
] | ||
return_param = KernelParameterMetadata( | ||
name="return", | ||
description="Result from the tool", | ||
default_value=None, | ||
type_="str", | ||
type_object=str, | ||
is_required=False, | ||
) | ||
|
||
metadata = KernelFunctionMetadata( | ||
name=tool.name, | ||
description=tool.description, | ||
parameters=parameters, | ||
return_parameter=return_param, | ||
is_prompt=False, | ||
is_asynchronous=True, | ||
plugin_name=plugin_name or "", | ||
) | ||
super().__init__(metadata=metadata) | ||
self._tool = tool | ||
|
||
async def _invoke_internal(self, context: FunctionInvocationContext) -> None: | ||
# Extract the "args" parameter from the context | ||
if "args" not in context.arguments: | ||
raise FunctionExecutionException("Missing 'args' in FunctionInvocationContext.arguments") | ||
tool_args = context.arguments["args"] | ||
|
||
# Call your tool’s run_json | ||
result = await self._tool.run_json(tool_args, cancellation_token=None) | ||
|
||
# Wrap in a FunctionResult | ||
context.result = FunctionResult( | ||
function=self.metadata, | ||
value=result, | ||
metadata={"used_arguments": tool_args}, | ||
) | ||
|
||
async def _invoke_internal_stream(self, context: FunctionInvocationContext) -> None: | ||
# If you don’t have a streaming mechanism in your tool, you can simply reuse _invoke_internal | ||
# or raise NotImplementedError. For example: | ||
await self._invoke_internal(context) |
214 changes: 214 additions & 0 deletions
214
...ackages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
from typing import Any, Mapping, Optional, Sequence | ||
from autogen_core._cancellation_token import CancellationToken | ||
from autogen_core.models import RequestUsage, FunctionExecutionResultMessage, ModelCapabilities, AssistantMessage, SystemMessage, UserMessage, FunctionExecutionResult | ||
from autogen_core.models import ChatCompletionClient, CreateResult, LLMMessage | ||
from autogen_core.tools import Tool, ToolSchema | ||
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase | ||
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings | ||
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior | ||
from semantic_kernel.contents.chat_history import ChatHistory | ||
from semantic_kernel.contents.chat_message_content import ChatMessageContent | ||
from semantic_kernel.contents.utils.author_role import AuthorRole | ||
from semantic_kernel.kernel import Kernel | ||
from semantic_kernel.functions.kernel_plugin import KernelPlugin | ||
from typing_extensions import AsyncGenerator, Union | ||
from ._kernel_function_from_tool import KernelFunctionFromTool | ||
|
||
|
||
class SKChatCompletionAdapter(ChatCompletionClient): | ||
def __init__(self, sk_client: ChatCompletionClientBase): | ||
self._sk_client = sk_client | ||
self._total_prompt_tokens = 0 | ||
self._total_completion_tokens = 0 | ||
self._tools_plugin: Optional[KernelPlugin] = None | ||
|
||
def _convert_to_chat_history(self, messages: Sequence[LLMMessage]) -> ChatHistory: | ||
"""Convert Autogen LLMMessages to SK ChatHistory""" | ||
chat_history = ChatHistory() | ||
|
||
for msg in messages: | ||
if msg.type == "SystemMessage": | ||
chat_history.add_system_message(msg.content) | ||
|
||
elif msg.type == "UserMessage": | ||
if isinstance(msg.content, str): | ||
chat_history.add_user_message(msg.content) | ||
else: | ||
# Handle list of str/Image - would need to convert to SK content types | ||
chat_history.add_user_message(str(msg.content)) | ||
|
||
elif msg.type == "AssistantMessage": | ||
if isinstance(msg.content, str): | ||
chat_history.add_assistant_message(msg.content) | ||
else: | ||
# Handle function calls - would need to convert to SK function call format | ||
chat_history.add_assistant_message(str(msg.content)) | ||
|
||
elif msg.type == "FunctionExecutionResultMessage": | ||
for result in msg.content: | ||
chat_history.add_tool_message(result.content) | ||
|
||
return chat_history | ||
|
||
def _convert_from_chat_message(self, message: ChatMessageContent, source: str = "assistant") -> LLMMessage: | ||
"""Convert SK ChatMessageContent to Autogen LLMMessage""" | ||
if message.role == AuthorRole.SYSTEM: | ||
return SystemMessage(content=message.content) | ||
|
||
elif message.role == AuthorRole.USER: | ||
return UserMessage(content=message.content, source=source) | ||
|
||
elif message.role == AuthorRole.ASSISTANT: | ||
return AssistantMessage(content=message.content, source=source) | ||
|
||
elif message.role == AuthorRole.TOOL: | ||
return FunctionExecutionResultMessage( | ||
content=[FunctionExecutionResult(content=message.content, call_id="")] | ||
) | ||
|
||
raise ValueError(f"Unknown role: {message.role}") | ||
|
||
def _build_execution_settings(self, extra_create_args: Mapping[str, Any], tools: Sequence[Tool | ToolSchema]) -> PromptExecutionSettings: | ||
"""Build PromptExecutionSettings from extra_create_args""" | ||
# Extract service_id if provided, otherwise use None | ||
service_id = extra_create_args.get("service_id") | ||
|
||
# If tools are available, configure function choice behavior with auto_invoke disabled | ||
function_choice_behavior = None | ||
if tools: | ||
function_choice_behavior = FunctionChoiceBehavior.NoneInvoke() | ||
|
||
# Create settings with remaining args as extension_data | ||
settings = PromptExecutionSettings( | ||
service_id=service_id, | ||
extension_data=dict(extra_create_args), | ||
function_choice_behavior=function_choice_behavior | ||
) | ||
|
||
return settings | ||
|
||
def _sync_tools_with_kernel(self, kernel: Kernel, tools: Sequence[Tool | ToolSchema]) -> None: | ||
"""Sync tools with kernel by updating the plugin""" | ||
# Create new plugin if none exists | ||
if not self._tools_plugin: | ||
self._tools_plugin = KernelPlugin(name="autogen_tools") | ||
kernel.add_plugin(self._tools_plugin) | ||
|
||
# Get current tool names in plugin | ||
current_tool_names = set(self._tools_plugin.functions.keys()) | ||
|
||
# Get new tool names | ||
new_tool_names = {tool.schema["name"] if isinstance(tool, Tool) else tool.name for tool in tools} | ||
|
||
# Remove tools that are no longer needed | ||
for tool_name in current_tool_names - new_tool_names: | ||
del self._tools_plugin.functions[tool_name] | ||
|
||
# Add or update tools | ||
for tool in tools: | ||
if isinstance(tool, Tool): | ||
# Convert Tool to KernelFunction using KernelFunctionFromTool | ||
kernel_function = KernelFunctionFromTool(tool, plugin_name="autogen_tools") | ||
self._tools_plugin.functions[tool.name] = kernel_function | ||
|
||
async def create( | ||
self, | ||
messages: Sequence[LLMMessage], | ||
tools: Sequence[Tool | ToolSchema] = [], | ||
json_output: Optional[bool] = None, | ||
extra_create_args: Mapping[str, Any] = {}, | ||
cancellation_token: Optional[CancellationToken] = None, | ||
) -> CreateResult: | ||
if "kernel" not in extra_create_args: | ||
raise ValueError("kernel is required in extra_create_args") | ||
|
||
kernel = extra_create_args["kernel"] | ||
if not isinstance(kernel, Kernel): | ||
raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel") | ||
|
||
chat_history = self._convert_to_chat_history(messages) | ||
|
||
# Build execution settings from extra args and tools | ||
settings = self._build_execution_settings(extra_create_args, tools) | ||
|
||
# Sync tools with kernel | ||
self._sync_tools_with_kernel(kernel, tools) | ||
|
||
result = await self._sk_client.get_chat_message_contents( | ||
chat_history, | ||
settings=settings, | ||
kernel=kernel | ||
) | ||
# Track token usage from result metadata | ||
prompt_tokens = 0 | ||
completion_tokens = 0 | ||
|
||
if result[0].metadata and 'usage' in result[0].metadata: | ||
usage = result[0].metadata['usage'] | ||
prompt_tokens = getattr(usage, 'prompt_tokens', 0) | ||
completion_tokens = getattr(usage, 'completion_tokens', 0) | ||
|
||
self._total_prompt_tokens += prompt_tokens | ||
self._total_completion_tokens += completion_tokens | ||
|
||
return CreateResult( | ||
content=result[0].content, | ||
finish_reason="stop", | ||
usage=RequestUsage( | ||
prompt_tokens=prompt_tokens, | ||
completion_tokens=completion_tokens | ||
), | ||
cached=False | ||
) | ||
|
||
async def create_stream( | ||
self, | ||
messages: Sequence[LLMMessage], | ||
tools: Sequence[Tool | ToolSchema] = [], | ||
json_output: Optional[bool] = None, | ||
extra_create_args: Mapping[str, Any] = {}, | ||
cancellation_token: Optional[CancellationToken] = None, | ||
) -> AsyncGenerator[Union[str, CreateResult], None]: | ||
# Very similar to create(), but orchestrates streaming. | ||
# 1. Convert messages -> ChatHistory | ||
# 2. Possibly set function-calling if needed | ||
# 3. Build generator that yields str segments or a final CreateResult | ||
# from SK's get_streaming_chat_message_contents(...) | ||
raise NotImplementedError("create_stream is not implemented") | ||
|
||
def actual_usage(self) -> RequestUsage: | ||
return RequestUsage( | ||
prompt_tokens=self._total_prompt_tokens, | ||
completion_tokens=self._total_completion_tokens | ||
) | ||
|
||
def total_usage(self) -> RequestUsage: | ||
return RequestUsage( | ||
prompt_tokens=self._total_prompt_tokens, | ||
completion_tokens=self._total_completion_tokens | ||
) | ||
|
||
def count_tokens(self, messages: Sequence[LLMMessage]) -> int: | ||
chat_history = self._convert_to_chat_history(messages) | ||
total_tokens = 0 | ||
for message in chat_history.messages: | ||
if message.metadata and 'usage' in message.metadata: | ||
usage = message.metadata['usage'] | ||
total_tokens += getattr(usage, 'total_tokens', 0) | ||
return total_tokens | ||
|
||
def remaining_tokens(self, messages: Sequence[LLMMessage]) -> int: | ||
# Get total token count | ||
used_tokens = self.count_tokens(messages) | ||
# Assume max tokens from SK client if available, otherwise use default | ||
max_tokens = getattr(self._sk_client, 'max_tokens', 4096) | ||
return max_tokens - used_tokens | ||
|
||
@property | ||
def capabilities(self) -> ModelCapabilities: | ||
# Return something consistent with the underlying SK client | ||
return { | ||
"vision": False, | ||
"function_calling": self._sk_client.SUPPORTS_FUNCTION_CALLING, | ||
"json_output": False, | ||
} |
73 changes: 73 additions & 0 deletions
73
python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import os | ||
import pytest | ||
from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion | ||
from semantic_kernel.kernel import Kernel | ||
from semantic_kernel.memory.null_memory import NullMemory | ||
from autogen_core.models import SystemMessage, UserMessage | ||
from autogen_core.tools import BaseTool | ||
from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter | ||
from pydantic import BaseModel | ||
from autogen_core import CancellationToken | ||
|
||
class CalculatorArgs(BaseModel): | ||
a: float | ||
b: float | ||
|
||
class CalculatorResult(BaseModel): | ||
result: float | ||
|
||
class CalculatorTool(BaseTool[CalculatorArgs, CalculatorResult]): | ||
def __init__(self): | ||
super().__init__( | ||
args_type=CalculatorArgs, | ||
return_type=CalculatorResult, | ||
name="calculator", | ||
description="Add two numbers together" | ||
) | ||
|
||
async def run(self, args: CalculatorArgs, cancellation_token: CancellationToken) -> CalculatorResult: | ||
return CalculatorResult(result=args.a + args.b) | ||
|
||
@pytest.mark.asyncio | ||
async def test_sk_chat_completion_with_tools(): | ||
# Set up Azure OpenAI client with token auth | ||
deployment_name = "gpt-4o-mini" | ||
endpoint = "https://<your-endpoint>.openai.azure.com/" | ||
api_version = "2024-07-18" | ||
|
||
# Create SK client | ||
sk_client = AzureChatCompletion( | ||
deployment_name=deployment_name, | ||
endpoint=endpoint, | ||
api_key=os.getenv("AZURE_OPENAI_API_KEY"), | ||
) | ||
|
||
# Create adapter | ||
adapter = SKChatCompletionAdapter(sk_client) | ||
|
||
# Create kernel | ||
kernel = Kernel(memory=NullMemory()) | ||
|
||
# Create calculator tool instance | ||
tool = CalculatorTool() | ||
|
||
# Test messages | ||
messages = [ | ||
SystemMessage(content="You are a helpful assistant."), | ||
UserMessage(content="What is 2 + 2?", source="user"), | ||
] | ||
|
||
# Call create with tool | ||
result = await adapter.create( | ||
messages=messages, | ||
tools=[tool], | ||
extra_create_args={"kernel": kernel} | ||
) | ||
|
||
|
||
# Verify response | ||
assert isinstance(result.content, str) | ||
assert result.finish_reason == "stop" | ||
assert result.usage.prompt_tokens >= 0 | ||
assert result.usage.completion_tokens >= 0 | ||
assert not result.cached |
Oops, something went wrong.