Skip to content

Commit

Permalink
Fix in-memory issues with value comparers (#29745)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajcvickers authored Dec 8, 2022
1 parent 906a2af commit 2e7a4a1
Show file tree
Hide file tree
Showing 6 changed files with 3,292 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -253,80 +253,16 @@ static Expression RemoveConvert(Expression e)
newRight = ConvertToNullable(newRight);
}

if (binaryExpression.NodeType == ExpressionType.Equal
|| binaryExpression.NodeType == ExpressionType.NotEqual)
if ((binaryExpression.NodeType == ExpressionType.Equal
|| binaryExpression.NodeType == ExpressionType.NotEqual)
&& TryUseComparer(newLeft, newRight, out var updatedExpression))
{
var property = FindProperty(newLeft) ?? FindProperty(newRight);
var comparer = property?.GetValueComparer();

if (comparer != null)
if (binaryExpression.NodeType == ExpressionType.NotEqual)
{
MethodInfo? objectEquals = null;
MethodInfo? exactMatch = null;

var converter = property?.GetValueConverter();
foreach (var candidate in comparer
.GetType()
.GetMethods(BindingFlags.Public | BindingFlags.Instance)
.Where(
m => m.Name == "Equals" && m.GetParameters().Length == 2)
.ToList())
{
var parameters = candidate.GetParameters();
var leftType = parameters[0].ParameterType;
var rightType = parameters[1].ParameterType;

if (leftType == typeof(object)
&& rightType == typeof(object))
{
objectEquals = candidate;
continue;
}

var matchingLeft = leftType.IsAssignableFrom(newLeft.Type)
? newLeft
: converter != null && leftType.IsAssignableFrom(converter.ModelClrType)
? ReplacingExpressionVisitor.Replace(
converter.ConvertFromProviderExpression.Parameters.Single(),
newLeft,
converter.ConvertFromProviderExpression.Body)
: null;

var matchingRight = rightType.IsAssignableFrom(newRight.Type)
? newRight
: converter != null && rightType.IsAssignableFrom(converter.ModelClrType)
? ReplacingExpressionVisitor.Replace(
converter.ConvertFromProviderExpression.Parameters.Single(),
newRight,
converter.ConvertFromProviderExpression.Body)
: null;

if (matchingLeft != null && matchingRight != null)
{
exactMatch = candidate;
newLeft = matchingLeft;
newRight = matchingRight;
break;
}
}

var equalsExpression =
exactMatch != null
? Expression.Call(
Expression.Constant(comparer, comparer.GetType()),
exactMatch,
newLeft,
newRight)
: Expression.Call(
Expression.Constant(comparer, comparer.GetType()),
objectEquals!,
Expression.Convert(newLeft, typeof(object)),
Expression.Convert(newRight, typeof(object)));

return binaryExpression.NodeType == ExpressionType.NotEqual
? Expression.IsFalse(equalsExpression)
: equalsExpression;
updatedExpression = Expression.IsFalse(updatedExpression!);
}

return updatedExpression!;
}

return Expression.MakeBinary(
Expand Down Expand Up @@ -408,6 +344,103 @@ static bool IsTypeConstant(Expression expression, out Type? type)
}
}

