Skip to content

Commit

Permalink
Query: Make keyless entity type materialization checks part of Discri…
Browse files Browse the repository at this point in the history
…minatorCondition

If DiscriminatorCondition returns null for IEntityType then return null instance.

Part of #18923
  • Loading branch information
smitpatel committed Mar 18, 2020
1 parent e72297c commit 20481dd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 50 deletions.
30 changes: 21 additions & 9 deletions src/EFCore/Query/EntityShaperExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ protected EntityShaperExpression(

if (discriminatorCondition == null)
{
// Generate condition to discriminator if TPH
discriminatorCondition = GenerateDiscriminatorCondition(entityType);

discriminatorCondition = GenerateDiscriminatorCondition(entityType, nullable);
}
else if (discriminatorCondition.Parameters.Count != 1
|| discriminatorCondition.Parameters[0].Type != typeof(ValueBuffer)
Expand All @@ -63,11 +61,11 @@ protected EntityShaperExpression(
DiscriminatorCondition = discriminatorCondition;
}

private LambdaExpression GenerateDiscriminatorCondition(IEntityType entityType)
private LambdaExpression GenerateDiscriminatorCondition(IEntityType entityType, bool nullable)
{
var valueBufferParameter = Parameter(typeof(ValueBuffer));
Expression body;
var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList();
var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToArray();
var discriminatorProperty = entityType.GetDiscriminatorProperty();
if (discriminatorProperty != null)
{
Expand All @@ -80,8 +78,8 @@ private LambdaExpression GenerateDiscriminatorCondition(IEntityType entityType)
discriminatorProperty.ClrType, discriminatorProperty.GetIndex(), discriminatorProperty))
};

var switchCases = new SwitchCase[concreteEntityTypes.Count];
for (var i = 0; i < concreteEntityTypes.Count; i++)
var switchCases = new SwitchCase[concreteEntityTypes.Length];
for (var i = 0; i < concreteEntityTypes.Length; i++)
{
var discriminatorValue = Constant(concreteEntityTypes[i].GetDiscriminatorValue(), discriminatorProperty.ClrType);
switchCases[i] = SwitchCase(Constant(concreteEntityTypes[i], typeof(IEntityType)), discriminatorValue);
Expand All @@ -97,7 +95,20 @@ private LambdaExpression GenerateDiscriminatorCondition(IEntityType entityType)
}
else
{
body = Constant(concreteEntityTypes.Count == 1 ? concreteEntityTypes[0] : entityType, typeof(IEntityType));
body = Constant(concreteEntityTypes.Length == 1 ? concreteEntityTypes[0] : entityType, typeof(IEntityType));
}

if (entityType.FindPrimaryKey() == null
&& nullable)
{
body = Condition(
entityType.GetProperties()
.Select(p => NotEqual(
valueBufferParameter.CreateValueBufferReadValueExpression(typeof(object), p.GetIndex(), p),
Constant(null)))
.Aggregate((a, b) => OrElse(a, b)),
body,
Default(typeof(IEntityType)));
}

return Lambda(body, valueBufferParameter);
Expand Down Expand Up @@ -128,7 +139,8 @@ public virtual EntityShaperExpression WithEntityType([NotNull] IEntityType entit

public virtual EntityShaperExpression MarkAsNullable()
=> !IsNullable
? new EntityShaperExpression(EntityType, ValueBufferExpression, true, DiscriminatorCondition)
// Marking nullable requires recomputation of Discriminator condition
? new EntityShaperExpression(EntityType, ValueBufferExpression, true)
: this;

public virtual EntityShaperExpression Update([NotNull] Expression valueBufferExpression)
Expand Down
59 changes: 18 additions & 41 deletions src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -414,29 +414,9 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres
}
else
{
if (entityShaperExpression.IsNullable)
{
expressions.Add(
Expression.IfThen(
entityType.GetProperties()
.Select(
p =>
Expression.NotEqual(
valueBufferExpression.CreateValueBufferReadValueExpression(
typeof(object),
p.GetIndex(),
p),
Expression.Constant(null)))
.Aggregate((a, b) => Expression.OrElse(a, b)),
MaterializeEntity(
entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null)));
}
else
{
expressions.Add(
MaterializeEntity(
entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null));
}
expressions.Add(
MaterializeEntity(
entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null));
}
}

Expand Down Expand Up @@ -476,30 +456,27 @@ private Expression MaterializeEntity(
valueBufferExpression,
entityShaperExpression.DiscriminatorCondition.Body)));

var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList();
var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToArray();
var discriminatorProperty = entityType.GetDiscriminatorProperty();
if (discriminatorProperty != null)
if (discriminatorProperty == null
&& concreteEntityTypes.Length > 1)
{
var switchCases = new SwitchCase[concreteEntityTypes.Count];
for (var i = 0; i < concreteEntityTypes.Count; i++)
{
switchCases[i] = Expression.SwitchCase(
CreateFullMaterializeExpression(concreteEntityTypes[i], expressionContext),
Expression.Constant(concreteEntityTypes[i], typeof(IEntityType)));
}

materializationExpression = Expression.Switch(
concreteEntityTypeVariable,
Expression.Constant(null, returnType),
switchCases);
concreteEntityTypes = new [] { entityType };
}
else

var switchCases = new SwitchCase[concreteEntityTypes.Length];
for (var i = 0; i < concreteEntityTypes.Length; i++)
{
materializationExpression = CreateFullMaterializeExpression(
concreteEntityTypes.Count == 1 ? concreteEntityTypes[0] : entityType,
expressionContext);
switchCases[i] = Expression.SwitchCase(
CreateFullMaterializeExpression(concreteEntityTypes[i], expressionContext),
Expression.Constant(concreteEntityTypes[i], typeof(IEntityType)));
}

materializationExpression = Expression.Switch(
concreteEntityTypeVariable,
Expression.Constant(null, returnType),
switchCases);

expressions.Add(Expression.Assign(instanceVariable, materializationExpression));

if (_trackQueryResults
Expand Down

0 comments on commit 20481dd

Please sign in to comment.