Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for custom types in function parameters #504

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 140 additions & 4 deletions OpenAI.Utilities.Tests/FunctionCallingHelperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public void VerifyGetFunctionDefinition()
stringParameter.Description.ShouldBe("String Parameter");
stringParameter.Type.ShouldBe("string");

var enumValues = new List<string> {"Value1", "Value2", "Value3"};
var enumValues = new List<string> { "Value1", "Value2", "Value3" };

var enumParameter = functionDefinition.Parameters.Properties["enumParameter"];
enumParameter.Description.ShouldBe("Enum Parameter");
Expand Down Expand Up @@ -124,8 +124,7 @@ public void VerifyCallFunction_Complex()
{
Name = "TestFunction",
// arguments is a json dictionary
Arguments =
"{\"intParameter\": 1, \"floatParameter\": 2.0, \"boolParameter\": true, \"stringParameter\": \"Hello\", \"enumParameter\": \"Value1\", \"enumParameter2\": \"Value2\", \"requiredIntParameter\": 1, \"notRequiredIntParameter\": 2, \"OverriddenName\": 3}"
Arguments = "{\"intParameter\": 1, \"floatParameter\": 2.0, \"boolParameter\": true, \"stringParameter\": \"Hello\", \"enumParameter\": \"Value1\", \"enumParameter2\": \"Value2\", \"requiredIntParameter\": 1, \"notRequiredIntParameter\": 2, \"OverriddenName\": 3}"
};

var result = FunctionCallingHelper.CallFunction<int>(functionCall, obj);
Expand Down Expand Up @@ -227,6 +226,108 @@ public void VerifyCallFunctionWithOverriddenType()
FunctionCallingHelper.CallFunction<object>(functionCall, obj);
obj.OverriddenTypeParameter.ShouldBe("1");
}


[Fact]
public void VerifyGetFunctionDefinition_CustomType()
{
var functionDefinition =
FunctionCallingHelper.GetFunctionDefinition(
typeof(FunctionCallingTestClass).GetMethod("FunctionWithCustomType")!);

functionDefinition.Name.ShouldBe("FunctionWithCustomType");
functionDefinition.Description.ShouldBe("Function with custom type parameter");
functionDefinition.Parameters.ShouldNotBeNull();
functionDefinition.Parameters.Properties!.Count.ShouldBe(1);

var customTypeParameter = functionDefinition.Parameters.Properties["customTypeParameter"];
customTypeParameter.Description.ShouldBe("Custom type parameter");
customTypeParameter.Type.ShouldBe("object");

customTypeParameter.Properties.ShouldNotBeNull();
customTypeParameter.Properties!.Count.ShouldBe(4);

var nameProperty = customTypeParameter.Properties["Name"];
nameProperty.Type.ShouldBe("string");
nameProperty.Description.ShouldBe("The name");

var ageProperty = customTypeParameter.Properties["Age"];
ageProperty.Type.ShouldBe("integer");
ageProperty.Description.ShouldBe("The age");

var scoreProperty = customTypeParameter.Properties["Score"];
scoreProperty.Type.ShouldBe("number");
scoreProperty.Description.ShouldBe("The score");

var isActiveProperty = customTypeParameter.Properties["IsActive"];
isActiveProperty.Type.ShouldBe("boolean");
isActiveProperty.Description.ShouldBe("The status");
}

[Fact]
public void VerifyGetFunctionDefinition_ComplexCustomType()
{
var functionDefinition =
FunctionCallingHelper.GetFunctionDefinition(
typeof(FunctionCallingTestClass).GetMethod("FunctionWithComplexCustomType")!);


functionDefinition.Name.ShouldBe("FunctionWithComplexCustomType");
functionDefinition.Description.ShouldBe("Function with complex custom type parameter");
functionDefinition.Parameters.ShouldNotBeNull();
functionDefinition.Parameters.Properties!.Count.ShouldBe(1);

var complexCustomTypeParameter = functionDefinition.Parameters.Properties["complexCustomTypeParameter"];
complexCustomTypeParameter.Description.ShouldBe("Complex custom type parameter");
complexCustomTypeParameter.Type.ShouldBe("object");

complexCustomTypeParameter.Properties.ShouldNotBeNull();
complexCustomTypeParameter.Properties!.Count.ShouldBe(5);

complexCustomTypeParameter.Properties.ShouldContainKey("Name");
complexCustomTypeParameter.Properties.ShouldContainKey("Age");
complexCustomTypeParameter.Properties.ShouldContainKey("Scores");
complexCustomTypeParameter.Properties.ShouldContainKey("IsActive");
complexCustomTypeParameter.Properties.ShouldContainKey("NestedCustomType");

var nestedCustomTypeProperty = complexCustomTypeParameter.Properties["NestedCustomType"];
nestedCustomTypeProperty.Type.ShouldBe("object");

nestedCustomTypeProperty.Properties.ShouldNotBeNull();
nestedCustomTypeProperty.Properties!.Count.ShouldBe(4);

nestedCustomTypeProperty.Properties.ShouldContainKey("Name");
nestedCustomTypeProperty.Properties.ShouldContainKey("Age");
nestedCustomTypeProperty.Properties.ShouldContainKey("Score");
nestedCustomTypeProperty.Properties.ShouldContainKey("IsActive");
}