private static bool TryUseComparer(
Expression? newLeft,
Expression? newRight,
out Expression? updatedExpression)
{
updatedExpression = null;

if (newLeft == null
|| newRight == null)
{
return false;
}

var property = FindProperty(newLeft) ?? FindProperty(newRight);
var comparer = property?.GetValueComparer();

if (comparer == null)
{
return false;
}

MethodInfo? objectEquals = null;
MethodInfo? exactMatch = null;

var converter = property?.GetValueConverter();
foreach (var candidate in comparer
.GetType()
.GetMethods(BindingFlags.Public | BindingFlags.Instance)
.Where(
m => m.Name == "Equals" && m.GetParameters().Length == 2)
.ToList())
{
var parameters = candidate.GetParameters();
var leftType = parameters[0].ParameterType;
var rightType = parameters[1].ParameterType;

if (leftType == typeof(object)
&& rightType == typeof(object))
{
objectEquals = candidate;
continue;
}

var matchingLeft = leftType.IsAssignableFrom(newLeft.Type)
? newLeft
: converter != null
&& leftType.IsAssignableFrom(converter.ModelClrType)
&& converter.ProviderClrType.IsAssignableFrom(newLeft.Type)
? ReplacingExpressionVisitor.Replace(
converter.ConvertFromProviderExpression.Parameters.Single(),
newLeft,
converter.ConvertFromProviderExpression.Body)
: null;

var matchingRight = rightType.IsAssignableFrom(newRight.Type)
? newRight
: converter != null
&& rightType.IsAssignableFrom(converter.ModelClrType)
&& converter.ProviderClrType.IsAssignableFrom(newRight.Type)
? ReplacingExpressionVisitor.Replace(
converter.ConvertFromProviderExpression.Parameters.Single(),
newRight,
converter.ConvertFromProviderExpression.Body)
: null;

if (matchingLeft != null && matchingRight != null)
{
exactMatch = candidate;
newLeft = matchingLeft;
newRight = matchingRight;
break;
}
}

if (exactMatch == null
&& (!property!.ClrType.IsAssignableFrom(newLeft.Type))
|| !property!.ClrType.IsAssignableFrom(newRight.Type))
{
return false;
}

updatedExpression =
exactMatch != null
? Expression.Call(
Expression.Constant(comparer, comparer.GetType()),
exactMatch,
newLeft,
newRight)
: Expression.Call(
Expression.Constant(comparer, comparer.GetType()),
objectEquals!,
Expression.Convert(newLeft, typeof(object)),
Expression.Convert(newRight, typeof(object)));

return true;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -750,6 +783,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var left = Visit(methodCallExpression.Arguments[0]);
var right = Visit(methodCallExpression.Arguments[1]);

if (TryUseComparer(left, right, out var updatedExpression))
{
return updatedExpression!;
}

if (TryRewriteEntityEquality(
ExpressionType.Equal,
left == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[0] : left,
Expand Down Expand Up @@ -1192,12 +1230,19 @@ private static Expression ConvertToNonNullable(Expression expression)
? Expression.Convert(expression, expression.Type.UnwrapNullableType())
: expression;

private static IProperty? FindProperty(Expression expression)
private static IProperty? FindProperty(Expression? expression)
{
if (expression.NodeType == ExpressionType.Convert
if (expression?.NodeType == ExpressionType.Convert
&& expression.Type == typeof(object))
{
expression = ((UnaryExpression)expression).Operand;
}

if (expression?.NodeType == ExpressionType.Convert
&& expression.Type.IsNullableType()
&& expression is UnaryExpression unaryExpression
&& expression.Type.UnwrapNullableType() == unaryExpression.Type)
&& (expression.Type.UnwrapNullableType() == unaryExpression.Type
|| expression.Type == unaryExpression.Type))
{
expression = unaryExpression.Operand;
}
Expand Down
57 changes: 39 additions & 18 deletions src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ private static readonly PropertyInfo ValueBufferCountMemberInfo
= typeof(ValueBuffer).GetTypeInfo().GetProperty(nameof(ValueBuffer.Count))!;

private static readonly MethodInfo LeftJoinMethodInfo = typeof(InMemoryQueryExpression).GetTypeInfo()
.GetDeclaredMethods(nameof(LeftJoin)).Single(mi => mi.GetParameters().Length == 6);
.GetDeclaredMethods(nameof(LeftJoin)).Single(mi => mi.GetParameters().Length == 7);

private static readonly ConstructorInfo ResultEnumerableConstructor
= typeof(ResultEnumerable).GetConstructors().Single();
Expand Down Expand Up @@ -706,9 +706,8 @@ public virtual EntityShaperExpression AddNavigationToWeakEntityType(
outerKeySelector,
innerKeySelector,
resultSelector,
Constant(
new ValueBuffer(
Enumerable.Repeat((object?)null, selectorExpressions.Count - outerIndex).ToArray())));
Constant(new ValueBuffer(Enumerable.Repeat((object?)null, selectorExpressions.Count - outerIndex).ToArray())),
Constant(null, typeof(IEqualityComparer<>).MakeGenericType(outerKeySelector.ReturnType)));

var entityShaper = new EntityShaperExpression(innerEntityProjection.EntityType, innerEntityProjection, nullable: true);
entityProjectionExpression.AddNavigationBinding(navigation, entityShaper);
Expand Down Expand Up @@ -872,7 +871,8 @@ private static Expression GetGroupingKey(Expression key, List<Expression> groupi

case EntityShaperExpression entityShaperExpression
when entityShaperExpression.ValueBufferExpression is ProjectionBindingExpression projectionBindingExpression:
var entityProjectionExpression = (EntityProjectionExpression)((InMemoryQueryExpression)projectionBindingExpression.QueryExpression)
var entityProjectionExpression =
(EntityProjectionExpression)((InMemoryQueryExpression)projectionBindingExpression.QueryExpression)
.GetProjection(projectionBindingExpression);
var readExpressions = new Dictionary<IProperty, MethodCallExpression>();
foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType))
Expand Down Expand Up @@ -1054,6 +1054,15 @@ private Expression AddJoin(
if (outerKeySelector != null
&& innerKeySelector != null)
{
var comparer = ((InferPropertyFromInner(outerKeySelector.Body)
?? InferPropertyFromInner(outerKeySelector.Body))
as IProperty)?.GetValueComparer();

if (comparer?.Type != outerKeySelector.ReturnType)
{
comparer = null;
}

if (innerNullable)
{
ServerQueryExpression = Call(
Expand All @@ -1064,20 +1073,29 @@ private Expression AddJoin(
outerKeySelector,
innerKeySelector,
resultSelector,
Constant(
new ValueBuffer(
Enumerable.Repeat((object?)null, resultSelectorExpressions.Count - outerIndex).ToArray())));
Constant(new ValueBuffer(Enumerable.Repeat((object?)null, resultSelectorExpressions.Count - outerIndex).ToArray())),
Constant(comparer, typeof(IEqualityComparer<>).MakeGenericType(outerKeySelector.ReturnType)));
}
else
{
ServerQueryExpression = Call(
EnumerableMethods.Join.MakeGenericMethod(
typeof(ValueBuffer), typeof(ValueBuffer), outerKeySelector.ReturnType, typeof(ValueBuffer)),
ServerQueryExpression,
innerQueryExpression.ServerQueryExpression,
outerKeySelector,
innerKeySelector,
resultSelector);
ServerQueryExpression = comparer == null
? Call(
EnumerableMethods.Join.MakeGenericMethod(
typeof(ValueBuffer), typeof(ValueBuffer), outerKeySelector.ReturnType, typeof(ValueBuffer)),
ServerQueryExpression,
innerQueryExpression.ServerQueryExpression,
outerKeySelector,
innerKeySelector,
resultSelector)
: Call(
EnumerableMethods.JoinWithComparer.MakeGenericMethod(
typeof(ValueBuffer), typeof(ValueBuffer), outerKeySelector.ReturnType, typeof(ValueBuffer)),
ServerQueryExpression,
innerQueryExpression.ServerQueryExpression,
outerKeySelector,
innerKeySelector,
resultSelector,
Constant(comparer, typeof(IEqualityComparer<>).MakeGenericType(outerKeySelector.ReturnType)));
}
}
else
Expand Down Expand Up @@ -1235,8 +1253,11 @@ private static IEnumerable<TResult> LeftJoin<TOuter, TInner, TKey, TResult>(
Func<TOuter, TKey> outerKeySelector,
Func<TInner, TKey> innerKeySelector,
Func<TOuter, TInner, TResult> resultSelector,
TInner defaultValue)
=> outer.GroupJoin(inner, outerKeySelector, innerKeySelector, (oe, ies) => new { oe, ies })
TInner defaultValue,
IEqualityComparer<TKey>? comparer)
=> (comparer == null
? outer.GroupJoin(inner, outerKeySelector, innerKeySelector, (oe, ies) => new { oe, ies })
: outer.GroupJoin(inner, outerKeySelector, innerKeySelector, (oe, ies) => new { oe, ies }, comparer))
.SelectMany(t => t.ies.DefaultIfEmpty(defaultValue), (t, i) => resultSelector(t.oe, i));

private static MethodCallExpression MakeReadValueNullable(Expression expression)
Expand Down
Loading

0 comments on commit 2e7a4a1

Please sign in to comment.