diff --git a/presto-docs/src/main/sphinx/connector/oracle.rst b/presto-docs/src/main/sphinx/connector/oracle.rst
index 0d5d6a3abe9..03480664909 100644
--- a/presto-docs/src/main/sphinx/connector/oracle.rst
+++ b/presto-docs/src/main/sphinx/connector/oracle.rst
@@ -49,7 +49,7 @@ sales.
Querying Oracle
---------------
-The Oracle connector provides a schema for every Oracle database.
+The Oracle connector provides a schema for every Oracle database.
Run ``SHOW SCHEMAS`` to see the available Oracle databases::
@@ -159,9 +159,9 @@ For example:
- If ``unsupported-type.handling`` is set to ``FAIL``, then the
querying of an unsupported table fails.
-- If ``unsupported-type.handling`` is set to ``IGNORE``,
+- If ``unsupported-type.handling`` is set to ``IGNORE``,
then you can't see the unsupported types in Presto.
-- If ``unsupported-type.handling`` is set to ``CONVERT_TO_VARCHAR``,
+- If ``unsupported-type.handling`` is set to ``CONVERT_TO_VARCHAR``,
then the column is exposed as unbounded ``VARCHAR``.
Presto to Oracle type mapping
@@ -234,13 +234,13 @@ An Oracle ``NUMBER(p, s)`` maps to Presto's ``DECIMAL(p, s)`` except in these
conditions:
- No precision is specified for the column (example: ``NUMBER`` or
- ``NUMBER(*)``), unless ``oracle.number.default-scale`` is set.
-- Scale (``s`` ) is greater than precision.
-- Precision (``p`` ) is greater than 38.
+ ``NUMBER(*)``), unless ``oracle.number.default-scale`` is set.
+- Scale (``s`` ) is greater than precision.
+- Precision (``p`` ) is greater than 38.
- Scale is negative and the difference between ``p`` and ``s`` is greater than
38, unless ``oracle.number.rounding-mode`` is set to a different value than
- ``UNNECESSARY``.
-
+ ``UNNECESSARY``.
+
If ``s`` is negative, ``NUMBER(p, s)`` maps to ``DECIMAL(p + s, 0)``.
For Oracle ``NUMBER`` (without precision and scale), you can change
@@ -308,19 +308,19 @@ Type mapping configuration properties
- ``IGNORE``
* - ``oracle.number.default-scale``
- - ``number_default_scale``
- - Default Presto ``DECIMAL`` scale for Oracle ``NUMBER`` (without precision
- and scale) date type. When not set then such column is treated as not
+ - ``number_default_scale``
+ - Default Presto ``DECIMAL`` scale for Oracle ``NUMBER`` (without precision
+ and scale) date type. When not set then such column is treated as not
supported.
- not set
* - ``oracle.number.rounding-mode``
- ``number_rounding_mode``
- - Rounding mode for the Oracle ``NUMBER`` data type. This is useful when
- Oracle ``NUMBER`` data type specifies higher scale than is supported in
+ - Rounding mode for the Oracle ``NUMBER`` data type. This is useful when
+ Oracle ``NUMBER`` data type specifies higher scale than is supported in
Presto. Possible values are:
- - ``UNNECESSARY`` - Rounding mode to assert that the
- requested operation has an exact result,
+ - ``UNNECESSARY`` - Rounding mode to assert that the
+ requested operation has an exact result,
hence no rounding is necessary.
- ``CEILING`` - Rounding mode to round towards
positive infinity.
@@ -345,7 +345,7 @@ Synonyms
--------
Based on performance reasons, Presto disables support for Oracle ``SYNONYM``. To
-include ``SYNONYM``, add the following configuration property:
+include ``SYNONYM``, add the following configuration property:
.. code-block:: none
@@ -354,7 +354,14 @@ include ``SYNONYM``, add the following configuration property:
Pushdown
--------
-The connector supports :doc:`pushdown ` for optimized query processing.
+The connector supports :doc:`pushdown ` for processing the
+following aggregate functions:
+
+* :func:`avg`
+* :func:`count`
+* :func:`max`
+* :func:`min`
+* :func:`sum`
Limitations
-----------
diff --git a/presto-oracle/pom.xml b/presto-oracle/pom.xml
index 5291994d2a6..ca02476fe9b 100644
--- a/presto-oracle/pom.xml
+++ b/presto-oracle/pom.xml
@@ -22,6 +22,11 @@
presto-base-jdbc
+
+ io.prestosql
+ presto-matching
+
+
io.airlift
configuration
diff --git a/presto-oracle/src/main/java/io/prestosql/plugin/oracle/ImplementAvgBigint.java b/presto-oracle/src/main/java/io/prestosql/plugin/oracle/ImplementAvgBigint.java
new file mode 100644
index 00000000000..31af5232d22
--- /dev/null
+++ b/presto-oracle/src/main/java/io/prestosql/plugin/oracle/ImplementAvgBigint.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.prestosql.plugin.oracle;
+
+import io.prestosql.matching.Capture;
+import io.prestosql.matching.Captures;
+import io.prestosql.matching.Pattern;
+import io.prestosql.plugin.jdbc.JdbcColumnHandle;
+import io.prestosql.plugin.jdbc.JdbcExpression;
+import io.prestosql.plugin.jdbc.JdbcTypeHandle;
+import io.prestosql.plugin.jdbc.expression.AggregateFunctionRule;
+import io.prestosql.spi.connector.AggregateFunction;
+import io.prestosql.spi.expression.Variable;
+
+import java.sql.Types;
+import java.util.Optional;
+
+import static com.google.common.base.Verify.verify;
+import static com.google.common.base.Verify.verifyNotNull;
+import static io.prestosql.matching.Capture.newCapture;
+import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
+import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType;
+import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
+import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput;
+import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.variable;
+import static io.prestosql.spi.type.BigintType.BIGINT;
+import static io.prestosql.spi.type.DoubleType.DOUBLE;
+import static java.lang.String.format;
+
+public class ImplementAvgBigint
+ implements AggregateFunctionRule
+{
+ private static final Capture INPUT = newCapture();
+
+ @Override
+ public Pattern getPattern()
+ {
+ return basicAggregation()
+ .with(functionName().equalTo("avg"))
+ .with(singleInput().matching(
+ variable()
+ .with(expressionType().matching(type -> type == BIGINT))
+ .capturedAs(INPUT)));
+ }
+
+ @Override
+ public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
+ {
+ Variable input = captures.get(INPUT);
+ JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignments().get(input.getName());
+ verifyNotNull(columnHandle, "Unbound variable: %s", input);
+ verify(aggregateFunction.getOutputType() == DOUBLE);
+
+ return Optional.of(new JdbcExpression(
+ format("avg(CAST(%s AS DECIMAL))", columnHandle.toSqlExpression(context.getIdentifierQuote())),
+ new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), 0, Optional.empty(), Optional.empty(), Optional.empty())));
+ }
+}
diff --git a/presto-oracle/src/main/java/io/prestosql/plugin/oracle/OracleClient.java b/presto-oracle/src/main/java/io/prestosql/plugin/oracle/OracleClient.java
index 1f72459b8e8..135c020f1b7 100644
--- a/presto-oracle/src/main/java/io/prestosql/plugin/oracle/OracleClient.java
+++ b/presto-oracle/src/main/java/io/prestosql/plugin/oracle/OracleClient.java
@@ -20,13 +20,24 @@
import io.prestosql.plugin.jdbc.ColumnMapping;
import io.prestosql.plugin.jdbc.ConnectionFactory;
import io.prestosql.plugin.jdbc.DoubleWriteFunction;
+import io.prestosql.plugin.jdbc.JdbcExpression;
import io.prestosql.plugin.jdbc.JdbcIdentity;
import io.prestosql.plugin.jdbc.JdbcTypeHandle;
import io.prestosql.plugin.jdbc.LongWriteFunction;
import io.prestosql.plugin.jdbc.PredicatePushdownController;
import io.prestosql.plugin.jdbc.SliceWriteFunction;
import io.prestosql.plugin.jdbc.WriteMapping;
+import io.prestosql.plugin.jdbc.expression.AggregateFunctionRewriter;
+import io.prestosql.plugin.jdbc.expression.AggregateFunctionRule;
+import io.prestosql.plugin.jdbc.expression.ImplementAvgDecimal;
+import io.prestosql.plugin.jdbc.expression.ImplementAvgFloatingPoint;
+import io.prestosql.plugin.jdbc.expression.ImplementCount;
+import io.prestosql.plugin.jdbc.expression.ImplementCountAll;
+import io.prestosql.plugin.jdbc.expression.ImplementMinMax;
+import io.prestosql.plugin.jdbc.expression.ImplementSum;
import io.prestosql.spi.PrestoException;
+import io.prestosql.spi.connector.AggregateFunction;
+import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.predicate.Domain;
@@ -114,6 +125,7 @@
public class OracleClient
extends BaseJdbcClient
{
+ private final AggregateFunctionRewriter aggregateFunctionRewriter;
private static final int MAX_BYTES_PER_CHAR = 4;
private static final int ORACLE_VARCHAR2_MAX_BYTES = 4000;
@@ -162,6 +174,31 @@ public OracleClient(
requireNonNull(oracleConfig, "oracle config is null");
this.synonymsEnabled = oracleConfig.isSynonymsEnabled();
+
+ JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), 40, 0, Optional.empty(), Optional.empty());
+ this.aggregateFunctionRewriter = new AggregateFunctionRewriter(
+ this::quoted,
+ ImmutableSet.builder()
+ .add(new ImplementCountAll(bigintTypeHandle))
+ .add(new ImplementCount(bigintTypeHandle))
+ .add(new ImplementMinMax())
+ .add(new ImplementSum(OracleClient::toTypeHandle))
+ .add(new ImplementAvgFloatingPoint())
+ .add(new ImplementAvgDecimal())
+ .add(new ImplementAvgBigint())
+ .build());
+ }
+
+ @Override
+ public Optional implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map assignments)
+ {
+ // TODO support complex ConnectorExpressions
+ return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
+ }
+
+ private static Optional toTypeHandle(DecimalType decimalType)
+ {
+ return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), decimalType.getPrecision(), decimalType.getScale(), Optional.empty(), Optional.empty()));
}
private String[] getTableTypes()
@@ -275,7 +312,14 @@ public Optional toPrestoType(ConnectorSession session, Connection
int scale = max(decimalDigits, 0);
Optional numberDefaultScale = getNumberDefaultScale(session);
RoundingMode roundingMode = getNumberRoundingMode(session);
- if (precision < scale) {
+ if (precision == 40 && decimalDigits == 0) {
+ return Optional.of(ColumnMapping.longMapping(
+ BIGINT,
+ ResultSet::getLong,
+ bigintWriteFunction(),
+ OracleClient::fullPushdownIfSupported));
+ }
+ else if (precision < scale) {
if (roundingMode == RoundingMode.UNNECESSARY) {
break;
}
diff --git a/presto-oracle/src/test/java/io/prestosql/plugin/oracle/BaseOracleIntegrationSmokeTest.java b/presto-oracle/src/test/java/io/prestosql/plugin/oracle/BaseOracleIntegrationSmokeTest.java
index f4f8c13d5ff..0546ad7261f 100644
--- a/presto-oracle/src/test/java/io/prestosql/plugin/oracle/BaseOracleIntegrationSmokeTest.java
+++ b/presto-oracle/src/test/java/io/prestosql/plugin/oracle/BaseOracleIntegrationSmokeTest.java
@@ -172,6 +172,23 @@ private void predicatePushdownTest(String oracleType, String oracleLiteral, Stri
}
}
+ @Test
+ public void testAggregationPushdown()
+ throws Exception
+ {
+ // TODO support aggregation pushdown with GROUPING SETS
+ // TODO support aggregation over expressions
+
+ assertThat(query("SELECT count(*) FROM nation")).isFullyPushedDown();
+ assertThat(query("SELECT count(nationkey) FROM nation")).isFullyPushedDown();
+ assertThat(query("SELECT count(1) FROM nation")).isFullyPushedDown();
+ assertThat(query("SELECT regionkey, min(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();
+ assertThat(query("SELECT regionkey, max(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();
+ assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();
+ assertThat(query("SELECT regionkey, avg(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();
+ assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 GROUP BY regionkey")).isFullyPushedDown();
+ }
+
protected String getUser()
{
return TEST_USER;