Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.Net] Add a constructor which takes ChatCompletionOptions for OpenAIChatAgent #3170

Merged
merged 2 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 71 additions & 25 deletions dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.OpenAI.Extension;
Expand Down Expand Up @@ -32,13 +33,8 @@ namespace AutoGen.OpenAI;
public class OpenAIChatAgent : IStreamingAgent
{
private readonly OpenAIClient openAIClient;
private readonly string modelName;
private readonly float _temperature;
private readonly int _maxTokens = 1024;
private readonly IEnumerable<FunctionDefinition>? _functions;
private readonly string _systemMessage;
private readonly ChatCompletionsResponseFormat? _responseFormat;
private readonly int? _seed;
private readonly ChatCompletionsOptions options;
private readonly string systemMessage;

/// <summary>
/// Create a new instance of <see cref="OpenAIChatAgent"/>.
Expand All @@ -62,16 +58,36 @@ public OpenAIChatAgent(
int? seed = null,
ChatCompletionsResponseFormat? responseFormat = null,
IEnumerable<FunctionDefinition>? functions = null)
: this(
openAIClient: openAIClient,
name: name,
options: CreateChatCompletionOptions(modelName, temperature, maxTokens, seed, responseFormat, functions),
systemMessage: systemMessage)
{
}

/// <summary>
/// Create a new instance of <see cref="OpenAIChatAgent"/>.
/// </summary>
/// <param name="openAIClient">openai client</param>
/// <param name="name">agent name</param>
/// <param name="systemMessage">system message</param>
/// <param name="options">chat completion option. The option can't contain messages</param>
public OpenAIChatAgent(
OpenAIClient openAIClient,
string name,
ChatCompletionsOptions options,
string systemMessage = "You are a helpful AI assistant")
{
if (options.Messages is { Count: > 0 })
{
throw new ArgumentException("Messages should not be provided in options");
}

this.openAIClient = openAIClient;
this.modelName = modelName;
this.Name = name;
_temperature = temperature;
_maxTokens = maxTokens;
_functions = functions;
_systemMessage = systemMessage;
_responseFormat = responseFormat;
_seed = seed;
this.options = options;
this.systemMessage = systemMessage;
}

public string Name { get; }
Expand Down Expand Up @@ -116,22 +132,25 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions
// add system message if there's no system message in messages
if (!oaiMessages.Any(m => m is ChatRequestSystemMessage))
{
oaiMessages = new[] { new ChatRequestSystemMessage(_systemMessage) }.Concat(oaiMessages);
oaiMessages = new[] { new ChatRequestSystemMessage(systemMessage) }.Concat(oaiMessages);
}

var settings = new ChatCompletionsOptions(this.modelName, oaiMessages)
// clone the options by serializing and deserializing
var json = JsonSerializer.Serialize(this.options);
var settings = JsonSerializer.Deserialize<ChatCompletionsOptions>(json) ?? throw new InvalidOperationException("Failed to clone options");

foreach (var m in oaiMessages)
{
MaxTokens = options?.MaxToken ?? _maxTokens,
Temperature = options?.Temperature ?? _temperature,
ResponseFormat = _responseFormat,
Seed = _seed,
};
settings.Messages.Add(m);
}

settings.Temperature = options?.Temperature ?? settings.Temperature;
settings.MaxTokens = options?.MaxToken ?? settings.MaxTokens;

var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToOpenAIFunctionDefinition());
var functions = openAIFunctionDefinitions ?? _functions;
if (functions is not null && functions.Count() > 0)
var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToOpenAIFunctionDefinition()).ToList();
if (openAIFunctionDefinitions is { Count: > 0 })
{
foreach (var f in functions)
foreach (var f in openAIFunctionDefinitions)
{
settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f));
}
Expand All @@ -147,4 +166,31 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions

return settings;
}

private static ChatCompletionsOptions CreateChatCompletionOptions(
string modelName,
float temperature = 0.7f,
int maxTokens = 1024,
int? seed = null,
ChatCompletionsResponseFormat? responseFormat = null,
IEnumerable<FunctionDefinition>? functions = null)
{
var options = new ChatCompletionsOptions(modelName, [])
{
Temperature = temperature,
MaxTokens = maxTokens,
Seed = seed,
ResponseFormat = responseFormat,
};

if (functions is not null)
{
foreach (var f in functions)
{
options.Tools.Add(new ChatCompletionsFunctionToolDefinition(f));
}
}

return options;
}
}
64 changes: 52 additions & 12 deletions dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ public async Task<string> GetWeatherAsync(string location)
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task BasicConversationTestAsync()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key));
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var openAIChatAgent = new OpenAIChatAgent(
openAIClient: openaiClient,
name: "assistant",
Expand Down Expand Up @@ -60,10 +58,8 @@ public async Task BasicConversationTestAsync()
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIChatMessageContentConnectorTestAsync()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key));
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var openAIChatAgent = new OpenAIChatAgent(
openAIClient: openaiClient,
name: "assistant",
Expand Down Expand Up @@ -107,10 +103,8 @@ public async Task OpenAIChatMessageContentConnectorTestAsync()
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIChatAgentToolCallTestAsync()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key));
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var openAIChatAgent = new OpenAIChatAgent(
openAIClient: openaiClient,
name: "assistant",
Expand Down Expand Up @@ -176,10 +170,8 @@ public async Task OpenAIChatAgentToolCallTestAsync()
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIChatAgentToolCallInvokingTestAsync()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key));
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var openAIChatAgent = new OpenAIChatAgent(
openAIClient: openaiClient,
name: "assistant",
Expand Down Expand Up @@ -236,4 +228,52 @@ public async Task OpenAIChatAgentToolCallInvokingTestAsync()
}
}
}

[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task ItCreateOpenAIChatAgentWithChatCompletionOptionAsync()
{
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var options = new ChatCompletionsOptions(deployName, [])
{
Temperature = 0.7f,
MaxTokens = 1,
};

var openAIChatAgent = new OpenAIChatAgent(
openAIClient: openaiClient,
name: "assistant",
options: options)
.RegisterMessageConnector();

var respond = await openAIChatAgent.SendAsync("hello");
respond.GetContent()?.Should().NotBeNullOrEmpty();
}

[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task ItThrowExceptionWhenChatCompletionOptionContainsMessages()
{
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var options = new ChatCompletionsOptions(deployName, [new ChatRequestUserMessage("hi")])
{
Temperature = 0.7f,
MaxTokens = 1,
};

var action = () => new OpenAIChatAgent(
openAIClient: openaiClient,
name: "assistant",
options: options)
.RegisterMessageConnector();

action.Should().ThrowExactly<ArgumentException>().WithMessage("Messages should not be provided in options");
}

private OpenAIClient CreateOpenAIClientFromAzureOpenAI()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
return new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key));
}
}
Loading