Skip to content

Commit

Permalink
Fix to #19883 - Query: null semantics should keep track of columns th…
Browse files Browse the repository at this point in the history
…at are guaranteed to be null and remove redundant IS NULL/ IS NOT NULL checks

Adding a list of columns guaranteed to be null in the given subtree. When processing right side of the || operator, we can convert those to non-nullable columns and therefore improve the generated sql.

Fixes #19883
Fixes #19410
  • Loading branch information
maumar committed Sep 14, 2021
1 parent 9e3b869 commit d638623
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 37 deletions.
78 changes: 67 additions & 11 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace Microsoft.EntityFrameworkCore.Query
public class SqlNullabilityProcessor
{
private readonly List<ColumnExpression> _nonNullableColumns;
private readonly List<ColumnExpression> _nullValueColumns;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private bool _canCache;

Expand All @@ -46,6 +47,7 @@ public SqlNullabilityProcessor(
_sqlExpressionFactory = dependencies.SqlExpressionFactory;
UseRelationalNulls = useRelationalNulls;
_nonNullableColumns = new List<ColumnExpression>();
_nullValueColumns = new List<ColumnExpression>();
ParameterValues = null!;
}

Expand Down Expand Up @@ -81,6 +83,7 @@ public virtual SelectExpression Process(

_canCache = true;
_nonNullableColumns.Clear();
_nullValueColumns.Clear();
ParameterValues = parameterValues;

var result = Visit(selectExpression);
Expand Down Expand Up @@ -342,13 +345,13 @@ protected virtual SelectExpression Visit(SelectExpression selectExpression)
/// <returns> An optimized sql expression. </returns>
[return: NotNullIfNotNull("sqlExpression")]
protected virtual SqlExpression? Visit(SqlExpression? sqlExpression, bool allowOptimizedExpansion, out bool nullable)
=> Visit(sqlExpression, allowOptimizedExpansion, preserveNonNullableColumns: false, out nullable);
=> Visit(sqlExpression, allowOptimizedExpansion, preserveColumnNullabilityInformation: false, out nullable);

[return: NotNullIfNotNull("sqlExpression")]
private SqlExpression? Visit(
SqlExpression? sqlExpression,
bool allowOptimizedExpansion,
bool preserveNonNullableColumns,
bool preserveColumnNullabilityInformation,
out bool nullable)
{
if (sqlExpression == null)
Expand All @@ -358,6 +361,7 @@ protected virtual SelectExpression Visit(SelectExpression selectExpression)
}

var nonNullableColumnsCount = _nonNullableColumns.Count;
var nullValueColumnsCount = _nullValueColumns.Count;
var result = sqlExpression switch
{
CaseExpression caseExpression
Expand Down Expand Up @@ -393,9 +397,10 @@ SqlUnaryExpression sqlUnaryExpression
_ => VisitCustomSqlExpression(sqlExpression, allowOptimizedExpansion, out nullable)
};

if (!preserveNonNullableColumns)
if (!preserveColumnNullabilityInformation)
{
RestoreNonNullableColumnsList(nonNullableColumnsCount);
RestoreNullValueColumnsList(nullValueColumnsCount);
}

return result;
Expand Down Expand Up @@ -430,6 +435,7 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
// otherwise the result is nullable if any of the WhenClause results OR ElseResult is nullable
nullable = caseExpression.ElseResult == null;
var currentNonNullableColumnsCount = _nonNullableColumns.Count;
var currentNullValueColumnsCount = _nullValueColumns.Count;

var operand = Visit(caseExpression.Operand, out _);
var whenClauses = new List<CaseWhenClause>();
Expand All @@ -438,8 +444,8 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
var testEvaluatesToTrue = false;
foreach (var whenClause in caseExpression.WhenClauses)
{
// we can use non-nullable column information we got from visiting Test, in the Result
var test = Visit(whenClause.Test, allowOptimizedExpansion: testIsCondition, preserveNonNullableColumns: true, out _);
// we can use column nullability information we got from visiting Test, in the Result
var test = Visit(whenClause.Test, allowOptimizedExpansion: testIsCondition, preserveColumnNullabilityInformation: true, out _);

if (TryGetBoolConstantValue(test) is bool testConstantBool)
{
Expand All @@ -451,6 +457,7 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
{
// if test evaluates to 'false' we can remove the WhenClause
RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
RestoreNullValueColumnsList(currentNullValueColumnsCount);

continue;
}
Expand All @@ -461,6 +468,7 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
nullable |= resultNullable;
whenClauses.Add(new CaseWhenClause(test, newResult));
RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
RestoreNullValueColumnsList(currentNonNullableColumnsCount);

// if test evaluates to 'true' we can remove every condition that comes after, including ElseResult
if (testEvaluatesToTrue)
Expand All @@ -476,6 +484,9 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
nullable |= elseResultNullable;
}

RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
RestoreNullValueColumnsList(currentNullValueColumnsCount);

// if there are no whenClauses left (e.g. their tests evaluated to false):
// - if there is Else block, return it
// - if there is no Else block, return null
Expand Down Expand Up @@ -830,16 +841,30 @@ protected virtual SqlExpression VisitSqlBinary(
|| sqlBinaryExpression.OperatorType == ExpressionType.OrElse);

var currentNonNullableColumnsCount = _nonNullableColumns.Count;
var currentNullValueColumnsCount = _nullValueColumns.Count;

var left = Visit(sqlBinaryExpression.Left, allowOptimizedExpansion, preserveNonNullableColumns: true, out var leftNullable);
var left = Visit(sqlBinaryExpression.Left, allowOptimizedExpansion, preserveColumnNullabilityInformation: true, out var leftNullable);

var leftNonNullableColumns = _nonNullableColumns.Skip(currentNonNullableColumnsCount).ToList();
var leftNullValueColumns = _nullValueColumns.Skip(currentNullValueColumnsCount).ToList();
if (sqlBinaryExpression.OperatorType != ExpressionType.AndAlso)
{
RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
}

var right = Visit(sqlBinaryExpression.Right, allowOptimizedExpansion, preserveNonNullableColumns: true, out var rightNullable);
if (sqlBinaryExpression.OperatorType == ExpressionType.OrElse)
{
// in case of OrElse, we can assume all null value columns on the left side can be treated as non-nullable on the right
// e.g. (a == null || b == null) || f(a, b)
// f(a, b) will only be executed if a != null and b != null
_nonNullableColumns.AddRange(_nullValueColumns.Skip(currentNullValueColumnsCount).ToList());
}
else
{
RestoreNullValueColumnsList(currentNullValueColumnsCount);
}

var right = Visit(sqlBinaryExpression.Right, allowOptimizedExpansion, preserveColumnNullabilityInformation: true, out var rightNullable);

if (sqlBinaryExpression.OperatorType == ExpressionType.OrElse)
{
Expand All @@ -853,6 +878,17 @@ protected virtual SqlExpression VisitSqlBinary(
RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
}

if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso)
{
var intersect = leftNullValueColumns.Intersect(_nullValueColumns.Skip(currentNullValueColumnsCount)).ToList();
RestoreNullValueColumnsList(currentNullValueColumnsCount);
_nullValueColumns.AddRange(intersect);
}
else if (sqlBinaryExpression.OperatorType != ExpressionType.OrElse)
{
RestoreNullValueColumnsList(currentNullValueColumnsCount);
}

// nullableStringColumn + a -> COALESCE(nullableStringColumn, "") + a
if (sqlBinaryExpression.OperatorType == ExpressionType.Add
&& sqlBinaryExpression.Type == typeof(string))
Expand Down Expand Up @@ -886,10 +922,16 @@ protected virtual SqlExpression VisitSqlBinary(
out nullable);

if (optimized is SqlUnaryExpression optimizedUnary
&& optimizedUnary.OperatorType == ExpressionType.NotEqual
&& optimizedUnary.Operand is ColumnExpression optimizedUnaryColumnOperand)
{
_nonNullableColumns.Add(optimizedUnaryColumnOperand);
if (optimizedUnary.OperatorType == ExpressionType.NotEqual)
{
_nonNullableColumns.Add(optimizedUnaryColumnOperand);
}
else if (optimizedUnary.OperatorType == ExpressionType.Equal)
{
_nullValueColumns.Add(optimizedUnaryColumnOperand);
}
}

// we assume that NullSemantics rewrite is only needed (on the current level)
Expand Down Expand Up @@ -1069,10 +1111,16 @@ protected virtual SqlExpression VisitSqlUnary(
nullable = false;

if (result is SqlUnaryExpression resultUnary
&& resultUnary.OperatorType == ExpressionType.NotEqual
&& resultUnary.Operand is ColumnExpression resultColumnOperand)
{
_nonNullableColumns.Add(resultColumnOperand);
if (resultUnary.OperatorType == ExpressionType.NotEqual)
{
_nonNullableColumns.Add(resultColumnOperand);
}
else if (resultUnary.OperatorType == ExpressionType.Equal)
{
_nullValueColumns.Add(resultColumnOperand);
}
}

return result;
Expand All @@ -1099,6 +1147,14 @@ private void RestoreNonNullableColumnsList(int counter)
}
}

private void RestoreNullValueColumnsList(int counter)
{
if (counter < _nullValueColumns.Count)
{
_nullValueColumns.RemoveRange(counter, _nullValueColumns.Count - counter);
}
}

private SqlExpression ProcessJoinPredicate(SqlExpression predicate)
{
if (predicate is SqlBinaryExpression sqlBinaryExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1087,25 +1087,33 @@ FROM root c
WHERE (c[""Discriminator""] = ""Customer"")");
}

[ConditionalTheory(Skip = "Issue #17246")]
public override async Task IsNullOrEmpty_negated_in_predicate(bool async)
{
await base.IsNullOrEmpty_negated_in_predicate(async);

AssertSql(@"");
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task IsNullOrWhiteSpace_in_predicate_on_non_nullable_column(bool async)
{
return base.IsNullOrWhiteSpace_in_predicate_on_non_nullable_column(async);
}

public override void IsNullOrEmpty_in_projection()
public override async Task IsNullOrEmpty_in_projection(bool async)
{
base.IsNullOrEmpty_in_projection();
await base.IsNullOrEmpty_in_projection(async);

AssertSql(
@"SELECT c[""CustomerID""], c[""Region""]
FROM root c
WHERE (c[""Discriminator""] = ""Customer"")");
}

public override void IsNullOrEmpty_negated_in_projection()
public override async Task IsNullOrEmpty_negated_in_projection(bool async)
{
base.IsNullOrEmpty_negated_in_projection();
await base.IsNullOrEmpty_negated_in_projection(async);

AssertSql(
@"SELECT c[""CustomerID""], c[""Region""]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1756,6 +1756,104 @@ public virtual async Task Negated_contains_with_comparison_without_null_get_comb
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_null_on_column_followed_by_OrElse_optimizes_nullability_simple(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => !(x.NullableStringA == null || x.NullableStringA != "Foo")));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_null_on_column_followed_by_OrElse_optimizes_nullability_negative(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => !(x.NullableStringA == null && x.NullableStringA != "Foo")));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_null_on_column_followed_by_OrElse_optimizes_nullability_nested(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => x.NullableStringA == null
|| x.NullableStringB == null
|| x.NullableStringA != x.NullableStringB));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_null_on_column_followed_by_OrElse_optimizes_nullability_intersection(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => (x.NullableStringA == null
&& (x.StringA == "Foo" || x.NullableStringA == null || x.NullableStringB == null))
|| x.NullableStringA != x.NullableStringB));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_null_on_column_followed_by_OrElse_optimizes_nullability_conditional(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => x.NullableStringA == null
? x.NullableStringA != x.NullableStringB
: x.NullableStringA != x.NullableStringC));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_null_on_column_followed_by_OrElse_optimizes_nullability_conditional_multiple(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => x.NullableStringA == null || x.NullableStringB == null
? x.NullableStringA == x.NullableStringB
: x.NullableStringA != x.NullableStringB));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_null_on_column_followed_by_OrElse_optimizes_nullability_conditional_negative(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => (x.NullableStringA == null || x.NullableStringB == null) && x.NullableBoolC == null
? x.NullableStringA == x.NullableStringB
: x.NullableStringA != x.NullableStringB));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_null_on_column_followed_by_OrElse_optimizes_nullability_conditional_with_setup(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => x.NullableBoolA == null
|| (x.NullableBoolB == null
? x.NullableBoolB != x.NullableBoolA
: x.NullableBoolA != x.NullableBoolB)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_null_on_column_followed_by_OrElse_optimizes_nullability_conditional_nested(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => x.NullableBoolA == null
? x.BoolA == x.BoolB
: (x.NullableBoolC == null
? x.NullableBoolA != x.NullableBoolC
: x.NullableBoolC != x.NullableBoolA)));
}

private string NormalizeDelimitersInRawString(string sql)
=> Fixture.TestStore.NormalizeDelimitersInRawString(sql);

Expand Down
Loading

0 comments on commit d638623

Please sign in to comment.