[Fact]
public void VerifyCallFunction_ComplexCustomType()
{
var obj = new FunctionCallingTestClass();

var functionCall = new FunctionCall
{
Name = "FunctionWithComplexCustomType",
Arguments =
"{\"complexCustomTypeParameter\": {\"Name\": \"John\", \"Age\": 30, \"Scores\": [85.5, 92.0, 78.5], \"IsActive\": true, \"NestedCustomType\": {\"Name\": \"Nested\", \"Age\": 20, \"Score\": 95.0, \"IsActive\": false}}}"
};

FunctionCallingHelper.CallFunction<object>(functionCall, obj);

obj.ComplexCustomTypeParameter.ShouldNotBeNull();
obj.ComplexCustomTypeParameter.Name.ShouldBe("John");
obj.ComplexCustomTypeParameter.Age.ShouldBe(30);
obj.ComplexCustomTypeParameter.Scores.ShouldBe(new List<float> { 85.5f, 92.0f, 78.5f });
obj.ComplexCustomTypeParameter.IsActive.ShouldBe(true);

obj.ComplexCustomTypeParameter.NestedCustomType.ShouldNotBeNull();
obj.ComplexCustomTypeParameter.NestedCustomType.Name.ShouldBe("Nested");
obj.ComplexCustomTypeParameter.NestedCustomType.Age.ShouldBe(20);
obj.ComplexCustomTypeParameter.NestedCustomType.Score.ShouldBe(95.0f);
obj.ComplexCustomTypeParameter.NestedCustomType.IsActive.ShouldBe(false);
}
}

internal class FunctionCallingTestClass
Expand All @@ -241,6 +342,7 @@ internal class FunctionCallingTestClass
public string OverriddenTypeParameter = null!;
public int RequiredIntParameter;
public string StringParameter = null!;
public ComplexCustomType ComplexCustomTypeParameter { get; set; } = new();

[FunctionDescription("Test Function")]
public int TestFunction(
Expand Down Expand Up @@ -284,7 +386,9 @@ public string SecondFunction()
}

[FunctionDescription("Third Function")]
public void ThirdFunction([ParameterDescription(Type = "string", Description = "Overridden type parameter")] int overriddenTypeParameter)
public void ThirdFunction(
[ParameterDescription(Type = "string", Description = "Overridden type parameter")]
int overriddenTypeParameter)
{
OverriddenTypeParameter = overriddenTypeParameter.ToString();
}
Expand All @@ -294,6 +398,38 @@ public string FourthFunction()
{
return "Ciallo~(∠・ω< )⌒★";
}

[FunctionDescription("Function with complex custom type parameter")]
public void FunctionWithComplexCustomType(
[ParameterDescription("Complex custom type parameter")]
ComplexCustomType complexCustomTypeParameter)
{
ComplexCustomTypeParameter = complexCustomTypeParameter;
}
}

public class ComplexCustomType
{
public string Name { get; set; } = string.Empty;

public int Age { get; set; }

public List<float> Scores { get; set; } = new();

public bool IsActive { get; set; }

public CustomType NestedCustomType { get; set; } = new();
}

public class CustomType
{
public string Name { get; set; } = string.Empty;

public int Age { get; set; }

public float Score { get; set; }

public bool IsActive { get; set; }
}

public enum TestEnum
Expand Down
98 changes: 85 additions & 13 deletions OpenAI.Utilities/FunctionCalling/FunctionCallingHelper.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Reflection;
using System.ComponentModel;
using System.ComponentModel.DataAnnotations;
using System.Reflection;
using System.Text.Json;
using OpenAI.Builders;
using OpenAI.ObjectModels.RequestModels;
Expand Down Expand Up @@ -41,7 +43,6 @@ public static FunctionDefinition GetFunctionDefinition(MethodInfo methodInfo)
Type = parameterDescriptionAttribute!.Type!,
Description = description
};

break;
case ({ } t, _) when t.IsAssignableFrom(typeof(int)):
definition = PropertyDefinition.DefineInteger(description);
Expand All @@ -55,18 +56,31 @@ public static FunctionDefinition GetFunctionDefinition(MethodInfo methodInfo)
case ({ } t, _) when t.IsAssignableFrom(typeof(string)):
definition = PropertyDefinition.DefineString(description);
break;
case ({IsEnum: true}, _):

case ({ IsEnum: true }, _):
var enumValues = string.IsNullOrEmpty(parameterDescriptionAttribute?.Enum)
? Enum.GetNames(parameter.ParameterType).ToList()
: parameterDescriptionAttribute.Enum.Split(",").Select(x => x.Trim()).ToList();
definition = PropertyDefinition.DefineEnum(enumValues, description);
break;
default:
// Handling custom types
var properties = new Dictionary<string, PropertyDefinition>();
var requiredProperties = new List<string>();

