Skip to content

Commit

Permalink
Support for row value comparisons
Browse files Browse the repository at this point in the history
Closes #2349
  • Loading branch information
roji committed May 6, 2022
1 parent 94648a3 commit db16332
Show file tree
Hide file tree
Showing 12 changed files with 511 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@


// ReSharper disable once CheckNamespace

using System.Runtime.CompilerServices;

namespace Microsoft.EntityFrameworkCore;

/// <summary>
Expand Down Expand Up @@ -42,4 +45,16 @@ public static bool ILike(this DbFunctions _, string matchExpression, string patt
/// <returns>The reversed string.</returns>
public static string Reverse(this DbFunctions _, string value)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Reverse)));

public static bool GreaterThan(this DbFunctions _, ITuple a, ITuple b)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(GreaterThan)));

public static bool LessThan(this DbFunctions _, ITuple a, ITuple b)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(LessThan)));

public static bool GreaterThanOrEqual(this DbFunctions _, ITuple a, ITuple b)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(GreaterThanOrEqual)));

public static bool LessThanOrEqual(this DbFunctions _, ITuple a, ITuple b)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(LessThanOrEqual)));
}
8 changes: 8 additions & 0 deletions src/EFCore.PG/Properties/NpgsqlStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/EFCore.PG/Properties/NpgsqlStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -226,4 +226,7 @@
<data name="IdentityBadType" xml:space="preserve">
<value>Identity value generation cannot be used for the property '{property}' on entity type '{entityType}' because the property type is '{propertyType}'. Identity value generation can only be used with signed integer properties.</value>
</data>
<data name="RowValueMethodRequiresTwoArraysOfSameLength" xml:space="preserve">
<value>'{method}' requires two array parameters of the same length.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public NpgsqlMethodCallTranslatorProvider(
new NpgsqlRandomTranslator(sqlExpressionFactory),
new NpgsqlRangeTranslator(typeMappingSource, sqlExpressionFactory, model),
new NpgsqlRegexIsMatchTranslator(sqlExpressionFactory),
new NpgsqlRowValueComparisonTranslator(sqlExpressionFactory),
new NpgsqlStringMethodTranslator(typeMappingSource, sqlExpressionFactory, model),
new NpgsqlTrigramsMethodTranslator(typeMappingSource, sqlExpressionFactory, model),
});
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.CompilerServices;
using Npgsql.EntityFrameworkCore.PostgreSQL.Internal;
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;

public class NpgsqlRowValueComparisonTranslator : IMethodCallTranslator
{
private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;

private static readonly MethodInfo GreaterThan =
typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod(
nameof(NpgsqlDbFunctionsExtensions.GreaterThan),
new[] { typeof(DbFunctions), typeof(ITuple), typeof(ITuple) })!;

private static readonly MethodInfo LessThan =
typeof(NpgsqlDbFunctionsExtensions).GetMethods()
.Single(m => m.Name == nameof(NpgsqlDbFunctionsExtensions.LessThan));

private static readonly MethodInfo GreaterThanOrEqual =
typeof(NpgsqlDbFunctionsExtensions).GetMethods()
.Single(m => m.Name == nameof(NpgsqlDbFunctionsExtensions.GreaterThanOrEqual));

private static readonly MethodInfo LessThanOrEqual =
typeof(NpgsqlDbFunctionsExtensions).GetMethods()
.Single(m => m.Name == nameof(NpgsqlDbFunctionsExtensions.LessThanOrEqual));

private static readonly Dictionary<MethodInfo, ExpressionType> Methods = new()
{
{ GreaterThan, ExpressionType.GreaterThan },
{ LessThan, ExpressionType.LessThan },
{ GreaterThanOrEqual, ExpressionType.GreaterThanOrEqual },
{ LessThanOrEqual, ExpressionType.LessThanOrEqual }
};

/// <summary>
/// Initializes a new instance of the <see cref="NpgsqlRowValueComparisonTranslator"/> class.
/// </summary>
public NpgsqlRowValueComparisonTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory)
=> _sqlExpressionFactory = sqlExpressionFactory;

