Skip to content

Commit

Permalink
Fix to #18555 - Query: when rewriting null semantics for comparisons …
Browse files Browse the repository at this point in the history
…with functions use function specific metadata to get better SQL

When we need to compute whether a function is null, we often can just evaluate nullability of it's constituents (instance & arguments), e.g.
SUBSTRING(stringProperty, 0, 5) == null -> stringProperty == null

Adding metadata to SqlFunctionExpression:
nullResultAllowed - indicates whether function can ever be null,
instancePropagatesNullability - indicates whether function instance can be used to calculate nullability of the entire function
argumentsPropagateNullability - array indicating which (if any) function arguments can be used to calculate nullability of the entire function

If "canBeNull" is set to false we can instantly compute IsNull/IsNotNull of that function.
Otherwise, we look at values of instancePropagatesNullability and argumentsPropagateNullability - if any of them are set to true, we use corresponding argument(s) to compute function nullability.
If all of them are set to false we must fallback to the old method and evaluate nullability of the entire function.
  • Loading branch information
maumar committed Jan 22, 2020
1 parent cbb5928 commit 9ca9ce7
Show file tree
Hide file tree
Showing 68 changed files with 1,865 additions and 651 deletions.
48 changes: 48 additions & 0 deletions src/EFCore.Relational/Query/ISqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,54 @@ SqlFunctionExpression Function(
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

SqlFunctionExpression Function(
[NotNull] string name,
[NotNull] IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
[NotNull] IEnumerable<bool> argumentsPropagateNullability,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

SqlFunctionExpression Function(
[CanBeNull] string schema,
[NotNull] string name,
[NotNull] IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
[NotNull] IEnumerable<bool> argumentsPropagateNullability,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

SqlFunctionExpression Function(
[CanBeNull] SqlExpression instance,
[NotNull] string name,
[NotNull] IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool instancePropagatesNullability,
[NotNull] IEnumerable<bool> argumentsPropagateNullability,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

SqlFunctionExpression Function(
[NotNull] string name,
bool nullResultAllowed,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

SqlFunctionExpression Function(
[NotNull] string schema,
[NotNull] string name,
bool nullResultAllowed,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

SqlFunctionExpression Function(
[CanBeNull] SqlExpression instance,
[NotNull] string name,
bool nullResultAllowed,
bool instancePropagatesNullability,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

ExistsExpression Exists([NotNull] SelectExpression subquery, bool negated);
InExpression In([NotNull] SqlExpression item, [NotNull] SqlExpression values, bool negated);
InExpression In([NotNull] SqlExpression item, [NotNull] SelectExpression subquery, bool negated);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -898,8 +898,7 @@ private SqlExpression RewriteNullSemantics(
return sqlBinaryExpression.Update(left, right);
}

private SqlExpression SimplifyLogicalSqlBinaryExpression(
SqlBinaryExpression sqlBinaryExpression)
private SqlExpression SimplifyLogicalSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpression)
{
var leftUnary = sqlBinaryExpression.Left as SqlUnaryExpression;
var rightUnary = sqlBinaryExpression.Right as SqlUnaryExpression;
Expand Down Expand Up @@ -1253,37 +1252,96 @@ protected virtual SqlExpression ProcessNullNotNull(
sqlUnaryExpression.TypeMapping));
}

case SqlFunctionExpression sqlFunctionExpression
when sqlFunctionExpression.IsBuiltIn && string.Equals("COALESCE", sqlFunctionExpression.Name, StringComparison.OrdinalIgnoreCase):
case SqlFunctionExpression sqlFunctionExpression:
{
// for coalesce:
// (a ?? b) == null -> a == null && b == null
// (a ?? b) != null -> a != null || b != null
var left = ProcessNullNotNull(
SqlExpressionFactory.MakeUnary(
sqlUnaryExpression.OperatorType,
sqlFunctionExpression.Arguments[0],
typeof(bool),
sqlUnaryExpression.TypeMapping),
operandNullable: null);
if (sqlFunctionExpression.IsBuiltIn && string.Equals("COALESCE", sqlFunctionExpression.Name, StringComparison.OrdinalIgnoreCase))
{
// for coalesce:
// (a ?? b) == null -> a == null && b == null
// (a ?? b) != null -> a != null || b != null
var left = ProcessNullNotNull(
SqlExpressionFactory.MakeUnary(
sqlUnaryExpression.OperatorType,
sqlFunctionExpression.Arguments[0],
typeof(bool),
sqlUnaryExpression.TypeMapping),
operandNullable: null);

var right = ProcessNullNotNull(
SqlExpressionFactory.MakeUnary(
sqlUnaryExpression.OperatorType,
sqlFunctionExpression.Arguments[1],
typeof(bool),
sqlUnaryExpression.TypeMapping),
operandNullable: null);

var right = ProcessNullNotNull(
SqlExpressionFactory.MakeUnary(
sqlUnaryExpression.OperatorType,
sqlFunctionExpression.Arguments[1],
typeof(bool),
sqlUnaryExpression.TypeMapping),
operandNullable: null);
return SimplifyLogicalSqlBinaryExpression(
SqlExpressionFactory.MakeBinary(
sqlUnaryExpression.OperatorType == ExpressionType.Equal
? ExpressionType.AndAlso
: ExpressionType.OrElse,
left,
right,
sqlUnaryExpression.TypeMapping));
}

return SimplifyLogicalSqlBinaryExpression(
SqlExpressionFactory.MakeBinary(
sqlUnaryExpression.OperatorType == ExpressionType.Equal
? ExpressionType.AndAlso
: ExpressionType.OrElse,
left,
right,
sqlUnaryExpression.TypeMapping));
if (!sqlFunctionExpression.NullResultAllowed)
{
// when we know that function can't be nullable:
// non_nullable_function() is null-> false
// non_nullable_function() is not null -> true
return SqlExpressionFactory.Constant(
sqlUnaryExpression.OperatorType == ExpressionType.NotEqual,
sqlUnaryExpression.TypeMapping);
}

// see if we can derive function nullability from it's instance and/or arguments
// rather than evaluating nullability of the entire function
var nullabilityPropagationElements = new List<SqlExpression>();
if (sqlFunctionExpression.Instance != null
&& sqlFunctionExpression.InstancPropagatesNullability == true)
{
nullabilityPropagationElements.Add(sqlFunctionExpression.Instance);
}

for (var i = 0; i < sqlFunctionExpression.Arguments.Count; i++)
{
if (sqlFunctionExpression.ArgumentsPropagateNullability[i])
{
nullabilityPropagationElements.Add(sqlFunctionExpression.Arguments[i]);
}
}

if (nullabilityPropagationElements.Count > 0)
{
var result = ProcessNullNotNull(
SqlExpressionFactory.MakeUnary(
sqlUnaryExpression.OperatorType,
nullabilityPropagationElements[0],
sqlUnaryExpression.Type,
sqlUnaryExpression.TypeMapping),
operandNullable: null);

foreach (var element in nullabilityPropagationElements.Skip(1))
{
result = SimplifyLogicalSqlBinaryExpression(
sqlUnaryExpression.OperatorType == ExpressionType.Equal
? SqlExpressionFactory.OrElse(
result,
ProcessNullNotNull(
SqlExpressionFactory.IsNull(element),
operandNullable: null))
: SqlExpressionFactory.AndAlso(
result,
ProcessNullNotNull(
SqlExpressionFactory.IsNotNull(element),
operandNullable: null)));
}

return result;
}
}
break;
}

return sqlUnaryExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ public virtual SqlExpression Translate(
dbFunction.Schema,
dbFunction.Name,
arguments,
nullResultAllowed: true,
argumentsPropagateNullability: arguments.Select(a => true).ToList(),
method.ReturnType);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,20 @@ public virtual SqlExpression TranslateAverage([NotNull] Expression expression)
return inputType == typeof(float)
? SqlExpressionFactory.Convert(
SqlExpressionFactory.Function(
"AVG", new[] { sqlExpression }, typeof(double)),
"AVG",
new[] { sqlExpression },
nullResultAllowed: true,
argumentsPropagateNullability: new[] { false },
typeof(double)),
sqlExpression.Type,
sqlExpression.TypeMapping)
: (SqlExpression)SqlExpressionFactory.Function(
"AVG", new[] { sqlExpression }, sqlExpression.Type, sqlExpression.TypeMapping);
"AVG",
new[] { sqlExpression },
nullResultAllowed: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping);
}

public virtual SqlExpression TranslateCount([CanBeNull] Expression expression = null)
Expand All @@ -115,7 +124,12 @@ public virtual SqlExpression TranslateCount([CanBeNull] Expression expression =
}

return SqlExpressionFactory.ApplyDefaultTypeMapping(
SqlExpressionFactory.Function("COUNT", new[] { SqlExpressionFactory.Fragment("*") }, typeof(int)));
SqlExpressionFactory.Function(
"COUNT",
new[] { SqlExpressionFactory.Fragment("*") },
nullResultAllowed: false,
argumentsPropagateNullability: new[] { false },
typeof(int)));
}

