Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Google_GenerativeAI.sln
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TwoWayAudioCommunicationWpf
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "GenerativeAI.Live.Tests", "tests\GenerativeAI.Live.Tests\GenerativeAI.Live.Tests.csproj", "{157399AE-E8D1-4306-8B59-0BDEB45AED03}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AotTest", "tests\AotTest\AotTest.csproj", "{5BF6737C-D3E4-4C46-ABBF-73ECAEE128AF}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -126,6 +128,10 @@ Global
{157399AE-E8D1-4306-8B59-0BDEB45AED03}.Debug|Any CPU.Build.0 = Debug|Any CPU
{157399AE-E8D1-4306-8B59-0BDEB45AED03}.Release|Any CPU.ActiveCfg = Release|Any CPU
{157399AE-E8D1-4306-8B59-0BDEB45AED03}.Release|Any CPU.Build.0 = Release|Any CPU
{5BF6737C-D3E4-4C46-ABBF-73ECAEE128AF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{5BF6737C-D3E4-4C46-ABBF-73ECAEE128AF}.Debug|Any CPU.Build.0 = Debug|Any CPU
{5BF6737C-D3E4-4C46-ABBF-73ECAEE128AF}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5BF6737C-D3E4-4C46-ABBF-73ECAEE128AF}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand All @@ -149,6 +155,7 @@ Global
{ACB3E4E1-F967-45E7-81CD-70A4F7785AED} = {AC161F1D-EC76-48D2-86A3-B52584618D49}
{8DC0FF3E-8B46-41B0-B814-6049FD80C8C3} = {61CC49B3-1325-40EB-95DF-89E18A0D041B}
{157399AE-E8D1-4306-8B59-0BDEB45AED03} = {FCCDE15A-B121-4D6C-BD56-D1B043A26F18}
{5BF6737C-D3E4-4C46-ABBF-73ECAEE128AF} = {FCCDE15A-B121-4D6C-BD56-D1B043A26F18}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {FFF3E8BB-BACD-4376-8E33-55D6E8A30BE0}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<ImplicitUsings>enable</ImplicitUsings>
<UseWPF>true</UseWPF>
<EnableWindowsTargeting>true</EnableWindowsTargeting>

</PropertyGroup>

<ItemGroup>
Expand Down
1 change: 1 addition & 0 deletions src/GenerativeAI.Auth/GenerativeAI.Auth.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
<GeneratePackageOnBuild>True</GeneratePackageOnBuild>
<RootNamespace>GenerativeAI.Authenticators</RootNamespace>
<AssemblyName>GenerativeAI.Authenticators</AssemblyName>

</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\GenerativeAI\GenerativeAI.csproj"/>
Expand Down
1 change: 1 addition & 0 deletions src/GenerativeAI.Live/GenerativeAI.Live.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
<FileVersion>2.3.1</FileVersion>
<SignAssembly>True</SignAssembly>
<GeneratePackageOnBuild>True</GeneratePackageOnBuild>
<IsAotCompatible>true</IsAotCompatible>
</PropertyGroup>

<ItemGroup Condition=" '$(TargetFramework)' == 'net9.0'">
Expand Down
9 changes: 5 additions & 4 deletions src/GenerativeAI.Live/Models/MultiModalLiveClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using GenerativeAI.Core;
using GenerativeAI.Live.Helper;
using GenerativeAI.Live.Logging;
Expand Down Expand Up @@ -183,11 +184,11 @@ private void ProcessReceivedMessage(ResponseMessage msg)
BidiResponsePayload? responsePayload = null;
if (msg.MessageType == WebSocketMessageType.Binary)
{
responsePayload = JsonSerializer.Deserialize<BidiResponsePayload>(msg.Binary);
responsePayload = JsonSerializer.Deserialize(msg.Binary,(JsonTypeInfo<BidiResponsePayload>) DefaultSerializerOptions.Options.GetTypeInfo(typeof(BidiResponsePayload)));
}
else
{
responsePayload = JsonSerializer.Deserialize<BidiResponsePayload>(msg.Text);
responsePayload = JsonSerializer.Deserialize(msg.Text,(JsonTypeInfo<BidiResponsePayload>) DefaultSerializerOptions.Options.GetTypeInfo(typeof(BidiResponsePayload)));
}

