diff --git a/docs/src/main/sphinx/connector/oracle.rst b/docs/src/main/sphinx/connector/oracle.rst index 88b401c1998e..25bd864a79b5 100644 --- a/docs/src/main/sphinx/connector/oracle.rst +++ b/docs/src/main/sphinx/connector/oracle.rst @@ -368,6 +368,13 @@ The connector supports pushdown for a number of operations: * :ref:`limit-pushdown` * :ref:`topn-pushdown` +:ref:`Aggregate pushdown ` for the following functions: +* :func:`avg` +* :func:`count` +* :func:`max` +* :func:`min` +* :func:`sum` + Limitations ----------- diff --git a/plugin/trino-oracle/pom.xml b/plugin/trino-oracle/pom.xml index d7bffa18877b..43f9d7a01693 100644 --- a/plugin/trino-oracle/pom.xml +++ b/plugin/trino-oracle/pom.xml @@ -23,6 +23,11 @@ trino-base-jdbc + + io.trino + trino-matching + + io.trino trino-plugin-toolkit diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/ImplementAvgBigint.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/ImplementAvgBigint.java new file mode 100644 index 000000000000..f393c28136f3 --- /dev/null +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/ImplementAvgBigint.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.oracle; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.AggregateFunctionRule; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.DecimalType; + +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput; +import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variable; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static java.lang.String.format; + +public class ImplementAvgBigint + implements AggregateFunctionRule +{ + private static final Capture INPUT = newCapture(); + + @Override + public Pattern getPattern() + { + return basicAggregation() + .with(functionName().equalTo("avg")) + .with(singleInput().matching(variable().with(expressionType().equalTo(BIGINT)).capturedAs(INPUT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + Variable input = captures.get(INPUT); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + DecimalType type = (DecimalType) columnHandle.getColumnType(); + verify(aggregateFunction.getOutputType() == DOUBLE); + + return Optional.of(new JdbcExpression( + format("avg(CAST(%s AS DECIMAL(%s, %s)))", context.getIdentifierQuote().apply(columnHandle.getColumnName()), type.getPrecision(), type.getScale()), + columnHandle.getJdbcTypeHandle())); + } +} diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java index c1216fbcb99d..45a3347814a9 100644 --- a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java @@ -22,14 +22,25 @@ import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.DoubleWriteFunction; import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.LongWriteFunction; import io.trino.plugin.jdbc.SliceWriteFunction; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.expression.AggregateFunctionRewriter; +import io.trino.plugin.jdbc.expression.AggregateFunctionRule; +import io.trino.plugin.jdbc.expression.ImplementAvgDecimal; +import io.trino.plugin.jdbc.expression.ImplementAvgFloatingPoint; +import io.trino.plugin.jdbc.expression.ImplementCount; +import io.trino.plugin.jdbc.expression.ImplementCountAll; +import io.trino.plugin.jdbc.expression.ImplementMinMax; +import io.trino.plugin.jdbc.expression.ImplementSum; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.SchemaTableName; @@ -129,6 +140,8 @@ public class OracleClient private static final int PRECISION_OF_UNSPECIFIED_NUMBER = 127; + private final AggregateFunctionRewriter aggregateFunctionRewriter; + private static final Set INTERNAL_SCHEMAS = ImmutableSet.builder() .add("ctxsys") .add("flows_files") @@ -166,6 +179,25 @@ public OracleClient( requireNonNull(oracleConfig, "oracleConfig is null"); this.synonymsEnabled = oracleConfig.isSynonymsEnabled(); + + JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), 40, 0, Optional.empty(), Optional.empty()); + this.aggregateFunctionRewriter = new AggregateFunctionRewriter( + this::quoted, + ImmutableSet.builder() + .add(new ImplementCountAll(bigintTypeHandle)) + .add(new ImplementCount(bigintTypeHandle)) + .add(new ImplementMinMax()) + .add(new ImplementSum(OracleClient::toTypeHandle)) + .add(new ImplementAvgFloatingPoint()) + .add(new ImplementAvgDecimal()) + .add(new ImplementAvgBigint()) + .build()); + } + + @Override + public Optional implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map assignments) + { + return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); } @Override @@ -272,7 +304,14 @@ public Optional toColumnMapping(ConnectorSession session, Connect int scale = max(decimalDigits, 0); Optional numberDefaultScale = getNumberDefaultScale(session); RoundingMode roundingMode = getNumberRoundingMode(session); - if (precision < scale) { + if (precision == 40 && decimalDigits == 0) { + return Optional.of(ColumnMapping.longMapping( + BIGINT, + ResultSet::getLong, + bigintWriteFunction(), + FULL_PUSHDOWN)); + } + else if (precision < scale) { if (roundingMode == RoundingMode.UNNECESSARY) { break; } @@ -427,6 +466,11 @@ private SliceWriteFunction oracleCharWriteFunction() ((OraclePreparedStatement) statement).setFixedCHAR(index, value.toStringUtf8()); } + private static Optional toTypeHandle(DecimalType decimalType) + { + return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), decimalType.getPrecision(), decimalType.getScale(), Optional.empty(), Optional.empty())); + } + @Override public WriteMapping toWriteMapping(ConnectorSession session, Type type) { diff --git a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorSmokeTest.java b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorSmokeTest.java index 4dda6cbf5e20..16f912af9fbf 100644 --- a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorSmokeTest.java +++ b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorSmokeTest.java @@ -52,4 +52,18 @@ public void testCommentColumn() assertUpdate("DROP TABLE " + tableName); } + + @Test + public void testAggregationPushdown() + throws Exception + { + assertThat(query("SELECT count(*) FROM nation")).isFullyPushedDown(); + assertThat(query("SELECT count(nationkey) FROM nation")).isFullyPushedDown(); + assertThat(query("SELECT count(1) FROM nation")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, min(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, max(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, avg(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 GROUP BY regionkey")).isFullyPushedDown(); + } }