Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,6 @@ public PlanOptimizers(
new RemoveRedundantExists(),
new ImplementFilteredAggregations(metadata),
new SingleDistinctAggregationToGroupBy(),
new MultipleDistinctAggregationToMarkDistinct(),
new MergeLimitWithDistinct(),
new PruneCountAggregationOverScalar(metadata),
new PruneOrderByInAggregation(metadata),
Expand Down Expand Up @@ -675,7 +674,8 @@ public PlanOptimizers(
new RemoveEmptyExceptBranches(),
new RemoveRedundantIdentityProjections(),
new PushAggregationThroughOuterJoin(),
new ReplaceRedundantJoinWithSource())), // Run this after PredicatePushDown optimizer as it inlines filter constants
new ReplaceRedundantJoinWithSource(), // Run this after PredicatePushDown optimizer as it inlines filter constants
new MultipleDistinctAggregationToMarkDistinct())), // Run this after aggregation pushdown so that multiple distinct aggregations can be pushed into a connector
inlineProjections,
simplifyOptimizer, // Re-run the SimplifyExpressions to simplify any recomposed expressions from other optimizations
pushProjectionIntoTableScanOptimizer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import static com.google.common.base.Functions.identity;
import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -208,8 +210,14 @@ public Optional<ProjectionApplicationResult<ConnectorTableHandle>> applyProjecti
.map(JdbcColumnHandle.class::cast)
.collect(toImmutableList());

