Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace our nullability decoding with the new BCL API #26967

Merged
merged 2 commits into from
Dec 14, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 23 additions & 100 deletions src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;

namespace Microsoft.EntityFrameworkCore.Metadata.Conventions;

/// <summary>
Expand All @@ -14,12 +12,7 @@ namespace Microsoft.EntityFrameworkCore.Metadata.Conventions;
/// </remarks>
public abstract class NonNullableConventionBase : IModelFinalizingConvention
{
// For the interpretation of nullability metadata, see
// https://github.com/dotnet/roslyn/blob/master/docs/features/nullable-metadata.md

private const string StateAnnotationName = "NonNullableConventionState";
private const string NullableAttributeFullName = "System.Runtime.CompilerServices.NullableAttribute";
private const string NullableContextAttributeFullName = "System.Runtime.CompilerServices.NullableContextAttribute";

/// <summary>
/// Creates a new instance of <see cref="NonNullableConventionBase" />.
Expand Down Expand Up @@ -50,118 +43,48 @@ protected virtual bool IsNonNullableReferenceType(
return false;
}

var state = GetOrInitializeState(modelBuilder);

// First check for [MaybeNull] on the return value. If it exists, the member is nullable.
// Note: avoid using GetCustomAttribute<> below because of https://github.com/mono/mono/issues/17477
var isMaybeNull = memberInfo switch
{
FieldInfo f
=> f.CustomAttributes.Any(a => a.AttributeType == typeof(MaybeNullAttribute)),
PropertyInfo p
=> p.GetMethod?.ReturnParameter?.CustomAttributes?.Any(a => a.AttributeType == typeof(MaybeNullAttribute)) == true,
_ => false
};

if (isMaybeNull)
{
return false;
}
var annotation =
modelBuilder.Metadata.FindAnnotation(StateAnnotationName)
?? modelBuilder.Metadata.AddAnnotation(StateAnnotationName, new NullabilityInfoContext());

// For C# 8.0 nullable types, the C# compiler currently synthesizes a NullableAttribute that expresses nullability into
// assemblies it produces. If the model is spread across more than one assembly, there will be multiple versions of this
// attribute, so look for it by name, caching to avoid reflection on every check.
// Note that this may change - if https://github.com/dotnet/corefx/issues/36222 is done we can remove all of this.
var nullabilityInfoContext = (NullabilityInfoContext)annotation.Value!;

// First look for NullableAttribute on the member itself
if (Attribute.GetCustomAttributes(memberInfo)
.FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName) is Attribute attribute)
if (memberInfo is PropertyInfo propertyInfo)
{
var attributeType = attribute.GetType();
var nullabilityInfo = nullabilityInfoContext.Create(propertyInfo);

if (attributeType != state.NullableAttrType)
if (nullabilityInfo.ReadState == NullabilityState.NotNull)
{
state.NullableFlagsFieldInfo = attributeType.GetField("NullableFlags");
state.NullableAttrType = attributeType;
return true;
}

if (state.NullableFlagsFieldInfo?.GetValue(attribute) is byte[] flags)
{
return flags.FirstOrDefault() == 1;
}
}

// No attribute on the member, try to find a NullableContextAttribute on the declaring type
var type = memberInfo.DeclaringType;
if (type is not null)
{
// We currently don't calculate support nullability for generic properties, since calculating that is complex
// (depends on the nullability of generic type argument).
// However, we special case Dictionary as it's used for property bags, and specifically don't identify its indexer
// as non-nullable.
if (memberInfo is PropertyInfo property
&& property.IsIndexerProperty()
&& type.IsGenericType
&& type.GetGenericTypeDefinition() == typeof(Dictionary<,>))
{
return false;
}

return DoesTypeHaveNonNullableContext(type, state);
}

return false;
}

private static bool DoesTypeHaveNonNullableContext(Type type, NonNullabilityConventionState state)
{
if (state.TypeCache.TryGetValue(type, out var cachedTypeNonNullable))
{
return cachedTypeNonNullable;
// In order for us to configure a property as non-nullable, it must be:
// 1. Non-nullable for both read and write, or
// 2. Non-nullable for read and read-only, or
// 3. Non-nullable for write and write-only
// if (nullabilityInfo.ReadState == NullabilityState.NotNull
// && (nullabilityInfo.WriteState == NullabilityState.NotNull || !propertyInfo.CanWrite)
// || nullabilityInfo.WriteState == NullabilityState.NotNull && !propertyInfo.CanRead)
// {
// return true;
// }
roji marked this conversation as resolved.
Show resolved Hide resolved
}

if (Attribute.GetCustomAttributes(type)
.FirstOrDefault(a => a.GetType().FullName == NullableContextAttributeFullName) is Attribute contextAttr)
else if (memberInfo is FieldInfo fieldInfo)
{
var attributeType = contextAttr.GetType();

if (attributeType != state.NullableContextAttrType)
{
state.NullableContextFlagFieldInfo = attributeType.GetField("Flag");
state.NullableContextAttrType = attributeType;
}
var nullabilityInfo = nullabilityInfoContext.Create(fieldInfo);

if (state.NullableContextFlagFieldInfo?.GetValue(contextAttr) is byte flag)
if (nullabilityInfo.ReadState == NullabilityState.NotNull /* && nullabilityInfo.WriteState == NullabilityState.NotNull */)
roji marked this conversation as resolved.
Show resolved Hide resolved
{
return state.TypeCache[type] = flag == 1;
return true;
}
}
else if (type.IsNested)
{
return state.TypeCache[type] = DoesTypeHaveNonNullableContext(type.DeclaringType!, state);
}

return state.TypeCache[type] = false;
return false;
}

