Skip to content

Commit

Permalink
WIP: Azure AI Client
Browse files Browse the repository at this point in the history
* Added: object-level usage data
* Added: doc string
* Added: check existing response_format value
* Added: _validate_config and _create_client
  • Loading branch information
rohanthacker committed Dec 30, 2024
1 parent 06d3f95 commit daf43de
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 32 deletions.
1 change: 1 addition & 0 deletions python/packages/autogen-core/docs/src/reference/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ python/autogen_ext.agents.video_surfer.tools
python/autogen_ext.teams.magentic_one
python/autogen_ext.models.openai
python/autogen_ext.models.replay
python/autogen_ext.models.azure
python/autogen_ext.tools.langchain
python/autogen_ext.code_executors.local
python/autogen_ext.code_executors.docker
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
autogen\_ext.models.azure
==========================


.. automodule:: autogen_ext.models.azure
:members:
:undoc-members:
:show-inheritance:
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._azure_ai_client import AzureAIChatCompletionClient
from .config import AzureAIChatCompletionClientConfig

__all__ = ["AzureAIChatCompletionClient"]
__all__ = ["AzureAIChatCompletionClient", "AzureAIChatCompletionClientConfig"]
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import re
import warnings
from asyncio import Task
from typing import Sequence, Optional, Mapping, Any, List, Unpack, Dict, cast
from inspect import getfullargspec
Expand Down Expand Up @@ -154,25 +155,95 @@ def assert_valid_name(name: str) -> str:


class AzureAIChatCompletionClient(ChatCompletionClient):
"""
Chat completion client for models hosted on Azure AI Foundry or GitHub Models.
See `here <https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-chat-completions>`_ for more info.
Args:
endpoint (str): The endpoint to use. **Required.**
credential (Union[AzureKeyCredential, AsyncTokenCredential]): The credential to use. **Required.**
model_capabilities (ModelCapabilities): The capabilities of the model. **Required.**
model (str): The name of the model. **Required if model is hosted on GitHub Models.**
frequency_penalty: (optional,float)
presence_penalty: (optional,float)
temperature: (optional,float)
top_p: (optional,float)
max_tokens: (optional,int)
response_format: (optional,ChatCompletionsResponseFormat)
stop: (optional,List[str])
tools: (optional,List[ChatCompletionsToolDefinition])
tool_choice: (optional,Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]])
seed: (optional,int)
model_extras: (optional,Dict[str, Any])
To use this client, you must install the `azure-ai-inference` extension:
.. code-block:: bash
pip install 'autogen-ext[azure-ai-inference]==0.4.0.dev11'
The following code snippet shows how to use the client:
.. code-block:: python
from azure.core.credentials import AzureKeyCredential
from autogen_ext.models.azure import AzureAIChatCompletionClient
from autogen_core.models import UserMessage
client = AzureAIChatCompletionClient(
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
model_capabilities={
"json_output": False,
"function_calling": False,
"vision": False,
},
)
result = await client.create([UserMessage(content="What is the capital of France?", source="user")]) # type: ignore
print(result)
"""

def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]):
if "endpoint" not in kwargs:
config = self._validate_config(kwargs)
self._model_capabilities = config["model_capabilities"]
self._client = self._create_client(config)
self._create_args = self._prepare_create_args(config)

self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)

@staticmethod
def _validate_config(config: Dict) -> AzureAIChatCompletionClientConfig:
if "endpoint" not in config:
raise ValueError("endpoint is required for AzureAIChatCompletionClient")
if "credential" not in kwargs:
if "credential" not in config:
raise ValueError("credential is required for AzureAIChatCompletionClient")
if "model_capabilities" not in kwargs:
if "model_capabilities" not in config:
raise ValueError("model_capabilities is required for AzureAIChatCompletionClient")
if _is_github_model(kwargs['endpoint']) and "model" not in kwargs:
if _is_github_model(config["endpoint"]) and "model" not in config:
raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")

# TODO: Change
_endpoint = kwargs.pop("endpoint")
_credential = kwargs.pop("credential")
self._model_capabilities = kwargs.pop("model_capabilities")
self._create_args = kwargs.copy()

self._client = ChatCompletionsClient(_endpoint, _credential, **self._create_args)
self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
return config

    @staticmethod
    def _create_client(config: AzureAIChatCompletionClientConfig):
        """Construct the underlying ``ChatCompletionsClient`` from *config*.

        NOTE(review): the entire config mapping is splatted into the client
        constructor, so non-service keys (e.g. ``model_capabilities``) are
        forwarded as well — presumably the SDK tolerates extra kwargs; confirm
        against ``azure.ai.inference.aio.ChatCompletionsClient``.
        """
        return ChatCompletionsClient(**config)