if (handle.getColumns().isPresent() && containSameElements(newColumns, handle.getColumns().get())) {
return Optional.empty();
if (handle.getColumns().isPresent()) {
Set<JdbcColumnHandle> newColumnSet = ImmutableSet.copyOf(newColumns);
Comment thread
losipiuk marked this conversation as resolved.
Outdated
Set<JdbcColumnHandle> tableColumnSet = ImmutableSet.copyOf(handle.getColumns().get());
if (newColumnSet.equals(tableColumnSet)) {
return Optional.empty();
}

verify(tableColumnSet.containsAll(newColumnSet), "applyProjection called with columns %s and some are not available in existing query: %s", newColumnSet, tableColumnSet);
}

return Optional.of(new ProjectionApplicationResult<>(
Expand Down Expand Up @@ -263,6 +271,24 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
ImmutableList.Builder<ConnectorExpression> projections = ImmutableList.builder();
ImmutableList.Builder<Assignment> resultAssignments = ImmutableList.builder();
ImmutableMap.Builder<String, String> expressions = ImmutableMap.builder();

List<List<JdbcColumnHandle>> groupingSetsAsJdbcColumnHandles = groupingSets.stream()
.map(groupingSet -> groupingSet.stream()
.map(JdbcColumnHandle.class::cast)
.collect(toImmutableList()))
.collect(toImmutableList());
Optional<List<JdbcColumnHandle>> tableColumns = handle.getColumns();
groupingSetsAsJdbcColumnHandles.stream()
.flatMap(List::stream)
.distinct()
.peek(handle.getColumns().<Consumer<JdbcColumnHandle>>map(
columns -> groupKey -> verify(columns.contains(groupKey),
"applyAggregation called with a grouping column %s which was not included in the table columns: %s",
groupKey,
tableColumns))
.orElse(groupKey -> {}))
.forEach(newColumns::add);
Comment thread
losipiuk marked this conversation as resolved.
Outdated

for (AggregateFunction aggregate : aggregates) {
Optional<JdbcExpression> expression = jdbcClient.implementAggregation(session, aggregate, assignments);
if (expression.isEmpty()) {
Expand All @@ -284,25 +310,16 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
expressions.put(columnName, expression.get().getExpression());
}

List<List<JdbcColumnHandle>> groupingSetsAsJdbcColumnHandles = groupingSets.stream()
.map(groupingSet -> groupingSet.stream()
.map(JdbcColumnHandle.class::cast)
.collect(toImmutableList()))
.collect(toImmutableList());

List<JdbcColumnHandle> newColumnsList = newColumns.build();

// TODO(https://github.com/trinodb/trino/issues/9021) We are reading all grouping columns from remote database as at this point we are not able to tell if they are needed up in the query.
// As a reason of that we need to also have matching column handles in JdbcTableHandle constructed below, as columns read via JDBC must match column handles list.
// For more context see assertion in JdbcRecordSetProvider.getRecordSet
PreparedQuery preparedQuery = jdbcClient.prepareQuery(
session,
handle,
Optional.of(groupingSetsAsJdbcColumnHandles),
ImmutableList.<JdbcColumnHandle>builder()
.addAll(groupingSetsAsJdbcColumnHandles.stream()
.flatMap(List::stream)
.distinct()
.iterator())
.addAll(newColumnsList)
.build(),
newColumnsList,
expressions.build());
handle = new JdbcTableHandle(
new JdbcQueryRelationHandle(preparedQuery),
Expand Down Expand Up @@ -759,9 +776,4 @@ public void dropSchema(ConnectorSession session, String schemaName)
{
jdbcClient.dropSchema(session, schemaName);
}

private static boolean containSameElements(Iterable<? extends ColumnHandle> first, Iterable<? extends ColumnHandle> second)
{
return ImmutableSet.copyOf(first).equals(ImmutableSet.copyOf(second));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,7 @@ public static SliceReadFunction charReadFunction(CharType charType)

public static SliceWriteFunction charWriteFunction()
{
return (statement, index, value) -> {
statement.setString(index, value.toStringUtf8());
};
return (statement, index, value) -> statement.setString(index, value.toStringUtf8());
}

public static ColumnMapping defaultVarcharColumnMapping(int columnSize, boolean isRemoteCaseSensitive)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Allows pushdown of multiple distinct aggregations as long as
only one non-count aggregation is distinct.

I may have asked about that before but don't remember the answer -- why "only one non-count"?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what exactly @alexjo2144 had in mind here.
There are many constraints here. And definitely not all queries which satisfy the condition from the commit message will be fully pushed down

E.g. this one will not:

select max(distinct regionkey), count(distinct regionkey), count(distinct nationkey) from nation;

I suggest to just leave title line in commit message and drop the rest.

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.VarcharType;

import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.distinct;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.hasFilter;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
* Implements {@code count(DISTINCT x)}.
*/
public class ImplementCountDistinct
        implements AggregateFunctionRule
{
    private static final Capture<Variable> INPUT = newCapture();

    private final JdbcTypeHandle bigintTypeHandle;
    private final boolean isRemoteCollationSensitive;

    /**
     * @param bigintTypeHandle A {@link JdbcTypeHandle} that will be mapped to {@link BigintType} by {@link JdbcClient#toColumnMapping}.
     * @param isRemoteCollationSensitive whether the remote database compares character data the same way Trino does;
     * when {@code false}, {@code count(DISTINCT x)} over char/varchar columns is not pushed down
     */
    public ImplementCountDistinct(JdbcTypeHandle bigintTypeHandle, boolean isRemoteCollationSensitive)
    {
        this.bigintTypeHandle = requireNonNull(bigintTypeHandle, "bigintTypeHandle is null");
        this.isRemoteCollationSensitive = isRemoteCollationSensitive;
    }

    @Override
    public Pattern<AggregateFunction> getPattern()
    {
        // Matches only unfiltered count(DISTINCT <column>) with a single plain column argument
        return Pattern.typeOf(AggregateFunction.class)
                .with(distinct().equalTo(true))
                .with(hasFilter().equalTo(false))
                .with(functionName().equalTo("count"))
                .with(singleInput().matching(variable().capturedAs(INPUT)));
    }

    @Override
    public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
    {
        Variable input = captures.get(INPUT);
        JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
        verify(aggregateFunction.getOutputType() == BIGINT, "Unexpected count() output type: %s", aggregateFunction.getOutputType());

        // The pattern guarantees isDistinct(); only the collation check is needed here.
        // DISTINCT over textual types depends on remote equality semantics: a case-insensitive
        // (or otherwise differently-collating) remote database could produce a different count.
        boolean isCaseSensitiveType = columnHandle.getColumnType() instanceof CharType || columnHandle.getColumnType() instanceof VarcharType;
        if (!isRemoteCollationSensitive && isCaseSensitiveType) {
            // Remote database is case insensitive or compares values differently from Trino
            return Optional.empty();
        }

        return Optional.of(new JdbcExpression(
                format("count(DISTINCT %s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())),
                bigintTypeHandle));
    }
}
Loading