diff --git a/plugin/trino-mysql/pom.xml b/plugin/trino-mysql/pom.xml index 3751817e19a8..1a61d6fdf7b2 100644 --- a/plugin/trino-mysql/pom.xml +++ b/plugin/trino-mysql/pom.xml @@ -160,6 +160,12 @@ test + + io.trino + trino-parser + test + + io.trino trino-plugin-toolkit diff --git a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java index 14d86597eaf2..ca4475b9e245 100644 --- a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java +++ b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java @@ -270,6 +270,7 @@ public MySqlClient( this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) + .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) // No "real" on the list; pushdown on REAL is disabled also in toColumnMapping .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "double")) .map("$equal(left: numeric_type, right: numeric_type)").to("left = right") @@ -279,8 +280,17 @@ public MySqlClient( .map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right") .map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right") .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right") + .map("$add(left: integer_type, right: integer_type)").to("left + right") + .map("$subtract(left: integer_type, right: integer_type)").to("left - right") + .map("$multiply(left: integer_type, right: integer_type)").to("left * right") + .map("$divide(left: integer_type, right: integer_type)").to("left / right") + .map("$modulus(left: integer_type, right: integer_type)").to("left % right") + .map("$negate(value: integer_type)").to("-value") .add(new RewriteLikeWithCaseSensitivity()) .add(new RewriteLikeEscapeWithCaseSensitivity()) + .map("$not($is_null(value))").to("value IS NOT NULL") + .map("$is_null(value)").to("value IS NULL") + .map("$nullif(first, second)").to("NULLIF(first, second)") .build(); JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java index 690fcf5f97f7..6462f39267be 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.mysql; +import com.google.common.collect.ImmutableMap; import io.trino.plugin.base.mapping.DefaultIdentifierMapping; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ColumnMapping; @@ -22,11 +23,25 @@ import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Variable; +import io.trino.spi.type.Type; +import io.trino.sql.planner.ConnectorExpressionTranslator; +import io.trino.sql.planner.LiteralEncoder; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.TypeProvider; +import io.trino.sql.tree.ArithmeticBinaryExpression; +import io.trino.sql.tree.ArithmeticUnaryExpression; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.IsNotNullPredicate; +import io.trino.sql.tree.IsNullPredicate; +import io.trino.sql.tree.NullIfExpression; +import io.trino.sql.tree.SymbolReference; import org.junit.jupiter.api.Test; import java.sql.Types; @@ -34,11 +49,17 @@ import java.util.Map; import java.util.Optional; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; +import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; import static io.trino.testing.TestingConnectorSession.SESSION; import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; +import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; public class TestMySqlClient @@ -57,6 +78,13 @@ public class TestMySqlClient .setJdbcTypeHandle(new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())) .build(); + private static final JdbcColumnHandle VARCHAR_COLUMN = + JdbcColumnHandle.builder() + .setColumnName("c_varchar") + .setColumnType(createVarcharType(10)) + .setJdbcTypeHandle(new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), Optional.of(10), Optional.empty(), Optional.empty(), Optional.empty())) + .build(); + private static final JdbcClient JDBC_CLIENT = new MySqlClient( new BaseJdbcConfig(), new JdbcStatisticsConfig(), @@ -68,6 +96,8 @@ public class TestMySqlClient new DefaultIdentifierMapping(), RemoteQueryModifier.NONE); + private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT); + @Test public void testImplementCount() { @@ -169,4 +199,99 @@ private static void testImplementAggregation(AggregateFunction aggregateFunction .isEqualTo(aggregateFunction.getOutputType()); } } + + @Test + public void testConvertArithmeticBinary() + { + for (ArithmeticBinaryExpression.Operator operator : ArithmeticBinaryExpression.Operator.values()) { + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( + SESSION, + translateToConnectorExpression( + new ArithmeticBinaryExpression( + operator, + new SymbolReference("c_bigint_symbol"), + LITERAL_ENCODER.toExpression(42L, BIGINT)), + Map.of("c_bigint_symbol", BIGINT)), + Map.of("c_bigint_symbol", BIGINT_COLUMN)) + .orElseThrow(); + + assertThat(converted.expression()).isEqualTo(format("(`c_bigint`) %s (?)", operator.getValue())); + assertThat(converted.parameters()).isEqualTo(List.of(new QueryParameter(BIGINT, Optional.of(42L)))); + } + } + + @Test + public void testConvertArithmeticUnaryMinus() + { + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate( + SESSION, + translateToConnectorExpression( + new ArithmeticUnaryExpression( + ArithmeticUnaryExpression.Sign.MINUS, + new SymbolReference("c_bigint_symbol")), + Map.of("c_bigint_symbol", BIGINT)), + Map.of("c_bigint_symbol", BIGINT_COLUMN)) + .orElseThrow(); + + assertThat(converted.expression()).isEqualTo("-(`c_bigint`)"); + assertThat(converted.parameters()).isEqualTo(List.of()); + } + + @Test + public void testConvertIsNull() + { + // c_varchar IS NULL + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new IsNullPredicate( + new SymbolReference("c_varchar_symbol")), + Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), + Map.of("c_varchar_symbol", VARCHAR_COLUMN)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("(`c_varchar`) IS NULL"); + assertThat(converted.parameters()).isEqualTo(List.of()); + } + + @Test + public void testConvertIsNotNull() + { + // c_varchar IS NOT NULL + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new IsNotNullPredicate( + new SymbolReference("c_varchar_symbol")), + Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), + Map.of("c_varchar_symbol", VARCHAR_COLUMN)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("(`c_varchar`) IS NOT NULL"); + assertThat(converted.parameters()).isEqualTo(List.of()); + } + + @Test + public void testConvertNullIf() + { + // nullif(a_varchar, b_varchar) + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new NullIfExpression( + new SymbolReference("a_varchar_symbol"), + new SymbolReference("b_varchar_symbol")), + ImmutableMap.of("a_varchar_symbol", VARCHAR_COLUMN.getColumnType(), "b_varchar_symbol", VARCHAR_COLUMN.getColumnType())), + ImmutableMap.of("a_varchar_symbol", VARCHAR_COLUMN, "b_varchar_symbol", VARCHAR_COLUMN)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("NULLIF((`c_varchar`), (`c_varchar`))"); + assertThat(converted.parameters()).isEqualTo(List.of()); + } + + private ConnectorExpression translateToConnectorExpression(Expression expression, Map symbolTypes) + { + return ConnectorExpressionTranslator.translate( + TEST_SESSION, + expression, + TypeProvider.viewOf(symbolTypes.entrySet().stream() + .collect(toImmutableMap(entry -> new Symbol(entry.getKey()), Map.Entry::getValue))), + PLANNER_CONTEXT, + createTestingTypeAnalyzer(PLANNER_CONTEXT)) + .orElseThrow(); + } } diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlConnectorTest.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlConnectorTest.java index b9929673e15c..1b19e5d699cd 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlConnectorTest.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlConnectorTest.java @@ -15,7 +15,13 @@ import com.google.common.collect.ImmutableMap; import io.trino.testing.QueryRunner; +import io.trino.testing.TestingConnectorBehavior; +import io.trino.testing.sql.TestTable; +import org.junit.jupiter.api.Test; +import java.util.List; + +import static com.google.common.base.Verify.verify; import static io.trino.plugin.mysql.MySqlQueryRunner.createMySqlQueryRunner; import static org.assertj.core.api.Assertions.assertThat; @@ -30,9 +36,80 @@ protected QueryRunner createQueryRunner() return createMySqlQueryRunner(mySqlServer, ImmutableMap.of(), ImmutableMap.of(), REQUIRED_TPCH_TABLES); } + @Override + protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) + { + return switch (connectorBehavior) { + case SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN -> { + // TODO remove once super has this set to true + verify(!super.hasBehavior(connectorBehavior)); + yield true; + } + default -> super.hasBehavior(connectorBehavior); + }; + } + @Override protected void verifyColumnNameLengthFailurePermissible(Throwable e) { assertThat(e).hasMessageMatching("(Incorrect column name '.*'|Identifier name '.*' is too long)"); } + + @Test + public void testIsNullPredicatePushdown() + { + assertThat(query("SELECT nationkey FROM nation WHERE name IS NULL")).isFullyPushedDown(); + assertThat(query("SELECT nationkey FROM nation WHERE name IS NULL OR regionkey = 4")).isFullyPushedDown(); + + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_is_null_predicate_pushdown", + "(a_int integer, a_varchar varchar(1))", + List.of( + "1, 'A'", + "2, 'B'", + "1, NULL", + "2, NULL"))) { + assertThat(query("SELECT a_int FROM " + table.getName() + " WHERE a_varchar IS NULL OR a_int = 1")).isFullyPushedDown(); + } + } + + @Test + public void testIsNotNullPredicatePushdown() + { + assertThat(query("SELECT nationkey FROM nation WHERE name IS NOT NULL OR regionkey = 4")).isFullyPushedDown(); + + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_is_not_null_predicate_pushdown", + "(a_int integer, a_varchar varchar(1))", + List.of( + "1, 'A'", + "2, 'B'", + "1, NULL", + "2, NULL"))) { + assertThat(query("SELECT a_int FROM " + table.getName() + " WHERE a_varchar IS NOT NULL OR a_int = 1")).isFullyPushedDown(); + } + } + + @Test + public void testNullIfPredicatePushdown() + { + assertThat(query("SELECT nationkey FROM nation WHERE NULLIF(name, 'ALGERIA') IS NULL")) + .matches("VALUES BIGINT '0'") + .isFullyPushedDown(); + + assertThat(query("SELECT name FROM nation WHERE NULLIF(nationkey, 0) IS NULL")) + .matches("VALUES CAST('ALGERIA' AS varchar(255))") + .isFullyPushedDown(); + + assertThat(query("SELECT nationkey FROM nation WHERE NULLIF(name, 'Algeria') IS NULL")) + .returnsEmptyResult() + .isFullyPushedDown(); + + // NULLIF returns the first argument because arguments aren't the same + assertThat(query("SELECT nationkey FROM nation WHERE NULLIF(name, 'Name not found') = name")) + .matches("SELECT nationkey FROM nation") + .isFullyPushedDown(); + } }