From db16332e159b5cdde54f4ac31773b2b276975e84 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Wed, 4 May 2022 22:06:56 +0200 Subject: [PATCH] Support for row value comparisons Closes #2349 --- .../NpgsqlDbFunctionsExtensions.cs | 15 ++ .../Properties/NpgsqlStrings.Designer.cs | 8 + src/EFCore.PG/Properties/NpgsqlStrings.resx | 3 + .../NpgsqlMethodCallTranslatorProvider.cs | 1 + .../NpgsqlRowValueComparisonTranslator.cs | 67 +++++++++ .../Internal/PostgresRowValueExpression.cs | 141 ++++++++++++++++++ .../NpgsqlEvaluatableExpressionFilter.cs | 10 ++ .../Query/Internal/NpgsqlQuerySqlGenerator.cs | 22 +++ .../Internal/NpgsqlSqlNullabilityProcessor.cs | 40 ++++- .../NpgsqlSqlTranslatingExpressionVisitor.cs | 50 ++++++- .../Query/NpgsqlSqlExpressionFactory.cs | 37 ++++- .../NorthwindMiscellaneousQueryNpgsqlTest.cs | 129 ++++++++++++++++ 12 files changed, 511 insertions(+), 12 deletions(-) create mode 100644 src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRowValueComparisonTranslator.cs create mode 100644 src/EFCore.PG/Query/Expressions/Internal/PostgresRowValueExpression.cs diff --git a/src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlDbFunctionsExtensions.cs b/src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlDbFunctionsExtensions.cs index 8d7e916b4c..a1e4a85e74 100644 --- a/src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlDbFunctionsExtensions.cs +++ b/src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlDbFunctionsExtensions.cs @@ -1,6 +1,9 @@  // ReSharper disable once CheckNamespace + +using System.Runtime.CompilerServices; + namespace Microsoft.EntityFrameworkCore; /// @@ -42,4 +45,16 @@ public static bool ILike(this DbFunctions _, string matchExpression, string patt /// The reversed string. public static string Reverse(this DbFunctions _, string value) => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(Reverse))); + + public static bool GreaterThan(this DbFunctions _, ITuple a, ITuple b) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(GreaterThan))); + + public static bool LessThan(this DbFunctions _, ITuple a, ITuple b) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(LessThan))); + + public static bool GreaterThanOrEqual(this DbFunctions _, ITuple a, ITuple b) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(GreaterThanOrEqual))); + + public static bool LessThanOrEqual(this DbFunctions _, ITuple a, ITuple b) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(LessThanOrEqual))); } \ No newline at end of file diff --git a/src/EFCore.PG/Properties/NpgsqlStrings.Designer.cs b/src/EFCore.PG/Properties/NpgsqlStrings.Designer.cs index df2906684b..fe5851b5ca 100644 --- a/src/EFCore.PG/Properties/NpgsqlStrings.Designer.cs +++ b/src/EFCore.PG/Properties/NpgsqlStrings.Designer.cs @@ -121,6 +121,14 @@ public static string NonKeyValueGeneration(object? property, object? entityType) GetString("NonKeyValueGeneration", nameof(property), nameof(entityType)), property, entityType); + /// + /// '{method}' requires two array parameters of the same length. + /// + public static string RowValueMethodRequiresTwoArraysOfSameLength(object? method) + => string.Format( + GetString("RowValueMethodRequiresTwoArraysOfSameLength", nameof(method)), + method); + /// /// PostgreSQL sequences cannot be used to generate values for the property '{property}' on entity type '{entityType}' because the property type is '{propertyType}'. Sequences can only be used with integer properties. /// diff --git a/src/EFCore.PG/Properties/NpgsqlStrings.resx b/src/EFCore.PG/Properties/NpgsqlStrings.resx index ce7d441e2b..fce0932bf0 100644 --- a/src/EFCore.PG/Properties/NpgsqlStrings.resx +++ b/src/EFCore.PG/Properties/NpgsqlStrings.resx @@ -226,4 +226,7 @@ Identity value generation cannot be used for the property '{property}' on entity type '{entityType}' because the property type is '{propertyType}'. Identity value generation can only be used with signed integer properties. + + '{method}' requires two array parameters of the same length. + \ No newline at end of file diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMethodCallTranslatorProvider.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMethodCallTranslatorProvider.cs index 2a33aa4de1..11dbd1555f 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMethodCallTranslatorProvider.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMethodCallTranslatorProvider.cs @@ -37,6 +37,7 @@ public NpgsqlMethodCallTranslatorProvider( new NpgsqlRandomTranslator(sqlExpressionFactory), new NpgsqlRangeTranslator(typeMappingSource, sqlExpressionFactory, model), new NpgsqlRegexIsMatchTranslator(sqlExpressionFactory), + new NpgsqlRowValueComparisonTranslator(sqlExpressionFactory), new NpgsqlStringMethodTranslator(typeMappingSource, sqlExpressionFactory, model), new NpgsqlTrigramsMethodTranslator(typeMappingSource, sqlExpressionFactory, model), }); diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRowValueComparisonTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRowValueComparisonTranslator.cs new file mode 100644 index 0000000000..e33da5b279 --- /dev/null +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRowValueComparisonTranslator.cs @@ -0,0 +1,67 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; +using Npgsql.EntityFrameworkCore.PostgreSQL.Internal; +using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; + +public class NpgsqlRowValueComparisonTranslator : IMethodCallTranslator +{ + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; + + private static readonly MethodInfo GreaterThan = + typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod( + nameof(NpgsqlDbFunctionsExtensions.GreaterThan), + new[] { typeof(DbFunctions), typeof(ITuple), typeof(ITuple) })!; + + private static readonly MethodInfo LessThan = + typeof(NpgsqlDbFunctionsExtensions).GetMethods() + .Single(m => m.Name == nameof(NpgsqlDbFunctionsExtensions.LessThan)); + + private static readonly MethodInfo GreaterThanOrEqual = + typeof(NpgsqlDbFunctionsExtensions).GetMethods() + .Single(m => m.Name == nameof(NpgsqlDbFunctionsExtensions.GreaterThanOrEqual)); + + private static readonly MethodInfo LessThanOrEqual = + typeof(NpgsqlDbFunctionsExtensions).GetMethods() + .Single(m => m.Name == nameof(NpgsqlDbFunctionsExtensions.LessThanOrEqual)); + + private static readonly Dictionary Methods = new() + { + { GreaterThan, ExpressionType.GreaterThan }, + { LessThan, ExpressionType.LessThan }, + { GreaterThanOrEqual, ExpressionType.GreaterThanOrEqual }, + { LessThanOrEqual, ExpressionType.LessThanOrEqual } + }; + + /// + /// Initializes a new instance of the class. + /// + public NpgsqlRowValueComparisonTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) + => _sqlExpressionFactory = sqlExpressionFactory; + + /// + public virtual SqlExpression? Translate( + SqlExpression? instance, + MethodInfo method, + IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + if (method.DeclaringType != typeof(NpgsqlDbFunctionsExtensions) + || !Methods.TryGetValue(method, out var expressionType) + || arguments[1] is not PostgresRowValueExpression rowValue1 + || arguments[2] is not PostgresRowValueExpression rowValue2) + { + return null; + } + + if (rowValue1.Values.Count != rowValue2.Values.Count) + { + throw new ArgumentException(NpgsqlStrings.RowValueMethodRequiresTwoArraysOfSameLength(method.Name)); + } + + return _sqlExpressionFactory.MakeBinary(expressionType, rowValue1, rowValue2, typeMapping: null); + } +} diff --git a/src/EFCore.PG/Query/Expressions/Internal/PostgresRowValueExpression.cs b/src/EFCore.PG/Query/Expressions/Internal/PostgresRowValueExpression.cs new file mode 100644 index 0000000000..d319ba4b4d --- /dev/null +++ b/src/EFCore.PG/Query/Expressions/Internal/PostgresRowValueExpression.cs @@ -0,0 +1,141 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal; + +/// +/// An expression that represents a PostgreSQL-specific row value expression in a SQL tree. +/// +/// +/// See the PostgreSQL docs +/// for more information. +/// +public class PostgresRowValueExpression : SqlExpression, IEquatable +{ + /// + /// The values of this PostgreSQL row value expression. + /// + public virtual IReadOnlyList Values { get; } + + /// + public PostgresRowValueExpression(IReadOnlyList values) + : base(typeof(ITuple), typeMapping: RowValueTypeMapping.Instance) + { + Check.NotNull(values, nameof(values)); + + Values = values; + } + + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + Check.NotNull(visitor, nameof(visitor)); + + SqlExpression[]? newRowValues = null; + + for (var i = 0; i < Values.Count; i++) + { + var rowValue = Values[i]; + var visited = (SqlExpression)visitor.Visit(rowValue); + if (visited != rowValue && newRowValues is null) + { + newRowValues = new SqlExpression[Values.Count]; + for (var j = 0; j < i; i++) + { + newRowValues[j] = Values[j]; + } + } + + if (newRowValues is not null) + { + newRowValues[i] = visited; + } + } + + return newRowValues is null ? this : new PostgresRowValueExpression(newRowValues); + } + + public virtual PostgresRowValueExpression Update(IReadOnlyList values) + => values.Count == Values.Count && values.Zip(Values, (x, y) => (x, y)).All(tup => tup.x == tup.y) + ? this + : new PostgresRowValueExpression(values); + + /// + protected override void Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.Append("("); + + var count = Values.Count; + for (var i = 0; i < count; i++) + { + expressionPrinter.Visit(Values[i]); + + if (i < count - 1) + { + expressionPrinter.Append(", "); + } + } + + expressionPrinter.Append(")"); + } + + /// + public override bool Equals(object? obj) + => obj is PostgresRowValueExpression other && Equals(other); + + /// + public virtual bool Equals(PostgresRowValueExpression? other) + { + if (other is null || !base.Equals(other) || other.Values.Count != Values.Count) + { + return false; + } + + if (ReferenceEquals(this, other)) + { + return true; + } + + for (var i = 0; i < Values.Count; i++) + { + if (other.Values[i].Equals(Values[i])) + { + return false; + } + } + + return true; + } + + /// + public override int GetHashCode() + { + var hashCode = new HashCode(); + + foreach (var rowValue in Values) + { + hashCode.Add(rowValue); + } + + return hashCode.ToHashCode(); + } + + /// + /// Every node in the SQL tree must have a type mapping, but row values aren't actual values (in the sense that they can be sent as + /// parameters, or have a literal representation). So we have a dummy type mapping for that. + /// + private sealed class RowValueTypeMapping : RelationalTypeMapping + { + internal static RowValueTypeMapping Instance { get; } = new(); + + private RowValueTypeMapping() + : base(new(new(), storeType: "rowvalue")) + { + } + + protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters) + => this; + } +} diff --git a/src/EFCore.PG/Query/Internal/NpgsqlEvaluatableExpressionFilter.cs b/src/EFCore.PG/Query/Internal/NpgsqlEvaluatableExpressionFilter.cs index ad6433f133..2128253af1 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlEvaluatableExpressionFilter.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlEvaluatableExpressionFilter.cs @@ -1,3 +1,5 @@ +using System.Runtime.CompilerServices; + namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Internal; public class NpgsqlEvaluatableExpressionFilter : RelationalEvaluatableExpressionFilter @@ -35,6 +37,14 @@ public override bool IsEvaluatableExpression(Expression expression, IModel model } break; + + case NewExpression newExpression when newExpression.Type.IsAssignableTo(typeof(ITuple)): + // We translate new ValueTuple(x, y...) to a SQL row value expression: (x, y) + // (see NpgsqlSqlTranslatingExpressionVisitor.VisitNew). + // We must prevent evaluation when the tuple contains only constants/parameters, since SQL row values cannot be + // parameterized; we need to render them as "literals" instead: + // WHERE (x, y) > (3, $1) + return false; } return base.IsEvaluatableExpression(expression, model); diff --git a/src/EFCore.PG/Query/Internal/NpgsqlQuerySqlGenerator.cs b/src/EFCore.PG/Query/Internal/NpgsqlQuerySqlGenerator.cs index d1f3953445..f577b99e79 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlQuerySqlGenerator.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlQuerySqlGenerator.cs @@ -49,6 +49,7 @@ protected override Expression VisitExtension(Expression extensionExpression) PostgresJsonTraversalExpression jsonTraversalExpression => VisitJsonPathTraversal(jsonTraversalExpression), PostgresNewArrayExpression newArrayExpression => VisitPostgresNewArray(newArrayExpression), PostgresRegexMatchExpression regexMatchExpression => VisitRegexMatch(regexMatchExpression), + PostgresRowValueExpression rowValueExpression => VisitRowValue(rowValueExpression), PostgresUnknownBinaryExpression unknownBinaryExpression => VisitUnknownBinary(unknownBinaryExpression), _ => base.VisitExtension(extensionExpression) }; @@ -547,6 +548,27 @@ public virtual Expression VisitRegexMatch(PostgresRegexMatchExpression expressio return expression; } + public virtual Expression VisitRowValue(PostgresRowValueExpression rowValueExpression) + { + Sql.Append("("); + + var values = rowValueExpression.Values; + var count = values.Count; + for (var i = 0; i < count; i++) + { + Visit(values[i]); + + if (i < count - 1) + { + Sql.Append(", "); + } + } + + Sql.Append(")"); + + return rowValueExpression; + } + /// /// Visits the children of an . /// diff --git a/src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs b/src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs index dc921a3683..a01c698fcd 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs @@ -38,12 +38,14 @@ PostgresBinaryExpression binaryExpression => VisitBinary(binaryExpression, allowOptimizedExpansion, out nullable), PostgresILikeExpression ilikeExpression => VisitILike(ilikeExpression, allowOptimizedExpansion, out nullable), + PostgresJsonTraversalExpression postgresJsonTraversalExpression + => VisitJsonTraversal(postgresJsonTraversalExpression, allowOptimizedExpansion, out nullable), PostgresNewArrayExpression newArrayExpression => VisitNewArray(newArrayExpression, allowOptimizedExpansion, out nullable), PostgresRegexMatchExpression regexMatchExpression => VisitRegexMatch(regexMatchExpression, allowOptimizedExpansion, out nullable), - PostgresJsonTraversalExpression postgresJsonTraversalExpression - => VisitJsonTraversal(postgresJsonTraversalExpression, allowOptimizedExpansion, out nullable), + PostgresRowValueExpression postgresRowValueExpression + => VisitRowValueExpression(postgresRowValueExpression, allowOptimizedExpansion, out nullable), PostgresUnknownBinaryExpression postgresUnknownBinaryExpression => VisitUnknownBinary(postgresUnknownBinaryExpression, allowOptimizedExpansion, out nullable), @@ -276,6 +278,40 @@ protected virtual SqlExpression VisitJsonTraversal( return jsonTraversalExpression.Update(expression, newPath?.ToArray() ?? jsonTraversalExpression.Path); } + protected virtual SqlExpression VisitRowValueExpression( + PostgresRowValueExpression rowValueExpression, + bool allowOptimizedExpansion, + out bool nullable) + { + SqlExpression[]? newValues = null; + + for (var i = 0; i < rowValueExpression.Values.Count; i++) + { + var value = rowValueExpression.Values[i]; + + // Note that we disallow optimized expansion, since the null vs. false distinction does matter inside the row's values + var newValue = Visit(value, allowOptimizedExpansion: false, out _); + if (newValue != value && newValues is null) + { + newValues = new SqlExpression[rowValueExpression.Values.Count]; + for (var j = 0; j < i; j++) + { + newValues[j] = newValue; + } + } + + if (newValues is not null) + { + newValues[i] = newValue; + } + } + + // The row value expression itself can never be null + nullable = false; + + return rowValueExpression.Update(newValues ?? rowValueExpression.Values); + } + /// /// Visits a and computes its nullability. /// diff --git a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs index 8a93353b8c..39ff94b5be 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs @@ -1,5 +1,6 @@ using System.Collections.ObjectModel; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using Npgsql.EntityFrameworkCore.PostgreSQL.Internal; using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal; using Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; @@ -17,7 +18,8 @@ public class NpgsqlSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExp typeof(DateTime).GetConstructor(new[] { typeof(int), typeof(int), typeof(int), typeof(int), typeof(int), typeof(int) })!; private static readonly ConstructorInfo DateTimeCtor3 = - typeof(DateTime).GetConstructor(new[] { typeof(int), typeof(int), typeof(int), typeof(int), typeof(int), typeof(int), typeof(DateTimeKind) })!; + typeof(DateTime).GetConstructor( + new[] { typeof(int), typeof(int), typeof(int), typeof(int), typeof(int), typeof(int), typeof(DateTimeKind) })!; private static readonly ConstructorInfo DateOnlyCtor = typeof(DateOnly).GetConstructor(new[] { typeof(int), typeof(int), typeof(int) })!; @@ -131,8 +133,8 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) // Translate Length on byte[], but only if the type mapping is for bytea. There's also array of bytes // (mapped to smallint[]), which is handled below with CARDINALITY. - if (sqlOperand!.Type == typeof(byte[]) && - (sqlOperand.TypeMapping is null || sqlOperand.TypeMapping is NpgsqlByteArrayTypeMapping)) + if (sqlOperand!.Type == typeof(byte[]) + && (sqlOperand.TypeMapping is null || sqlOperand.TypeMapping is NpgsqlByteArrayTypeMapping)) { return _sqlExpressionFactory.Function( "length", @@ -142,8 +144,8 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) typeof(int)); } - return _jsonPocoTranslator.TranslateArrayLength(sqlOperand) ?? - _sqlExpressionFactory.Function( + return _jsonPocoTranslator.TranslateArrayLength(sqlOperand) + ?? _sqlExpressionFactory.Function( "cardinality", new[] { sqlOperand }, nullable: true, @@ -151,7 +153,27 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) typeof(int)); } - return base.VisitUnary(unaryExpression); + var translated = base.VisitUnary(unaryExpression); + + // Temporary hack around https://github.com/dotnet/efcore/pull/27964 + if (translated == QueryCompilationContext.NotTranslatedExpression) + { + var operand = Visit(unaryExpression.Operand); + + if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand)) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + if (unaryExpression.NodeType is ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs + && unaryExpression.Type.IsInterface + && operand.Type.IsAssignableTo(unaryExpression.Type)) + { + return sqlOperand!; + } + } + + return translated; } protected override Expression VisitMethodCall(MethodCallExpression methodCall) @@ -413,12 +435,27 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) protected override Expression VisitNew(NewExpression newExpression) { + // TEMPORARY HACK around https://github.com/dotnet/efcore/pull/27965 + // This row value translation should happen after base.VisitNew, but the base implementation doesn't take evaluatable filters into + // account. Move it down after that's fixed. + + // We translate new ValueTuple(x, y...) to a SQL row value expression: (x, y). + // This is notably done to support row value comparisons: WHERE (x, y) > (3, 4) (see e.g. NpgsqlDbFunctionsExtensions.GreaterThan) + if (newExpression.Type.IsAssignableTo(typeof(ITuple))) + { + return TryTranslateArguments(out var sqlArguments) + ? new PostgresRowValueExpression(sqlArguments) + : QueryCompilationContext.NotTranslatedExpression; + } + var visitedNewExpression = base.VisitNew(newExpression); + if (visitedNewExpression != QueryCompilationContext.NotTranslatedExpression) { return visitedNewExpression; } + // Translate new DateTime(...) -> make_timestamp/make_date if (newExpression.Constructor?.DeclaringType == typeof(DateTime)) { if (newExpression.Constructor == DateTimeCtor1) @@ -471,6 +508,7 @@ protected override Expression VisitNew(NewExpression newExpression) } } + // Translate new DateOnly(...) -> make_date if (newExpression.Constructor == DateOnlyCtor) { return TryTranslateArguments(out var sqlArguments) diff --git a/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs b/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs index 800b318088..21e49604f4 100644 --- a/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs +++ b/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs @@ -1,5 +1,6 @@ using System.Collections; using System.Diagnostics.CodeAnalysis; +using System.Security.Cryptography; using System.Text.RegularExpressions; using Npgsql.EntityFrameworkCore.PostgreSQL.Internal; using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions; @@ -321,14 +322,14 @@ public virtual PostgresBinaryExpression Overlaps(SqlExpression left, SqlExpressi return sqlExpression; } - private SqlExpression ApplyTypeMappingOnSqlBinary(SqlBinaryExpression binary, RelationalTypeMapping? typeMapping) + private SqlBinaryExpression ApplyTypeMappingOnSqlBinary(SqlBinaryExpression binary, RelationalTypeMapping? typeMapping) { // The default SqlExpressionFactory behavior is to assume that the two added operands have the same type, // and so to infer one side's mapping from the other if needed. Here we take care of some heterogeneous // operand cases where this doesn't work: // * Period + Period (???) - if (binary.OperatorType == ExpressionType.Add || binary.OperatorType == ExpressionType.Subtract) + if (binary.OperatorType is ExpressionType.Add or ExpressionType.Subtract) { var (left, right) = (binary.Left, binary.Right); var leftType = left.Type.UnwrapNullableType(); @@ -391,14 +392,42 @@ private SqlExpression ApplyTypeMappingOnSqlBinary(SqlBinaryExpression binary, Re } } - return base.ApplyTypeMapping(binary, typeMapping); + if (binary.OperatorType is ExpressionType.GreaterThan or ExpressionType.LessThan + or ExpressionType.GreaterThanOrEqual or ExpressionType.LessThanOrEqual + && binary.Left is PostgresRowValueExpression leftRowValue + && binary.Right is PostgresRowValueExpression rightRowValue) + { + Check.DebugAssert(leftRowValue.Values.Count == rightRowValue.Values.Count, "Row value count mismatch in comparison"); + + var count = leftRowValue.Values.Count; + var updatedLeftValues = new SqlExpression[count]; + var updatedRightValues = new SqlExpression[count]; + + for (var i = 0; i < count; i++) + { + var updatedElementBinaryExpression = + MakeBinary(binary.OperatorType, leftRowValue.Values[i], rightRowValue.Values[i], typeMapping: null)!; + + updatedLeftValues[i] = updatedElementBinaryExpression.Left; + updatedRightValues[i] = updatedElementBinaryExpression.Right; + } + + binary = new SqlBinaryExpression( + binary.OperatorType, + new PostgresRowValueExpression(updatedLeftValues), + new PostgresRowValueExpression(updatedRightValues), + binary.Type, + binary.TypeMapping); + } + + return (SqlBinaryExpression)base.ApplyTypeMapping(binary, typeMapping); } private SqlExpression ApplyTypeMappingOnRegexMatch(PostgresRegexMatchExpression postgresRegexMatchExpression) { var inferredTypeMapping = ExpressionExtensions.InferTypeMapping( postgresRegexMatchExpression.Match, postgresRegexMatchExpression.Pattern) - ?? (RelationalTypeMapping?)_typeMappingSource.FindMapping(postgresRegexMatchExpression.Match.Type, Dependencies.Model); + ?? _typeMappingSource.FindMapping(postgresRegexMatchExpression.Match.Type, Dependencies.Model); return new PostgresRegexMatchExpression( ApplyTypeMapping(postgresRegexMatchExpression.Match, inferredTypeMapping), diff --git a/test/EFCore.PG.FunctionalTests/Query/NorthwindMiscellaneousQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/NorthwindMiscellaneousQueryNpgsqlTest.cs index eb50aafbb1..1394f2e9d3 100644 --- a/test/EFCore.PG.FunctionalTests/Query/NorthwindMiscellaneousQueryNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/NorthwindMiscellaneousQueryNpgsqlTest.cs @@ -1,4 +1,5 @@ using Microsoft.EntityFrameworkCore.TestModels.Northwind; +using Npgsql.EntityFrameworkCore.PostgreSQL.Internal; using Xunit.Sdk; namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query; @@ -17,6 +18,7 @@ public NorthwindMiscellaneousQueryNpgsqlTest( public override async Task Query_expression_with_to_string_and_contains(bool async) { await base.Query_expression_with_to_string_and_contains(async); + AssertContainsSqlFragment(@"strpos(o.""EmployeeID""::text, '10') > 0"); } @@ -346,6 +348,133 @@ public async Task Array_All_ILike(bool async) #endregion Any/All Like + #region Row value comparisons + + [ConditionalFact] + public async Task Row_value_GreaterThan() + { + await using var ctx = CreateContext(); + + _ = await ctx.Customers + .Where(c => EF.Functions.GreaterThan( + new ValueTuple(c.City, c.CustomerID), + new ValueTuple("Buenos Aires", "OCEAN"))) + .CountAsync(); + + AssertSql( + @"SELECT COUNT(*)::INT +FROM ""Customers"" AS c +WHERE (c.""City"", c.""CustomerID"") > ('Buenos Aires', 'OCEAN')"); + } + + [ConditionalFact] + public async Task Row_value_GreaterThan_with_differing_types() + { + await using var ctx = CreateContext(); + + _ = await ctx.Orders + .Where(o => EF.Functions.GreaterThan( + new ValueTuple(o.CustomerID, o.OrderID), + new ValueTuple("ALFKI", 10702))) + .CountAsync(); + + AssertSql( + @"SELECT COUNT(*)::INT +FROM ""Orders"" AS o +WHERE (o.""CustomerID"", o.""OrderID"") > ('ALFKI', 10702)"); + } + + [ConditionalFact] + public async Task Row_value_GreaterThan_with_parameter() + { + await using var ctx = CreateContext(); + + var city1 = "Buenos Aires"; + + _ = await ctx.Customers + .Where(c => EF.Functions.GreaterThan( + new ValueTuple(c.City, c.CustomerID), + new ValueTuple(city1, "OCEAN"))) + .CountAsync(); + + AssertSql( + @"@__city1_1='Buenos Aires' + +SELECT COUNT(*)::INT +FROM ""Customers"" AS c +WHERE (c.""City"", c.""CustomerID"") > (@__city1_1, 'OCEAN')"); + } + + [ConditionalFact] + public async Task Row_value_LessThan() + { + await using var ctx = CreateContext(); + + _ = await ctx.Customers + .Where(c => EF.Functions.LessThan( + new ValueTuple(c.City, c.CustomerID), + new ValueTuple("Buenos Aires", "OCEAN"))) + .CountAsync(); + + AssertSql( + @"SELECT COUNT(*)::INT +FROM ""Customers"" AS c +WHERE (c.""City"", c.""CustomerID"") < ('Buenos Aires', 'OCEAN')"); + } + + [ConditionalFact] + public async Task Row_value_GreaterThanOrEqual() + { + await using var ctx = CreateContext(); + + _ = await ctx.Customers + .Where(c => EF.Functions.GreaterThanOrEqual( + new ValueTuple(c.City, c.CustomerID), + new ValueTuple("Buenos Aires", "OCEAN"))) + .CountAsync(); + + AssertSql( + @"SELECT COUNT(*)::INT +FROM ""Customers"" AS c +WHERE (c.""City"", c.""CustomerID"") >= ('Buenos Aires', 'OCEAN')"); + } + + [ConditionalFact] + public async Task Row_value_LessThanOrEqual() + { + await using var ctx = CreateContext(); + + _ = await ctx.Customers + .Where(c => EF.Functions.LessThanOrEqual( + new ValueTuple(c.City, c.CustomerID), + new ValueTuple("Buenos Aires", "OCEAN"))) + .CountAsync(); + + AssertSql( + @"SELECT COUNT(*)::INT +FROM ""Customers"" AS c +WHERE (c.""City"", c.""CustomerID"") <= ('Buenos Aires', 'OCEAN')"); + } + + [ConditionalFact] + public async Task Row_value_parameter_count_mismatch() + { + await using var ctx = CreateContext(); + + var exception = await Assert.ThrowsAsync( + () => ctx.Customers + .Where(c => EF.Functions.LessThanOrEqual( + new ValueTuple(c.City, c.CustomerID), + new ValueTuple("Buenos Aires", "OCEAN", "foo"))) + .CountAsync()); + + Assert.Equal( + NpgsqlStrings.RowValueMethodRequiresTwoArraysOfSameLength(nameof(NpgsqlDbFunctionsExtensions.LessThanOrEqual)), + exception.Message); + } + + #endregion Row value comparisons + [ConditionalFact] // #1560 public async Task Lateral_join_with_table_is_rewritten_with_subquery() {