diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementStddevPop.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementStddevPop.java new file mode 100644 index 000000000000..e69c19ad8d96 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementStddevPop.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.DoubleType; + +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variable; +import static java.lang.String.format; + +public class ImplementStddevPop + implements AggregateFunctionRule +{ + private static final Capture INPUT = newCapture(); + + @Override + public Pattern getPattern() + { + return basicAggregation() + .with(functionName().matching(name -> name.equalsIgnoreCase("stddev_pop"))) + .with(singleInput().matching( + variable() + .with(expressionType().matching(DoubleType.class::isInstance)) + .capturedAs(INPUT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + Variable input = captures.get(INPUT); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + verify(aggregateFunction.getOutputType() == columnHandle.getColumnType()); + + return Optional.of(new JdbcExpression( + format("stddev_pop(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), + columnHandle.getJdbcTypeHandle())); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementStddevSamp.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementStddevSamp.java new file mode 100644 index 000000000000..80ba22dabf24 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementStddevSamp.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.DoubleType; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variable; +import static java.lang.String.format; + +public class ImplementStddevSamp + implements AggregateFunctionRule +{ + // TODO (https://github.com/trinodb/trino/issues/6189): remove stddev, an alias, from the list & simplify the pattern + private static final List STDDEV_FUNCTION_NAMES = ImmutableList.of("stddev", "stddev_samp"); + private static final Capture INPUT = newCapture(); + + @Override + public Pattern getPattern() + { + return basicAggregation() + .with(functionName().matching(name -> STDDEV_FUNCTION_NAMES.stream().anyMatch(name::equalsIgnoreCase))) + .with(singleInput().matching( + variable() + .with(expressionType().matching(DoubleType.class::isInstance)) + .capturedAs(INPUT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + Variable input = captures.get(INPUT); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + verify(aggregateFunction.getOutputType() == columnHandle.getColumnType()); + + return Optional.of(new JdbcExpression( + format("stddev_samp(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), + columnHandle.getJdbcTypeHandle())); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementVariancePop.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementVariancePop.java new file mode 100644 index 000000000000..65870a3beec8 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementVariancePop.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.DoubleType; + +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variable; +import static java.lang.String.format; + +public class ImplementVariancePop + implements AggregateFunctionRule +{ + private static final Capture INPUT = newCapture(); + + @Override + public Pattern getPattern() + { + return basicAggregation() + .with(functionName().matching(name -> name.equalsIgnoreCase("var_pop"))) + .with(singleInput().matching( + variable() + .with(expressionType().matching(DoubleType.class::isInstance)) + .capturedAs(INPUT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + Variable input = captures.get(INPUT); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + verify(aggregateFunction.getOutputType() == columnHandle.getColumnType()); + + return Optional.of(new JdbcExpression( + format("var_pop(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), + columnHandle.getJdbcTypeHandle())); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementVarianceSamp.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementVarianceSamp.java new file mode 100644 index 000000000000..36ae3fa63f26 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ImplementVarianceSamp.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.DoubleType; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variable; +import static java.lang.String.format; + +public class ImplementVarianceSamp + implements AggregateFunctionRule +{ + // TODO (https://github.com/trinodb/trino/issues/6189): remove variance, an alias, from the list & simplify the pattern + private static final List VARIANCE_FUNCTION_NAMES = ImmutableList.of("variance", "var_samp"); + private static final Capture INPUT = newCapture(); + + @Override + public Pattern getPattern() + { + return basicAggregation() + .with(functionName().matching(name -> VARIANCE_FUNCTION_NAMES.stream().anyMatch(name::equalsIgnoreCase))) + .with(singleInput().matching( + variable() + .with(expressionType().matching(DoubleType.class::isInstance)) + .capturedAs(INPUT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + Variable input = captures.get(INPUT); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + verify(aggregateFunction.getOutputType() == columnHandle.getColumnType()); + + return Optional.of(new JdbcExpression( + format("var_samp(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())), + columnHandle.getJdbcTypeHandle())); + } +} diff --git a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java index 53ea7a226753..4150269db81a 100644 --- a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java +++ b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java @@ -32,7 +32,11 @@ import io.trino.plugin.jdbc.expression.ImplementCount; import io.trino.plugin.jdbc.expression.ImplementCountAll; import io.trino.plugin.jdbc.expression.ImplementMinMax; +import io.trino.plugin.jdbc.expression.ImplementStddevPop; +import io.trino.plugin.jdbc.expression.ImplementStddevSamp; import io.trino.plugin.jdbc.expression.ImplementSum; +import io.trino.plugin.jdbc.expression.ImplementVariancePop; +import io.trino.plugin.jdbc.expression.ImplementVarianceSamp; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -144,6 +148,10 @@ public MySqlClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, T .add(new ImplementAvgFloatingPoint()) .add(new ImplementAvgDecimal()) .add(new ImplementAvgBigint()) + .add(new ImplementStddevSamp()) + .add(new ImplementStddevPop()) + .add(new ImplementVarianceSamp()) + .add(new ImplementVariancePop()) .build()); } diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlIntegrationSmokeTest.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlIntegrationSmokeTest.java index ef0b7e22a3a5..e88a6f83e41c 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlIntegrationSmokeTest.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlIntegrationSmokeTest.java @@ -13,12 +13,14 @@ */ package io.trino.plugin.mysql; +import com.google.common.collect.ImmutableList; import io.trino.Session; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.testing.AbstractTestIntegrationSmokeTest; import io.trino.testing.MaterializedResult; import io.trino.testing.MaterializedRow; +import io.trino.testing.sql.TestTable; import org.intellij.lang.annotations.Language; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; @@ -263,6 +265,72 @@ public void testAggregationPushdown() assertThat(query("SELECT approx_set(nationkey) FROM nation")).isNotFullyPushedDown(AggregationNode.class); } + @Test + public void testStddevPushdown() + { + String schemaName = getSession().getSchema().orElseThrow(); + try (TestTable testTable = new TestTable(mysqlServer::execute, schemaName + ".test_stddev_pushdown", + "(t_double DOUBLE PRECISION)")) { + assertThat(query("SELECT stddev_pop(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT stddev(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT stddev_samp(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + + mysqlServer.execute("INSERT INTO " + testTable.getName() + " VALUES (1)"); + + assertThat(query("SELECT stddev_pop(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT stddev(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT stddev_samp(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + + mysqlServer.execute("INSERT INTO " + testTable.getName() + " VALUES (3)"); + assertThat(query("SELECT stddev_pop(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + + mysqlServer.execute("INSERT INTO " + testTable.getName() + " VALUES (5)"); + assertThat(query("SELECT stddev(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT stddev_samp(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + } + + try (TestTable testTable = new TestTable(mysqlServer::execute, schemaName + ".test_stddev_pushdown", + "(t_double DOUBLE PRECISION)", ImmutableList.of("1", "2", "4", "5"))) { + // Test non-whole number results + assertThat(query("SELECT stddev_pop(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT stddev(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT stddev_samp(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + } + } + + @Test + public void testVariancePushdown() + { + String schemaName = getSession().getSchema().orElseThrow(); + try (TestTable testTable = new TestTable(mysqlServer::execute, schemaName + ".test_variance_pushdown", + "(t_double DOUBLE PRECISION)")) { + assertThat(query("SELECT var_pop(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT variance(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT var_samp(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + + mysqlServer.execute("INSERT INTO " + testTable.getName() + " VALUES (1)"); + + assertThat(query("SELECT var_pop(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT variance(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT var_samp(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + + mysqlServer.execute("INSERT INTO " + testTable.getName() + " VALUES (3)"); + assertThat(query("SELECT var_pop(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + + mysqlServer.execute("INSERT INTO " + testTable.getName() + " VALUES (5)"); + assertThat(query("SELECT variance(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT var_samp(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + } + + try (TestTable testTable = new TestTable(mysqlServer::execute, schemaName + ".test_variance_pushdown", + "(t_double DOUBLE PRECISION)", ImmutableList.of("1", "2", "3", "4", "5"))) { + // Test non-whole number results + assertThat(query("SELECT var_pop(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT variance(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT var_samp(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + } + } + @Test public void testLimitPushdown() {