Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.Session;
import com.facebook.presto.execution.warnings.WarningCollector;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.NodeRef;

import java.util.Map;

import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static java.util.Collections.emptyList;

@Deprecated
public class AnalyzedExpressionRewriter
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm... AnalyzedExpressionRewriter sounds a bit vague... let's talk in person quickly to figure out a better name (or decided that's the best name and we just need some comments.. 😃 )

Copy link
Copy Markdown
Contributor Author

@hellium01 hellium01 May 14, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, will add a Deprecated Mark over there and keep the name as is.

{
private final Session session;
private final Metadata metadata;
private final SqlParser sqlParser;
private final TypeProvider typeProvider;

public AnalyzedExpressionRewriter(Session session, Metadata metadata, SqlParser sqlParser, TypeProvider typeProvider)
{
this.session = session;
this.metadata = metadata;
this.sqlParser = sqlParser;
this.typeProvider = typeProvider;
}

public Expression rewriteWith(RewriterProvider<Void> rewriterProvider, Expression expression)
{
return rewriteWith(rewriterProvider, expression, null);
}

public <C> Expression rewriteWith(RewriterProvider<C> rewriterProvider, Expression expression, C context)
{
// Lambda cannot be analyzed outside the context of function, but its body can be rewritten.
if (expression instanceof LambdaExpression) {
LambdaExpression lambdaExpression = (LambdaExpression) expression;
return new LambdaExpression(lambdaExpression.getArguments(), rewriteWith(rewriterProvider, lambdaExpression.getBody(), context));
}
Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(
session,
metadata,
sqlParser,
typeProvider,
expression,
emptyList(),
WarningCollector.NOOP);
return ExpressionTreeRewriter.rewriteWith(rewriterProvider.get(expressionTypes), expression, context);
}

interface RewriterProvider<C>
{
ExpressionRewriter<C> get(Map<NodeRef<Expression>, Type> expressionTypes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package com.facebook.presto.sql.planner;

import com.facebook.presto.Session;
import com.facebook.presto.execution.warnings.WarningCollector;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.parser.SqlParser;
Expand All @@ -28,23 +27,20 @@
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.Map;

import static com.facebook.presto.spi.type.TimeType.TIME;
import static com.facebook.presto.spi.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE;
import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;

public class DesugarAtTimeZoneRewriter
{
public static Expression rewrite(Expression expression, Map<NodeRef<Expression>, Type> expressionTypes)
{
return ExpressionTreeRewriter.rewriteWith(new Visitor(expressionTypes), expression);
return ExpressionTreeRewriter.rewriteWith(new Visitor(expressionTypes), expression, null);
}

private DesugarAtTimeZoneRewriter() {}
Expand All @@ -57,9 +53,7 @@ public static Expression rewrite(Expression expression, Session session, Metadat
if (expression instanceof SymbolReference) {
return expression;
}
Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), WarningCollector.NOOP);

return rewrite(expression, expressionTypes);
return new AnalyzedExpressionRewriter(session, metadata, sqlParser, symbolAllocator.getTypes()).rewriteWith(Visitor::new, expression);
}

private static class Visitor
Expand All @@ -69,14 +63,14 @@ private static class Visitor

public Visitor(Map<NodeRef<Expression>, Type> expressionTypes)
{
this.expressionTypes = ImmutableMap.copyOf(requireNonNull(expressionTypes, "expressionTypes is null"));
this.expressionTypes = requireNonNull(expressionTypes, "expressionTypes is null");
}

@Override
public Expression rewriteAtTimeZone(AtTimeZone node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
Expression value = treeRewriter.rewrite(node.getValue(), context);
Type type = getType(node.getValue());
Type type = expressionTypes.get(NodeRef.of(node.getValue()));
if (type.equals(TIME)) {
value = new Cast(value, TIME_WITH_TIME_ZONE.getDisplayName());
}
Expand All @@ -86,10 +80,5 @@ else if (type.equals(TIMESTAMP)) {

return new FunctionCall(QualifiedName.of("at_timezone"), ImmutableList.of(value, treeRewriter.rewrite(node.getTimeZone(), context)));
}

private Type getType(Expression expression)
{
return expressionTypes.get(NodeRef.of(expression));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import java.util.List;

import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -92,7 +91,20 @@ public Void visitGroupReference(GroupReference node, ImmutableList.Builder<RowEx
public Void visitAggregation(AggregationNode node, ImmutableList.Builder<RowExpression> context)
{
node.getAggregations().values()
.forEach(aggregation -> context.add(castToRowExpression(aggregation.getCall())));
.forEach(aggregation -> {
aggregation.getArguments()
.stream()
.map(OriginalExpressionUtils::castToRowExpression)
.forEach(context::add);
aggregation.getFilter().map(OriginalExpressionUtils::castToRowExpression).ifPresent(context::add);
aggregation.getOrderBy()
.map(OrderingScheme::getOrderBy)
.orElse(ImmutableList.of())
.stream()
.map(Symbol::toSymbolReference)
.map(OriginalExpressionUtils::castToRowExpression)
.forEach(context::add);
});
return super.visitAggregation(node, context);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,6 @@
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.OrderBy;
import com.facebook.presto.sql.tree.SortItem;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.VerifyException;
import com.google.common.collect.ContiguousSet;
Expand Down Expand Up @@ -2525,15 +2523,15 @@ private AccumulatorFactory buildAccumulatorFactory(
InternalAggregationFunction internalAggregationFunction = functionManager.getAggregateFunctionImplementation(aggregation.getFunctionHandle());

List<Integer> valueChannels = new ArrayList<>();
for (Expression argument : aggregation.getCall().getArguments()) {
for (Expression argument : aggregation.getArguments()) {
if (!(argument instanceof LambdaExpression)) {
Symbol argumentSymbol = Symbol.from(argument);
valueChannels.add(source.getLayout().get(argumentSymbol));
}
}

List<LambdaProvider> lambdaProviders = new ArrayList<>();
List<LambdaExpression> lambdaExpressions = aggregation.getCall().getArguments().stream()
List<LambdaExpression> lambdaExpressions = aggregation.getArguments().stream()
.filter(LambdaExpression.class::isInstance)
.map(LambdaExpression.class::cast)
.collect(toImmutableList());
Expand Down Expand Up @@ -2601,17 +2599,10 @@ private AccumulatorFactory buildAccumulatorFactory(
Optional<Integer> maskChannel = aggregation.getMask().map(value -> source.getLayout().get(value));
List<SortOrder> sortOrders = ImmutableList.of();
List<Symbol> sortKeys = ImmutableList.of();
if (aggregation.getCall().getOrderBy().isPresent()) {
OrderBy orderBy = aggregation.getCall().getOrderBy().get();

sortKeys = orderBy.getSortItems().stream()
.map(SortItem::getSortKey)
.map(Symbol::from)
.collect(toImmutableList());

sortOrders = orderBy.getSortItems().stream()
.map(QueryPlanner::toSortOrder)
.collect(toImmutableList());
if (aggregation.getOrderBy().isPresent()) {
OrderingScheme orderBy = aggregation.getOrderBy().get();
sortKeys = orderBy.getOrderBy();
sortOrders = orderBy.getOrderingList();
}

return internalAggregationFunction.bind(
Expand All @@ -2621,7 +2612,7 @@ private AccumulatorFactory buildAccumulatorFactory(
getChannelsForSymbols(sortKeys, source.getLayout()),
sortOrders,
pagesIndexFactory,
aggregation.getCall().isDistinct(),
aggregation.isDistinct(),
joinCompiler,
lambdaProviders,
session);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ public PlanOptimizers(
new MultipleDistinctAggregationToMarkDistinct(),
new ImplementBernoulliSampleAsFilter(),
new MergeLimitWithDistinct(),
new PruneCountAggregationOverScalar(),
new PruneCountAggregationOverScalar(metadata.getFunctionManager()),
new PruneOrderByInAggregation(metadata.getFunctionManager()),
new RewriteSpatialPartitioningAggregation(metadata)))
.build()),
Expand Down Expand Up @@ -357,7 +357,7 @@ public PlanOptimizers(
estimatedExchangesCostCalculator,
ImmutableSet.of(
new RemoveRedundantIdentityProjections(),
new PushAggregationThroughOuterJoin())),
new PushAggregationThroughOuterJoin(metadata.getFunctionManager()))),
inlineProjections,
simplifyOptimizer, // Re-run the SimplifyExpressions to simplify any recomposed expressions from other optimizations
projectionPushDown,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.spi.block.SortOrder;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.OrderBy;
import com.facebook.presto.sql.tree.SortItem;
import com.facebook.presto.sql.tree.SymbolReference;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;

public class PlannerUtils
{
private PlannerUtils() {}

public static SortOrder toSortOrder(SortItem sortItem)
{
if (sortItem.getOrdering() == SortItem.Ordering.ASCENDING) {
if (sortItem.getNullOrdering() == SortItem.NullOrdering.FIRST) {
return SortOrder.ASC_NULLS_FIRST;
}
return SortOrder.ASC_NULLS_LAST;
}
if (sortItem.getNullOrdering() == SortItem.NullOrdering.FIRST) {
return SortOrder.DESC_NULLS_FIRST;
}
return SortOrder.DESC_NULLS_LAST;
}

public static OrderingScheme toOrderingScheme(List<SortItem> sortItems)
{
return toOrderingScheme(sortItems, item -> {
checkArgument(item instanceof SymbolReference, "must be symbol reference");
return new Symbol(((SymbolReference) item).getName());
});
}

public static OrderingScheme toOrderingScheme(List<SortItem> sortItems, Function<Expression, Symbol> translator)
{
// The logic is similar to QueryPlanner::sort
Map<Symbol, SortOrder> orderings = new LinkedHashMap<>();
for (SortItem item : sortItems) {
Symbol symbol = translator.apply(item.getSortKey());
// don't override existing keys, i.e. when "ORDER BY a ASC, a DESC" is specified
Comment thread
hellium01 marked this conversation as resolved.
Outdated
orderings.putIfAbsent(symbol, toSortOrder(item));
}
return new OrderingScheme(orderings.keySet().stream().collect(toImmutableList()), orderings);
}

public static OrderingScheme toOrderingScheme(OrderBy orderBy)
{
return toOrderingScheme(orderBy.getSortItems());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@
import com.facebook.presto.sql.tree.Query;
import com.facebook.presto.sql.tree.QuerySpecification;
import com.facebook.presto.sql.tree.SortItem;
import com.facebook.presto.sql.tree.SortItem.NullOrdering;
import com.facebook.presto.sql.tree.SortItem.Ordering;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.sql.tree.Window;
import com.facebook.presto.sql.tree.WindowFrame;
Expand All @@ -80,6 +78,7 @@
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy;
import static com.facebook.presto.sql.planner.PlannerUtils.toSortOrder;
import static com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.toBoundType;
import static com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.toWindowType;
import static com.facebook.presto.sql.planner.plan.AggregationNode.groupingSets;
Expand Down Expand Up @@ -559,8 +558,16 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node)
needPostProjectionCoercion = true;
}
aggregationTranslations.put(aggregate, newSymbol);

aggregationsBuilder.put(newSymbol, new Aggregation((FunctionCall) rewritten, analysis.getFunctionHandle(aggregate), Optional.empty()));
FunctionCall rewrittenFunction = (FunctionCall) rewritten;

aggregationsBuilder.put(newSymbol,
new Aggregation(
analysis.getFunctionHandle(aggregate),
rewrittenFunction.getArguments(),
rewrittenFunction.getFilter(),
rewrittenFunction.getOrderBy().map(OrderBy::getSortItems).map(PlannerUtils::toOrderingScheme),
rewrittenFunction.isDistinct(),
Optional.empty()));
}
Map<Symbol, Aggregation> aggregations = aggregationsBuilder.build();

Expand Down Expand Up @@ -888,6 +895,7 @@ private PlanBuilder sort(PlanBuilder subPlan, Optional<OrderBy> orderBy, Optiona

Iterator<SortItem> sortItems = orderBy.get().getSortItems().iterator();

// This logic is similar to PlannerUtils::toOrderingScheme
ImmutableList.Builder<Symbol> orderBySymbols = ImmutableList.builder();
Map<Symbol, SortOrder> orderings = new HashMap<>();
for (Expression fieldOrExpression : orderByExpressions) {
Expand Down Expand Up @@ -947,24 +955,4 @@ private static Map<Expression, Symbol> symbolsForExpressions(PlanBuilder builder
.distinct()
.collect(toImmutableMap(expression -> expression, builder::translate));
}

public static SortOrder toSortOrder(SortItem sortItem)
{
if (sortItem.getOrdering() == Ordering.ASCENDING) {
if (sortItem.getNullOrdering() == NullOrdering.FIRST) {
return SortOrder.ASC_NULLS_FIRST;
}
else {
return SortOrder.ASC_NULLS_LAST;
}
}
else {
if (sortItem.getNullOrdering() == NullOrdering.FIRST) {
return SortOrder.DESC_NULLS_FIRST;
}
else {
return SortOrder.DESC_NULLS_LAST;
}
}
}
}
Loading