diff --git a/eng/MSBuild/LegacySupport.props b/eng/MSBuild/LegacySupport.props index 9e83541b0d8..1f0f22c7aa8 100644 --- a/eng/MSBuild/LegacySupport.props +++ b/eng/MSBuild/LegacySupport.props @@ -74,4 +74,8 @@ + + + + diff --git a/src/LegacySupport/NullabilityInfoContext/NullabilityInfo.cs b/src/LegacySupport/NullabilityInfoContext/NullabilityInfo.cs new file mode 100644 index 00000000000..bd9b132cd0f --- /dev/null +++ b/src/LegacySupport/NullabilityInfoContext/NullabilityInfo.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#if !NET6_0_OR_GREATER +using System.Diagnostics.CodeAnalysis; + +#pragma warning disable SA1623 // Property summary documentation should match accessors + +namespace System.Reflection +{ + /// + /// A class that represents nullability info. + /// + [ExcludeFromCodeCoverage] + internal sealed class NullabilityInfo + { + internal NullabilityInfo(Type type, NullabilityState readState, NullabilityState writeState, + NullabilityInfo? elementType, NullabilityInfo[] typeArguments) + { + Type = type; + ReadState = readState; + WriteState = writeState; + ElementType = elementType; + GenericTypeArguments = typeArguments; + } + + /// + /// The of the member or generic parameter + /// to which this NullabilityInfo belongs. + /// + public Type Type { get; } + + /// + /// The nullability read state of the member. + /// + public NullabilityState ReadState { get; internal set; } + + /// + /// The nullability write state of the member. + /// + public NullabilityState WriteState { get; internal set; } + + /// + /// If the member type is an array, gives the of the elements of the array, null otherwise. + /// + public NullabilityInfo? ElementType { get; } + + /// + /// If the member type is a generic type, gives the array of for each type parameter. + /// + public NullabilityInfo[] GenericTypeArguments { get; } + } + + /// + /// An enum that represents nullability state. + /// + internal enum NullabilityState + { + /// + /// Nullability context not enabled (oblivious). + /// + Unknown, + + /// + /// Non nullable value or reference type. + /// + NotNull, + + /// + /// Nullable value or reference type. + /// + Nullable, + } +} +#endif diff --git a/src/LegacySupport/NullabilityInfoContext/NullabilityInfoContext.cs b/src/LegacySupport/NullabilityInfoContext/NullabilityInfoContext.cs new file mode 100644 index 00000000000..33f13a56164 --- /dev/null +++ b/src/LegacySupport/NullabilityInfoContext/NullabilityInfoContext.cs @@ -0,0 +1,661 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#if !NET6_0_OR_GREATER +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + +#pragma warning disable SA1204 // Static elements should appear before instance elements +#pragma warning disable S109 // Magic numbers should not be used +#pragma warning disable S1067 // Expressions should not be too complex +#pragma warning disable S4136 // Method overloads should be grouped together +#pragma warning disable SA1202 // Elements should be ordered by access +#pragma warning disable IDE1006 // Naming Styles + +namespace System.Reflection +{ + /// + /// Provides APIs for populating nullability information/context from reflection members: + /// , , and . + /// + [ExcludeFromCodeCoverage] + internal sealed class NullabilityInfoContext + { + private const string CompilerServicesNameSpace = "System.Runtime.CompilerServices"; + private readonly Dictionary _publicOnlyModules = new(); + private readonly Dictionary _context = new(); + + [Flags] + private enum NotAnnotatedStatus + { + None = 0x0, // no restriction, all members annotated + Private = 0x1, // private members not annotated + Internal = 0x2, // internal members not annotated + } + + private NullabilityState? GetNullableContext(MemberInfo? memberInfo) + { + while (memberInfo != null) + { + if (_context.TryGetValue(memberInfo, out NullabilityState state)) + { + return state; + } + + foreach (CustomAttributeData attribute in memberInfo.GetCustomAttributesData()) + { + if (attribute.AttributeType.Name == "NullableContextAttribute" && + attribute.AttributeType.Namespace == CompilerServicesNameSpace && + attribute.ConstructorArguments.Count == 1) + { + state = TranslateByte(attribute.ConstructorArguments[0].Value); + _context.Add(memberInfo, state); + return state; + } + } + + memberInfo = memberInfo.DeclaringType; + } + + return null; + } + + /// + /// Populates for the given . + /// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's + /// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state. + /// + /// The parameter which nullability info gets populated. + /// If the parameterInfo parameter is null. + /// . + public NullabilityInfo Create(ParameterInfo parameterInfo) + { + IList attributes = parameterInfo.GetCustomAttributesData(); + NullableAttributeStateParser parser = parameterInfo.Member is MethodBase method && IsPrivateOrInternalMethodAndAnnotationDisabled(method) + ? NullableAttributeStateParser.Unknown + : CreateParser(attributes); + NullabilityInfo nullability = GetNullabilityInfo(parameterInfo.Member, parameterInfo.ParameterType, parser); + + if (nullability.ReadState != NullabilityState.Unknown) + { + CheckParameterMetadataType(parameterInfo, nullability); + } + + CheckNullabilityAttributes(nullability, attributes); + return nullability; + } + + private void CheckParameterMetadataType(ParameterInfo parameter, NullabilityInfo nullability) + { + ParameterInfo? metaParameter; + MemberInfo metaMember; + + switch (parameter.Member) + { + case ConstructorInfo ctor: + var metaCtor = (ConstructorInfo)GetMemberMetadataDefinition(ctor); + metaMember = metaCtor; + metaParameter = GetMetaParameter(metaCtor, parameter); + break; + + case MethodInfo method: + MethodInfo metaMethod = GetMethodMetadataDefinition(method); + metaMember = metaMethod; + metaParameter = string.IsNullOrEmpty(parameter.Name) ? metaMethod.ReturnParameter : GetMetaParameter(metaMethod, parameter); + break; + + default: + return; + } + + if (metaParameter != null) + { + CheckGenericParameters(nullability, metaMember, metaParameter.ParameterType, parameter.Member.ReflectedType); + } + } + + private static ParameterInfo? GetMetaParameter(MethodBase metaMethod, ParameterInfo parameter) + { + var parameters = metaMethod.GetParameters(); + for (int i = 0; i < parameters.Length; i++) + { + if (parameter.Position == i && + parameter.Name == parameters[i].Name) + { + return parameters[i]; + } + } + + return null; + } + + private static MethodInfo GetMethodMetadataDefinition(MethodInfo method) + { + if (method.IsGenericMethod && !method.IsGenericMethodDefinition) + { + method = method.GetGenericMethodDefinition(); + } + + return (MethodInfo)GetMemberMetadataDefinition(method); + } + + private static void CheckNullabilityAttributes(NullabilityInfo nullability, IList attributes) + { + var codeAnalysisReadState = NullabilityState.Unknown; + var codeAnalysisWriteState = NullabilityState.Unknown; + + foreach (CustomAttributeData attribute in attributes) + { + if (attribute.AttributeType.Namespace == "System.Diagnostics.CodeAnalysis") + { + if (attribute.AttributeType.Name == "NotNullAttribute") + { + codeAnalysisReadState = NullabilityState.NotNull; + } + else if ((attribute.AttributeType.Name == "MaybeNullAttribute" || + attribute.AttributeType.Name == "MaybeNullWhenAttribute") && + codeAnalysisReadState == NullabilityState.Unknown && + !IsValueTypeOrValueTypeByRef(nullability.Type)) + { + codeAnalysisReadState = NullabilityState.Nullable; + } + else if (attribute.AttributeType.Name == "DisallowNullAttribute") + { + codeAnalysisWriteState = NullabilityState.NotNull; + } + else if (attribute.AttributeType.Name == "AllowNullAttribute" && + codeAnalysisWriteState == NullabilityState.Unknown && + !IsValueTypeOrValueTypeByRef(nullability.Type)) + { + codeAnalysisWriteState = NullabilityState.Nullable; + } + } + } + + if (codeAnalysisReadState != NullabilityState.Unknown) + { + nullability.ReadState = codeAnalysisReadState; + } + + if (codeAnalysisWriteState != NullabilityState.Unknown) + { + nullability.WriteState = codeAnalysisWriteState; + } + } + + /// + /// Populates for the given . + /// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's + /// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state. + /// + /// The parameter which nullability info gets populated. + /// If the propertyInfo parameter is null. + /// . + public NullabilityInfo Create(PropertyInfo propertyInfo) + { + MethodInfo? getter = propertyInfo.GetGetMethod(true); + MethodInfo? setter = propertyInfo.GetSetMethod(true); + bool annotationsDisabled = (getter == null || IsPrivateOrInternalMethodAndAnnotationDisabled(getter)) + && (setter == null || IsPrivateOrInternalMethodAndAnnotationDisabled(setter)); + NullableAttributeStateParser parser = annotationsDisabled ? NullableAttributeStateParser.Unknown : CreateParser(propertyInfo.GetCustomAttributesData()); + NullabilityInfo nullability = GetNullabilityInfo(propertyInfo, propertyInfo.PropertyType, parser); + + if (getter != null) + { + CheckNullabilityAttributes(nullability, getter.ReturnParameter.GetCustomAttributesData()); + } + else + { + nullability.ReadState = NullabilityState.Unknown; + } + + if (setter != null) + { + var setterParams = setter.GetParameters(); + CheckNullabilityAttributes(nullability, setterParams[setterParams.Length - 1].GetCustomAttributesData()); + } + else + { + nullability.WriteState = NullabilityState.Unknown; + } + + return nullability; + } + + private bool IsPrivateOrInternalMethodAndAnnotationDisabled(MethodBase method) + { + if ((method.IsPrivate || method.IsFamilyAndAssembly || method.IsAssembly) && + IsPublicOnly(method.IsPrivate, method.IsFamilyAndAssembly, method.IsAssembly, method.Module)) + { + return true; + } + + return false; + } + + /// + /// Populates for the given . + /// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's + /// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state. + /// + /// The parameter which nullability info gets populated. + /// If the eventInfo parameter is null. + /// . + public NullabilityInfo Create(EventInfo eventInfo) + { + return GetNullabilityInfo(eventInfo, eventInfo.EventHandlerType!, CreateParser(eventInfo.GetCustomAttributesData())); + } + + /// + /// Populates for the given + /// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's + /// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state. + /// + /// The parameter which nullability info gets populated. + /// If the fieldInfo parameter is null. + /// . + public NullabilityInfo Create(FieldInfo fieldInfo) + { + IList attributes = fieldInfo.GetCustomAttributesData(); + NullableAttributeStateParser parser = IsPrivateOrInternalFieldAndAnnotationDisabled(fieldInfo) ? NullableAttributeStateParser.Unknown : CreateParser(attributes); + NullabilityInfo nullability = GetNullabilityInfo(fieldInfo, fieldInfo.FieldType, parser); + CheckNullabilityAttributes(nullability, attributes); + return nullability; + } + + private bool IsPrivateOrInternalFieldAndAnnotationDisabled(FieldInfo fieldInfo) + { + if ((fieldInfo.IsPrivate || fieldInfo.IsFamilyAndAssembly || fieldInfo.IsAssembly) && + IsPublicOnly(fieldInfo.IsPrivate, fieldInfo.IsFamilyAndAssembly, fieldInfo.IsAssembly, fieldInfo.Module)) + { + return true; + } + + return false; + } + + private bool IsPublicOnly(bool isPrivate, bool isFamilyAndAssembly, bool isAssembly, Module module) + { + if (!_publicOnlyModules.TryGetValue(module, out NotAnnotatedStatus value)) + { + value = PopulateAnnotationInfo(module.GetCustomAttributesData()); + _publicOnlyModules.Add(module, value); + } + + if (value == NotAnnotatedStatus.None) + { + return false; + } + + if (((isPrivate || isFamilyAndAssembly) && value.HasFlag(NotAnnotatedStatus.Private)) || + (isAssembly && value.HasFlag(NotAnnotatedStatus.Internal))) + { + return true; + } + + return false; + } + + private static NotAnnotatedStatus PopulateAnnotationInfo(IList customAttributes) + { + foreach (CustomAttributeData attribute in customAttributes) + { + if (attribute.AttributeType.Name == "NullablePublicOnlyAttribute" && + attribute.AttributeType.Namespace == CompilerServicesNameSpace && + attribute.ConstructorArguments.Count == 1) + { + if (attribute.ConstructorArguments[0].Value is bool boolValue && boolValue) + { + return NotAnnotatedStatus.Internal | NotAnnotatedStatus.Private; + } + else + { + return NotAnnotatedStatus.Private; + } + } + } + + return NotAnnotatedStatus.None; + } + + private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, NullableAttributeStateParser parser) + { + int index = 0; + NullabilityInfo nullability = GetNullabilityInfo(memberInfo, type, parser, ref index); + + if (nullability.ReadState != NullabilityState.Unknown) + { + TryLoadGenericMetaTypeNullability(memberInfo, nullability); + } + + return nullability; + } + + private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, NullableAttributeStateParser parser, ref int index) + { + NullabilityState state = NullabilityState.Unknown; + NullabilityInfo? elementState = null; + NullabilityInfo[] genericArgumentsState = Array.Empty(); + Type underlyingType = type; + + if (underlyingType.IsByRef || underlyingType.IsPointer) + { + underlyingType = underlyingType.GetElementType()!; + } + + if (underlyingType.IsValueType) + { + if (Nullable.GetUnderlyingType(underlyingType) is { } nullableUnderlyingType) + { + underlyingType = nullableUnderlyingType; + state = NullabilityState.Nullable; + } + else + { + state = NullabilityState.NotNull; + } + + if (underlyingType.IsGenericType) + { + ++index; + } + } + else + { + if (!parser.ParseNullableState(index++, ref state) + && GetNullableContext(memberInfo) is { } contextState) + { + state = contextState; + } + + if (underlyingType.IsArray) + { + elementState = GetNullabilityInfo(memberInfo, underlyingType.GetElementType()!, parser, ref index); + } + } + + if (underlyingType.IsGenericType) + { + Type[] genericArguments = underlyingType.GetGenericArguments(); + genericArgumentsState = new NullabilityInfo[genericArguments.Length]; + + for (int i = 0; i < genericArguments.Length; i++) + { + genericArgumentsState[i] = GetNullabilityInfo(memberInfo, genericArguments[i], parser, ref index); + } + } + + return new NullabilityInfo(type, state, state, elementState, genericArgumentsState); + } + + private static NullableAttributeStateParser CreateParser(IList customAttributes) + { + foreach (CustomAttributeData attribute in customAttributes) + { + if (attribute.AttributeType.Name == "NullableAttribute" && + attribute.AttributeType.Namespace == CompilerServicesNameSpace && + attribute.ConstructorArguments.Count == 1) + { + return new NullableAttributeStateParser(attribute.ConstructorArguments[0].Value); + } + } + + return new NullableAttributeStateParser(null); + } + + private void TryLoadGenericMetaTypeNullability(MemberInfo memberInfo, NullabilityInfo nullability) + { + MemberInfo? metaMember = GetMemberMetadataDefinition(memberInfo); + Type? metaType = null; + if (metaMember is FieldInfo field) + { + metaType = field.FieldType; + } + else if (metaMember is PropertyInfo property) + { + metaType = GetPropertyMetaType(property); + } + + if (metaType != null) + { + CheckGenericParameters(nullability, metaMember!, metaType, memberInfo.ReflectedType); + } + } + + private static MemberInfo GetMemberMetadataDefinition(MemberInfo member) + { + Type? type = member.DeclaringType; + if ((type != null) && type.IsGenericType && !type.IsGenericTypeDefinition) + { + return NullabilityInfoHelpers.GetMemberWithSameMetadataDefinitionAs(type.GetGenericTypeDefinition(), member); + } + + return member; + } + + private static Type GetPropertyMetaType(PropertyInfo property) + { + if (property.GetGetMethod(true) is MethodInfo method) + { + return method.ReturnType; + } + + return property.GetSetMethod(true)!.GetParameters()[0].ParameterType; + } + + private void CheckGenericParameters(NullabilityInfo nullability, MemberInfo metaMember, Type metaType, Type? reflectedType) + { + if (metaType.IsGenericParameter) + { + if (nullability.ReadState == NullabilityState.NotNull) + { + _ = TryUpdateGenericParameterNullability(nullability, metaType, reflectedType); + } + } + else if (metaType.ContainsGenericParameters) + { + if (nullability.GenericTypeArguments.Length > 0) + { + Type[] genericArguments = metaType.GetGenericArguments(); + + for (int i = 0; i < genericArguments.Length; i++) + { + CheckGenericParameters(nullability.GenericTypeArguments[i], metaMember, genericArguments[i], reflectedType); + } + } + else if (nullability.ElementType is { } elementNullability && metaType.IsArray) + { + CheckGenericParameters(elementNullability, metaMember, metaType.GetElementType()!, reflectedType); + } + + // We could also follow this branch for metaType.IsPointer, but since pointers must be unmanaged this + // will be a no-op regardless + else if (metaType.IsByRef) + { + CheckGenericParameters(nullability, metaMember, metaType.GetElementType()!, reflectedType); + } + } + } + + private bool TryUpdateGenericParameterNullability(NullabilityInfo nullability, Type genericParameter, Type? reflectedType) + { + Debug.Assert(genericParameter.IsGenericParameter, "must be generic parameter"); + + if (reflectedType is not null + && !genericParameter.IsGenericMethodParameter() + && TryUpdateGenericTypeParameterNullabilityFromReflectedType(nullability, genericParameter, reflectedType, reflectedType)) + { + return true; + } + + if (IsValueTypeOrValueTypeByRef(nullability.Type)) + { + return true; + } + + var state = NullabilityState.Unknown; + if (CreateParser(genericParameter.GetCustomAttributesData()).ParseNullableState(0, ref state)) + { + nullability.ReadState = state; + nullability.WriteState = state; + return true; + } + + if (GetNullableContext(genericParameter) is { } contextState) + { + nullability.ReadState = contextState; + nullability.WriteState = contextState; + return true; + } + + return false; + } + + private bool TryUpdateGenericTypeParameterNullabilityFromReflectedType(NullabilityInfo nullability, Type genericParameter, Type context, Type reflectedType) + { + Debug.Assert(genericParameter.IsGenericParameter && !genericParameter.IsGenericMethodParameter(), "must be generic parameter"); + + Type contextTypeDefinition = context.IsGenericType && !context.IsGenericTypeDefinition ? context.GetGenericTypeDefinition() : context; + if (genericParameter.DeclaringType == contextTypeDefinition) + { + return false; + } + + Type? baseType = contextTypeDefinition.BaseType; + if (baseType is null) + { + return false; + } + + if (!baseType.IsGenericType + || (baseType.IsGenericTypeDefinition ? baseType : baseType.GetGenericTypeDefinition()) != genericParameter.DeclaringType) + { + return TryUpdateGenericTypeParameterNullabilityFromReflectedType(nullability, genericParameter, baseType, reflectedType); + } + + Type[] genericArguments = baseType.GetGenericArguments(); + Type genericArgument = genericArguments[genericParameter.GenericParameterPosition]; + if (genericArgument.IsGenericParameter) + { + return TryUpdateGenericParameterNullability(nullability, genericArgument, reflectedType); + } + + NullableAttributeStateParser parser = CreateParser(contextTypeDefinition.GetCustomAttributesData()); + int nullabilityStateIndex = 1; // start at 1 since index 0 is the type itself + for (int i = 0; i < genericParameter.GenericParameterPosition; i++) + { + nullabilityStateIndex += CountNullabilityStates(genericArguments[i]); + } + + return TryPopulateNullabilityInfo(nullability, parser, ref nullabilityStateIndex); + + static int CountNullabilityStates(Type type) + { + Type underlyingType = Nullable.GetUnderlyingType(type) ?? type; + if (underlyingType.IsGenericType) + { + int count = 1; + foreach (Type genericArgument in underlyingType.GetGenericArguments()) + { + count += CountNullabilityStates(genericArgument); + } + + return count; + } + + if (underlyingType.HasElementType) + { + return (underlyingType.IsArray ? 1 : 0) + CountNullabilityStates(underlyingType.GetElementType()!); + } + + return type.IsValueType ? 0 : 1; + } + } + +#pragma warning disable SA1204 // Static elements should appear before instance elements + private static bool TryPopulateNullabilityInfo(NullabilityInfo nullability, NullableAttributeStateParser parser, ref int index) +#pragma warning restore SA1204 // Static elements should appear before instance elements + { + bool isValueType = IsValueTypeOrValueTypeByRef(nullability.Type); + if (!isValueType) + { + var state = NullabilityState.Unknown; + if (!parser.ParseNullableState(index, ref state)) + { + return false; + } + + nullability.ReadState = state; + nullability.WriteState = state; + } + + if (!isValueType || (Nullable.GetUnderlyingType(nullability.Type) ?? nullability.Type).IsGenericType) + { + index++; + } + + if (nullability.GenericTypeArguments.Length > 0) + { + foreach (NullabilityInfo genericTypeArgumentNullability in nullability.GenericTypeArguments) + { + _ = TryPopulateNullabilityInfo(genericTypeArgumentNullability, parser, ref index); + } + } + else if (nullability.ElementType is { } elementTypeNullability) + { + _ = TryPopulateNullabilityInfo(elementTypeNullability, parser, ref index); + } + + return true; + } + + private static NullabilityState TranslateByte(object? value) + { + return value is byte b ? TranslateByte(b) : NullabilityState.Unknown; + } + + private static NullabilityState TranslateByte(byte b) => + b switch + { + 1 => NullabilityState.NotNull, + 2 => NullabilityState.Nullable, + _ => NullabilityState.Unknown + }; + + private static bool IsValueTypeOrValueTypeByRef(Type type) => + type.IsValueType || ((type.IsByRef || type.IsPointer) && type.GetElementType()!.IsValueType); + + private readonly struct NullableAttributeStateParser + { + private static readonly object UnknownByte = (byte)0; + + private readonly object? _nullableAttributeArgument; + + public NullableAttributeStateParser(object? nullableAttributeArgument) + { + _nullableAttributeArgument = nullableAttributeArgument; + } + + public static NullableAttributeStateParser Unknown => new(UnknownByte); + + public bool ParseNullableState(int index, ref NullabilityState state) + { + switch (_nullableAttributeArgument) + { + case byte b: + state = TranslateByte(b); + return true; + case ReadOnlyCollection args + when index < args.Count && args[index].Value is byte elementB: + state = TranslateByte(elementB); + return true; + default: + return false; + } + } + } + } +} +#endif diff --git a/src/LegacySupport/NullabilityInfoContext/NullabilityInfoHelpers.cs b/src/LegacySupport/NullabilityInfoContext/NullabilityInfoHelpers.cs new file mode 100644 index 00000000000..1ee573a0020 --- /dev/null +++ b/src/LegacySupport/NullabilityInfoContext/NullabilityInfoHelpers.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#if !NET6_0_OR_GREATER +using System.Diagnostics.CodeAnalysis; + +#pragma warning disable IDE1006 // Naming Styles +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields + +namespace System.Reflection +{ + /// + /// Polyfills for System.Private.CoreLib internals. + /// + [ExcludeFromCodeCoverage] + internal static class NullabilityInfoHelpers + { + public static MemberInfo GetMemberWithSameMetadataDefinitionAs(Type type, MemberInfo member) + { + const BindingFlags all = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance; + foreach (var info in type.GetMembers(all)) + { + if (info.HasSameMetadataDefinitionAs(member)) + { + return info; + } + } + + throw new MissingMemberException(type.FullName, member.Name); + } + + // https://github.com/dotnet/runtime/blob/main/src/coreclr/System.Private.CoreLib/src/System/Reflection/MemberInfo.Internal.cs + public static bool HasSameMetadataDefinitionAs(this MemberInfo target, MemberInfo other) + { + return target.MetadataToken == other.MetadataToken && + target.Module.Equals(other.Module); + } + + // https://github.com/dotnet/runtime/issues/23493 + public static bool IsGenericMethodParameter(this Type target) + { + return target.IsGenericParameter && + target.DeclaringMethod != null; + } + } +} +#endif diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj index 72c354ccb97..15cf9aeecc4 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj @@ -22,6 +22,7 @@ true + true true true true diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs index 200b470d4d8..c2b81817d79 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs @@ -83,6 +83,7 @@ public static JsonElement CreateFunctionJsonSchema( title ??= method.GetCustomAttribute()?.DisplayName ?? method.Name; description ??= method.GetCustomAttribute()?.Description; + NullabilityInfoContext nullabilityContext = new(); JsonObject parameterSchemas = new(); JsonArray? requiredProperties = null; foreach (ParameterInfo parameter in method.GetParameters()) @@ -118,6 +119,7 @@ public static JsonElement CreateFunctionJsonSchema( JsonNode parameterSchema = CreateJsonSchemaCore( type: parameter.ParameterType, parameter: parameter, + nullabilityContext: nullabilityContext, description: parameterDescription, hasDefaultValue: hasDefaultValue, defaultValue: defaultValue, @@ -182,7 +184,7 @@ public static JsonElement CreateJsonSchema( { serializerOptions ??= DefaultOptions; inferenceOptions ??= AIJsonSchemaCreateOptions.Default; - JsonNode schema = CreateJsonSchemaCore(type, parameter: null, description, hasDefaultValue, defaultValue, serializerOptions, inferenceOptions); + JsonNode schema = CreateJsonSchemaCore(type, parameter: null, nullabilityContext: null, description, hasDefaultValue, defaultValue, serializerOptions, inferenceOptions); // Finally, apply any schema transformations if specified. if (inferenceOptions.TransformOptions is { } options) @@ -208,6 +210,7 @@ internal static void ValidateSchemaDocument(JsonElement document, [CallerArgumen private static JsonNode CreateJsonSchemaCore( Type? type, ParameterInfo? parameter, + NullabilityInfoContext? nullabilityContext, string? description, bool hasDefaultValue, object? defaultValue, @@ -338,6 +341,21 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js objSchema.InsertAtStart(TypePropertyName, new JsonArray { (JsonNode)"string", (JsonNode)"null" }); } } + else if (parameter is not null && + !ctx.TypeInfo.Type.IsValueType && + nullabilityContext?.Create(parameter).WriteState is NullabilityState.Nullable) + { + // Handle nullable reference type parameters (e.g., object?). + if (objSchema.TryGetPropertyValue(TypePropertyName, out JsonNode? typeKeyWord) && + typeKeyWord?.GetValueKind() is JsonValueKind.String) + { + string typeValue = typeKeyWord.GetValue()!; + if (typeValue is not "null") + { + objSchema[TypePropertyName] = new JsonArray { (JsonNode)typeValue, (JsonNode)"null" }; + } + } + } } if (ctx.Path.IsEmpty && hasDefaultValue) diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs index deea4cbcf13..0b40c67d5d9 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs @@ -1314,6 +1314,110 @@ public void RegularStaticMethod_NameUnchanged() Assert.Equal("TestStaticMethod", tool.Name); } + [Fact] + public void JsonSchema_NullableValueTypeParameters_AllowNull() + { + // Test that nullable value type parameters (e.g., int?) generate JSON schemas that allow null values. + // This should work on all target frameworks. + AIFunction func = AIFunctionFactory.Create( + (int? nullableInt, int? nullableIntWithDefault = null) => { }); + + JsonElement schema = func.JsonSchema; + JsonElement properties = schema.GetProperty("properties"); + + // nullableInt should have type ["integer", "null"] + JsonElement nullableIntSchema = properties.GetProperty("nullableInt"); + Assert.True( + nullableIntSchema.TryGetProperty("type", out JsonElement nullableIntType), + "nullableInt schema should have a 'type' property"); + Assert.Equal(JsonValueKind.Array, nullableIntType.ValueKind); + Assert.Contains("integer", nullableIntType.EnumerateArray().Select(e => e.GetString())); + Assert.Contains("null", nullableIntType.EnumerateArray().Select(e => e.GetString())); + + // nullableIntWithDefault should have type ["integer", "null"] and default: null + JsonElement nullableIntWithDefaultSchema = properties.GetProperty("nullableIntWithDefault"); + Assert.True( + nullableIntWithDefaultSchema.TryGetProperty("type", out JsonElement nullableIntWithDefaultType), + "nullableIntWithDefault schema should have a 'type' property"); + Assert.Equal(JsonValueKind.Array, nullableIntWithDefaultType.ValueKind); + Assert.Contains("integer", nullableIntWithDefaultType.EnumerateArray().Select(e => e.GetString())); + Assert.Contains("null", nullableIntWithDefaultType.EnumerateArray().Select(e => e.GetString())); + Assert.True( + nullableIntWithDefaultSchema.TryGetProperty("default", out JsonElement nullableIntWithDefaultDefault), + "nullableIntWithDefault schema should have a 'default' property"); + Assert.Equal(JsonValueKind.Null, nullableIntWithDefaultDefault.ValueKind); + + // Required array should contain only parameters without default values + JsonElement required = schema.GetProperty("required"); + List requiredParams = required.EnumerateArray().Select(e => e.GetString()!).ToList(); + Assert.Contains("nullableInt", requiredParams); + Assert.DoesNotContain("nullableIntWithDefault", requiredParams); + } + + [Fact] + public void JsonSchema_NullableReferenceTypeParameters_AllowNull() + { + // Regression test for https://github.com/dotnet/extensions/issues/7182 + // Nullable reference type parameters (e.g., string?) should generate JSON schemas that allow null values. + AIFunction func = AIFunctionFactory.Create( + (string? nullableString, int? nullableInt, string? nullableStringWithDefault = null, int? nullableIntWithDefault = null) => { }); + + JsonElement schema = func.JsonSchema; + JsonElement properties = schema.GetProperty("properties"); + + // nullableString should have type ["string", "null"] + JsonElement nullableStringSchema = properties.GetProperty("nullableString"); + Assert.True( + nullableStringSchema.TryGetProperty("type", out JsonElement nullableStringType), + "nullableString schema should have a 'type' property"); + Assert.Equal(JsonValueKind.Array, nullableStringType.ValueKind); + Assert.Contains("string", nullableStringType.EnumerateArray().Select(e => e.GetString())); + Assert.Contains("null", nullableStringType.EnumerateArray().Select(e => e.GetString())); + + // nullableInt should have type ["integer", "null"] + JsonElement nullableIntSchema = properties.GetProperty("nullableInt"); + Assert.True( + nullableIntSchema.TryGetProperty("type", out JsonElement nullableIntType), + "nullableInt schema should have a 'type' property"); + Assert.Equal(JsonValueKind.Array, nullableIntType.ValueKind); + Assert.Contains("integer", nullableIntType.EnumerateArray().Select(e => e.GetString())); + Assert.Contains("null", nullableIntType.EnumerateArray().Select(e => e.GetString())); + + // nullableStringWithDefault should have type ["string", "null"] and default: null + JsonElement nullableStringWithDefaultSchema = properties.GetProperty("nullableStringWithDefault"); + Assert.True( + nullableStringWithDefaultSchema.TryGetProperty("type", out JsonElement nullableStringWithDefaultType), + "nullableStringWithDefault schema should have a 'type' property"); + Assert.Equal(JsonValueKind.Array, nullableStringWithDefaultType.ValueKind); + Assert.Contains("string", nullableStringWithDefaultType.EnumerateArray().Select(e => e.GetString())); + Assert.Contains("null", nullableStringWithDefaultType.EnumerateArray().Select(e => e.GetString())); + Assert.True( + nullableStringWithDefaultSchema.TryGetProperty("default", out JsonElement nullableStringWithDefaultDefault), + "nullableStringWithDefault schema should have a 'default' property"); + Assert.Equal(JsonValueKind.Null, nullableStringWithDefaultDefault.ValueKind); + + // nullableIntWithDefault should have type ["integer", "null"] and default: null + JsonElement nullableIntWithDefaultSchema = properties.GetProperty("nullableIntWithDefault"); + Assert.True( + nullableIntWithDefaultSchema.TryGetProperty("type", out JsonElement nullableIntWithDefaultType), + "nullableIntWithDefault schema should have a 'type' property"); + Assert.Equal(JsonValueKind.Array, nullableIntWithDefaultType.ValueKind); + Assert.Contains("integer", nullableIntWithDefaultType.EnumerateArray().Select(e => e.GetString())); + Assert.Contains("null", nullableIntWithDefaultType.EnumerateArray().Select(e => e.GetString())); + Assert.True( + nullableIntWithDefaultSchema.TryGetProperty("default", out JsonElement nullableIntWithDefaultDefault), + "nullableIntWithDefault schema should have a 'default' property"); + Assert.Equal(JsonValueKind.Null, nullableIntWithDefaultDefault.ValueKind); + + // Required array should contain only parameters without default values + JsonElement required = schema.GetProperty("required"); + List requiredParams = required.EnumerateArray().Select(e => e.GetString()!).ToList(); + Assert.Contains("nullableString", requiredParams); + Assert.Contains("nullableInt", requiredParams); + Assert.DoesNotContain("nullableStringWithDefault", requiredParams); + Assert.DoesNotContain("nullableIntWithDefault", requiredParams); + } + [JsonSerializable(typeof(IAsyncEnumerable))] [JsonSerializable(typeof(int[]))] [JsonSerializable(typeof(string))]