Skip to content

Commit

Permalink
[.Net] Add AutoGen.AzureAIInference (#3332)
Browse files Browse the repository at this point in the history
* add AutoGen.AzureAIInference

* add tests

* update readme

* fix format
  • Loading branch information
LittleLittleCloud authored Aug 8, 2024
1 parent 5732b3e commit 4dab28c
Show file tree
Hide file tree
Showing 22 changed files with 1,900 additions and 102 deletions.
21 changes: 21 additions & 0 deletions dotnet/AutoGen.sln
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.Sample", "sa
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI.Sample", "sample\AutoGen.WebAPI.Sample\AutoGen.WebAPI.Sample.csproj", "{12079C18-A519-403F-BBFD-200A36A0C083}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.AzureAIInference", "src\AutoGen.AzureAIInference\AutoGen.AzureAIInference.csproj", "{5C45981D-1319-4C25-935C-83D411CB28DF}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.AzureAIInference.Tests", "test\AutoGen.AzureAIInference.Tests\AutoGen.AzureAIInference.Tests.csproj", "{5970868F-831E-418F-89A9-4EC599563E16}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Tests.Share", "test\AutoGen.Test.Share\AutoGen.Tests.Share.csproj", "{143725E2-206C-4D37-93E4-9EDF699826B2}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -194,6 +200,18 @@ Global
{12079C18-A519-403F-BBFD-200A36A0C083}.Debug|Any CPU.Build.0 = Debug|Any CPU
{12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.ActiveCfg = Release|Any CPU
{12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.Build.0 = Release|Any CPU
{5C45981D-1319-4C25-935C-83D411CB28DF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{5C45981D-1319-4C25-935C-83D411CB28DF}.Debug|Any CPU.Build.0 = Debug|Any CPU
{5C45981D-1319-4C25-935C-83D411CB28DF}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5C45981D-1319-4C25-935C-83D411CB28DF}.Release|Any CPU.Build.0 = Release|Any CPU
{5970868F-831E-418F-89A9-4EC599563E16}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{5970868F-831E-418F-89A9-4EC599563E16}.Debug|Any CPU.Build.0 = Debug|Any CPU
{5970868F-831E-418F-89A9-4EC599563E16}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5970868F-831E-418F-89A9-4EC599563E16}.Release|Any CPU.Build.0 = Release|Any CPU
{143725E2-206C-4D37-93E4-9EDF699826B2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{143725E2-206C-4D37-93E4-9EDF699826B2}.Debug|Any CPU.Build.0 = Debug|Any CPU
{143725E2-206C-4D37-93E4-9EDF699826B2}.Release|Any CPU.ActiveCfg = Release|Any CPU
{143725E2-206C-4D37-93E4-9EDF699826B2}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -229,6 +247,9 @@ Global
{6B82F26D-5040-4453-B21B-C8D1F913CE4C} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{0E635268-351C-4A6B-A28D-593D868C2CA4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
{12079C18-A519-403F-BBFD-200A36A0C083} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
{5C45981D-1319-4C25-935C-83D411CB28DF} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{5970868F-831E-418F-89A9-4EC599563E16} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{143725E2-206C-4D37-93E4-9EDF699826B2} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B}
Expand Down
1 change: 1 addition & 0 deletions dotnet/eng/Version.props
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
<MicrosoftASPNETCoreVersion>8.0.4</MicrosoftASPNETCoreVersion>
<GoogleCloudAPIPlatformVersion>3.0.0</GoogleCloudAPIPlatformVersion>
<JsonSchemaVersion>4.3.0.2</JsonSchemaVersion>
<AzureAIInferenceVersion>1.0.0-beta.1</AzureAIInferenceVersion>
<PowershellSDKVersion>7.4.4</PowershellSDKVersion>
</PropertyGroup>
</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatCompletionsClientAgent.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.AzureAIInference.Extension;
using AutoGen.Core;
using Azure.AI.Inference;

namespace AutoGen.AzureAIInference;

/// <summary>
/// ChatCompletions client agent. This agent is a thin wrapper around <see cref="ChatCompletionsClient"/> to provide a simple interface for chat completions.
/// <para><see cref="ChatCompletionsClientAgent" /> supports the following message types:</para>
/// <list type="bullet">
/// <item>
/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatRequestMessage"/>: chat request message.
/// </item>
/// </list>
/// <para><see cref="ChatCompletionsClientAgent" /> returns the following message types:</para>
/// <list type="bullet">
/// <item>
/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatResponseMessage"/>: chat response message.
/// <see cref="MessageEnvelope{T}"/> where T is <see cref="StreamingChatCompletionsUpdate"/>: streaming chat completions update.
/// </item>
/// </list>
/// </summary>
public class ChatCompletionsClientAgent : IStreamingAgent
{
    private readonly ChatCompletionsClient chatCompletionsClient;

    // Template options shared across calls; cloned per request so it is never mutated.
    private readonly ChatCompletionsOptions options;
    private readonly string systemMessage;

    /// <summary>
    /// Create a new instance of <see cref="ChatCompletionsClientAgent"/>.
    /// </summary>
    /// <param name="chatCompletionsClient">chat completions client</param>
    /// <param name="name">agent name</param>
    /// <param name="modelName">model name. e.g. gpt-turbo-3.5</param>
    /// <param name="systemMessage">system message</param>
    /// <param name="temperature">temperature</param>
    /// <param name="maxTokens">max tokens to generate</param>
    /// <param name="seed">seed to use, set it to enable deterministic output</param>
    /// <param name="responseFormat">response format, set it to <see cref="ChatCompletionsResponseFormatJSON"/> to enable json mode.</param>
    /// <param name="functions">functions</param>
    public ChatCompletionsClientAgent(
        ChatCompletionsClient chatCompletionsClient,
        string name,
        string modelName,
        string systemMessage = "You are a helpful AI assistant",
        float temperature = 0.7f,
        int maxTokens = 1024,
        int? seed = null,
        ChatCompletionsResponseFormat? responseFormat = null,
        IEnumerable<FunctionDefinition>? functions = null)
        : this(
            chatCompletionsClient: chatCompletionsClient,
            name: name,
            options: CreateChatCompletionOptions(modelName, temperature, maxTokens, seed, responseFormat, functions),
            systemMessage: systemMessage)
    {
    }

    /// <summary>
    /// Create a new instance of <see cref="ChatCompletionsClientAgent"/>.
    /// </summary>
    /// <param name="chatCompletionsClient">chat completions client</param>
    /// <param name="name">agent name</param>
    /// <param name="options">chat completion option. The option can't contain messages</param>
    /// <param name="systemMessage">system message</param>
    /// <exception cref="ArgumentException">thrown when <paramref name="options"/> already carries messages</exception>
    public ChatCompletionsClientAgent(
        ChatCompletionsClient chatCompletionsClient,
        string name,
        ChatCompletionsOptions options,
        string systemMessage = "You are a helpful AI assistant")
    {
        // Messages are supplied per call in GenerateReplyAsync; a pre-populated
        // template would silently leak into every conversation.
        if (options.Messages is { Count: > 0 })
        {
            throw new ArgumentException("Messages should not be provided in options");
        }

        this.chatCompletionsClient = chatCompletionsClient;
        this.Name = name;
        this.options = options;
        this.systemMessage = systemMessage;
    }

    /// <summary>Agent name used as the <c>from</c> field on produced messages.</summary>
    public string Name { get; }

    /// <summary>
    /// Send the conversation to the service and return the completed response
    /// wrapped in a <see cref="MessageEnvelope{T}"/> of <see cref="ChatCompletions"/>.
    /// </summary>
    public async Task<IMessage> GenerateReplyAsync(
        IEnumerable<IMessage> messages,
        GenerateReplyOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        var settings = this.CreateChatCompletionsOptions(options, messages);
        var reply = await this.chatCompletionsClient.CompleteAsync(settings, cancellationToken: cancellationToken);

        return new MessageEnvelope<ChatCompletions>(reply, from: this.Name);
    }

    /// <summary>
    /// Send the conversation to the service and stream back
    /// <see cref="StreamingChatCompletionsUpdate"/> envelopes as they arrive.
    /// </summary>
    public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
        IEnumerable<IMessage> messages,
        GenerateReplyOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var settings = this.CreateChatCompletionsOptions(options, messages);
        var response = await this.chatCompletionsClient.CompleteStreamingAsync(settings, cancellationToken);
        await foreach (var update in response.WithCancellation(cancellationToken))
        {
            yield return new MessageEnvelope<StreamingChatCompletionsUpdate>(update, from: this.Name);
        }
    }

    /// <summary>
    /// Build the per-call request options: clone the configured template, prepend the
    /// system message when the conversation has none, append the conversation, and
    /// apply per-call overrides from <paramref name="options"/>.
    /// </summary>
    /// <exception cref="ArgumentException">thrown when a message is not a <see cref="MessageEnvelope{T}"/> of <see cref="ChatRequestMessage"/></exception>
    private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable<IMessage> messages)
    {
        // Materialize the projection once; as a lazy query it would be re-evaluated
        // (switch + potential throw included) by Any() and again by the foreach below.
        var chatRequestMessages = messages.Select(m => m switch
        {
            IMessage<ChatRequestMessage> chatRequestMessage => chatRequestMessage.Content,
            _ => throw new ArgumentException("Invalid message type")
        }).ToList();

        // add system message if there's no system message in messages
        if (!chatRequestMessages.Any(m => m is ChatRequestSystemMessage))
        {
            chatRequestMessages.Insert(0, new ChatRequestSystemMessage(systemMessage));
        }

        // clone the options by serializing and deserializing so the shared template is never mutated
        var json = JsonSerializer.Serialize(this.options);
        var settings = JsonSerializer.Deserialize<ChatCompletionsOptions>(json) ?? throw new InvalidOperationException("Failed to clone options");

        foreach (var m in chatRequestMessages)
        {
            settings.Messages.Add(m);
        }

        settings.Temperature = options?.Temperature ?? settings.Temperature;
        settings.MaxTokens = options?.MaxToken ?? settings.MaxTokens;

        // NOTE(review): tools and stop sequences are copied from the template explicitly,
        // which assumes the JSON round-trip above does not carry them over — confirm
        // against Azure.AI.Inference serialization, otherwise entries are duplicated.
        foreach (var tool in this.options.Tools)
        {
            settings.Tools.Add(tool);
        }

        foreach (var stopSequence in this.options.StopSequences)
        {
            settings.StopSequences.Add(stopSequence);
        }

        var functionDefinitions = options?.Functions?.Select(f => f.ToAzureAIInferenceFunctionDefinition()).ToList();
        if (functionDefinitions is { Count: > 0 })
        {
            foreach (var f in functionDefinitions)
            {
                settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f));
            }
        }

        // Per-call stop sequences are appended on top of the template's.
        if (options?.StopSequence is { Length: > 0 } stopSequences)
        {
            foreach (var seq in stopSequences)
            {
                settings.StopSequences.Add(seq);
            }
        }

        return settings;
    }

    /// <summary>
    /// Build the reusable <see cref="ChatCompletionsOptions"/> template from the
    /// convenience-constructor arguments.
    /// </summary>
    private static ChatCompletionsOptions CreateChatCompletionOptions(
        string modelName,
        float temperature = 0.7f,
        int maxTokens = 1024,
        int? seed = null,
        ChatCompletionsResponseFormat? responseFormat = null,
        IEnumerable<FunctionDefinition>? functions = null)
    {
        var options = new ChatCompletionsOptions()
        {
            Model = modelName,
            Temperature = temperature,
            MaxTokens = maxTokens,
            Seed = seed,
            ResponseFormat = responseFormat,
        };

        if (functions is not null)
        {
            foreach (var f in functions)
            {
                options.Tools.Add(new ChatCompletionsFunctionToolDefinition(f));
            }
        }

        return options;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
<RootNamespace>AutoGen.AzureAIInference</RootNamespace>
</PropertyGroup>

<Import Project="$(RepoRoot)/nuget/nuget-package.props" />

<PropertyGroup>
<!-- NuGet Package Settings -->
<Title>AutoGen.AzureAIInference</Title>
<Description>
Azure AI Inference Integration for AutoGen.
</Description>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.AI.Inference" Version="$(AzureAIInferenceVersion)" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatComptionClientAgentExtension.cs

using AutoGen.Core;

namespace AutoGen.AzureAIInference.Extension;

/// <summary>
/// Extension methods for wiring an <see cref="AzureAIInferenceChatRequestMessageConnector"/>
/// into a <see cref="ChatCompletionsClientAgent"/> middleware pipeline.
/// </summary>
/// <remarks>
/// NOTE(review): the type name is misspelled ("Comption" instead of "Completion");
/// renaming is a breaking change for existing callers, so it is only flagged here.
/// </remarks>
public static class ChatComptionClientAgentExtension
{
    /// <summary>
    /// Register an <see cref="AzureAIInferenceChatRequestMessageConnector"/> to the <see cref="ChatCompletionsClientAgent"/>
    /// </summary>
    /// <param name="agent">the agent to attach the connector to</param>
    /// <param name="connector">the connector to use. If null, a new instance of <see cref="AzureAIInferenceChatRequestMessageConnector"/> will be created.</param>
    public static MiddlewareStreamingAgent<ChatCompletionsClientAgent> RegisterMessageConnector(
        this ChatCompletionsClientAgent agent, AzureAIInferenceChatRequestMessageConnector? connector = null)
    {
        // Fall back to a default connector when the caller supplies none.
        connector ??= new AzureAIInferenceChatRequestMessageConnector();

        return agent.RegisterStreamingMiddleware(connector);
    }

    /// <summary>
    /// Register an <see cref="AzureAIInferenceChatRequestMessageConnector"/> to the <see cref="MiddlewareAgent{T}"/> where T is <see cref="ChatCompletionsClientAgent"/>
    /// </summary>
    /// <param name="agent">the middleware agent to attach the connector to</param>
    /// <param name="connector">the connector to use. If null, a new instance of <see cref="AzureAIInferenceChatRequestMessageConnector"/> will be created.</param>
    public static MiddlewareStreamingAgent<ChatCompletionsClientAgent> RegisterMessageConnector(
        this MiddlewareStreamingAgent<ChatCompletionsClientAgent> agent, AzureAIInferenceChatRequestMessageConnector? connector = null)
    {
        // Fall back to a default connector when the caller supplies none.
        connector ??= new AzureAIInferenceChatRequestMessageConnector();

        return agent.RegisterStreamingMiddleware(connector);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionContractExtension.cs

using System;
using System.Collections.Generic;
using AutoGen.Core;
using Azure.AI.Inference;
using Json.Schema;
using Json.Schema.Generation;

namespace AutoGen.AzureAIInference.Extension;

/// <summary>
/// Extension methods for converting AutoGen <see cref="FunctionContract"/> instances
/// into Azure AI Inference <see cref="FunctionDefinition"/> instances.
/// </summary>
public static class FunctionContractExtension
{
    // Reused across calls instead of being allocated per conversion; camelCase keys
    // match the JSON the original inline options produced.
    private static readonly System.Text.Json.JsonSerializerOptions s_camelCaseSerializerOptions = new()
    {
        PropertyNamingPolicy = System.Text.Json.JsonNamingPolicy.CamelCase
    };

    /// <summary>
    /// Convert a <see cref="FunctionContract"/> to a <see cref="FunctionDefinition"/> that can be used in gpt function call.
    /// </summary>
    /// <param name="functionContract">function contract</param>
    /// <returns><see cref="FunctionDefinition"/></returns>
    /// <exception cref="InvalidOperationException">thrown when a parameter has no name</exception>
    /// <exception cref="ArgumentNullException">thrown when a parameter has no type</exception>
    public static FunctionDefinition ToAzureAIInferenceFunctionDefinition(this FunctionContract functionContract)
    {
        var functionDefinition = new FunctionDefinition
        {
            Name = functionContract.Name,
            Description = functionContract.Description,
        };

        var requiredParameterNames = new List<string>();
        var propertiesSchemas = new Dictionary<string, JsonSchema>();

        foreach (var param in functionContract.Parameters ?? [])
        {
            if (param.Name is null)
            {
                throw new InvalidOperationException("Parameter name cannot be null");
            }

            // Derive the parameter's JSON schema from its CLR type.
            var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType)));
            if (param.Description != null)
            {
                schemaBuilder = schemaBuilder.Description(param.Description);
            }

            if (param.IsRequired)
            {
                requiredParameterNames.Add(param.Name);
            }

            propertiesSchemas[param.Name] = schemaBuilder.Build();
        }

        // Assemble the object schema describing all parameters.
        var propertySchemaBuilder = new JsonSchemaBuilder()
            .Type(SchemaValueType.Object)
            .Properties(propertiesSchemas)
            .Required(requiredParameterNames);

        functionDefinition.Parameters = BinaryData.FromObjectAsJson(propertySchemaBuilder.Build(), s_camelCaseSerializerOptions);

        return functionDefinition;
    }
}
Loading

0 comments on commit 4dab28c

Please sign in to comment.