Skip to content

Commit

Permalink
Replace our nullability decoding with the new BCL API (#26967)
Browse files Browse the repository at this point in the history
Closes #24744
  • Loading branch information
roji authored Dec 14, 2021
1 parent 0aeab4d commit 0e1e95b
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 146 deletions.
120 changes: 10 additions & 110 deletions src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;

namespace Microsoft.EntityFrameworkCore.Metadata.Conventions;

/// <summary>
Expand All @@ -14,12 +12,7 @@ namespace Microsoft.EntityFrameworkCore.Metadata.Conventions;
/// </remarks>
public abstract class NonNullableConventionBase : IModelFinalizingConvention
{
// For the interpretation of nullability metadata, see
// https://github.com/dotnet/roslyn/blob/master/docs/features/nullable-metadata.md

private const string StateAnnotationName = "NonNullableConventionState";
private const string NullableAttributeFullName = "System.Runtime.CompilerServices.NullableAttribute";
private const string NullableContextAttributeFullName = "System.Runtime.CompilerServices.NullableContextAttribute";

/// <summary>
/// Creates a new instance of <see cref="NonNullableConventionBase" />.
Expand Down Expand Up @@ -50,118 +43,25 @@ protected virtual bool IsNonNullableReferenceType(
return false;
}

var state = GetOrInitializeState(modelBuilder);

// First check for [MaybeNull] on the return value. If it exists, the member is nullable.
// Note: avoid using GetCustomAttribute<> below because of https://github.com/mono/mono/issues/17477
var isMaybeNull = memberInfo switch
{
FieldInfo f
=> f.CustomAttributes.Any(a => a.AttributeType == typeof(MaybeNullAttribute)),
PropertyInfo p
=> p.GetMethod?.ReturnParameter?.CustomAttributes?.Any(a => a.AttributeType == typeof(MaybeNullAttribute)) == true,
_ => false
};

if (isMaybeNull)
{
return false;
}

// For C# 8.0 nullable types, the C# compiler currently synthesizes a NullableAttribute that expresses nullability into
// assemblies it produces. If the model is spread across more than one assembly, there will be multiple versions of this
// attribute, so look for it by name, caching to avoid reflection on every check.
// Note that this may change - if https://github.com/dotnet/corefx/issues/36222 is done we can remove all of this.

// First look for NullableAttribute on the member itself
if (Attribute.GetCustomAttributes(memberInfo)
.FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName) is Attribute attribute)
{
var attributeType = attribute.GetType();

if (attributeType != state.NullableAttrType)
{
state.NullableFlagsFieldInfo = attributeType.GetField("NullableFlags");
state.NullableAttrType = attributeType;
}

if (state.NullableFlagsFieldInfo?.GetValue(attribute) is byte[] flags)
{
return flags.FirstOrDefault() == 1;
}
}

// No attribute on the member, try to find a NullableContextAttribute on the declaring type
var type = memberInfo.DeclaringType;
if (type is not null)
{
// We currently don't calculate support nullability for generic properties, since calculating that is complex
// (depends on the nullability of generic type argument).
// However, we special case Dictionary as it's used for property bags, and specifically don't identify its indexer
// as non-nullable.
if (memberInfo is PropertyInfo property
&& property.IsIndexerProperty()
&& type.IsGenericType
&& type.GetGenericTypeDefinition() == typeof(Dictionary<,>))
{
return false;
}

return DoesTypeHaveNonNullableContext(type, state);
}

return false;
}

private static bool DoesTypeHaveNonNullableContext(Type type, NonNullabilityConventionState state)
{
if (state.TypeCache.TryGetValue(type, out var cachedTypeNonNullable))
{
return cachedTypeNonNullable;
}

if (Attribute.GetCustomAttributes(type)
.FirstOrDefault(a => a.GetType().FullName == NullableContextAttributeFullName) is Attribute contextAttr)
{
var attributeType = contextAttr.GetType();
var annotation =
modelBuilder.Metadata.FindAnnotation(StateAnnotationName)
?? modelBuilder.Metadata.AddAnnotation(StateAnnotationName, new NullabilityInfoContext());

if (attributeType != state.NullableContextAttrType)
{
state.NullableContextFlagFieldInfo = attributeType.GetField("Flag");
state.NullableContextAttrType = attributeType;
}
var nullabilityInfoContext = (NullabilityInfoContext)annotation.Value!;

if (state.NullableContextFlagFieldInfo?.GetValue(contextAttr) is byte flag)
{
return state.TypeCache[type] = flag == 1;
}
}
else if (type.IsNested)
var nullabilityInfo = memberInfo switch
{
return state.TypeCache[type] = DoesTypeHaveNonNullableContext(type.DeclaringType!, state);
}
PropertyInfo propertyInfo => nullabilityInfoContext.Create(propertyInfo),
FieldInfo fieldInfo => nullabilityInfoContext.Create(fieldInfo),
_ => null
};

