Skip to content

Commit

Permalink
Query: Preseve constant based naked Initialization expressions in exp…
Browse files Browse the repository at this point in the history
…ression tree

This puts some of the processing (evaluating the expression to the corresponding constant) inside translation pipeline.
- When applying EntityEquality, assumption here is that property is always going to be mapped on server side so we can generate constant.
- When translating newExpression. If the generated constant can be mapped, it would work. (like new Datetime()) else translation would null out.

Resolves #15712
Resolves #17048
Resolves #7983
  • Loading branch information
smitpatel committed Aug 14, 2019
1 parent 29b269a commit c1392d6
Show file tree
Hide file tree
Showing 15 changed files with 230 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ public virtual SqlExpression Translate(Expression expression)
{
translation = _sqlExpressionFactory.ApplyDefaultTypeMapping(translation);

if ((translation is SqlConstantExpression
|| translation is SqlParameterExpression)
&& translation.TypeMapping == null)
{
// Non-mappable constant/parameter
return null;
}

_sqlVerifyingExpressionVisitor.Visit(translation);

return translation;
Expand Down Expand Up @@ -339,21 +347,52 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
return null;
}

private SqlConstantExpression GetConstantOrNull(Expression expression)
{
if (CanEvaluate(expression))
{
var value = Expression.Lambda<Func<object>>(Expression.Convert(expression, typeof(object))).Compile().Invoke();
return new SqlConstantExpression(Expression.Constant(value, expression.Type), null);
}

return null;
}

private static bool CanEvaluate(Expression expression)
{
switch (expression)
{
case ConstantExpression constantExpression:
return true;

case NewExpression newExpression:
return newExpression.Arguments.All(e => CanEvaluate(e));

case MemberInitExpression memberInitExpression:
return CanEvaluate(memberInitExpression.NewExpression)
&& memberInitExpression.Bindings.All(
mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression));

default:
return false;
}
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitNew(NewExpression node) => null;
protected override Expression VisitNew(NewExpression node) => GetConstantOrNull(node);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitMemberInit(MemberInitExpression node) => null;
protected override Expression VisitMemberInit(MemberInitExpression node) => GetConstantOrNull(node);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down
9 changes: 8 additions & 1 deletion src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,14 @@ protected override Expression VisitSelect(SelectExpression selectExpression)
_sqlBuilder.Append("DISTINCT ");
}

GenerateList(selectExpression.Projection, t => Visit(t));
if (selectExpression.Projection.Any())
{
GenerateList(selectExpression.Projection, e => Visit(e));
}
else
{
_sqlBuilder.Append("1");
}
_sqlBuilder.AppendLine();

