Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,8 @@ public Void visitAggregation(AggregationNode node, ImmutableList.Builder<RowExpr
{
node.getAggregations().values()
.forEach(aggregation -> {
aggregation.getArguments()
.stream()
.map(OriginalExpressionUtils::castToRowExpression)
.forEach(context::add);
aggregation.getFilter().map(OriginalExpressionUtils::castToRowExpression).ifPresent(context::add);
aggregation.getArguments().forEach(context::add);
aggregation.getFilter().ifPresent(context::add);
aggregation.getOrderBy()
.map(OrderingScheme::getOrderBy)
.orElse(ImmutableList.of())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.type.FunctionType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spiller.PartitioningSpillerFactory;
import com.facebook.presto.spiller.SingleStreamSpillerFactory;
Expand Down Expand Up @@ -177,8 +176,6 @@
import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator;
import com.facebook.presto.sql.relational.VariableToChannelTranslator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.VerifyException;
Expand Down Expand Up @@ -2259,8 +2256,8 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl
ImmutableMap.Builder<VariableReferenceExpression, Integer> outputMapping = ImmutableMap.builder();

OperatorFactory statisticsAggregation = node.getStatisticsAggregation().map(aggregation -> {
List<Symbol> groupingSymbols = aggregation.getGroupingSymbols();
if (groupingSymbols.isEmpty()) {
List<VariableReferenceExpression> groupingVariables = aggregation.getGroupingVariables();
if (groupingVariables.isEmpty()) {
return createAggregationOperatorFactory(
node.getId(),
aggregation.getAggregations(),
Expand Down Expand Up @@ -2528,77 +2525,26 @@ private AccumulatorFactory buildAccumulatorFactory(
InternalAggregationFunction internalAggregationFunction = functionManager.getAggregateFunctionImplementation(aggregation.getFunctionHandle());

List<Integer> valueChannels = new ArrayList<>();
for (int i = 0; i < aggregation.getArguments().size(); i++) {
Expression argument = aggregation.getArguments().get(i);
if (!(argument instanceof LambdaExpression)) {
VariableReferenceExpression argumentVariable = new VariableReferenceExpression(Symbol.from(argument).getName(), types.get(Symbol.from(argument)));
valueChannels.add(source.getLayout().get(argumentVariable));
for (RowExpression argument : aggregation.getArguments()) {
if (!(argument instanceof LambdaDefinitionExpression)) {
checkArgument(argument instanceof VariableReferenceExpression, "argument must be variable reference");
valueChannels.add(source.getLayout().get(argument));
}
}

List<LambdaProvider> lambdaProviders = new ArrayList<>();
List<LambdaExpression> lambdaExpressions = aggregation.getArguments().stream()
.filter(LambdaExpression.class::isInstance)
.map(LambdaExpression.class::cast)
List<LambdaDefinitionExpression> lambdas = aggregation.getArguments().stream()
.filter(LambdaDefinitionExpression.class::isInstance)
.map(LambdaDefinitionExpression.class::cast)
.collect(toImmutableList());
if (!lambdaExpressions.isEmpty()) {
List<FunctionType> functionTypes = functionManager.getFunctionMetadata(aggregation.getFunctionHandle()).getArgumentTypes().stream()
.filter(typeSignature -> typeSignature.getBase().equals(FunctionType.NAME))
.map(typeSignature -> (FunctionType) (metadata.getTypeManager().getType(typeSignature)))
.collect(toImmutableList());
for (int i = 0; i < lambdas.size(); i++) {
List<Class> lambdaInterfaces = internalAggregationFunction.getLambdaInterfaces();
verify(lambdaExpressions.size() == functionTypes.size());
verify(lambdaExpressions.size() == lambdaInterfaces.size());

for (int i = 0; i < lambdaExpressions.size(); i++) {
LambdaExpression lambdaExpression = lambdaExpressions.get(i);
FunctionType functionType = functionTypes.get(i);

// To compile lambda, LambdaDefinitionExpression needs to be generated from LambdaExpression,
// which requires the types of all sub-expressions.
//
// In project and filter expression compilation, ExpressionAnalyzer.getExpressionTypesFromInput
// is used to generate the types of all sub-expressions. (see visitScanFilterAndProject and visitFilter)
//
// This does not work here since the function call representation in final aggregation node
// is currently a hack: it takes intermediate type as input, and may not be a valid
// function call in Presto.
//
// TODO: Once the final aggregation function call representation is fixed,
// the same mechanism in project and filter expression should be used here.
verify(lambdaExpression.getArguments().size() == functionType.getArgumentTypes().size());
Map<NodeRef<Expression>, Type> lambdaArgumentExpressionTypes = new HashMap<>();
Map<Symbol, Type> lambdaArgumentSymbolTypes = new HashMap<>();
for (int j = 0; j < lambdaExpression.getArguments().size(); j++) {
LambdaArgumentDeclaration argument = lambdaExpression.getArguments().get(j);
Type type = functionType.getArgumentTypes().get(j);
lambdaArgumentExpressionTypes.put(NodeRef.of(argument), type);
lambdaArgumentSymbolTypes.put(new Symbol(argument.getName().getValue()), type);
}
Map<NodeRef<Expression>, Type> expressionTypes = ImmutableMap.<NodeRef<Expression>, Type>builder()
// the lambda expression itself
.put(NodeRef.of(lambdaExpression), functionType)
// expressions from lambda arguments
.putAll(lambdaArgumentExpressionTypes)
// expressions from lambda body
.putAll(getExpressionTypes(
session,
metadata,
sqlParser,
TypeProvider.copyOf(lambdaArgumentSymbolTypes),
lambdaExpression.getBody(),
emptyList(),
NOOP))
.build();

LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) toRowExpression(lambdaExpression, expressionTypes, ImmutableMap.of());
Class<? extends LambdaProvider> lambdaProviderClass = compileLambdaProvider(lambda, metadata.getFunctionManager(), lambdaInterfaces.get(i));
try {
lambdaProviders.add((LambdaProvider) constructorMethodHandle(lambdaProviderClass, ConnectorSession.class).invoke(session.toConnectorSession()));
}
catch (Throwable t) {
throw new RuntimeException(t);
}
Class<? extends LambdaProvider> lambdaProviderClass = compileLambdaProvider(lambdas.get(i), metadata.getFunctionManager(), lambdaInterfaces.get(i));
try {
lambdaProviders.add((LambdaProvider) constructorMethodHandle(lambdaProviderClass, ConnectorSession.class).invoke(session.toConnectorSession()));
}
catch (Throwable t) {
throw new RuntimeException(t);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@

import com.facebook.presto.spi.block.SortOrder;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.OrderBy;
import com.facebook.presto.sql.tree.SortItem;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Streams;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static com.facebook.presto.sql.relational.Expressions.variable;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;

public class PlannerUtils
{
Expand All @@ -44,18 +47,26 @@ public static SortOrder toSortOrder(SortItem sortItem)
return SortOrder.DESC_NULLS_LAST;
}

public static OrderingScheme toOrderingScheme(List<SortItem> sortItems, TypeProvider typeProvider)
public static OrderingScheme toOrderingScheme(OrderBy orderBy, TypeProvider types)
{
return toOrderingScheme(
orderBy.getSortItems().stream()
.map(SortItem::getSortKey)
.map(item -> {
checkArgument(item instanceof SymbolReference, "must be symbol reference");
Symbol symbol = Symbol.from(item);
return variable(symbol.getName(), types.get(symbol));
}).collect(toImmutableList()),
orderBy.getSortItems().stream()
.map(PlannerUtils::toSortOrder)
.collect(toImmutableList()));
}

public static OrderingScheme toOrderingScheme(List<VariableReferenceExpression> orderingSymbols, List<SortOrder> sortOrders)
{
// The logic is similar to QueryPlanner::sort
Map<VariableReferenceExpression, SortOrder> orderings = new LinkedHashMap<>();
for (SortItem item : sortItems) {
Expression sortKey = item.getSortKey();
checkArgument(sortKey instanceof SymbolReference, "must be symbol reference");
Symbol symbol = Symbol.from(sortKey);
VariableReferenceExpression variable = new VariableReferenceExpression(symbol.getName(), typeProvider.get(symbol));
// don't override existing keys, i.e. when "ORDER BY a ASC, a DESC" is specified
orderings.putIfAbsent(variable, toSortOrder(item));
}
// don't override existing keys, i.e. when "ORDER BY a ASC, a DESC" is specified
Streams.forEachPair(orderingSymbols.stream(), sortOrders.stream(), orderings::putIfAbsent);
return new OrderingScheme(ImmutableList.copyOf(orderings.keySet()), orderings);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.block.SortOrder;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.analyzer.Analysis;
Expand Down Expand Up @@ -68,8 +69,6 @@
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -562,10 +561,13 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node)

aggregationsBuilder.put(newVariable,
new Aggregation(
analysis.getFunctionHandle(aggregate),
rewrittenFunction.getArguments(),
rewrittenFunction.getFilter(),
rewrittenFunction.getOrderBy().map(OrderBy::getSortItems).map(sortItems -> toOrderingScheme(sortItems, symbolAllocator.getTypes())),
new CallExpression(
aggregate.getName().getSuffix(),
analysis.getFunctionHandle(aggregate),
analysis.getType(aggregate),
rewrittenFunction.getArguments().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList())),
rewrittenFunction.getFilter().map(OriginalExpressionUtils::castToRowExpression),
rewrittenFunction.getOrderBy().map(orderBy -> toOrderingScheme(orderBy, symbolAllocator.getTypes())),
rewrittenFunction.isDistinct(),
Optional.empty()));
}
Expand Down Expand Up @@ -892,23 +894,10 @@ private PlanBuilder sort(PlanBuilder subPlan, Optional<OrderBy> orderBy, Optiona
return subPlan;
}

Iterator<SortItem> sortItems = orderBy.get().getSortItems().iterator();

// This logic is similar to PlannerUtils::toOrderingScheme
ImmutableList.Builder<VariableReferenceExpression> orderByVariables = ImmutableList.builder();
Map<VariableReferenceExpression, SortOrder> orderings = new HashMap<>();
for (Expression fieldOrExpression : orderByExpressions) {
VariableReferenceExpression variable = subPlan.translateToVariable(fieldOrExpression);

SortItem sortItem = sortItems.next();
if (!orderings.containsKey(variable)) {
orderByVariables.add(variable);
orderings.put(variable, toSortOrder(sortItem));
}
}

PlanNode planNode;
OrderingScheme orderingScheme = new OrderingScheme(orderByVariables.build(), orderings);
OrderingScheme orderingScheme = toOrderingScheme(
orderByExpressions.stream().map(subPlan::translate).collect(toImmutableList()),
orderBy.get().getSortItems().stream().map(PlannerUtils::toSortOrder).collect(toImmutableList()));
if (limit.isPresent() && !limit.get().equalsIgnoreCase("all")) {
planNode = new TopNNode(idAllocator.getNextId(), subPlan.getRoot(), Long.parseLong(limit.get()), orderingScheme, TopNNode.Step.SINGLE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,27 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;

public final class RowExpressionSymbolInliner
public final class RowExpressionVariableInliner
extends RowExpressionRewriter<Void>
{
private final Set<String> excludedNames = new HashSet<>();
private final Map<Symbol, Symbol> mapping;
private final Map<VariableReferenceExpression, RowExpression> mapping;

private RowExpressionSymbolInliner(Map<Symbol, Symbol> mapping)
private RowExpressionVariableInliner(Map<VariableReferenceExpression, RowExpression> mapping)
Comment thread
hellium01 marked this conversation as resolved.
{
this.mapping = mapping;
}

public static RowExpression inlineSymbols(Map<Symbol, Symbol> mapping, RowExpression expression)
public static RowExpression inlineVariables(Map<VariableReferenceExpression, RowExpression> mapping, RowExpression expression)
{
return RowExpressionTreeRewriter.rewriteWith(new RowExpressionSymbolInliner(mapping), expression);
return RowExpressionTreeRewriter.rewriteWith(new RowExpressionVariableInliner(mapping), expression);
}

@Override
public RowExpression rewriteVariableReference(VariableReferenceExpression node, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
{
if (!excludedNames.contains(node.getName())) {
RowExpression result = new VariableReferenceExpression(mapping.get(new Symbol(node.getName())).getName(), node.getType());
RowExpression result = mapping.get(node);
checkState(result != null, "Cannot resolve symbol %s", node.getName());
return result;
}
Expand Down
Loading