diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 118ed56aa5f1..030f13d41509 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -90,6 +90,7 @@ import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_INEQUALITY; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_LIMIT_PUSHDOWN; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_ARITHMETIC_EXPRESSION_PUSHDOWN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN_WITH_LIKE; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_PUSHDOWN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY; @@ -1007,6 +1008,34 @@ public void testNullSensitiveTopNPushdown() } } + @Test + public void testArithmeticPredicatePushdown() + { + if (!hasBehavior(SUPPORTS_PREDICATE_ARITHMETIC_EXPRESSION_PUSHDOWN)) { + assertThat(query("SELECT shippriority FROM orders WHERE shippriority % 4 = 0")).isNotFullyPushedDown(FilterNode.class); + return; + } + assertThat(query("SELECT shippriority FROM orders WHERE shippriority % 4 = 0")).isFullyPushedDown(); + + assertThat(query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % nationkey = 2")) + .isFullyPushedDown() + .matches("VALUES (BIGINT '3', CAST('CANADA' AS varchar(25)), BIGINT '1')"); + + // some databases calculate remainder instead of modulus when one of the values is negative + assertThat(query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % -nationkey = 2")) + .isFullyPushedDown() + .matches("VALUES (BIGINT '3', CAST('CANADA' AS varchar(25)), BIGINT '1')"); + + assertThatThrownBy(() -> query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % 0 = 2")) + .hasMessageContaining("by zero"); + + // Expression that evaluates to 0 for some rows on RHS of modulus + assertThatThrownBy(() -> query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % (regionkey - 1) = 2")) + .hasMessageContaining("by zero"); + + // TODO add coverage for other arithmetic pushdowns https://github.com/trinodb/trino/issues/14808 + } + @Test public void testCaseSensitiveTopNPushdown() { diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index 9ecf448fa347..db41ac30f1f6 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -71,7 +71,6 @@ import static java.util.stream.Collectors.joining; import static java.util.stream.IntStream.range; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertTrue; public class TestPostgreSqlConnectorTest @@ -706,25 +705,6 @@ public void testOrPredicatePushdown() assertThat(query("SELECT * FROM nation WHERE name = NULL OR regionkey = 4")).isNotFullyPushedDown(FilterNode.class); // TODO `name = NULL` should be eliminated by the engine } - @Test - public void testArithmeticPredicatePushdown() - { - assertThat(query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % nationkey = 2")) - .isFullyPushedDown() - .matches("VALUES (BIGINT '3', CAST('CANADA' AS varchar(25)), BIGINT '1')"); - - // some databases calculate remainder instead of modulus when one of the values is negative - assertThat(query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % -nationkey = 2")) - .isFullyPushedDown() - .matches("VALUES (BIGINT '3', CAST('CANADA' AS varchar(25)), BIGINT '1')"); - - assertThatThrownBy(() -> query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % 0 = 2")) - .hasMessageContaining("ERROR: division by zero"); - // Expression that evaluates to 0 for some rows on RHS of modulus - assertThatThrownBy(() -> query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % (regionkey - 1) = 2")) - .hasMessageContaining("ERROR: division by zero"); - } - @Test public void testLikePredicatePushdown() { diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/RewriteUnicodeVarcharConstant.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/RewriteUnicodeVarcharConstant.java new file mode 100644 index 000000000000..71557c36f1ae --- /dev/null +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/RewriteUnicodeVarcharConstant.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.sqlserver; + +import com.google.common.base.CharMatcher; +import io.airlift.slice.Slice; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.spi.expression.Constant; +import io.trino.spi.type.VarcharType; + +import java.util.Optional; + +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.constant; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; + +public class RewriteUnicodeVarcharConstant + implements ConnectorExpressionRule +{ + private static final Pattern PATTERN = constant().with(type().matching(VarcharType.class::isInstance)); + private static final CharMatcher UNICODE_CHARACTER_MATCHER = CharMatcher.ascii().negate().precomputed(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional rewrite(Constant constant, Captures captures, RewriteContext context) + { + if (constant.getValue() == null) { + return Optional.empty(); + } + Slice slice = (Slice) constant.getValue(); + if (slice == null) { + return Optional.empty(); + } + + String sliceUtf8String = slice.toStringUtf8(); + boolean isUnicodeString = UNICODE_CHARACTER_MATCHER.matchesAnyOf(sliceUtf8String); + + if (isUnicodeString) { + return Optional.of("N'" + sliceUtf8String.replace("'", "''") + "'"); + } + + return Optional.of("'" + sliceUtf8String.replace("'", "''") + "'"); + } +} diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index 25a9b0dd3dd1..6d90783949b5 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -51,6 +51,8 @@ import io.trino.plugin.jdbc.aggregation.ImplementMinMax; import io.trino.plugin.jdbc.aggregation.ImplementSum; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.expression.RewriteComparison; +import io.trino.plugin.jdbc.expression.RewriteIn; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -60,6 +62,7 @@ import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.JoinStatistics; import io.trino.spi.connector.JoinType; +import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.statistics.ColumnStatistics; import io.trino.spi.statistics.Estimate; import io.trino.spi.statistics.TableStatistics; @@ -213,7 +216,25 @@ public SqlServerClient( this.statisticsEnabled = statisticsConfig.isEnabled(); this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() + // Only SqlServer requires N prefix for unicode characters (SQL-92 standard), + // so we add this rule to support such cases for pushdowns + .add(new RewriteUnicodeVarcharConstant()) .addStandardRules(this::quoted) + .add(new RewriteComparison(ImmutableSet.of(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL))) + .add(new RewriteIn()) + .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) + .map("$add(left: integer_type, right: integer_type)").to("left + right") + .map("$subtract(left: integer_type, right: integer_type)").to("left - right") + .map("$multiply(left: integer_type, right: integer_type)").to("left * right") + .map("$divide(left: integer_type, right: integer_type)").to("left / right") + .map("$modulus(left: integer_type, right: integer_type)").to("left % right") + .map("$negate(value: integer_type)").to("-value") + .map("$like(value: varchar, pattern: varchar): boolean").to("value LIKE pattern") + .map("$like(value: varchar, pattern: varchar, escape: varchar(1)): boolean").to("value LIKE pattern ESCAPE escape") + .map("$not($is_null(value))").to("value IS NOT NULL") + .map("$not(value: boolean)").to("NOT value") + .map("$is_null(value)").to("value IS NULL") + .map("$nullif(first, second)").to("NULLIF(first, second)") .build(); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( @@ -338,6 +359,12 @@ protected void renameColumn(ConnectorSession session, Connection connection, Rem } } + @Override + public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + { + return connectorExpressionRewriter.rewrite(session, expression, assignments); + } + @Override public void renameSchema(ConnectorSession session, String schemaName, String newSchemaName) { diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java index de357eb8a783..7f623c99ca92 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java @@ -47,7 +47,6 @@ import static java.util.stream.Collectors.joining; import static java.util.stream.IntStream.range; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -63,6 +62,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY: return false; + case SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN: case SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV: case SUPPORTS_AGGREGATION_PUSHDOWN_VARIANCE: return true; @@ -406,13 +406,6 @@ public void testShowCreateTable() ")"); } - @Override - public void testDeleteWithLike() - { - assertThatThrownBy(super::testDeleteWithLike) - .hasStackTraceContaining("TrinoException: Unsupported delete"); - } - @Test(dataProvider = "dataCompression") public void testCreateWithDataCompression(DataCompression dataCompression) { diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/BaseFailureRecoveryTest.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/BaseFailureRecoveryTest.java index ca2e8cb75695..cd052dfcb95f 100644 --- a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/BaseFailureRecoveryTest.java +++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/BaseFailureRecoveryTest.java @@ -251,7 +251,13 @@ protected void testSelect(String query, Optional session, Consumer getQueryRunner().execute("SELECT * FROM nation WHERE regionKey / nationKey - 1 = 0")) + // Some connectors have pushdowns enabled for arithmetic operations (like SqlServer), + // so exception will come not from trino, but from datasource itself + Session withoutPushdown = Session.builder(this.getSession()) + .setSystemProperty("allow_pushdown_into_connectors", "false") + .build(); + + assertThatThrownBy(() -> getQueryRunner().execute(withoutPushdown, "SELECT * FROM nation WHERE regionKey / nationKey - 1 = 0")) .hasMessageMatching("(?i).*Division by zero.*"); // some errors come back with different casing. assertThatQuery("SELECT * FROM nation") diff --git a/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java b/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java index 13ce0b7eafae..51773b514654 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java @@ -23,6 +23,7 @@ public enum TestingConnectorBehavior SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY(SUPPORTS_PREDICATE_PUSHDOWN), SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY(SUPPORTS_PREDICATE_PUSHDOWN), SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN(SUPPORTS_PREDICATE_PUSHDOWN), + SUPPORTS_PREDICATE_ARITHMETIC_EXPRESSION_PUSHDOWN(SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN), SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN_WITH_LIKE(SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN), SUPPORTS_DYNAMIC_FILTER_PUSHDOWN(false),