_sqlBuilder.Append("FROM root ");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ protected override ShapedQueryExpression TranslateGroupBy(
var remappedKeySelector = RemapLambdaBody(source, keySelector);

var translatedKey = TranslateGroupingKey(remappedKeySelector)
?? (remappedKeySelector is ConstantExpression ? remappedKeySelector : null);
?? (remappedKeySelector as ConstantExpression);
if (translatedKey != null)
{
if (elementSelector != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class RelationalSqlTranslatingExpressionVisitor : ExpressionVisitor
private readonly IModel _model;
private readonly QueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly SqlTypeMappingVerifyingExpressionVisitor _sqlVerifyingExpressionVisitor;
private readonly SqlTypeMappingVerifyingExpressionVisitor _sqlTypeMappingVerifyingExpressionVisitor;

public RelationalSqlTranslatingExpressionVisitor(
RelationalSqlTranslatingExpressionVisitorDependencies dependencies,
Expand All @@ -31,7 +31,7 @@ public RelationalSqlTranslatingExpressionVisitor(
_model = model;
_queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor;
_sqlExpressionFactory = dependencies.SqlExpressionFactory;
_sqlVerifyingExpressionVisitor = new SqlTypeMappingVerifyingExpressionVisitor();
_sqlTypeMappingVerifyingExpressionVisitor = new SqlTypeMappingVerifyingExpressionVisitor();
}

protected virtual RelationalSqlTranslatingExpressionVisitorDependencies Dependencies { get; }
Expand All @@ -51,14 +51,15 @@ public virtual SqlExpression Translate(Expression expression)

translation = _sqlExpressionFactory.ApplyDefaultTypeMapping(translation);

if (translation is SqlConstantExpression
if ((translation is SqlConstantExpression
|| translation is SqlParameterExpression)
&& translation.TypeMapping == null)
{
// Non-mappable constant
// Non-mappable constant/parameter
return null;
}

_sqlVerifyingExpressionVisitor.Visit(translation);
_sqlTypeMappingVerifyingExpressionVisitor.Visit(translation);

return translation;
}
Expand Down Expand Up @@ -442,9 +443,40 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
null);
}

protected override Expression VisitNew(NewExpression node) => null;
private SqlConstantExpression GetConstantOrNull(Expression expression)
{
if (CanEvaluate(expression))
{
var value = Expression.Lambda<Func<object>>(Expression.Convert(expression, typeof(object))).Compile().Invoke();
return new SqlConstantExpression(Expression.Constant(value, expression.Type), null);
}

return null;
}

private static bool CanEvaluate(Expression expression)
{
switch (expression)
{
case ConstantExpression constantExpression:
return true;

case NewExpression newExpression:
return newExpression.Arguments.All(e => CanEvaluate(e));

case MemberInitExpression memberInitExpression:
return CanEvaluate(memberInitExpression.NewExpression)
&& memberInitExpression.Bindings.All(
mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression));

default:
return false;
}
}

protected override Expression VisitNew(NewExpression node) => GetConstantOrNull(node);

protected override Expression VisitMemberInit(MemberInitExpression node) => null;
protected override Expression VisitMemberInit(MemberInitExpression node) => GetConstantOrNull(node);

protected override Expression VisitNewArray(NewArrayExpression node) => null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,13 @@ private Expression CreatePropertyAccessExpression(Expression target, IProperty p
return Expression.Constant(property.GetGetter().GetClrValue(constantExpression.Value), property.ClrType.MakeNullable());
}

// The target is complex which can be evaluated to Constant.
if (CanEvaluate(target))
{
var value = Expression.Lambda<Func<object>>(Expression.Convert(target, typeof(object))).Compile().Invoke();
return Expression.Constant(property.GetGetter().GetClrValue(value), property.ClrType.MakeNullable());
}

// If the target is a query parameter, we can't simply add a property access over it, but must instead cause a new
// parameter to be added at runtime, with the value of the property on the base parameter.
if (target is ParameterExpression baseParameterExpression
Expand All @@ -936,6 +943,26 @@ private Expression CreatePropertyAccessExpression(Expression target, IProperty p
return target.CreateEFPropertyExpression(property, true);
}

private static bool CanEvaluate(Expression expression)
{
switch (expression)
{
case ConstantExpression constantExpression:
return true;

case NewExpression newExpression:
return newExpression.Arguments.All(e => CanEvaluate(e));

case MemberInitExpression memberInitExpression:
return CanEvaluate(memberInitExpression.NewExpression)
&& memberInitExpression.Bindings.All(
mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression));

default:
return false;
}
}

private static object ParameterValueExtractor(QueryContext context, string baseParameterName, IProperty property)
{
var baseParameter = context.ParameterValues[baseParameterName];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}

if (methodCallExpression.Arguments.Count > 0
&& (methodCallExpression.Arguments[0] is ParameterExpression
|| methodCallExpression.Arguments[0] is ConstantExpression))
&& ClientSource(methodCallExpression.Arguments[0]))
{
// this is methodCall over closure variable or constant
return base.VisitMethodCall(methodCallExpression);
Expand Down Expand Up @@ -137,8 +136,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
&& methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(List<>)
&& string.Equals(nameof(List<int>.Contains), methodCallExpression.Method.Name))
{
if (methodCallExpression.Object is ParameterExpression
|| methodCallExpression.Object is ConstantExpression)
if (ClientSource(methodCallExpression.Object))
{
// this is methodCall over closure variable or constant
return base.VisitMethodCall(methodCallExpression);
Expand All @@ -157,6 +155,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return base.VisitMethodCall(methodCallExpression);
}

private static bool ClientSource(Expression expression)
=> expression is ConstantExpression
|| expression is MemberInitExpression
|| expression is NewExpression
|| expression is ParameterExpression;

private static bool CanConvertEnumerableToQueryable(Type enumerableType, Type queryableType)
{
if (enumerableType == typeof(IEnumerable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public override Expression Visit(Expression expression)
}

if (_evaluatableExpressions.TryGetValue(expression, out var generateParameter)
&& !PreserveInitializationConstant(expression, generateParameter)
&& !PreserveConvertNode(expression))
{
return Evaluate(expression, _parameterize && generateParameter);
Expand All @@ -108,13 +109,10 @@ public override Expression Visit(Expression expression)
return base.Visit(expression);
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected virtual bool PreserveConvertNode(Expression expression)
private bool PreserveInitializationConstant(Expression expression, bool generateParameter)
=> !generateParameter && (expression is NewExpression || expression is MemberInitExpression);

private bool PreserveConvertNode(Expression expression)
{
if (expression is UnaryExpression unaryExpression
&& (unaryExpression.NodeType == ExpressionType.Convert
Expand Down
15 changes: 10 additions & 5 deletions src/EFCore/Query/ReplacingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,20 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var newEntityExpression = Visit(entityExpression);
if (newEntityExpression is NewExpression newExpression)
{
var index = newExpression.Members.Select(m => m.Name).IndexOf(propertyName);

return newExpression.Arguments[index];
var index = newExpression.Members?.Select(m => m.Name).IndexOf(propertyName);
if (index > 0)
{
return newExpression.Arguments[index.Value];
}
}

if (newEntityExpression is MemberInitExpression memberInitExpression)
{
return ((MemberAssignment)memberInitExpression.Bindings
.Single(mb => mb.Member.Name == propertyName)).Expression;
if (memberInitExpression.Bindings.SingleOrDefault(
mb => mb.Member.Name == propertyName) is MemberAssignment memberAssignment)
{
return memberAssignment.Expression;
}
}

return methodCallExpression.Update(null, new[] { newEntityExpression, methodCallExpression.Arguments[1] });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ FROM root c
WHERE (c[""Discriminator""] = ""Customer"")");
}

[ConditionalTheory(Skip = "Issue#14935")]
public override async Task New_date_time_in_anonymous_type_works(bool isAsync)
{
await base.New_date_time_in_anonymous_type_works(isAsync);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6259,7 +6259,7 @@ public virtual Task Select_subquery_projecting_single_constant_bool(bool isAsync
}));
}

[ConditionalTheory(Skip = "issue #15712")]
[ConditionalTheory(Skip = "Issue#10001")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_subquery_projecting_single_constant_inside_anonymous(bool isAsync)
{
Expand All @@ -6277,7 +6277,7 @@ public virtual Task Select_subquery_projecting_single_constant_inside_anonymous(
}));
}

[ConditionalTheory(Skip = "issue #15712")]
[ConditionalTheory(Skip = "Issue#10001")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_subquery_projecting_multiple_constants_inside_anonymous(bool isAsync)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,7 @@ public virtual Task GroupBy_empty_key_Aggregate(bool isAsync)
.Select(g => g.Sum(o => o.OrderID)));
}

[ConditionalTheory(Skip = "Issue#17048")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_empty_key_Aggregate_Key(bool isAsync)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ private static void AssertArrays<T>(object e, object a, int count)
}
}

[ConditionalTheory(Skip = "Issue #15712")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Select_bool_closure(bool isAsync)
{
Expand Down Expand Up @@ -309,7 +309,7 @@ public virtual Task Select_anonymous_nested(bool isAsync)
e => e.City);
}

[ConditionalTheory(Skip = "Issue#15712")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_anonymous_empty(bool isAsync)
{
Expand All @@ -322,7 +322,7 @@ public virtual Task Select_anonymous_empty(bool isAsync)
e => 1);
}

[ConditionalTheory(Skip = "Issue#15712")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_anonymous_literal(bool isAsync)
{
Expand Down Expand Up @@ -620,7 +620,7 @@ orderby o2.OrderID
});
}

[ConditionalTheory(Skip = "Issue#15712")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task New_date_time_in_anonymous_type_works(bool isAsync)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4308,7 +4308,7 @@ public virtual Task Select_expression_int_to_string(bool isAsync)
e => e.ShipName);
}

[ConditionalTheory(Skip = "Issue#17048")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task ToString_with_formatter_is_evaluated_on_the_client(bool isAsync)
{
Expand Down
Loading

0 comments on commit c1392d6

Please sign in to comment.