return state.TypeCache[type] = false;
return nullabilityInfo?.ReadState == NullabilityState.NotNull;
}

private static NonNullabilityConventionState GetOrInitializeState(IConventionModelBuilder modelBuilder)
=> (NonNullabilityConventionState)(
modelBuilder.Metadata.FindAnnotation(StateAnnotationName)
?? modelBuilder.Metadata.AddAnnotation(StateAnnotationName, new NonNullabilityConventionState())
).Value!;

/// <inheritdoc />
public virtual void ProcessModelFinalizing(
IConventionModelBuilder modelBuilder,
IConventionContext<IConventionModelBuilder> context)
=> modelBuilder.Metadata.RemoveAnnotation(StateAnnotationName);

private sealed class NonNullabilityConventionState
{
public Type? NullableAttrType;
public Type? NullableContextAttrType;
public FieldInfo? NullableFlagsFieldInfo;
public FieldInfo? NullableContextFlagFieldInfo;
public Dictionary<Type, bool> TypeCache { get; } = new();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ public NonNullableReferencePropertyConvention(ProviderConventionSetBuilderDepend

private void Process(IConventionPropertyBuilder propertyBuilder)
{
// If the model is spread across multiple assemblies, it may contain different NullableAttribute types as
// the compiler synthesizes them for each assembly.
if (propertyBuilder.Metadata.GetIdentifyingMemberInfo() is MemberInfo memberInfo
&& IsNonNullableReferenceType(propertyBuilder.ModelBuilder, memberInfo))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,23 @@

namespace Microsoft.EntityFrameworkCore.Metadata.Conventions;

#nullable enable

public class NonNullableNavigationConventionTest
{
[ConditionalFact]
public void Non_nullability_does_not_override_configuration_from_explicit_source()
{
var dependentEntityTypeBuilder = CreateInternalEntityTypeBuilder<Post>();
var principalEntityTypeBuilder = dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Blog), ConfigurationSource.Convention);
var principalEntityTypeBuilder = dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Blog), ConfigurationSource.Convention)!;

var relationshipBuilder = dependentEntityTypeBuilder.HasRelationship(
principalEntityTypeBuilder.Metadata,
nameof(Post.Blog),
nameof(Blog.Posts),
ConfigurationSource.Convention);
ConfigurationSource.Convention)!;

var navigation = dependentEntityTypeBuilder.Metadata.FindNavigation(nameof(Post.Blog));
var navigation = dependentEntityTypeBuilder.Metadata.FindNavigation(nameof(Post.Blog))!;

relationshipBuilder.IsRequired(false, ConfigurationSource.Explicit);

Expand All @@ -39,15 +41,15 @@ public void Non_nullability_does_not_override_configuration_from_explicit_source
public void Non_nullability_does_not_override_configuration_from_data_annotation()
{
var dependentEntityTypeBuilder = CreateInternalEntityTypeBuilder<Post>();
var principalEntityTypeBuilder = dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Blog), ConfigurationSource.Convention);
var principalEntityTypeBuilder = dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Blog), ConfigurationSource.Convention)!;

var relationshipBuilder = dependentEntityTypeBuilder.HasRelationship(
principalEntityTypeBuilder.Metadata,
nameof(Post.Blog),
nameof(Blog.Posts),
ConfigurationSource.Convention);
ConfigurationSource.Convention)!;

var navigation = dependentEntityTypeBuilder.Metadata.FindNavigation(nameof(Post.Blog));
var navigation = dependentEntityTypeBuilder.Metadata.FindNavigation(nameof(Post.Blog))!;

relationshipBuilder.IsRequired(false, ConfigurationSource.DataAnnotation);

Expand All @@ -65,15 +67,15 @@ public void Non_nullability_does_not_set_is_required_for_collection_navigation()
{
var dependentEntityTypeBuilder = CreateInternalEntityTypeBuilder<Dependent>();
var principalEntityTypeBuilder =
dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Principal), ConfigurationSource.Convention);
dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Principal), ConfigurationSource.Convention)!;