if (responsePayload == null)
Expand Down Expand Up @@ -228,7 +229,7 @@ private void ProcessTextChunk(BidiResponsePayload responsePayload)
{
if (part.Text != null)
{
this.TextChunkReceived.Invoke(this,
this.TextChunkReceived?.Invoke(this,
new TextChunkReceivedArgs(part.Text, responsePayload.ServerContent.TurnComplete == true));
_logger?.LogInformation("Text chunk received: {Text}", part.Text);
}
Expand Down Expand Up @@ -565,7 +566,7 @@ private async Task SendAsync(BidiClientPayload payload, CancellationToken cancel

try
{
var json = JsonSerializer.Serialize(payload,DefaultSerializerOptions.Options);
var json = JsonSerializer.Serialize(payload,DefaultSerializerOptions.Options.GetTypeInfo(typeof(BidiClientPayload)));
_logger?.LogMessageSent(json);

_client.Send(json);
Expand Down
88 changes: 68 additions & 20 deletions src/GenerativeAI.Microsoft/Extensions/MicrosoftExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using GenerativeAI.Types;
using Json.More;

using Microsoft.Extensions.AI;
using System.Text.Json;
using System.Text.Json.Nodes;
Expand Down Expand Up @@ -70,8 +70,8 @@ where p is not null
/// <returns>A <see cref="Schema"/> object constructed from the provided JSON schema, or null if deserialization fails.</returns>
public static Schema? ToSchema(this JsonElement schema)
{

var serialized = JsonSerializer.Serialize(schema);
return GoogleSchemaHelper.ConvertToCompatibleSchemaSubset(schema.AsNode().AsObject());
var serialized = JsonSerializer.Serialize(schema, DefaultSerializerOptions.Options.GetTypeInfo(schema.GetType()));
return JsonSerializer.Deserialize(serialized,SchemaSourceGenerationContext.Default.Schema);
}

Expand All @@ -98,27 +98,69 @@ where p is not null
FunctionCall = new FunctionCall()
{
Name = fcc.Name,
Args = fcc.Arguments!,
Args = fcc.Arguments.ToJsonNode(),
}
},
FunctionResultContent frc => new Part
{
FunctionResponse = new FunctionResponse()
{
Name = frc.CallId,
Response = new
{
Name = frc.CallId,
Content = JsonSerializer.SerializeToNode(frc.Result)!,
}
Response = frc.ToJsonNodeResponse()
}
},
_ => null,
};
}


private static JsonNode ToJsonNode(this IDictionary<string, object?>? args)
{
var node = new JsonObject();
foreach (var arg in args!)
{
if (arg.Value is JsonNode nd)
node.Add(arg.Key, nd.DeepClone());
else
{
var p = arg.Value switch
{
string s => s,
int i => i,
float f => f,
double d => d,
bool b => b,
null => null,
JsonElement e => e.AsNode()?.AsObject(),
JsonNode n => n switch
{
JsonObject o => o,
JsonArray a => a,
JsonValue v => v.GetValue<JsonElement>().AsNode()
},
_ => throw new Exception("Unsupported argument type")
};
node.Add(arg.Key, p);
}
}

return node; //JsonSerializer.Deserialize(node.ToJsonString(), TypesSerializerContext.Default.JsonElement)!;
}
public static JsonNode ToJsonNodeResponse(this object? response)
{
if (response is FunctionResultContent content)
{
if (content.Result is JsonObject obj)
return obj;
else if (content.Result is JsonNode arr)
return arr;
}
if(response is JsonNode node)
{
return node;
}
else throw new Exception("Response is not a json node");

}
/// <summary>
/// Maps <see cref="ChatOptions"/> into a <see cref="GenerationConfig"/> object used by GenerativeAI.
/// </summary>
Expand All @@ -137,15 +179,15 @@ where p is not null
config.TopK = options.TopK;
config.MaxOutputTokens = options.MaxOutputTokens;
config.StopSequences = options.StopSequences?.ToList();
config.Seed = (int) options.Seed!;
config.Seed = (int?) options.Seed;
config.ResponseMimeType = options.ResponseFormat is ChatResponseFormatJson ? "application/json" : null;
if (options.ResponseFormat is ChatResponseFormatJson jsonFormat)
{
// see also: https://github.com/dotnet/extensions/blob/f775ed6bd07c0dd94ac422dc6098162eef0b48e5/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs#L186-L192
if (jsonFormat.Schema is JsonElement je && je.ValueKind == JsonValueKind.Object)
{
// Workaround to convert our real json schema to the format Google's api expects
var forGoogleApi = GoogleSchemaHelper.ConvertToCompatibleSchemaSubset(je.ToJsonDocument());
var forGoogleApi = GoogleSchemaHelper.ConvertToCompatibleSchemaSubset(je.AsNode());
config.ResponseSchema = forGoogleApi;
}
}
Expand Down Expand Up @@ -396,17 +438,23 @@ public static IList<AIContent> ToAiContents(this List<Part>? parts)
/// <returns>A dictionary where the keys represent argument names and values represent their corresponding data, or null if conversion is not possible.</returns>
private static IDictionary<string, object?>? ConvertFunctionCallArg(object? functionCallArgs)
{
if (functionCallArgs != null && functionCallArgs is not JsonElement)
if (functionCallArgs is JsonElement jsonElement)
{
functionCallArgs = JsonSerializer.Deserialize<dynamic>(JsonSerializer.Serialize(functionCallArgs));
var obj = jsonElement.AsNode().AsObject();
return obj?.ToDictionary(s=>s.Key,s=>(object?)s.Value?.DeepClone());

}
if (functionCallArgs is JsonElement jsonElement)
if (functionCallArgs is JsonNode jsonElement2)
{
if (jsonElement.ValueKind == JsonValueKind.Object)
{
var obj = JsonObject.Create(jsonElement);
return obj?.ToDictionary(s=>s.Key,s=>(object?)s.Value);
}
var obj = jsonElement2.AsObject();
return obj?.ToDictionary(s=>s.Key,s=>(object?)s.Value?.DeepClone());
}
else if (functionCallArgs != null && functionCallArgs is not JsonNode)
{
throw new Exception("Unsupported function call argument type");
// #pragma warning disable IL2026, IL3050
// functionCallArgs = JsonSerializer.Deserialize<dynamic>(JsonSerializer.Serialize(functionCallArgs));
// #pragma warning restore IL2026, IL3050
}

return null;
Expand Down
1 change: 1 addition & 0 deletions src/GenerativeAI.Microsoft/GenerativeAI.Microsoft.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
<FileVersion>2.3.1</FileVersion>
<SignAssembly>True</SignAssembly>
<GeneratePackageOnBuild>True</GeneratePackageOnBuild>
<IsAotCompatible>true</IsAotCompatible>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\GenerativeAI\GenerativeAI.csproj" />
Expand Down
12 changes: 7 additions & 5 deletions src/GenerativeAI.Microsoft/GenerativeAIChatClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Nodes;
using GenerativeAI.Core;
using GenerativeAI.Exceptions;
using GenerativeAI.Microsoft.Extensions;
Expand Down Expand Up @@ -70,15 +72,15 @@ private async Task<ChatResponse> CallFunctionAsync(GenerateContentRequest reques
var content = response.Candidates?.FirstOrDefault()?.Content;
if (content != null)
contents.Add(content);
var responseObject = new JsonObject();
responseObject["name"] = functionCall.Name;
responseObject["content"] = ((JsonElement)result).AsNode().DeepClone();
//responseObject["content"] = result as JsonNode;
var functionResponse = new FunctionResponse()
{
Name = tool.Name,
Id = functionCall.CallId,
Response = new
{
Name = tool.Name,
Content = result
}
Response = responseObject
};
var funcContent = new Content() { Role = Roles.Function };
funcContent.AddPart(new Part()
Expand Down
3 changes: 2 additions & 1 deletion src/GenerativeAI.Tools/GenerativeAI.Tools.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
<FileVersion>2.3.1</FileVersion>
<SignAssembly>True</SignAssembly>
<GeneratePackageOnBuild>True</GeneratePackageOnBuild>
<IsAotCompatible>true</IsAotCompatible>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\GenerativeAI\GenerativeAI.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="CSharpToJsonSchema" Version="3.10.1" />
<PackageReference Include="CSharpToJsonSchema" Version="3.10.2-dev.25" />
</ItemGroup>

</Project>
48 changes: 36 additions & 12 deletions src/GenerativeAI.Tools/GenericFunctionTool.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using System.Text.Json.Nodes;
using System.Text.Json;
using System.Text.Json.Nodes;
using CSharpToJsonSchema;
using GenerativeAI.Core;
using GenerativeAI.Types;

using JsonSerializer = System.Text.Json.JsonSerializer;
using Tool = GenerativeAI.Types.Tool;

Expand All @@ -13,7 +15,7 @@ namespace GenerativeAI.Tools;
/// It utilizes the code generation capabilities available in <see href="https://www.nuget.org/packages/CSharpToJsonSchema">CSharpToJsonSchema</see> for transforming
/// tool definitions into executable formats and managing function invocations.
/// </summary>
public class GenericFunctionTool:IFunctionTool
public class GenericFunctionTool:GoogleFunctionTool
{
/// <summary>
/// Represents a generic functional tool that enables interaction with a set of tools and their associated functions,
Expand All @@ -29,7 +31,7 @@ public GenericFunctionTool(IEnumerable<CSharpToJsonSchema.Tool> tools, IReadOnly


/// <inheritdoc/>
public Tool AsTool()
public override Tool AsTool()
{
return new Tool()
{
Expand All @@ -44,30 +46,52 @@ public Tool AsTool()

private Schema? ToSchema(object parameters)
{
var param = JsonSerializer.Serialize(parameters);
var param = JsonSerializer.Serialize(parameters, OpenApiSchemaSourceGenerationContext.Default.OpenApiSchema);
return JsonSerializer.Deserialize(param,SchemaSourceGenerationContext.Default.Schema);
}

/// <inheritdoc/>
public async Task<FunctionResponse?> CallAsync(FunctionCall functionCall, CancellationToken cancellationToken = default)
public override async Task<FunctionResponse?> CallAsync(FunctionCall functionCall, CancellationToken cancellationToken = default)
{
#pragma disable warning IL2026, IL3050
if (this.Calls.TryGetValue(functionCall.Name, out var call))
{
var str = JsonSerializer.Serialize(functionCall.Args);
var response = await call(str, cancellationToken).ConfigureAwait(false);
string? args = null;
if (functionCall.Args !=null)
{
args = functionCall.Args.ToJsonString();
}
// else if (functionCall.Args is JsonNode jsonNode)
// {
// args = jsonNode.ToJsonString();
// }
// else if (functionCall.Args is JsonObject jsonObject)
// {
// args = jsonObject.ToJsonString();
// }
else
{
throw new NotImplementedException();
//args = JsonSerializer.Serialize(functionCall.Args, DefaultSerializerOptions.Options.GetTypeInfo());
}
var response = await call(args, cancellationToken).ConfigureAwait(false);

var node = JsonNode.Parse(response);
var responseNode = new JsonObject();

return new FunctionResponse() { Id = functionCall.Id, Name = functionCall.Name, Response = new {
Name = functionCall.Name,
Content = node,
} };
responseNode["name"] = functionCall.Name;
responseNode["content"] = node;
return new FunctionResponse() { Id = functionCall.Id, Name = functionCall.Name,

Response = responseNode,
};
#pragma restore warning IL2026, IL3050
}
return null;
}

/// <inheritdoc/>
public bool IsContainFunction(string name)
public override bool IsContainFunction(string name)
{
return Tools.Any(s => s.Name == name);
}
Expand Down
Loading