Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@

import com.facebook.presto.Session;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.Expression;

import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression;
import static java.util.Objects.requireNonNull;

public class ProjectStatsRule
Expand Down Expand Up @@ -53,8 +55,14 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(ProjectNode node, StatsPro
PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder()
.setOutputRowCount(sourceStats.getOutputRowCount());

for (Map.Entry<Symbol, Expression> entry : node.getAssignments().entrySet()) {
calculatedStats.addSymbolStatistics(entry.getKey(), scalarStatsCalculator.calculate(entry.getValue(), sourceStats, session, types));
for (Map.Entry<Symbol, RowExpression> entry : node.getAssignments().entrySet()) {
RowExpression expression = entry.getValue();
if (isExpression(expression)) {
calculatedStats.addSymbolStatistics(entry.getKey(), scalarStatsCalculator.calculate(castToExpression(expression), sourceStats, session, types));
}
else {
calculatedStats.addSymbolStatistics(entry.getKey(), scalarStatsCalculator.calculate(expression, sourceStats, session));
}
}
return Optional.of(calculatedStats.build());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.Maps;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collector;

import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Arrays.asList;

/**
* Everything in this should be moved back to Assignments
*/
public class AssignmentsUtils
{
private AssignmentsUtils() {}

// Originally, the following functions are also static
public static Builder builder()
{
return new Builder();
}

public static Assignments identity(Symbol... symbols)
{
return identity(asList(symbols));
}

public static Assignments identity(Iterable<Symbol> symbols)
{
return builder().putIdentities(symbols).build();
}

public static Assignments copyOf(Map<Symbol, RowExpression> assignments)
{
return builder()
.putAll(assignments)
.build();
}

public static Assignments of()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These name should be changed if they are no longer in the Assignments class. It's ok to call the static creator of, but that's really strange from another class. I don't know what to call them though.

{
return builder().build();
}

public static Assignments of(Symbol symbol, RowExpression expression)
{
return builder().put(symbol, expression).build();
}

public static Assignments of(Symbol symbol, Expression expression)
{
return builder().put(symbol, castToRowExpression(expression)).build();
}

public static Assignments of(Symbol symbol1, RowExpression expression1, Symbol symbol2, RowExpression expression2)
{
return builder().put(symbol1, expression1).put(symbol2, expression2).build();
}

public static Assignments of(Symbol symbol1, Expression expression1, Symbol symbol2, Expression expression2)
{
return builder().put(symbol1, castToRowExpression(expression1)).put(symbol2, castToRowExpression(expression2)).build();
}

// Originally, the following functions are not static move assignments as member variables
public static <C> Assignments rewrite(Assignments assignments, ExpressionRewriter<C> rewriter)
{
return rewrite(assignments, expression -> ExpressionTreeRewriter.rewriteWith(rewriter, expression));
}

public static Assignments rewrite(Assignments assignments, Function<Expression, Expression> rewrite)
{
return assignments.entrySet().stream()
.map(entry -> Maps.immutableEntry(entry.getKey(), castToRowExpression(rewrite.apply(castToExpression(entry.getValue())))))
.collect(toAssignments());
}

public static Assignments filter(Assignments assignments, Collection<Symbol> symbols)
{
return filter(assignments, symbols::contains);
}

public static Assignments filter(Assignments assignments, Predicate<Symbol> predicate)
{
return assignments.entrySet().stream()
.filter(entry -> predicate.test(entry.getKey()))
.collect(toAssignments());
}

public static boolean isIdentity(Assignments assignments, Symbol output)
{
Expression expression = castToExpression(assignments.get(output));

return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName());
}

private static Collector<Map.Entry<Symbol, RowExpression>, Builder, Assignments> toAssignments()
{
return Collector.of(
AssignmentsUtils::builder,
(builder, entry) -> builder.put(entry.getKey(), entry.getValue()),
(left, right) -> {
left.putAll(right.build());
return left;
},
Builder::build);
}

// Originally, the following class is also static
public static class Builder
{
private final Map<Symbol, RowExpression> assignments = new LinkedHashMap<>();

public Builder putAll(Assignments assignments)
{
return putAll(assignments.getMap());
}

public Builder putAll(Map<Symbol, RowExpression> assignments)
{
for (Map.Entry<Symbol, RowExpression> assignment : assignments.entrySet()) {
put(assignment.getKey(), assignment.getValue());
}
return this;
}

public Builder put(Symbol symbol, RowExpression expression)
{
if (assignments.containsKey(symbol)) {
RowExpression assignment = assignments.get(symbol);
checkState(
assignment.equals(expression),
"Symbol %s already has assignment %s, while adding %s",
symbol,
assignment,
expression);
}
assignments.put(symbol, expression);
return this;
}

public Builder put(Symbol symbol, Expression expression)
{
return put(symbol, castToRowExpression(expression));
}

public Builder put(Map.Entry<Symbol, RowExpression> assignment)
{
put(assignment.getKey(), assignment.getValue());
return this;
}

public Builder putIdentities(Iterable<Symbol> symbols)
{
for (Symbol symbol : symbols) {
putIdentity(symbol);
}
return this;
}

public Builder putIdentity(Symbol symbol)
{
put(symbol, castToRowExpression(symbol.toSymbolReference()));
return this;
}

public Assignments build()
{
return new Assignments(assignments);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.facebook.presto.sql.planner.plan.TopNNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SymbolReference;
Expand Down Expand Up @@ -59,6 +60,7 @@
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.google.common.base.Predicates.in;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Maps.transformValues;
import static java.util.Objects.requireNonNull;

/**
Expand Down Expand Up @@ -159,7 +161,7 @@ public Expression visitProject(ProjectNode node, Void context)

Expression underlyingPredicate = node.getSource().accept(this, context);

List<Expression> projectionEqualities = node.getAssignments().entrySet().stream()
List<Expression> projectionEqualities = transformValues(node.getAssignments().getMap(), OriginalExpressionUtils::castToExpression).entrySet().stream()
.filter(SYMBOL_MATCHES_EXPRESSION.negate())
.map(ENTRY_TO_EQUALITY)
.collect(toImmutableList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public Void visitFilter(FilterNode node, ImmutableList.Builder<RowExpression> co
@Override
public Void visitProject(ProjectNode node, ImmutableList.Builder<RowExpression> context)
{
context.addAll(node.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList()));
context.addAll(node.getAssignments().getExpressions().stream().collect(toImmutableList()));
return super.visitProject(node, context);
}

Expand All @@ -130,7 +130,6 @@ public Void visitApply(ApplyNode node, ImmutableList.Builder<RowExpression> cont
context.addAll(node.getSubqueryAssignments()
.getExpressions()
.stream()
.map(OriginalExpressionUtils::castToRowExpression)
.collect(toImmutableList()));
return super.visitApply(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
import com.facebook.presto.spi.predicate.NullableValue;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.type.FunctionType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spiller.PartitioningSpillerFactory;
Expand Down Expand Up @@ -273,7 +274,6 @@
import static com.google.common.collect.DiscreteDomain.integers;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.concat;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Range.closedOpen;
import static io.airlift.units.DataSize.Unit.BYTE;
Expand Down Expand Up @@ -1143,7 +1143,12 @@ public PhysicalOperation visitFilter(FilterNode node, LocalExecutionPlanContext
RowExpression filterExpression = node.getPredicate();
List<Symbol> outputSymbols = node.getOutputSymbols();

return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), Assignments.identity(outputSymbols), outputSymbols);
AssignmentsUtils.Builder identities = AssignmentsUtils.builder();
for (Symbol symbol : outputSymbols) {
Type type = requireNonNull(context.getTypes().get(symbol), format("No type for symbol %s", symbol));
identities.put(symbol, new VariableReferenceExpression(symbol.getName(), type));
}
return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), identities.build(), outputSymbols);
}

@Override
Expand Down Expand Up @@ -1218,30 +1223,15 @@ private PhysicalOperation visitScanFilterAndProject(
Map<Symbol, Integer> outputMappings = outputMappingsBuilder.build();

// compiler uses inputs instead of symbols, so rewrite the expressions first

List<Expression> projections = new ArrayList<>();
for (Symbol symbol : outputSymbols) {
projections.add(assignments.get(symbol));
}

Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(
context.getSession(),
metadata,
sqlParser,
context.getTypes(),
concat(assignments.getExpressions()),
emptyList(),
NOOP,
false);

List<RowExpression> translatedProjections = projections.stream()
.map(expression -> toRowExpression(expression, expressionTypes, sourceLayout))
List<RowExpression> projections = outputSymbols.stream()
.map(assignments::get)
.map(expression -> bindChannels(expression, sourceLayout))
.collect(toImmutableList());

try {
if (columns != null) {
Supplier<CursorProcessor> cursorProcessor = expressionCompiler.compileCursorProcessor(filterExpression, translatedProjections, sourceNode.getId());
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId));
Supplier<CursorProcessor> cursorProcessor = expressionCompiler.compileCursorProcessor(filterExpression, projections, sourceNode.getId());
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, projections, Optional.of(context.getStageId() + "_" + planNodeId));

SourceOperatorFactory operatorFactory = new ScanFilterAndProjectOperatorFactory(
context.getNextOperatorId(),
Expand All @@ -1251,20 +1241,20 @@ private PhysicalOperation visitScanFilterAndProject(
cursorProcessor,
pageProcessor,
columns,
getTypes(projections, expressionTypes),
projections.stream().map(RowExpression::getType).collect(toImmutableList()),
getFilterAndProjectMinOutputPageSize(session),
getFilterAndProjectMinOutputPageRowCount(session));

return new PhysicalOperation(operatorFactory, outputMappings, context, stageExecutionDescriptor.isScanGroupedExecution(sourceNode.getId()) ? GROUPED_EXECUTION : UNGROUPED_EXECUTION);
}
else {
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId));
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, projections, Optional.of(context.getStageId() + "_" + planNodeId));

OperatorFactory operatorFactory = new FilterAndProjectOperator.FilterAndProjectOperatorFactory(
context.getNextOperatorId(),
planNodeId,
pageProcessor,
getTypes(projections, expressionTypes),
projections.stream().map(RowExpression::getType).collect(toImmutableList()),
getFilterAndProjectMinOutputPageSize(session),
getFilterAndProjectMinOutputPageRowCount(session));

Expand Down Expand Up @@ -2796,14 +2786,6 @@ private OperatorFactory createHashAggregationOperatorFactory(
}
}

private static List<Type> getTypes(List<Expression> expressions, Map<NodeRef<Expression>, Type> expressionTypes)
{
return expressions.stream()
.map(NodeRef::of)
.map(expressionTypes::get)
.collect(toImmutableList());
}

private static TableFinisher createTableFinisher(Session session, TableFinishNode node, Metadata metadata)
{
WriterTarget target = node.getTarget();
Expand Down
Loading