private static NonNullabilityConventionState GetOrInitializeState(IConventionModelBuilder modelBuilder)
=> (NonNullabilityConventionState)(
modelBuilder.Metadata.FindAnnotation(StateAnnotationName)
?? modelBuilder.Metadata.AddAnnotation(StateAnnotationName, new NonNullabilityConventionState())
).Value!;

/// <inheritdoc />
public virtual void ProcessModelFinalizing(
IConventionModelBuilder modelBuilder,
IConventionContext<IConventionModelBuilder> context)
=> modelBuilder.Metadata.RemoveAnnotation(StateAnnotationName);

private sealed class NonNullabilityConventionState
{
public Type? NullableAttrType;
public Type? NullableContextAttrType;
public FieldInfo? NullableFlagsFieldInfo;
public FieldInfo? NullableContextFlagFieldInfo;
public Dictionary<Type, bool> TypeCache { get; } = new();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ public class NonNullableReferencePropertyConvention : NonNullableConventionBase,
IPropertyAddedConvention,
IPropertyFieldChangedConvention
{
private readonly NullabilityInfoContext _nullabilityInfoContext = new();
roji marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Creates a new instance of <see cref="NonNullableReferencePropertyConvention" />.
/// </summary>
Expand All @@ -26,8 +28,6 @@ public NonNullableReferencePropertyConvention(ProviderConventionSetBuilderDepend

private void Process(IConventionPropertyBuilder propertyBuilder)
{
// If the model is spread across multiple assemblies, it may contain different NullableAttribute types as
// the compiler synthesizes them for each assembly.
if (propertyBuilder.Metadata.GetIdentifyingMemberInfo() is MemberInfo memberInfo
&& IsNonNullableReferenceType(propertyBuilder.ModelBuilder, memberInfo))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,23 @@

namespace Microsoft.EntityFrameworkCore.Metadata.Conventions;

#nullable enable

public class NonNullableNavigationConventionTest
{
[ConditionalFact]
public void Non_nullability_does_not_override_configuration_from_explicit_source()
{
var dependentEntityTypeBuilder = CreateInternalEntityTypeBuilder<Post>();
var principalEntityTypeBuilder = dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Blog), ConfigurationSource.Convention);
var principalEntityTypeBuilder = dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Blog), ConfigurationSource.Convention)!;

var relationshipBuilder = dependentEntityTypeBuilder.HasRelationship(
principalEntityTypeBuilder.Metadata,
nameof(Post.Blog),
nameof(Blog.Posts),
ConfigurationSource.Convention);
ConfigurationSource.Convention)!;

var navigation = dependentEntityTypeBuilder.Metadata.FindNavigation(nameof(Post.Blog));
var navigation = dependentEntityTypeBuilder.Metadata.FindNavigation(nameof(Post.Blog))!;

relationshipBuilder.IsRequired(false, ConfigurationSource.Explicit);

Expand All @@ -39,15 +41,15 @@ public void Non_nullability_does_not_override_configuration_from_explicit_source
public void Non_nullability_does_not_override_configuration_from_data_annotation()
{
var dependentEntityTypeBuilder = CreateInternalEntityTypeBuilder<Post>();
var principalEntityTypeBuilder = dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Blog), ConfigurationSource.Convention);
var principalEntityTypeBuilder = dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Blog), ConfigurationSource.Convention)!;

var relationshipBuilder = dependentEntityTypeBuilder.HasRelationship(
principalEntityTypeBuilder.Metadata,
nameof(Post.Blog),
nameof(Blog.Posts),
ConfigurationSource.Convention);
ConfigurationSource.Convention)!;

var navigation = dependentEntityTypeBuilder.Metadata.FindNavigation(nameof(Post.Blog));
var navigation = dependentEntityTypeBuilder.Metadata.FindNavigation(nameof(Post.Blog))!;

relationshipBuilder.IsRequired(false, ConfigurationSource.DataAnnotation);

