Skip to content
Closed
9 changes: 8 additions & 1 deletion presto-docs/src/main/sphinx/connector/oracle.rst
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,14 @@ include ``SYNONYM``, add the following configuration property:
Pushdown
--------

The connector supports :doc:`pushdown </optimizer/pushdown>` for optimized query processing.
The connector supports :doc:`pushdown </optimizer/pushdown>` for processing the
following aggregate functions:

* :func:`avg`
* :func:`count`
* :func:`max`
* :func:`min`
* :func:`sum`

Limitations
-----------
Expand Down
5 changes: 5 additions & 0 deletions presto-oracle/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
<artifactId>presto-base-jdbc</artifactId>
</dependency>

<dependency>
<groupId>io.prestosql</groupId>
<artifactId>presto-matching</artifactId>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>configuration</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.plugin.oracle;

import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.plugin.jdbc.JdbcColumnHandle;
import io.prestosql.plugin.jdbc.JdbcExpression;
import io.prestosql.plugin.jdbc.expression.AggregateFunctionRule;
import io.prestosql.spi.connector.AggregateFunction;
import io.prestosql.spi.expression.Variable;
import io.prestosql.spi.type.DecimalType;

import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.prestosql.matching.Capture.newCapture;
import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType;
import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput;
import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.variable;
import static io.prestosql.spi.type.BigintType.BIGINT;
import static java.lang.String.format;

public class ImplementAvgBigint
implements AggregateFunctionRule
{
private static final Capture<Variable> INPUT = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(
variable()
.with(expressionType().matching(type -> type == BIGINT))
.capturedAs(INPUT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignments().get(input.getName());
DecimalType type = (DecimalType) columnHandle.getColumnType();
verify(aggregateFunction.getOutputType().equals(type));

return Optional.of(new JdbcExpression(
format("avg(CAST(%s AS DECIMAL(%s, %s)))", columnHandle.toSqlExpression(context.getIdentifierQuote()), type.getPrecision(), type.getScale()),
columnHandle.getJdbcTypeHandle()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,23 @@
import io.prestosql.plugin.jdbc.ConnectionFactory;
import io.prestosql.plugin.jdbc.DoubleWriteFunction;
import io.prestosql.plugin.jdbc.JdbcColumnHandle;
import io.prestosql.plugin.jdbc.JdbcExpression;
import io.prestosql.plugin.jdbc.JdbcTableHandle;
import io.prestosql.plugin.jdbc.JdbcTypeHandle;
import io.prestosql.plugin.jdbc.LongWriteFunction;
import io.prestosql.plugin.jdbc.SliceWriteFunction;
import io.prestosql.plugin.jdbc.WriteMapping;
import io.prestosql.plugin.jdbc.expression.AggregateFunctionRewriter;
import io.prestosql.plugin.jdbc.expression.AggregateFunctionRule;
import io.prestosql.plugin.jdbc.expression.ImplementAvgDecimal;
import io.prestosql.plugin.jdbc.expression.ImplementAvgFloatingPoint;
import io.prestosql.plugin.jdbc.expression.ImplementCount;
import io.prestosql.plugin.jdbc.expression.ImplementCountAll;
import io.prestosql.plugin.jdbc.expression.ImplementMinMax;
import io.prestosql.plugin.jdbc.expression.ImplementSum;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.connector.AggregateFunction;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.type.CharType;
Expand Down Expand Up @@ -126,6 +137,8 @@ public class OracleClient

private static final int PRECISION_OF_UNSPECIFIED_NUMBER = 127;

private final AggregateFunctionRewriter aggregateFunctionRewriter;

private static final Set<String> INTERNAL_SCHEMAS = ImmutableSet.<String>builder()
.add("ctxsys")
.add("flows_files")
Expand Down Expand Up @@ -162,6 +175,25 @@ public OracleClient(

requireNonNull(oracleConfig, "oracle config is null");
this.synonymsEnabled = oracleConfig.isSynonymsEnabled();

JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), 40, 0, Optional.empty(), Optional.empty());
this.aggregateFunctionRewriter = new AggregateFunctionRewriter(
this::quoted,
ImmutableSet.<AggregateFunctionRule>builder()
.add(new ImplementCountAll(bigintTypeHandle))
.add(new ImplementCount(bigintTypeHandle))
.add(new ImplementMinMax())
.add(new ImplementSum(OracleClient::toTypeHandle))
.add(new ImplementAvgFloatingPoint())
.add(new ImplementAvgDecimal())
.add(new ImplementAvgBigint())
.build());
}

@Override
public Optional<JdbcExpression> implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map<String, ColumnHandle> assignments)
{
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}

private String[] getTableTypes()
Expand Down Expand Up @@ -274,7 +306,14 @@ public Optional<ColumnMapping> toPrestoType(ConnectorSession session, Connection
int scale = max(decimalDigits, 0);
Optional<Integer> numberDefaultScale = getNumberDefaultScale(session);
RoundingMode roundingMode = getNumberRoundingMode(session);
if (precision < scale) {
if (precision == 40 && decimalDigits == 0) {
return Optional.of(ColumnMapping.longMapping(
BIGINT,
ResultSet::getLong,
bigintWriteFunction(),
FULL_PUSHDOWN));
}
else if (precision < scale) {
if (roundingMode == RoundingMode.UNNECESSARY) {
break;
}
Expand Down Expand Up @@ -475,4 +514,9 @@ public void setColumnComment(ConnectorSession session, JdbcTableHandle handle, J
comment.orElse(""));
execute(session, sql);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), decimalType.getPrecision(), decimalType.getScale(), Optional.empty(), Optional.empty()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ public void testTooLargeDomainCompactionThreshold()
"SELECT * from nation", "Domain compaction threshold \\(10000\\) cannot exceed 1000");
}

@Test
public void testAggregationPushdown()
throws Exception
{
assertThat(query("SELECT count(*) FROM nation")).isFullyPushedDown();
assertThat(query("SELECT count(nationkey) FROM nation")).isFullyPushedDown();
assertThat(query("SELECT count(1) FROM nation")).isFullyPushedDown();
assertThat(query("SELECT regionkey, min(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();
assertThat(query("SELECT regionkey, max(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();
assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();
assertThat(query("SELECT regionkey, avg(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();
assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 GROUP BY regionkey")).isFullyPushedDown();
}

private void predicatePushdownTest(String oracleType, String oracleLiteral, String operator, String filterLiteral)
{
String tableName = "test_pdown_" + oracleType.replaceAll("[^a-zA-Z0-9]", "");
Expand Down