initial sk model adapter implementation
lspinheiro committed Dec 30, 2024
1 parent 501d8bb commit f77cd39
Showing 6 changed files with 672 additions and 374 deletions.
3 changes: 3 additions & 0 deletions python/packages/autogen-ext/pyproject.toml
@@ -52,6 +52,9 @@ video-surfer = [
grpc = [
"grpcio~=1.62.0", # TODO: update this once we have a stable version.
]
semantic-kernel = [
"semantic-kernel>=1.17.1",
]

[tool.hatch.build.targets.wheel]
packages = ["src/autogen_ext"]
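With this extra in place, the adapter's dependencies can be installed via pip (assuming the package is published under the name used in this repo):

pip install "autogen-ext[semantic-kernel]"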
5 changes: 5 additions & 0 deletions python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/__init__.py
@@ -0,0 +1,5 @@
from ._sk_chat_completion_adapter import SKChatCompletionAdapter

__all__ = [
"SKChatCompletionAdapter"
]
62 changes: 62 additions & 0 deletions python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_kernel_function_from_tool.py
@@ -0,0 +1,62 @@
from autogen_core import CancellationToken
from autogen_core.tools import BaseTool
from semantic_kernel.exceptions import FunctionExecutionException
from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext
from semantic_kernel.functions.function_result import FunctionResult
from semantic_kernel.functions.kernel_function import KernelFunction
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata

class KernelFunctionFromTool(KernelFunction):
def __init__(self, tool: BaseTool, plugin_name: str | None = None):
# Build up KernelFunctionMetadata. You can also parse the tool’s schema for parameters.
parameters = [
KernelParameterMetadata(
name="args",
description="JSON arguments for the tool",
default_value=None,
type_="dict",
type_object=dict,
is_required=True,
)
]
return_param = KernelParameterMetadata(
name="return",
description="Result from the tool",
default_value=None,
type_="str",
type_object=str,
is_required=False,
)

metadata = KernelFunctionMetadata(
name=tool.name,
description=tool.description,
parameters=parameters,
return_parameter=return_param,
is_prompt=False,
is_asynchronous=True,
plugin_name=plugin_name or "",
)
super().__init__(metadata=metadata)
self._tool = tool

async def _invoke_internal(self, context: FunctionInvocationContext) -> None:
# Extract the "args" parameter from the context
if "args" not in context.arguments:
raise FunctionExecutionException("Missing 'args' in FunctionInvocationContext.arguments")
tool_args = context.arguments["args"]

        # Call the tool's run_json with a fresh cancellation token
        result = await self._tool.run_json(tool_args, cancellation_token=CancellationToken())

# Wrap in a FunctionResult
context.result = FunctionResult(
function=self.metadata,
value=result,
metadata={"used_arguments": tool_args},
)

async def _invoke_internal_stream(self, context: FunctionInvocationContext) -> None:
        # The wrapped tool exposes no streaming interface, so reuse the
        # non-streaming implementation (alternatively, raise NotImplementedError).
await self._invoke_internal(context)
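For orientation, a minimal usage sketch of the wrapper (hypothetical: my_tool stands in for any concrete autogen BaseTool, such as the CalculatorTool defined in the tests below):

# Hypothetical sketch: expose an autogen tool to a Semantic Kernel instance.
from semantic_kernel.kernel import Kernel
from semantic_kernel.functions.kernel_plugin import KernelPlugin

kernel = Kernel()
kernel_fn = KernelFunctionFromTool(my_tool, plugin_name="autogen_tools")  # my_tool: a BaseTool
kernel.add_plugin(KernelPlugin(name="autogen_tools", functions=[kernel_fn]))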
214 changes: 214 additions & 0 deletions python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py
@@ -0,0 +1,214 @@
from typing import Any, AsyncGenerator, Mapping, Optional, Sequence, Union

from autogen_core import CancellationToken
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    CreateResult,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    LLMMessage,
    ModelCapabilities,
    RequestUsage,
    SystemMessage,
    UserMessage,
)
from autogen_core.tools import Tool, ToolSchema
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.functions.kernel_plugin import KernelPlugin
from semantic_kernel.kernel import Kernel

from ._kernel_function_from_tool import KernelFunctionFromTool


class SKChatCompletionAdapter(ChatCompletionClient):
def __init__(self, sk_client: ChatCompletionClientBase):
self._sk_client = sk_client
self._total_prompt_tokens = 0
self._total_completion_tokens = 0
self._tools_plugin: Optional[KernelPlugin] = None

def _convert_to_chat_history(self, messages: Sequence[LLMMessage]) -> ChatHistory:
"""Convert Autogen LLMMessages to SK ChatHistory"""
chat_history = ChatHistory()

for msg in messages:
if msg.type == "SystemMessage":
chat_history.add_system_message(msg.content)

elif msg.type == "UserMessage":
if isinstance(msg.content, str):
chat_history.add_user_message(msg.content)
else:
# Handle list of str/Image - would need to convert to SK content types
chat_history.add_user_message(str(msg.content))

elif msg.type == "AssistantMessage":
if isinstance(msg.content, str):
chat_history.add_assistant_message(msg.content)
else:
# Handle function calls - would need to convert to SK function call format
chat_history.add_assistant_message(str(msg.content))

elif msg.type == "FunctionExecutionResultMessage":
for result in msg.content:
chat_history.add_tool_message(result.content)

return chat_history

def _convert_from_chat_message(self, message: ChatMessageContent, source: str = "assistant") -> LLMMessage:
"""Convert SK ChatMessageContent to Autogen LLMMessage"""
if message.role == AuthorRole.SYSTEM:
return SystemMessage(content=message.content)

elif message.role == AuthorRole.USER:
return UserMessage(content=message.content, source=source)

elif message.role == AuthorRole.ASSISTANT:
return AssistantMessage(content=message.content, source=source)

elif message.role == AuthorRole.TOOL:
return FunctionExecutionResultMessage(
content=[FunctionExecutionResult(content=message.content, call_id="")]
)

raise ValueError(f"Unknown role: {message.role}")

def _build_execution_settings(self, extra_create_args: Mapping[str, Any], tools: Sequence[Tool | ToolSchema]) -> PromptExecutionSettings:
"""Build PromptExecutionSettings from extra_create_args"""
# Extract service_id if provided, otherwise use None
service_id = extra_create_args.get("service_id")

        # If tools are available, advertise them to the model but keep
        # auto-invocation off so the caller controls execution.
        function_choice_behavior = None
        if tools:
            function_choice_behavior = FunctionChoiceBehavior.Auto(auto_invoke=False)

# Create settings with remaining args as extension_data
settings = PromptExecutionSettings(
service_id=service_id,
extension_data=dict(extra_create_args),
function_choice_behavior=function_choice_behavior
)
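        # For example, extra_create_args={"service_id": "my-service", "temperature": 0.2}
        # (hypothetical values) produces settings whose extension_data carries the
        # temperature through to the underlying connector's settings.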

return settings

def _sync_tools_with_kernel(self, kernel: Kernel, tools: Sequence[Tool | ToolSchema]) -> None:
"""Sync tools with kernel by updating the plugin"""
# Create new plugin if none exists
if not self._tools_plugin:
self._tools_plugin = KernelPlugin(name="autogen_tools")
kernel.add_plugin(self._tools_plugin)

# Get current tool names in plugin
current_tool_names = set(self._tools_plugin.functions.keys())

        # Get new tool names (ToolSchema is a TypedDict, so it needs key access)
        new_tool_names = {tool.schema["name"] if isinstance(tool, Tool) else tool["name"] for tool in tools}

# Remove tools that are no longer needed
for tool_name in current_tool_names - new_tool_names:
del self._tools_plugin.functions[tool_name]

# Add or update tools
for tool in tools:
if isinstance(tool, Tool):
# Convert Tool to KernelFunction using KernelFunctionFromTool
kernel_function = KernelFunctionFromTool(tool, plugin_name="autogen_tools")
self._tools_plugin.functions[tool.name] = kernel_function

async def create(
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
if "kernel" not in extra_create_args:
raise ValueError("kernel is required in extra_create_args")

kernel = extra_create_args["kernel"]
if not isinstance(kernel, Kernel):
raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel")

chat_history = self._convert_to_chat_history(messages)

# Build execution settings from extra args and tools
settings = self._build_execution_settings(extra_create_args, tools)

# Sync tools with kernel
self._sync_tools_with_kernel(kernel, tools)

result = await self._sk_client.get_chat_message_contents(
chat_history,
settings=settings,
kernel=kernel
)
# Track token usage from result metadata
prompt_tokens = 0
completion_tokens = 0

if result[0].metadata and 'usage' in result[0].metadata:
usage = result[0].metadata['usage']
prompt_tokens = getattr(usage, 'prompt_tokens', 0)
completion_tokens = getattr(usage, 'completion_tokens', 0)

self._total_prompt_tokens += prompt_tokens
self._total_completion_tokens += completion_tokens

return CreateResult(
content=result[0].content,
finish_reason="stop",
usage=RequestUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens
),
cached=False
)

async def create_stream(
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
# Very similar to create(), but orchestrates streaming.
# 1. Convert messages -> ChatHistory
# 2. Possibly set function-calling if needed
# 3. Build generator that yields str segments or a final CreateResult
# from SK's get_streaming_chat_message_contents(...)
raise NotImplementedError("create_stream is not implemented")
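        # A possible (untested) sketch, assuming SK's get_streaming_chat_message_contents
        # yields lists of StreamingChatMessageContent chunks:
        #
        #   chat_history = self._convert_to_chat_history(messages)
        #   settings = self._build_execution_settings(extra_create_args, tools)
        #   self._sync_tools_with_kernel(extra_create_args["kernel"], tools)
        #   async for chunks in self._sk_client.get_streaming_chat_message_contents(
        #       chat_history, settings=settings, kernel=extra_create_args["kernel"]
        #   ):
        #       for chunk in chunks:
        #           if chunk.content:
        #               yield chunk.content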

def actual_usage(self) -> RequestUsage:
return RequestUsage(
prompt_tokens=self._total_prompt_tokens,
completion_tokens=self._total_completion_tokens
)

def total_usage(self) -> RequestUsage:
return RequestUsage(
prompt_tokens=self._total_prompt_tokens,
completion_tokens=self._total_completion_tokens
)

def count_tokens(self, messages: Sequence[LLMMessage]) -> int:
chat_history = self._convert_to_chat_history(messages)
total_tokens = 0
for message in chat_history.messages:
if message.metadata and 'usage' in message.metadata:
usage = message.metadata['usage']
total_tokens += getattr(usage, 'total_tokens', 0)
return total_tokens

def remaining_tokens(self, messages: Sequence[LLMMessage]) -> int:
# Get total token count
used_tokens = self.count_tokens(messages)
# Assume max tokens from SK client if available, otherwise use default
max_tokens = getattr(self._sk_client, 'max_tokens', 4096)
return max_tokens - used_tokens

@property
def capabilities(self) -> ModelCapabilities:
# Return something consistent with the underlying SK client
return {
"vision": False,
"function_calling": self._sk_client.SUPPORTS_FUNCTION_CALLING,
"json_output": False,
}
@@ -0,0 +1,73 @@
import os
import pytest
from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion
from semantic_kernel.kernel import Kernel
from semantic_kernel.memory.null_memory import NullMemory
from autogen_core.models import SystemMessage, UserMessage
from autogen_core.tools import BaseTool
from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter
from pydantic import BaseModel
from autogen_core import CancellationToken

class CalculatorArgs(BaseModel):
a: float
b: float

class CalculatorResult(BaseModel):
result: float

class CalculatorTool(BaseTool[CalculatorArgs, CalculatorResult]):
def __init__(self):
super().__init__(
args_type=CalculatorArgs,
return_type=CalculatorResult,
name="calculator",
description="Add two numbers together"
)

async def run(self, args: CalculatorArgs, cancellation_token: CancellationToken) -> CalculatorResult:
return CalculatorResult(result=args.a + args.b)

@pytest.mark.asyncio
async def test_sk_chat_completion_with_tools():
    # Azure OpenAI configuration: placeholder endpoint, API key from the environment,
    # and a stable data-plane API version.
    deployment_name = "gpt-4o-mini"
    endpoint = "https://<your-endpoint>.openai.azure.com/"
    api_version = "2024-06-01"

    # Create SK client
    sk_client = AzureChatCompletion(
        deployment_name=deployment_name,
        endpoint=endpoint,
        api_version=api_version,
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    )

# Create adapter
adapter = SKChatCompletionAdapter(sk_client)

# Create kernel
kernel = Kernel(memory=NullMemory())

# Create calculator tool instance
tool = CalculatorTool()

# Test messages
messages = [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="What is 2 + 2?", source="user"),
]

# Call create with tool
result = await adapter.create(
messages=messages,
tools=[tool],
extra_create_args={"kernel": kernel}
    )

    # Verify response
assert isinstance(result.content, str)
assert result.finish_reason == "stop"
assert result.usage.prompt_tokens >= 0
assert result.usage.completion_tokens >= 0
assert not result.cached
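Note that this test exercises a live Azure OpenAI deployment: it needs AZURE_OPENAI_API_KEY set in the environment and a real endpoint in place of the placeholder, and can then be run with, e.g., pytest -k test_sk_chat_completion_with_tools.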