@staticmethod
def _prepare_create_args(config: Mapping[str, Any]) -> Mapping[str, Any]:
create_args = {k: v for k, v in config.items() if k in create_kwargs}
return create_args
# self._endpoint = config.pop("endpoint")
# self._credential = config.pop("credential")
# self._model_capabilities = config.pop("model_capabilities")
# self._create_args = config.copy()

def add_usage(self, usage: RequestUsage):
self._total_usage = RequestUsage(
self._total_usage.prompt_tokens + usage.prompt_tokens,
self._total_usage.completion_tokens + usage.completion_tokens,
)

async def create(
self,
Expand Down Expand Up @@ -200,7 +271,7 @@ async def create(
if self.capabilities["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output")

if json_output is True:
if json_output is True and "response_format" not in create_args:
create_args["response_format"] = ChatCompletionsResponseFormatJSON()

if self.capabilities["json_output"] is False and json_output is True:
Expand Down Expand Up @@ -259,6 +330,9 @@ async def create(
usage=usage,
cached=False,
)

self.add_usage(usage)

return response

async def create_stream(
Expand Down Expand Up @@ -286,7 +360,7 @@ async def create_stream(
if self.capabilities["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output")

if json_output is True:
if json_output is True and "response_format" not in create_args:
create_args["response_format"] = ChatCompletionsResponseFormatJSON()

if self.capabilities["json_output"] is False and json_output is True:
Expand Down Expand Up @@ -380,6 +454,9 @@ async def create_stream(
usage=usage,
cached=False,
)

self.add_usage(usage)

yield result

def actual_usage(self) -> RequestUsage:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@
ChatCompletionsClient,
)


from azure.ai.inference.models import (
ChatChoice,
ChatResponseMessage,
CompletionsUsage,
ChatCompletionsResponseFormatJSON,
)

from azure.ai.inference.models import (
ChatCompletions,
StreamingChatCompletionsUpdate,
StreamingChatChoiceUpdate,
StreamingChatResponseMessageUpdate,
)
from azure.ai.inference.models import (ChatCompletions,
StreamingChatCompletionsUpdate, StreamingChatChoiceUpdate,
StreamingChatResponseMessageUpdate)

from azure.core.credentials import AzureKeyCredential

Expand All @@ -32,7 +37,8 @@ async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[Strea
index=0,
finish_reason="stop",
delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
) for chunk_content in mock_chunks_content
)
for chunk_content in mock_chunks_content
]

for mock_chunk in mock_chunks:
Expand All @@ -46,20 +52,20 @@ async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[Strea
)


async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
async def _mock_create(
*args: Any, **kwargs: Any
) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
stream = kwargs.get("stream", False)

if not stream:
await asyncio.sleep(0.1)
return ChatCompletions(
id="id",
created=datetime.now(),
model='model',
model="model",
choices=[
ChatChoice(
index=0,
finish_reason="stop",
message=ChatResponseMessage(content="Hello", role="assistant")
index=0, finish_reason="stop", message=ChatResponseMessage(content="Hello", role="assistant")
)
],
usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
Expand All @@ -68,28 +74,29 @@ async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletions | AsyncGene
return _mock_create_stream(*args, **kwargs)



@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client() -> None:
    """Smoke test: the client constructs successfully when every required
    config key (endpoint, credential, model_capabilities, model) is given."""
    client = AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_capabilities={
            "json_output": False,
            "function_calling": False,
            "vision": False,
        },
        model="model",
    )
    assert client


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.MonkeyPatch) -> None:
# monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
client = AzureAIChatCompletionClient(
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
model_capabilities = {
model_capabilities={
"json_output": False,
"function_calling": False,
"vision": False,
Expand All @@ -98,14 +105,15 @@ async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.Monkey
result = await client.create(messages=[UserMessage(content="Hello", source="user")])
assert result.content == "Hello"


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_stream(monkeypatch:pytest.MonkeyPatch) -> None:
async def test_azure_ai_chat_completion_client_create_stream(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
chunks = []
client = AzureAIChatCompletionClient(
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
model_capabilities = {
model_capabilities={
"json_output": False,
"function_calling": False,
"vision": False,
Expand All @@ -118,6 +126,7 @@ async def test_azure_ai_chat_completion_client_create_stream(monkeypatch:pytest.
assert chunks[1] == " Another Hello"
assert chunks[2] == " Yet Another Hello"


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_cancel(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
Expand All @@ -138,6 +147,7 @@ async def test_azure_ai_chat_completion_client_create_cancel(monkeypatch: pytest
with pytest.raises(asyncio.CancelledError):
await task


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_stream_cancel(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
Expand All @@ -151,7 +161,9 @@ async def test_azure_ai_chat_completion_client_create_stream_cancel(monkeypatch:
"vision": False,
},
)
stream=client.create_stream(messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token)
stream = client.create_stream(
messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token
)
cancellation_token.cancel()
with pytest.raises(asyncio.CancelledError):
async for _ in stream:
Expand Down

0 comments on commit daf43de

Please sign in to comment.