diff --git a/plugin/trino-mysql/pom.xml b/plugin/trino-mysql/pom.xml
index 3751817e19a8..1a61d6fdf7b2 100644
--- a/plugin/trino-mysql/pom.xml
+++ b/plugin/trino-mysql/pom.xml
@@ -160,6 +160,12 @@
test
+
+ io.trino
+ trino-parser
+ test
+
+
io.trino
trino-plugin-toolkit
diff --git a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java
index 14d86597eaf2..ca4475b9e245 100644
--- a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java
+++ b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java
@@ -270,6 +270,7 @@ public MySqlClient(
this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder()
.addStandardRules(this::quoted)
+ .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint"))
// No "real" on the list; pushdown on REAL is disabled also in toColumnMapping
.withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "double"))
.map("$equal(left: numeric_type, right: numeric_type)").to("left = right")
@@ -279,8 +280,17 @@ public MySqlClient(
.map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right")
.map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right")
.map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right")
+ .map("$add(left: integer_type, right: integer_type)").to("left + right")
+ .map("$subtract(left: integer_type, right: integer_type)").to("left - right")
+ .map("$multiply(left: integer_type, right: integer_type)").to("left * right")
+ .map("$divide(left: integer_type, right: integer_type)").to("left / right")
+ .map("$modulus(left: integer_type, right: integer_type)").to("left % right")
+ .map("$negate(value: integer_type)").to("-value")
.add(new RewriteLikeWithCaseSensitivity())
.add(new RewriteLikeEscapeWithCaseSensitivity())
+ .map("$not($is_null(value))").to("value IS NOT NULL")
+ .map("$is_null(value)").to("value IS NULL")
+ .map("$nullif(first, second)").to("NULLIF(first, second)")
.build();
JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java
index 690fcf5f97f7..6462f39267be 100644
--- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java
+++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java
@@ -13,6 +13,7 @@
*/
package io.trino.plugin.mysql;
+import com.google.common.collect.ImmutableMap;
import io.trino.plugin.base.mapping.DefaultIdentifierMapping;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ColumnMapping;
@@ -22,11 +23,25 @@
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.JdbcTypeHandle;
+import io.trino.plugin.jdbc.QueryParameter;
+import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
+import io.trino.spi.type.Type;
+import io.trino.sql.planner.ConnectorExpressionTranslator;
+import io.trino.sql.planner.LiteralEncoder;
+import io.trino.sql.planner.Symbol;
+import io.trino.sql.planner.TypeProvider;
+import io.trino.sql.tree.ArithmeticBinaryExpression;
+import io.trino.sql.tree.ArithmeticUnaryExpression;
+import io.trino.sql.tree.Expression;
+import io.trino.sql.tree.IsNotNullPredicate;
+import io.trino.sql.tree.IsNullPredicate;
+import io.trino.sql.tree.NullIfExpression;
+import io.trino.sql.tree.SymbolReference;
import org.junit.jupiter.api.Test;
import java.sql.Types;
@@ -34,11 +49,17 @@
import java.util.Map;
import java.util.Optional;
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static io.trino.SessionTestUtils.TEST_SESSION;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DoubleType.DOUBLE;
+import static io.trino.spi.type.VarcharType.createVarcharType;
+import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT;
+import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer;
import static io.trino.testing.TestingConnectorSession.SESSION;
import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
+import static java.lang.String.format;
import static org.assertj.core.api.Assertions.assertThat;
public class TestMySqlClient
@@ -57,6 +78,13 @@ public class TestMySqlClient
.setJdbcTypeHandle(new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))
.build();
+ private static final JdbcColumnHandle VARCHAR_COLUMN =
+ JdbcColumnHandle.builder()
+ .setColumnName("c_varchar")
+ .setColumnType(createVarcharType(10))
+ .setJdbcTypeHandle(new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), Optional.of(10), Optional.empty(), Optional.empty(), Optional.empty()))
+ .build();
+
private static final JdbcClient JDBC_CLIENT = new MySqlClient(
new BaseJdbcConfig(),
new JdbcStatisticsConfig(),
@@ -68,6 +96,8 @@ public class TestMySqlClient
new DefaultIdentifierMapping(),
RemoteQueryModifier.NONE);
+ private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT);
+
@Test
public void testImplementCount()
{
@@ -169,4 +199,99 @@ private static void testImplementAggregation(AggregateFunction aggregateFunction
.isEqualTo(aggregateFunction.getOutputType());
}
}
+
+ @Test
+ public void testConvertArithmeticBinary()
+ {
+ for (ArithmeticBinaryExpression.Operator operator : ArithmeticBinaryExpression.Operator.values()) {
+ ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(
+ SESSION,
+ translateToConnectorExpression(
+ new ArithmeticBinaryExpression(
+ operator,
+ new SymbolReference("c_bigint_symbol"),
+ LITERAL_ENCODER.toExpression(42L, BIGINT)),
+ Map.of("c_bigint_symbol", BIGINT)),
+ Map.of("c_bigint_symbol", BIGINT_COLUMN))
+ .orElseThrow();
+
+ assertThat(converted.expression()).isEqualTo(format("(`c_bigint`) %s (?)", operator.getValue()));
+ assertThat(converted.parameters()).isEqualTo(List.of(new QueryParameter(BIGINT, Optional.of(42L))));
+ }
+ }
+
+ @Test
+ public void testConvertArithmeticUnaryMinus()
+ {
+ ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(
+ SESSION,
+ translateToConnectorExpression(
+ new ArithmeticUnaryExpression(
+ ArithmeticUnaryExpression.Sign.MINUS,
+ new SymbolReference("c_bigint_symbol")),
+ Map.of("c_bigint_symbol", BIGINT)),
+ Map.of("c_bigint_symbol", BIGINT_COLUMN))
+ .orElseThrow();
+
+ assertThat(converted.expression()).isEqualTo("-(`c_bigint`)");
+ assertThat(converted.parameters()).isEqualTo(List.of());
+ }
+
+ @Test
+ public void testConvertIsNull()
+ {
+ // c_varchar IS NULL
+ ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION,
+ translateToConnectorExpression(
+ new IsNullPredicate(
+ new SymbolReference("c_varchar_symbol")),
+ Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())),
+ Map.of("c_varchar_symbol", VARCHAR_COLUMN))
+ .orElseThrow();
+ assertThat(converted.expression()).isEqualTo("(`c_varchar`) IS NULL");
+ assertThat(converted.parameters()).isEqualTo(List.of());
+ }
+
+ @Test
+ public void testConvertIsNotNull()
+ {
+ // c_varchar IS NOT NULL
+ ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION,
+ translateToConnectorExpression(
+ new IsNotNullPredicate(
+ new SymbolReference("c_varchar_symbol")),
+ Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())),
+ Map.of("c_varchar_symbol", VARCHAR_COLUMN))
+ .orElseThrow();
+ assertThat(converted.expression()).isEqualTo("(`c_varchar`) IS NOT NULL");
+ assertThat(converted.parameters()).isEqualTo(List.of());
+ }
+
+ @Test
+ public void testConvertNullIf()
+ {
+ // nullif(a_varchar, b_varchar)
+ ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION,
+ translateToConnectorExpression(
+ new NullIfExpression(
+ new SymbolReference("a_varchar_symbol"),
+ new SymbolReference("b_varchar_symbol")),
+ ImmutableMap.of("a_varchar_symbol", VARCHAR_COLUMN.getColumnType(), "b_varchar_symbol", VARCHAR_COLUMN.getColumnType())),
+ ImmutableMap.of("a_varchar_symbol", VARCHAR_COLUMN, "b_varchar_symbol", VARCHAR_COLUMN))
+ .orElseThrow();
+ assertThat(converted.expression()).isEqualTo("NULLIF((`c_varchar`), (`c_varchar`))");
+ assertThat(converted.parameters()).isEqualTo(List.of());
+ }
+
+ private ConnectorExpression translateToConnectorExpression(Expression expression, Map symbolTypes)
+ {
+ return ConnectorExpressionTranslator.translate(
+ TEST_SESSION,
+ expression,
+ TypeProvider.viewOf(symbolTypes.entrySet().stream()
+ .collect(toImmutableMap(entry -> new Symbol(entry.getKey()), Map.Entry::getValue))),
+ PLANNER_CONTEXT,
+ createTestingTypeAnalyzer(PLANNER_CONTEXT))
+ .orElseThrow();
+ }
}
diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlConnectorTest.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlConnectorTest.java
index b9929673e15c..1b19e5d699cd 100644
--- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlConnectorTest.java
+++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlConnectorTest.java
@@ -15,7 +15,13 @@
import com.google.common.collect.ImmutableMap;
import io.trino.testing.QueryRunner;
+import io.trino.testing.TestingConnectorBehavior;
+import io.trino.testing.sql.TestTable;
+import org.junit.jupiter.api.Test;
+import java.util.List;
+
+import static com.google.common.base.Verify.verify;
import static io.trino.plugin.mysql.MySqlQueryRunner.createMySqlQueryRunner;
import static org.assertj.core.api.Assertions.assertThat;
@@ -30,9 +36,80 @@ protected QueryRunner createQueryRunner()
return createMySqlQueryRunner(mySqlServer, ImmutableMap.of(), ImmutableMap.of(), REQUIRED_TPCH_TABLES);
}
+ @Override
+ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
+ {
+ return switch (connectorBehavior) {
+ case SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN -> {
+ // TODO remove once super has this set to true
+ verify(!super.hasBehavior(connectorBehavior));
+ yield true;
+ }
+ default -> super.hasBehavior(connectorBehavior);
+ };
+ }
+
@Override
protected void verifyColumnNameLengthFailurePermissible(Throwable e)
{
assertThat(e).hasMessageMatching("(Incorrect column name '.*'|Identifier name '.*' is too long)");
}
+
+ @Test
+ public void testIsNullPredicatePushdown()
+ {
+ assertThat(query("SELECT nationkey FROM nation WHERE name IS NULL")).isFullyPushedDown();
+ assertThat(query("SELECT nationkey FROM nation WHERE name IS NULL OR regionkey = 4")).isFullyPushedDown();
+
+ try (TestTable table = new TestTable(
+ getQueryRunner()::execute,
+ "test_is_null_predicate_pushdown",
+ "(a_int integer, a_varchar varchar(1))",
+ List.of(
+ "1, 'A'",
+ "2, 'B'",
+ "1, NULL",
+ "2, NULL"))) {
+ assertThat(query("SELECT a_int FROM " + table.getName() + " WHERE a_varchar IS NULL OR a_int = 1")).isFullyPushedDown();
+ }
+ }
+
+ @Test
+ public void testIsNotNullPredicatePushdown()
+ {
+ assertThat(query("SELECT nationkey FROM nation WHERE name IS NOT NULL OR regionkey = 4")).isFullyPushedDown();
+
+ try (TestTable table = new TestTable(
+ getQueryRunner()::execute,
+ "test_is_not_null_predicate_pushdown",
+ "(a_int integer, a_varchar varchar(1))",
+ List.of(
+ "1, 'A'",
+ "2, 'B'",
+ "1, NULL",
+ "2, NULL"))) {
+ assertThat(query("SELECT a_int FROM " + table.getName() + " WHERE a_varchar IS NOT NULL OR a_int = 1")).isFullyPushedDown();
+ }
+ }
+
+ @Test
+ public void testNullIfPredicatePushdown()
+ {
+ assertThat(query("SELECT nationkey FROM nation WHERE NULLIF(name, 'ALGERIA') IS NULL"))
+ .matches("VALUES BIGINT '0'")
+ .isFullyPushedDown();
+
+ assertThat(query("SELECT name FROM nation WHERE NULLIF(nationkey, 0) IS NULL"))
+ .matches("VALUES CAST('ALGERIA' AS varchar(255))")
+ .isFullyPushedDown();
+
+ assertThat(query("SELECT nationkey FROM nation WHERE NULLIF(name, 'Algeria') IS NULL"))
+ .returnsEmptyResult()
+ .isFullyPushedDown();
+
+ // NULLIF returns the first argument because arguments aren't the same
+ assertThat(query("SELECT nationkey FROM nation WHERE NULLIF(name, 'Name not found') = name"))
+ .matches("SELECT nationkey FROM nation")
+ .isFullyPushedDown();
+ }
}