Expand All @@ -65,15 +67,15 @@ public void Non_nullability_does_not_set_is_required_for_collection_navigation()
{
var dependentEntityTypeBuilder = CreateInternalEntityTypeBuilder<Dependent>();
var principalEntityTypeBuilder =
dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Principal), ConfigurationSource.Convention);
dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Principal), ConfigurationSource.Convention)!;

var relationshipBuilder = principalEntityTypeBuilder.HasRelationship(
dependentEntityTypeBuilder.Metadata,
nameof(Principal.Dependents),
nameof(Dependent.Principal),
ConfigurationSource.Convention);
ConfigurationSource.Convention)!;

var navigation = principalEntityTypeBuilder.Metadata.FindNavigation(nameof(Principal.Dependents));
var navigation = principalEntityTypeBuilder.Metadata.FindNavigation(nameof(Principal.Dependents))!;

Assert.False(relationshipBuilder.Metadata.IsRequired);

Expand All @@ -89,17 +91,17 @@ public void Non_nullability_does_not_set_is_required_for_navigation_to_dependent
{
var dependentEntityTypeBuilder = CreateInternalEntityTypeBuilder<Dependent>();
var principalEntityTypeBuilder =
dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Principal), ConfigurationSource.Convention);
dependentEntityTypeBuilder.ModelBuilder.Entity(typeof(Principal), ConfigurationSource.Convention)!;

var relationshipBuilder = dependentEntityTypeBuilder.HasRelationship(
principalEntityTypeBuilder.Metadata,
nameof(Dependent.Principal),
nameof(Principal.Dependent),
ConfigurationSource.Convention)
ConfigurationSource.Convention)!
.HasEntityTypes
(principalEntityTypeBuilder.Metadata, dependentEntityTypeBuilder.Metadata, ConfigurationSource.Explicit);
(principalEntityTypeBuilder.Metadata, dependentEntityTypeBuilder.Metadata, ConfigurationSource.Explicit)!;

var navigation = principalEntityTypeBuilder.Metadata.FindNavigation(nameof(Principal.Dependent));
var navigation = principalEntityTypeBuilder.Metadata.FindNavigation(nameof(Principal.Dependent))!;

Assert.False(relationshipBuilder.Metadata.IsRequired);

Expand All @@ -116,7 +118,7 @@ public void Non_nullability_sets_is_required_with_conventional_builder()
modelBuilder.Entity<BlogDetails>();

Assert.True(
model.FindEntityType(typeof(BlogDetails)).GetForeignKeys().Single(fk => fk.PrincipalEntityType?.ClrType == typeof(Blog))
model.FindEntityType(typeof(BlogDetails))!.GetForeignKeys().Single(fk => fk.PrincipalEntityType?.ClrType == typeof(Blog))
.IsRequired);
}

Expand All @@ -125,7 +127,7 @@ private Navigation RunConvention(InternalForeignKeyBuilder relationshipBuilder,
var context = new ConventionContext<IConventionNavigationBuilder>(
relationshipBuilder.Metadata.DeclaringEntityType.Model.ConventionDispatcher);
CreateNotNullNavigationConvention().ProcessNavigationAdded(navigation.Builder, context);
return context.ShouldStopProcessing() ? (Navigation)context.Result?.Metadata : navigation;
return context.ShouldStopProcessing() ? (Navigation)context.Result?.Metadata! : navigation;
}

private NonNullableNavigationConvention CreateNotNullNavigationConvention()
Expand All @@ -146,15 +148,15 @@ private InternalEntityTypeBuilder CreateInternalEntityTypeBuilder<T>()

var modelBuilder = new Model(conventionSet).Builder;

return modelBuilder.Entity(typeof(T), ConfigurationSource.Explicit);
return modelBuilder.Entity(typeof(T), ConfigurationSource.Explicit)!;
}

private ModelBuilder CreateModelBuilder()
{
var serviceProvider = CreateServiceProvider();
return new ModelBuilder(
serviceProvider.GetService<IConventionSetBuilder>().CreateConventionSet(),
serviceProvider.GetService<ModelDependencies>());
serviceProvider.GetRequiredService<IConventionSetBuilder>().CreateConventionSet(),
serviceProvider.GetRequiredService<ModelDependencies>());
}

private ProviderConventionSetBuilderDependencies CreateDependencies()
Expand All @@ -179,9 +181,9 @@ protected IServiceProvider CreateServiceProvider()
return modelLogger;
}

#nullable enable
#pragma warning disable CS8618

// ReSharper disable UnusedMember.Local
// ReSharper disable ClassNeverInstantiated.Local
private class Blog
{
public int Id { get; set; }
Expand Down Expand Up @@ -245,6 +247,7 @@ private class Dependent
[ForeignKey("PrincipalId, PrincipalFk")]
public Principal? CompositePrincipal { get; set; }
}
// ReSharper restore ClassNeverInstantiated.Local
// ReSharper restore UnusedMember.Local
#pragma warning restore CS8618
#nullable disable
}
Loading