-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[.Net] Add AutoGen.AzureAIInference (#3332)
* add AutoGen.AzureAIInference * add tests * update readme * fix format
- Loading branch information
1 parent
369a75d
commit 06d6b82
Showing
22 changed files
with
1,900 additions
and
102 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
202 changes: 202 additions & 0 deletions
202
dotnet/src/AutoGen.AzureAIInference/Agent/ChatCompletionsClientAgent.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// ChatCompletionsClientAgent.cs | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Runtime.CompilerServices; | ||
using System.Text.Json; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using AutoGen.AzureAIInference.Extension; | ||
using AutoGen.Core; | ||
using Azure.AI.Inference; | ||
|
||
namespace AutoGen.AzureAIInference; | ||
|
||
/// <summary> | ||
/// ChatCompletions client agent. This agent is a thin wrapper around <see cref="ChatCompletionsClient"/> to provide a simple interface for chat completions. | ||
/// <para><see cref="ChatCompletionsClientAgent" /> supports the following message types:</para> | ||
/// <list type="bullet"> | ||
/// <item> | ||
/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatRequestMessage"/>: chat request message. | ||
/// </item> | ||
/// </list> | ||
/// <para><see cref="ChatCompletionsClientAgent" /> returns the following message types:</para> | ||
/// <list type="bullet"> | ||
/// <item> | ||
/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatResponseMessage"/>: chat response message. | ||
/// <see cref="MessageEnvelope{T}"/> where T is <see cref="StreamingChatCompletionsUpdate"/>: streaming chat completions update. | ||
/// </item> | ||
/// </list> | ||
/// </summary> | ||
public class ChatCompletionsClientAgent : IStreamingAgent | ||
{ | ||
private readonly ChatCompletionsClient chatCompletionsClient; | ||
private readonly ChatCompletionsOptions options; | ||
private readonly string systemMessage; | ||
|
||
/// <summary> | ||
/// Create a new instance of <see cref="ChatCompletionsClientAgent"/>. | ||
/// </summary> | ||
/// <param name="chatCompletionsClient">chat completions client</param> | ||
/// <param name="name">agent name</param> | ||
/// <param name="modelName">model name. e.g. gpt-turbo-3.5</param> | ||
/// <param name="systemMessage">system message</param> | ||
/// <param name="temperature">temperature</param> | ||
/// <param name="maxTokens">max tokens to generated</param> | ||
/// <param name="responseFormat">response format, set it to <see cref="ChatCompletionsResponseFormatJSON"/> to enable json mode.</param> | ||
/// <param name="seed">seed to use, set it to enable deterministic output</param> | ||
/// <param name="functions">functions</param> | ||
public ChatCompletionsClientAgent( | ||
ChatCompletionsClient chatCompletionsClient, | ||
string name, | ||
string modelName, | ||
string systemMessage = "You are a helpful AI assistant", | ||
float temperature = 0.7f, | ||
int maxTokens = 1024, | ||
int? seed = null, | ||
ChatCompletionsResponseFormat? responseFormat = null, | ||
IEnumerable<FunctionDefinition>? functions = null) | ||
: this( | ||
chatCompletionsClient: chatCompletionsClient, | ||
name: name, | ||
options: CreateChatCompletionOptions(modelName, temperature, maxTokens, seed, responseFormat, functions), | ||
systemMessage: systemMessage) | ||
{ | ||
} | ||
|
||
/// <summary> | ||
/// Create a new instance of <see cref="ChatCompletionsClientAgent"/>. | ||
/// </summary> | ||
/// <param name="chatCompletionsClient">chat completions client</param> | ||
/// <param name="name">agent name</param> | ||
/// <param name="systemMessage">system message</param> | ||
/// <param name="options">chat completion option. The option can't contain messages</param> | ||
public ChatCompletionsClientAgent( | ||
ChatCompletionsClient chatCompletionsClient, | ||
string name, | ||
ChatCompletionsOptions options, | ||
string systemMessage = "You are a helpful AI assistant") | ||
{ | ||
if (options.Messages is { Count: > 0 }) | ||
{ | ||
throw new ArgumentException("Messages should not be provided in options"); | ||
} | ||
|
||
this.chatCompletionsClient = chatCompletionsClient; | ||
this.Name = name; | ||
this.options = options; | ||
this.systemMessage = systemMessage; | ||
} | ||
|
||
public string Name { get; } | ||
|
||
public async Task<IMessage> GenerateReplyAsync( | ||
IEnumerable<IMessage> messages, | ||
GenerateReplyOptions? options = null, | ||
CancellationToken cancellationToken = default) | ||
{ | ||
var settings = this.CreateChatCompletionsOptions(options, messages); | ||
var reply = await this.chatCompletionsClient.CompleteAsync(settings, cancellationToken: cancellationToken); | ||
|
||
return new MessageEnvelope<ChatCompletions>(reply, from: this.Name); | ||
} | ||
|
||
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync( | ||
IEnumerable<IMessage> messages, | ||
GenerateReplyOptions? options = null, | ||
[EnumeratorCancellation] CancellationToken cancellationToken = default) | ||
{ | ||
var settings = this.CreateChatCompletionsOptions(options, messages); | ||
var response = await this.chatCompletionsClient.CompleteStreamingAsync(settings, cancellationToken); | ||
await foreach (var update in response.WithCancellation(cancellationToken)) | ||
{ | ||
yield return new MessageEnvelope<StreamingChatCompletionsUpdate>(update, from: this.Name); | ||
} | ||
} | ||
|
||
private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable<IMessage> messages) | ||
{ | ||
var oaiMessages = messages.Select(m => m switch | ||
{ | ||
IMessage<ChatRequestMessage> chatRequestMessage => chatRequestMessage.Content, | ||
_ => throw new ArgumentException("Invalid message type") | ||
}); | ||
|
||
// add system message if there's no system message in messages | ||
if (!oaiMessages.Any(m => m is ChatRequestSystemMessage)) | ||
{ | ||
oaiMessages = new[] { new ChatRequestSystemMessage(systemMessage) }.Concat(oaiMessages); | ||
} | ||
|
||
// clone the options by serializing and deserializing | ||
var json = JsonSerializer.Serialize(this.options); | ||
var settings = JsonSerializer.Deserialize<ChatCompletionsOptions>(json) ?? throw new InvalidOperationException("Failed to clone options"); | ||
|
||
foreach (var m in oaiMessages) | ||
{ | ||
settings.Messages.Add(m); | ||
} | ||
|
||
settings.Temperature = options?.Temperature ?? settings.Temperature; | ||
settings.MaxTokens = options?.MaxToken ?? settings.MaxTokens; | ||
|
||
foreach (var functions in this.options.Tools) | ||
{ | ||
settings.Tools.Add(functions); | ||
} | ||
|
||
foreach (var stopSequence in this.options.StopSequences) | ||
{ | ||
settings.StopSequences.Add(stopSequence); | ||
} | ||
|
||
var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToAzureAIInferenceFunctionDefinition()).ToList(); | ||
if (openAIFunctionDefinitions is { Count: > 0 }) | ||
{ | ||
foreach (var f in openAIFunctionDefinitions) | ||
{ | ||
settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); | ||
} | ||
} | ||
|
||
if (options?.StopSequence is var sequence && sequence is { Length: > 0 }) | ||
{ | ||
foreach (var seq in sequence) | ||
{ | ||
settings.StopSequences.Add(seq); | ||
} | ||
} | ||
|
||
return settings; | ||
} | ||
|
||
private static ChatCompletionsOptions CreateChatCompletionOptions( | ||
string modelName, | ||
float temperature = 0.7f, | ||
int maxTokens = 1024, | ||
int? seed = null, | ||
ChatCompletionsResponseFormat? responseFormat = null, | ||
IEnumerable<FunctionDefinition>? functions = null) | ||
{ | ||
var options = new ChatCompletionsOptions() | ||
{ | ||
Model = modelName, | ||
Temperature = temperature, | ||
MaxTokens = maxTokens, | ||
Seed = seed, | ||
ResponseFormat = responseFormat, | ||
}; | ||
|
||
if (functions is not null) | ||
{ | ||
foreach (var f in functions) | ||
{ | ||
options.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); | ||
} | ||
} | ||
|
||
return options; | ||
} | ||
} |
25 changes: 25 additions & 0 deletions
25
dotnet/src/AutoGen.AzureAIInference/AutoGen.AzureAIInference.csproj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
<PropertyGroup> | ||
<TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks> | ||
<RootNamespace>AutoGen.AzureAIInference</RootNamespace> | ||
</PropertyGroup> | ||
|
||
<Import Project="$(RepoRoot)/nuget/nuget-package.props" /> | ||
|
||
<PropertyGroup> | ||
<!-- NuGet Package Settings --> | ||
<Title>AutoGen.AzureAIInference</Title> | ||
<Description> | ||
Azure AI Inference Intergration for AutoGen. | ||
</Description> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<PackageReference Include="Azure.AI.Inference" Version="$(AzureAIInferenceVersion)" /> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" /> | ||
</ItemGroup> | ||
|
||
</Project> |
39 changes: 39 additions & 0 deletions
39
dotnet/src/AutoGen.AzureAIInference/Extension/ChatComptionClientAgentExtension.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// ChatComptionClientAgentExtension.cs | ||
|
||
using AutoGen.Core; | ||
|
||
namespace AutoGen.AzureAIInference.Extension; | ||
|
||
public static class ChatComptionClientAgentExtension | ||
{ | ||
/// <summary> | ||
/// Register an <see cref="AzureAIInferenceChatRequestMessageConnector"/> to the <see cref="ChatCompletionsClientAgent"/> | ||
/// </summary> | ||
/// <param name="connector">the connector to use. If null, a new instance of <see cref="AzureAIInferenceChatRequestMessageConnector"/> will be created.</param> | ||
public static MiddlewareStreamingAgent<ChatCompletionsClientAgent> RegisterMessageConnector( | ||
this ChatCompletionsClientAgent agent, AzureAIInferenceChatRequestMessageConnector? connector = null) | ||
{ | ||
if (connector == null) | ||
{ | ||
connector = new AzureAIInferenceChatRequestMessageConnector(); | ||
} | ||
|
||
return agent.RegisterStreamingMiddleware(connector); | ||
} | ||
|
||
/// <summary> | ||
/// Register an <see cref="AzureAIInferenceChatRequestMessageConnector"/> to the <see cref="MiddlewareAgent{T}"/> where T is <see cref="ChatCompletionsClientAgent"/> | ||
/// </summary> | ||
/// <param name="connector">the connector to use. If null, a new instance of <see cref="AzureAIInferenceChatRequestMessageConnector"/> will be created.</param> | ||
public static MiddlewareStreamingAgent<ChatCompletionsClientAgent> RegisterMessageConnector( | ||
this MiddlewareStreamingAgent<ChatCompletionsClientAgent> agent, AzureAIInferenceChatRequestMessageConnector? connector = null) | ||
{ | ||
if (connector == null) | ||
{ | ||
connector = new AzureAIInferenceChatRequestMessageConnector(); | ||
} | ||
|
||
return agent.RegisterStreamingMiddleware(connector); | ||
} | ||
} |
64 changes: 64 additions & 0 deletions
64
dotnet/src/AutoGen.AzureAIInference/Extension/FunctionContractExtension.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// FunctionContractExtension.cs | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using AutoGen.Core; | ||
using Azure.AI.Inference; | ||
using Json.Schema; | ||
using Json.Schema.Generation; | ||
|
||
namespace AutoGen.AzureAIInference.Extension; | ||
|
||
public static class FunctionContractExtension | ||
{ | ||
/// <summary> | ||
/// Convert a <see cref="FunctionContract"/> to a <see cref="FunctionDefinition"/> that can be used in gpt funciton call. | ||
/// </summary> | ||
/// <param name="functionContract">function contract</param> | ||
/// <returns><see cref="FunctionDefinition"/></returns> | ||
public static FunctionDefinition ToAzureAIInferenceFunctionDefinition(this FunctionContract functionContract) | ||
{ | ||
var functionDefinition = new FunctionDefinition | ||
{ | ||
Name = functionContract.Name, | ||
Description = functionContract.Description, | ||
}; | ||
var requiredParameterNames = new List<string>(); | ||
var propertiesSchemas = new Dictionary<string, JsonSchema>(); | ||
var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object); | ||
foreach (var param in functionContract.Parameters ?? []) | ||
{ | ||
if (param.Name is null) | ||
{ | ||
throw new InvalidOperationException("Parameter name cannot be null"); | ||
} | ||
|
||
var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType))); | ||
if (param.Description != null) | ||
{ | ||
schemaBuilder = schemaBuilder.Description(param.Description); | ||
} | ||
|
||
if (param.IsRequired) | ||
{ | ||
requiredParameterNames.Add(param.Name); | ||
} | ||
|
||
var schema = schemaBuilder.Build(); | ||
propertiesSchemas[param.Name] = schema; | ||
|
||
} | ||
propertySchemaBuilder = propertySchemaBuilder.Properties(propertiesSchemas); | ||
propertySchemaBuilder = propertySchemaBuilder.Required(requiredParameterNames); | ||
|
||
var option = new System.Text.Json.JsonSerializerOptions() | ||
{ | ||
PropertyNamingPolicy = System.Text.Json.JsonNamingPolicy.CamelCase | ||
}; | ||
|
||
functionDefinition.Parameters = BinaryData.FromObjectAsJson(propertySchemaBuilder.Build(), option); | ||
|
||
return functionDefinition; | ||
} | ||
} |
Oops, something went wrong.