/// <inheritdoc />
public virtual SqlExpression? Translate(
SqlExpression? instance,
MethodInfo method,
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (method.DeclaringType != typeof(NpgsqlDbFunctionsExtensions)
|| !Methods.TryGetValue(method, out var expressionType)
|| arguments[1] is not PostgresRowValueExpression rowValue1
|| arguments[2] is not PostgresRowValueExpression rowValue2)
{
return null;
}

if (rowValue1.Values.Count != rowValue2.Values.Count)
{
throw new ArgumentException(NpgsqlStrings.RowValueMethodRequiresTwoArraysOfSameLength(method.Name));
}

return _sqlExpressionFactory.MakeBinary(expressionType, rowValue1, rowValue2, typeMapping: null);
}
}
141 changes: 141 additions & 0 deletions src/EFCore.PG/Query/Expressions/Internal/PostgresRowValueExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.CompilerServices;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;

/// <summary>
/// An expression that represents a PostgreSQL-specific row value expression in a SQL tree.
/// </summary>
/// <remarks>
/// See the <see href="https://www.postgresql.org/docs/current/sql-expressions.html#SQL-SYNTAX-ROW-CONSTRUCTORS">PostgreSQL docs</see>
/// for more information.
/// </remarks>
public class PostgresRowValueExpression : SqlExpression, IEquatable<PostgresRowValueExpression>
{
/// <summary>
/// The values of this PostgreSQL row value expression.
/// </summary>
public virtual IReadOnlyList<SqlExpression> Values { get; }

/// <inheritdoc />
public PostgresRowValueExpression(IReadOnlyList<SqlExpression> values)
: base(typeof(ITuple), typeMapping: RowValueTypeMapping.Instance)
{
Check.NotNull(values, nameof(values));

Values = values;
}

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
{
Check.NotNull(visitor, nameof(visitor));

SqlExpression[]? newRowValues = null;

for (var i = 0; i < Values.Count; i++)
{
var rowValue = Values[i];
var visited = (SqlExpression)visitor.Visit(rowValue);
if (visited != rowValue && newRowValues is null)
{
newRowValues = new SqlExpression[Values.Count];
for (var j = 0; j < i; i++)
{
newRowValues[j] = Values[j];
}
}

if (newRowValues is not null)
{
newRowValues[i] = visited;
}
}

return newRowValues is null ? this : new PostgresRowValueExpression(newRowValues);
}

public virtual PostgresRowValueExpression Update(IReadOnlyList<SqlExpression> values)
=> values.Count == Values.Count && values.Zip(Values, (x, y) => (x, y)).All(tup => tup.x == tup.y)
? this
: new PostgresRowValueExpression(values);

/// <inheritdoc />
protected override void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.Append("(");

var count = Values.Count;
for (var i = 0; i < count; i++)
{
expressionPrinter.Visit(Values[i]);

if (i < count - 1)
{
expressionPrinter.Append(", ");
}
}

expressionPrinter.Append(")");
}

/// <inheritdoc />
public override bool Equals(object? obj)
=> obj is PostgresRowValueExpression other && Equals(other);

/// <inheritdoc />
public virtual bool Equals(PostgresRowValueExpression? other)
{
if (other is null || !base.Equals(other) || other.Values.Count != Values.Count)
{
return false;
}

if (ReferenceEquals(this, other))
{
return true;
}

for (var i = 0; i < Values.Count; i++)
{
if (other.Values[i].Equals(Values[i]))
{
return false;
}
}

return true;
}

/// <inheritdoc />
public override int GetHashCode()
{
var hashCode = new HashCode();

foreach (var rowValue in Values)
{
hashCode.Add(rowValue);
}

return hashCode.ToHashCode();
}

/// <summary>
/// Every node in the SQL tree must have a type mapping, but row values aren't actual values (in the sense that they can be sent as
/// parameters, or have a literal representation). So we have a dummy type mapping for that.
/// </summary>
private sealed class RowValueTypeMapping : RelationalTypeMapping
{
internal static RowValueTypeMapping Instance { get; } = new();

private RowValueTypeMapping()
: base(new(new(), storeType: "rowvalue"))
{
}

protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters)
=> this;
}
}
10 changes: 10 additions & 0 deletions src/EFCore.PG/Query/Internal/NpgsqlEvaluatableExpressionFilter.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using System.Runtime.CompilerServices;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Internal;

public class NpgsqlEvaluatableExpressionFilter : RelationalEvaluatableExpressionFilter
Expand Down Expand Up @@ -35,6 +37,14 @@ public override bool IsEvaluatableExpression(Expression expression, IModel model
}

break;