definition =
PropertyDefinition.DefineEnum(enumValues, description);
foreach (var prop in parameter.ParameterType.GetProperties())
{
var propDefinition = GetPropertyDefinition(prop);
properties[prop.Name] = propDefinition;

if (prop.GetCustomAttribute<RequiredAttribute>() != null)
{
requiredProperties.Add(prop.Name);
}
}

definition =
PropertyDefinition.DefineObject(properties, requiredProperties, false, description, null);
break;
default:
throw new Exception($"Parameter type '{parameter.ParameterType}' not supported");
}

result.AddParameter(
Expand All @@ -78,6 +92,59 @@ public static FunctionDefinition GetFunctionDefinition(MethodInfo methodInfo)
return result.Build();
}

/// <summary>
/// Gets the definition of a property.
/// </summary>
/// <param name="propertyInfo">The reflection information of the property.</param>
/// <returns>The definition of the property.</returns>
/// <remarks>
/// This method creates the appropriate property definition based on the property type. The following types are
/// supported:
/// - int: Defined as an integer type.
/// - float: Defined as a number type.
/// - bool: Defined as a boolean type.
/// - string: Defined as a string type.
/// - enum: Defined as an enum type, with enum values obtained from the property's enum type.
/// - Custom types: Recursively processes the properties of custom types and creates an object type property
/// definition.
/// </remarks>
private static PropertyDefinition GetPropertyDefinition(PropertyInfo propertyInfo)
{
var description = propertyInfo.GetCustomAttribute<DescriptionAttribute>()?.Description;

switch (propertyInfo.PropertyType)
{
case { } t when t == typeof(int):
return PropertyDefinition.DefineInteger(description);
case { } t when t == typeof(float):
return PropertyDefinition.DefineNumber(description);
case { } t when t == typeof(bool):
return PropertyDefinition.DefineBoolean(description);
case { } t when t == typeof(string):
return PropertyDefinition.DefineString(description);
case { IsEnum: true }:
var enumValues = Enum.GetNames(propertyInfo.PropertyType).ToList();
return PropertyDefinition.DefineEnum(enumValues, description);
default:
// Recursive processing if the property type is a custom class
var properties = new Dictionary<string, PropertyDefinition>();
var requiredProperties = new List<string>();

foreach (var prop in propertyInfo.PropertyType.GetProperties())
{
var propDefinition = GetPropertyDefinition(prop);
properties[prop.Name] = propDefinition;

if (prop.GetCustomAttribute<RequiredAttribute>() != null)
{
requiredProperties.Add(prop.Name);
}
}

return PropertyDefinition.DefineObject(properties, requiredProperties, false, description, null);
}
}

public static ToolDefinition GetToolDefinition(MethodInfo methodInfo)
{
return new ToolDefinition()
Expand All @@ -86,6 +153,7 @@ public static ToolDefinition GetToolDefinition(MethodInfo methodInfo)
Function = GetFunctionDefinition(methodInfo)
};
}

/// <summary>
/// Enumerates the methods in the provided object, and a returns a <see cref="List{FunctionDefinition}" /> of
/// <see cref="FunctionDefinition" /> for all methods
Expand Down Expand Up @@ -129,7 +197,7 @@ public static List<ToolDefinition> GetToolDefinitions(Type type)

return result;
}


/// <summary>
/// Calls the function on the provided object, using the provided <see cref="FunctionCall" /> and returns the result of
Expand Down Expand Up @@ -193,13 +261,15 @@ public static List<ToolDefinition> GetToolDefinitions(Type type)
}
else
{
value = parameter.ParameterType.IsEnum ? Enum.Parse(parameter.ParameterType, argument.Value.ToString()!) : ((JsonElement)argument.Value).Deserialize(parameter.ParameterType);
value = parameter.ParameterType.IsEnum
? Enum.Parse(parameter.ParameterType, argument.Value.ToString()!)
: ((JsonElement)argument.Value).Deserialize(parameter.ParameterType);
}

args.Add(value);
}

var result = (T?) methodInfo.Invoke(obj, args.ToArray());
var result = (T?)methodInfo.Invoke(obj, args.ToArray());
return result;
}

Expand All @@ -220,9 +290,11 @@ public static List<ToolDefinition> GetToolDefinitions(Type type)
// If not found, then look for methods with the custom attribute
var methodsWithAttributes = type
.GetMethods()
.FirstOrDefault(m => m.GetCustomAttributes(typeof(FunctionDescriptionAttribute), false).FirstOrDefault() is FunctionDescriptionAttribute attr && attr.Name == functionCall.Name);
.FirstOrDefault(m =>
m.GetCustomAttributes(typeof(FunctionDescriptionAttribute), false)
.FirstOrDefault() is FunctionDescriptionAttribute attr &&
attr.Name == functionCall.Name);

return methodsWithAttributes;
}

}