Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@

import com.facebook.presto.Session;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.Expression;

import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression;
import static java.util.Objects.requireNonNull;

public class ProjectStatsRule
Expand Down Expand Up @@ -53,8 +55,14 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(ProjectNode node, StatsPro
PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder()
.setOutputRowCount(sourceStats.getOutputRowCount());

for (Map.Entry<VariableReferenceExpression, Expression> entry : node.getAssignments().entrySet()) {
calculatedStats.addVariableStatistics(entry.getKey(), scalarStatsCalculator.calculate(entry.getValue(), sourceStats, session, types));
for (Map.Entry<VariableReferenceExpression, RowExpression> entry : node.getAssignments().entrySet()) {
RowExpression expression = entry.getValue();
if (isExpression(expression)) {
calculatedStats.addVariableStatistics(entry.getKey(), scalarStatsCalculator.calculate(castToExpression(expression), sourceStats, session, types));
}
else {
calculatedStats.addVariableStatistics(entry.getKey(), scalarStatsCalculator.calculate(expression, sourceStats, session));
}
}
return Optional.of(calculatedStats.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.google.common.base.Predicates.in;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Maps.transformValues;
import static java.util.Objects.requireNonNull;

/**
Expand Down Expand Up @@ -164,7 +165,7 @@ public Expression visitProject(ProjectNode node, Void context)

Expression underlyingPredicate = node.getSource().accept(this, context);

List<Expression> projectionEqualities = node.getAssignments().entrySet().stream()
List<Expression> projectionEqualities = transformValues(node.getAssignments().getMap(), OriginalExpressionUtils::castToExpression).entrySet().stream()
.filter(VARIABLE_MATCHES_EXPRESSION.negate())
.map(VARIABLE_ENTRY_TO_EQUALITY)
.collect(toImmutableList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.google.common.collect.ImmutableList;

import java.util.List;
Expand Down Expand Up @@ -112,7 +111,7 @@ public Void visitFilter(FilterNode node, ImmutableList.Builder<RowExpression> co
@Override
public Void visitProject(ProjectNode node, ImmutableList.Builder<RowExpression> context)
{
context.addAll(node.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList()));
context.addAll(node.getAssignments().getExpressions().stream().collect(toImmutableList()));
return super.visitProject(node, context);
}

Expand All @@ -136,7 +135,6 @@ public Void visitApply(ApplyNode node, ImmutableList.Builder<RowExpression> cont
context.addAll(node.getSubqueryAssignments()
.getExpressions()
.stream()
.map(OriginalExpressionUtils::castToRowExpression)
.collect(toImmutableList()));
return super.visitApply(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@
import static com.facebook.presto.SystemSessionProperties.getTaskWriterCount;
import static com.facebook.presto.SystemSessionProperties.isExchangeCompressionEnabled;
import static com.facebook.presto.SystemSessionProperties.isSpillEnabled;
import static com.facebook.presto.execution.warnings.WarningCollector.NOOP;
import static com.facebook.presto.operator.DistinctLimitOperator.DistinctLimitOperatorFactory;
import static com.facebook.presto.operator.NestedLoopBuildOperator.NestedLoopBuildOperatorFactory;
import static com.facebook.presto.operator.NestedLoopJoinOperator.NestedLoopJoinOperatorFactory;
Expand All @@ -240,7 +239,6 @@
import static com.facebook.presto.spi.relation.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.TypeUtils.writeNativeValue;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider;
import static com.facebook.presto.sql.planner.RowExpressionInterpreter.rowExpressionInterpreter;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION;
Expand All @@ -250,6 +248,7 @@
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.PARTIAL;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT;
Expand All @@ -271,11 +270,9 @@
import static com.google.common.collect.DiscreteDomain.integers;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.concat;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Range.closedOpen;
import static io.airlift.units.DataSize.Unit.BYTE;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.IntStream.range;

Expand Down Expand Up @@ -1140,7 +1137,7 @@ public PhysicalOperation visitFilter(FilterNode node, LocalExecutionPlanContext
RowExpression filterExpression = node.getPredicate();
List<VariableReferenceExpression> outputVariables = node.getOutputVariables();

return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), Assignments.identity(outputVariables), outputVariables);
return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), identityAssignments(outputVariables), outputVariables);
}

@Override
Expand Down Expand Up @@ -1211,30 +1208,15 @@ private PhysicalOperation visitScanFilterAndProject(
Map<VariableReferenceExpression, Integer> outputMappings = outputMappingsBuilder.build();

// compiler uses inputs instead of symbols, so rewrite the expressions first

List<Expression> projections = new ArrayList<>();
for (VariableReferenceExpression variable : outputVariables) {
projections.add(assignments.get(variable));
}

Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(
context.getSession(),
metadata,
sqlParser,
context.getTypes(),
concat(assignments.getExpressions()),
emptyList(),
NOOP,
false);

List<RowExpression> translatedProjections = projections.stream()
.map(expression -> toRowExpression(expression, expressionTypes, sourceLayout))
List<RowExpression> projections = outputVariables.stream()
.map(assignments::get)
.map(expression -> bindChannels(expression, sourceLayout))
.collect(toImmutableList());

try {
if (columns != null) {
Supplier<CursorProcessor> cursorProcessor = expressionCompiler.compileCursorProcessor(filterExpression, translatedProjections, sourceNode.getId());
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId));
Supplier<CursorProcessor> cursorProcessor = expressionCompiler.compileCursorProcessor(filterExpression, projections, sourceNode.getId());
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, projections, Optional.of(context.getStageId() + "_" + planNodeId));

SourceOperatorFactory operatorFactory = new ScanFilterAndProjectOperatorFactory(
context.getNextOperatorId(),
Expand All @@ -1244,20 +1226,20 @@ private PhysicalOperation visitScanFilterAndProject(
cursorProcessor,
pageProcessor,
columns,
getTypes(projections, expressionTypes),
projections.stream().map(RowExpression::getType).collect(toImmutableList()),
getFilterAndProjectMinOutputPageSize(session),
getFilterAndProjectMinOutputPageRowCount(session));

return new PhysicalOperation(operatorFactory, outputMappings, context, stageExecutionDescriptor.isScanGroupedExecution(sourceNode.getId()) ? GROUPED_EXECUTION : UNGROUPED_EXECUTION);
}
else {
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId));
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, projections, Optional.of(context.getStageId() + "_" + planNodeId));

OperatorFactory operatorFactory = new FilterAndProjectOperator.FilterAndProjectOperatorFactory(
context.getNextOperatorId(),
planNodeId,
pageProcessor,
getTypes(projections, expressionTypes),
projections.stream().map(RowExpression::getType).collect(toImmutableList()),
getFilterAndProjectMinOutputPageSize(session),
getFilterAndProjectMinOutputPageRowCount(session));

Expand Down Expand Up @@ -2729,14 +2711,6 @@ private OperatorFactory createHashAggregationOperatorFactory(
}
}

private static List<Type> getTypes(List<Expression> expressions, Map<NodeRef<Expression>, Type> expressionTypes)
{
return expressions.stream()
.map(NodeRef::of)
.map(expressionTypes::get)
.collect(toImmutableList());
}

private static TableFinisher createTableFinisher(Session session, TableFinishNode node, Metadata metadata)
{
WriterTarget target = node.getTarget();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
import static com.facebook.presto.sql.planner.plan.TableWriterNode.WriterTarget;
import static com.facebook.presto.sql.planner.sanity.PlanSanityChecker.DISTRIBUTED_PLAN_SANITY_CHECKER;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -358,19 +359,19 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement)
int index = insert.getColumns().indexOf(columns.get(column.getName()));
if (index < 0) {
Expression cast = new Cast(new NullLiteral(), column.getType().getTypeSignature().toString());
assignments.put(output, cast);
assignments.put(output, castToRowExpression(cast));
}
else {
Symbol input = plan.getSymbol(index);
Type tableType = column.getType();
Type queryType = symbolAllocator.getTypes().get(input);

if (queryType.equals(tableType) || metadata.getTypeManager().isTypeOnlyCoercion(queryType, tableType)) {
assignments.put(output, input.toSymbolReference());
assignments.put(output, castToRowExpression(input.toSymbolReference()));
}
else {
Expression cast = new Cast(input.toSymbolReference(), tableType.getTypeSignature().toString());
assignments.put(output, cast);
assignments.put(output, castToRowExpression(cast));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.List;
import java.util.Map;

import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static java.util.Objects.requireNonNull;

class PlanBuilder
Expand Down Expand Up @@ -105,13 +106,13 @@ public PlanBuilder appendProjections(Iterable<Expression> expressions, SymbolAll

// add an identity projection for underlying plan
for (VariableReferenceExpression variable : getRoot().getOutputVariables()) {
projections.put(variable, new SymbolReference(variable.getName()));
projections.put(variable, castToRowExpression(new SymbolReference(variable.getName())));
}

ImmutableMap.Builder<VariableReferenceExpression, Expression> newTranslations = ImmutableMap.builder();
for (Expression expression : expressions) {
VariableReferenceExpression variable = symbolAllocator.newVariable(expression, getAnalysis().getTypeWithCoercions(expression));
projections.put(variable, translations.rewrite(expression));
projections.put(variable, castToRowExpression(translations.rewrite(expression)));
newTranslations.put(variable, expression);
}
// Now append the new translations into the TranslationMap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding;
Expand All @@ -69,7 +70,6 @@
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.planner.sanity.PlanSanityChecker;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -109,6 +109,7 @@
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPLICATE;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.gatheringExchange;
import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.jsonFragmentPlan;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
Expand Down Expand Up @@ -560,13 +561,13 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite
private PartitioningVariableAssignments assignPartitioningVariables(Partitioning partitioning)
{
ImmutableList.Builder<VariableReferenceExpression> variables = ImmutableList.builder();
ImmutableMap.Builder<VariableReferenceExpression, Expression> constants = ImmutableMap.builder();
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> constants = ImmutableMap.builder();
for (ArgumentBinding argumentBinding : partitioning.getArguments()) {
VariableReferenceExpression variable;
if (argumentBinding.isConstant()) {
ConstantExpression constant = argumentBinding.getConstant();
Expression expression = literalEncoder.toExpression(constant.getValue(), constant.getType());
Comment thread
hellium01 marked this conversation as resolved.
Outdated
variable = symbolAllocator.newVariable(expression, constant.getType());
RowExpression expression = constant(constant.getValue(), constant.getType());
Comment thread
hellium01 marked this conversation as resolved.
variable = symbolAllocator.newVariable("constant_partition", constant.getType());
constants.put(variable, expression);
}
else {
Expand Down Expand Up @@ -632,7 +633,7 @@ private TableFinishNode createTemporaryTableWrite(
List<VariableReferenceExpression> outputs,
List<List<VariableReferenceExpression>> inputs,
List<PlanNode> sources,
Map<VariableReferenceExpression, Expression> constantExpressions,
Map<VariableReferenceExpression, RowExpression> constantExpressions,
PartitioningMetadata partitioningMetadata)
{
if (!constantExpressions.isEmpty()) {
Expand All @@ -656,8 +657,8 @@ private TableFinishNode createTemporaryTableWrite(
sources = sources.stream()
.map(source -> {
Assignments.Builder assignments = Assignments.builder();
assignments.putIdentities(source.getOutputVariables());
constantVariables.forEach(variable -> assignments.put(variable, constantExpressions.get(variable)));
source.getOutputVariables().forEach(variable -> assignments.put(variable, new VariableReferenceExpression(variable.getName(), variable.getType())));
constantVariables.forEach(symbol -> assignments.put(symbol, constantExpressions.get(symbol)));
Comment thread
hellium01 marked this conversation as resolved.
return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
})
.collect(toImmutableList());
Expand Down Expand Up @@ -1216,9 +1217,9 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext<Void> context)
private static class PartitioningVariableAssignments
{
private final List<VariableReferenceExpression> variables;
private final Map<VariableReferenceExpression, Expression> constants;
private final Map<VariableReferenceExpression, RowExpression> constants;

private PartitioningVariableAssignments(List<VariableReferenceExpression> variables, Map<VariableReferenceExpression, Expression> constants)
private PartitioningVariableAssignments(List<VariableReferenceExpression> variables, Map<VariableReferenceExpression, RowExpression> constants)
{
this.variables = ImmutableList.copyOf(requireNonNull(variables, "variables is null"));
this.constants = ImmutableMap.copyOf(requireNonNull(constants, "constants is null"));
Expand All @@ -1232,7 +1233,7 @@ public List<VariableReferenceExpression> getVariables()
return variables;
}

public Map<VariableReferenceExpression, Expression> getConstants()
public Map<VariableReferenceExpression, RowExpression> getConstants()
{
return constants;
}
Expand Down
Loading