case NewExpression newExpression when newExpression.Type.IsAssignableTo(typeof(ITuple)):
// We translate new ValueTuple<T1, T2...>(x, y...) to a SQL row value expression: (x, y)
// (see NpgsqlSqlTranslatingExpressionVisitor.VisitNew).
// We must prevent evaluation when the tuple contains only constants/parameters, since SQL row values cannot be
// parameterized; we need to render them as "literals" instead:
// WHERE (x, y) > (3, $1)
return false;
}

return base.IsEvaluatableExpression(expression, model);
Expand Down
22 changes: 22 additions & 0 deletions src/EFCore.PG/Query/Internal/NpgsqlQuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
PostgresJsonTraversalExpression jsonTraversalExpression => VisitJsonPathTraversal(jsonTraversalExpression),
PostgresNewArrayExpression newArrayExpression => VisitPostgresNewArray(newArrayExpression),
PostgresRegexMatchExpression regexMatchExpression => VisitRegexMatch(regexMatchExpression),
PostgresRowValueExpression rowValueExpression => VisitRowValue(rowValueExpression),
PostgresUnknownBinaryExpression unknownBinaryExpression => VisitUnknownBinary(unknownBinaryExpression),
_ => base.VisitExtension(extensionExpression)
};
Expand Down Expand Up @@ -547,6 +548,27 @@ public virtual Expression VisitRegexMatch(PostgresRegexMatchExpression expressio
return expression;
}

public virtual Expression VisitRowValue(PostgresRowValueExpression rowValueExpression)
{
Sql.Append("(");

var values = rowValueExpression.Values;
var count = values.Count;
for (var i = 0; i < count; i++)
{
Visit(values[i]);

if (i < count - 1)
{
Sql.Append(", ");
}
}

Sql.Append(")");

return rowValueExpression;
}

/// <summary>
/// Visits the children of an <see cref="PostgresILikeExpression"/>.
/// </summary>
Expand Down
40 changes: 38 additions & 2 deletions src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ PostgresBinaryExpression binaryExpression
=> VisitBinary(binaryExpression, allowOptimizedExpansion, out nullable),
PostgresILikeExpression ilikeExpression
=> VisitILike(ilikeExpression, allowOptimizedExpansion, out nullable),
PostgresJsonTraversalExpression postgresJsonTraversalExpression
=> VisitJsonTraversal(postgresJsonTraversalExpression, allowOptimizedExpansion, out nullable),
PostgresNewArrayExpression newArrayExpression
=> VisitNewArray(newArrayExpression, allowOptimizedExpansion, out nullable),
PostgresRegexMatchExpression regexMatchExpression
=> VisitRegexMatch(regexMatchExpression, allowOptimizedExpansion, out nullable),
PostgresJsonTraversalExpression postgresJsonTraversalExpression
=> VisitJsonTraversal(postgresJsonTraversalExpression, allowOptimizedExpansion, out nullable),
PostgresRowValueExpression postgresRowValueExpression
=> VisitRowValueExpression(postgresRowValueExpression, allowOptimizedExpansion, out nullable),
PostgresUnknownBinaryExpression postgresUnknownBinaryExpression
=> VisitUnknownBinary(postgresUnknownBinaryExpression, allowOptimizedExpansion, out nullable),

Expand Down Expand Up @@ -276,6 +278,40 @@ protected virtual SqlExpression VisitJsonTraversal(
return jsonTraversalExpression.Update(expression, newPath?.ToArray() ?? jsonTraversalExpression.Path);
}

protected virtual SqlExpression VisitRowValueExpression(
PostgresRowValueExpression rowValueExpression,
bool allowOptimizedExpansion,
out bool nullable)
{
SqlExpression[]? newValues = null;

for (var i = 0; i < rowValueExpression.Values.Count; i++)
{
var value = rowValueExpression.Values[i];

// Note that we disallow optimized expansion, since the null vs. false distinction does matter inside the row's values
var newValue = Visit(value, allowOptimizedExpansion: false, out _);
if (newValue != value && newValues is null)
{
newValues = new SqlExpression[rowValueExpression.Values.Count];
for (var j = 0; j < i; j++)
{
newValues[j] = newValue;
}
}

if (newValues is not null)
{
newValues[i] = newValue;
}
}

// The row value expression itself can never be null
nullable = false;

return rowValueExpression.Update(newValues ?? rowValueExpression.Values);
}

/// <summary>
/// Visits a <see cref="PostgresUnknownBinaryExpression" /> and computes its nullability.
/// </summary>
Expand Down
Loading

0 comments on commit db16332

Please sign in to comment.