public virtual SqlExpression TranslateLongCount([CanBeNull] Expression expression = null)
Expand All @@ -127,7 +141,12 @@ public virtual SqlExpression TranslateLongCount([CanBeNull] Expression expressio
}

return SqlExpressionFactory.ApplyDefaultTypeMapping(
SqlExpressionFactory.Function("COUNT", new[] { SqlExpressionFactory.Fragment("*") }, typeof(long)));
SqlExpressionFactory.Function(
"COUNT",
new[] { SqlExpressionFactory.Fragment("*") },
nullResultAllowed: false,
argumentsPropagateNullability: new[] { false },
typeof(long)));
}

public virtual SqlExpression TranslateMax([NotNull] Expression expression)
Expand All @@ -140,7 +159,13 @@ public virtual SqlExpression TranslateMax([NotNull] Expression expression)
}

return sqlExpression != null
? SqlExpressionFactory.Function("MAX", new[] { sqlExpression }, sqlExpression.Type, sqlExpression.TypeMapping)
? SqlExpressionFactory.Function(
"MAX",
new[] { sqlExpression },
nullResultAllowed: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping)
: null;
}

Expand All @@ -154,7 +179,13 @@ public virtual SqlExpression TranslateMin([NotNull] Expression expression)
}

return sqlExpression != null
? SqlExpressionFactory.Function("MIN", new[] { sqlExpression }, sqlExpression.Type, sqlExpression.TypeMapping)
? SqlExpressionFactory.Function(
"MIN",
new[] { sqlExpression },
nullResultAllowed: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping)
: null;
}

Expand All @@ -176,11 +207,21 @@ public virtual SqlExpression TranslateSum([NotNull] Expression expression)

return inputType == typeof(float)
? SqlExpressionFactory.Convert(
SqlExpressionFactory.Function("SUM", new[] { sqlExpression }, typeof(double)),
SqlExpressionFactory.Function(
"SUM",
new[] { sqlExpression },
nullResultAllowed: true,
argumentsPropagateNullability: new[] { false },
typeof(double)),
inputType,
sqlExpression.TypeMapping)
: (SqlExpression)SqlExpressionFactory.Function(
"SUM", new[] { sqlExpression }, inputType, sqlExpression.TypeMapping);
"SUM",
new[] { sqlExpression },
nullResultAllowed: true,
argumentsPropagateNullability: new[] { false },
inputType,
sqlExpression.TypeMapping);
}

private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor
Expand Down
Loading

0 comments on commit 9ca9ce7

Please sign in to comment.