var relationshipBuilder = principalEntityTypeBuilder.HasRelationship(
dependentEntityTypeBuilder.Metadata,
nameof(Principal.Dependents),
nameof(Dependent.Principal),
ConfigurationSource.Convention);
ConfigurationSource.Convention)!;

var navigation = principalEntityTypeBuilder.Metadata.FindNavigation(nameof(Principal.Dependents));
var navigation = principalEntityTypeBuilder.Metadata.FindNavigation(nameof(Principal.Dependents))!;

Assert.False(relationshipBuilder.Metadata.IsRequired);

Expand All @@ -89,17 +91,17 @@ public void Non_nullability_does_not_set_is_required_for_navigation_to_dependent
{
var dependentEntityTypeBuilder = CreateInternalEntityTypeBuilder<Dependent>();
var principalEntityTypeBuilder =
dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Principal), ConfigurationSource.Convention);
dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Principal), ConfigurationSource.Convention)!;

var relationshipBuilder = dependentEntityTypeBuilder.HasRelationship(
principalEntityTypeBuilder.Metadata,
nameof(Dependent.Principal),
nameof(Principal.Dependent),
ConfigurationSource.Convention)
ConfigurationSource.Convention)!
.HasEntityTypes
(principalEntityTypeBuilder.Metadata, dependentEntityTypeBuilder.Metadata, ConfigurationSource.Explicit);
(principalEntityTypeBuilder.Metadata, dependentEntityTypeBuilder.Metadata, ConfigurationSource.Explicit)!;

var navigation = principalEntityTypeBuilder.Metadata.FindNavigation(nameof(Principal.Dependent));
var navigation = principalEntityTypeBuilder.Metadata.FindNavigation(nameof(Principal.Dependent))!;

Assert.False(relationshipBuilder.Metadata.IsRequired);

Expand All @@ -116,7 +118,7 @@ public void Non_nullability_sets_is_required_with_conventional_builder()
modelBuilder.Entity<BlogDetails>();

Assert.True(
model.FindEntityType(typeof(BlogDetails)).GetForeignKeys().Single(fk => fk.PrincipalEntityType?.ClrType == typeof(Blog))
model.FindEntityType(typeof(BlogDetails))!.GetForeignKeys().Single(fk => fk.PrincipalEntityType?.ClrType == typeof(Blog))
.IsRequired);
}

Expand All @@ -125,7 +127,7 @@ private Navigation RunConvention(InternalForeignKeyBuilder relationshipBuilder,
var context = new ConventionContext<IConventionNavigationBuilder>(
relationshipBuilder.Metadata.DeclaringEntityType.Model.ConventionDispatcher);
CreateNotNullNavigationConvention().ProcessNavigationAdded(navigation.Builder, context);
return context.ShouldStopProcessing() ? (Navigation)context.Result?.Metadata : navigation;
return context.ShouldStopProcessing() ? (Navigation)context.Result?.Metadata! : navigation;
}

private NonNullableNavigationConvention CreateNotNullNavigationConvention()
Expand All @@ -146,15 +148,15 @@ private InternalEntityTypeBuilder CreateInternalEntityTypeBuilder<T>()

var modelBuilder = new Model(conventionSet).Builder;

return modelBuilder.Entity(typeof(T), ConfigurationSource.Explicit);
return modelBuilder.Entity(typeof(T), ConfigurationSource.Explicit)!;
}

private ModelBuilder CreateModelBuilder()
{
var serviceProvider = CreateServiceProvider();
return new ModelBuilder(
serviceProvider.GetService<IConventionSetBuilder>().CreateConventionSet(),
serviceProvider.GetService<ModelDependencies>());
serviceProvider.GetRequiredService<IConventionSetBuilder>().CreateConventionSet(),
serviceProvider.GetRequiredService<ModelDependencies>());
}

private ProviderConventionSetBuilderDependencies CreateDependencies()
Expand All @@ -179,9 +181,9 @@ protected IServiceProvider CreateServiceProvider()
return modelLogger;
}

#nullable enable
#pragma warning disable CS8618

// ReSharper disable UnusedMember.Local
// ReSharper disable ClassNeverInstantiated.Local
private class Blog
{
public int Id { get; set; }
Expand Down Expand Up @@ -245,6 +247,7 @@ private class Dependent
[ForeignKey("PrincipalId, PrincipalFk")]
public Principal? CompositePrincipal { get; set; }
}
// ReSharper restore ClassNeverInstantiated.Local
// ReSharper restore UnusedMember.Local
#pragma warning restore CS8618
#nullable disable
}
Loading

0 comments on commit 0e1e95b

Please sign in to comment.