Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ private PlanBuilder appendInPredicateApplyNode(PlanBuilder subPlan, InPredicate

subPlan.getTranslations().put(inPredicate, inPredicateSubquerySymbol);

return appendApplyNode(subPlan, inPredicate, subqueryPlan.getRoot(), Assignments.of(inPredicateSubquerySymbol, inPredicateSubqueryExpression), correlationAllowed);
return appendApplyNode(subPlan, inPredicate, subqueryPlan.getRoot(), Assignments.of(inPredicateSubquerySymbol, castToRowExpression(inPredicateSubqueryExpression)), correlationAllowed);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually we do the cast at the very last step when replacing expressions. This will force us to build all the necessary infra for RowExpression then start a one-time migration. It would be cleaner and easier to revert.

}

private PlanBuilder appendScalarSubqueryApplyNodes(PlanBuilder builder, Set<SubqueryExpression> scalarSubqueries, boolean correlationAllowed)
Expand Down Expand Up @@ -295,7 +295,7 @@ private PlanBuilder appendExistSubqueryApplyNode(PlanBuilder subPlan, ExistsPred
subPlan,
existsPredicate.getSubquery(),
subqueryNode,
Assignments.of(exists, rewrittenExistsPredicate),
Assignments.of(exists, castToRowExpression(rewrittenExistsPredicate)),
correlationAllowed);
}

Expand Down Expand Up @@ -393,7 +393,7 @@ private PlanBuilder planQuantifiedApplyNode(PlanBuilder subPlan, QuantifiedCompa
subPlan,
quantifiedComparison.getSubquery(),
subqueryPlan.getRoot(),
Assignments.of(coercedQuantifiedComparisonSymbol, coercedQuantifiedComparison),
Assignments.of(coercedQuantifiedComparisonSymbol, castToRowExpression(coercedQuantifiedComparison)),
correlationAllowed);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS;
import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.facebook.presto.sql.relational.ProjectNodeUtils.getAsRowExpression;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -202,7 +203,7 @@ public static PlanNode buildJoinTree(List<Symbol> expectedOutputSymbols, JoinGra
result = new ProjectNode(
idAllocator.getNextId(),
result,
Assignments.copyOf(graph.getAssignments().get()));
Assignments.copyOf(getAsRowExpression(graph.getAssignments().get())));
}

// If needed, introduce a projection to constrain the outputs to what was originally expected
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.facebook.presto.sql.planner.plan.Patterns.window;
import static com.facebook.presto.sql.relational.ProjectNodeUtils.isIdentity;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;

Expand Down Expand Up @@ -140,7 +141,7 @@ protected static Optional<WindowNode> pullWindowNodeAboveProjects(
// The target node, when hoisted above the projections, will provide the symbols directly.
Map<Symbol, Expression> assignmentsWithoutTargetOutputIdentities = Maps.filterKeys(
project.getAssignments().getMap(),
output -> !(project.getAssignments().isIdentity(output) && targetOutputs.contains(output)));
output -> !(isIdentity(project, output) && targetOutputs.contains(output)));

if (targetInputs.stream().anyMatch(assignmentsWithoutTargetOutputIdentities::containsKey)) {
// Redefinition of an input to the target -- can't handle this case.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import static com.facebook.presto.sql.planner.ExpressionSymbolInliner.inlineSymbols;
import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.facebook.presto.sql.relational.ProjectNodeUtils.getAsRowExpression;
import static com.facebook.presto.sql.relational.ProjectNodeUtils.isIdentity;
import static java.util.stream.Collectors.toSet;

/**
Expand Down Expand Up @@ -105,7 +107,7 @@ public Result apply(ProjectNode parent, Captures captures, Context context)
child.getId(),
child.getSource(),
childAssignments.build()),
Assignments.copyOf(parentAssignments)));
Assignments.copyOf(getAsRowExpression(parentAssignments))));
}

private Expression inlineReferences(Expression expression, Assignments assignments)
Expand Down Expand Up @@ -155,7 +157,7 @@ private Sets.SetView<Symbol> extractInliningTargets(ProjectNode parent, ProjectN
Set<Symbol> singletons = dependencies.entrySet().stream()
.filter(entry -> entry.getValue() == 1) // reference appears just once across all expressions in parent project node
.filter(entry -> !tryArguments.contains(entry.getKey())) // they are not inputs to TRY. Otherwise, inlining might change semantics
.filter(entry -> !child.getAssignments().isIdentity(entry.getKey())) // skip identities, otherwise, this rule will keep firing forever
.filter(entry -> !isIdentity(child, entry.getKey())) // skip identities, otherwise, this rule will keep firing forever
.map(Map.Entry::getKey)
.collect(toSet());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import static com.facebook.presto.sql.planner.plan.LateralJoinNode.Type.INNER;
import static com.facebook.presto.sql.planner.plan.LateralJoinNode.Type.LEFT;
import static com.facebook.presto.sql.planner.plan.Patterns.applyNode;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN;
import static com.google.common.base.Preconditions.checkState;
Expand Down Expand Up @@ -130,7 +131,7 @@ private Optional<PlanNode> rewriteToNonDefaultAggregation(ApplyNode applyNode, C
applyNode.getSubquery(),
1L,
false),
Assignments.of(subqueryTrue, TRUE_LITERAL));
Assignments.of(subqueryTrue, castToRowExpression(TRUE_LITERAL)));

PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getLookup());
if (!decorrelator.decorrelateFilters(subquery, applyNode.getCorrelation()).isPresent()) {
Expand All @@ -153,6 +154,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context)
Symbol count = context.getSymbolAllocator().newSymbol(COUNT.toString(), BIGINT);
Symbol exists = getOnlyElement(parent.getSubqueryAssignments().getSymbols());

final ComparisonExpression comparisonExpression = new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), BIGINT.toString()));
return new LateralJoinNode(
parent.getId(),
parent.getInput(),
Expand All @@ -170,7 +172,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context)
AggregationNode.Step.SINGLE,
Optional.empty(),
Optional.empty()),
Assignments.of(exists, new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), BIGINT.toString())))),
Assignments.of(exists, castToRowExpression(comparisonExpression))),
parent.getCorrelation(),
INNER,
parent.getOriginSubquery());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.planner.plan.AggregationNode.globalAggregation;
import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL;
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL;
Expand Down Expand Up @@ -180,7 +181,7 @@ countNonNullValue, new Aggregation(

Symbol quantifiedComparisonSymbol = getOnlyElement(node.getSubqueryAssignments().getSymbols());

return projectExpressions(lateralJoinNode, Assignments.of(quantifiedComparisonSymbol, valueComparedToSubquery));
return projectExpressions(lateralJoinNode, Assignments.of(quantifiedComparisonSymbol, castToRowExpression(valueComparedToSubquery)));
}

public Expression rewriteUsingBounds(QuantifiedComparisonExpression quantifiedComparison, Symbol minValue, Symbol maxValue, Symbol countAllValue, Symbol countNonNullValue)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
*/
package com.facebook.presto.sql.planner.plan;

import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.SymbolReference;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Predicate;
Expand All @@ -35,6 +35,8 @@
import java.util.function.Function;
import java.util.stream.Collector;

import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression;
import static com.facebook.presto.sql.relational.ProjectNodeUtils.getAsExpression;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;
Expand All @@ -56,10 +58,10 @@ public static Assignments identity(Iterable<Symbol> symbols)
return builder().putIdentities(symbols).build();
}

public static Assignments copyOf(Map<Symbol, Expression> assignments)
public static Assignments copyOf(Map<Symbol, RowExpression> assignments)
{
return builder()
.putAll(assignments)
.putAll(getAsExpression(assignments))
.build();
}

Expand All @@ -68,14 +70,14 @@ public static Assignments of()
return builder().build();
}

public static Assignments of(Symbol symbol, Expression expression)
public static Assignments of(Symbol symbol, RowExpression expression)
{
return builder().put(symbol, expression).build();
return builder().put(symbol, castToExpression(expression)).build();
}

public static Assignments of(Symbol symbol1, Expression expression1, Symbol symbol2, Expression expression2)
public static Assignments of(Symbol symbol1, RowExpression expression1, Symbol symbol2, RowExpression expression2)
{
return builder().put(symbol1, expression1).put(symbol2, expression2).build();
return builder().put(symbol1, castToExpression(expression1)).put(symbol2, castToExpression(expression2)).build();
}

private final Map<Symbol, Expression> assignments;
Expand Down Expand Up @@ -121,13 +123,6 @@ public Assignments filter(Predicate<Symbol> predicate)
.collect(toAssignments());
}

public boolean isIdentity(Symbol output)
{
Expression expression = assignments.get(output);

return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName());
}

private Collector<Entry<Symbol, Expression>, Builder, Assignments> toAssignments()
{
return Collector.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,52 @@
*/
package com.facebook.presto.sql.relational;

import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SymbolReference;

import java.util.Map;
import java.util.stream.Collectors;

import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;

public class ProjectNodeUtils
{
private ProjectNodeUtils() {}

public static boolean isIdentity(ProjectNode projectNode)
{
for (Map.Entry<Symbol, Expression> entry : projectNode.getAssignments().entrySet()) {
Expression expression = entry.getValue();
Symbol symbol = entry.getKey();
if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(symbol.getName()))) {
for (Symbol symbol : projectNode.getAssignments().getSymbols()) {
if (!isIdentity(projectNode, symbol)) {
return false;
}
}
return true;
}

public static boolean isIdentity(ProjectNode projectNode, Symbol output)
{
Expression expression = projectNode.getAssignments().get(output);

return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName());
}

public static Map<Symbol, Expression> getAsExpression(Map<Symbol, RowExpression> assignments)
{
return assignments.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> castToExpression(entry.getValue())));
}

public static Map<Symbol, RowExpression> getAsRowExpression(Map<Symbol, Expression> assignments)
{
return assignments.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> castToRowExpression(entry.getValue())));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.replicatedExchange;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.systemPartitionedExchange;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.facebook.presto.testing.TestingSession.createBogusTestingCatalog;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.facebook.presto.tpch.TpchTransactionHandle.INSTANCE;
Expand Down Expand Up @@ -759,7 +760,7 @@ private PlanNode project(String id, PlanNode source, String symbol, Expression e
return new ProjectNode(
new PlanNodeId(id),
source,
Assignments.of(new Symbol(symbol), expression));
Assignments.of(new Symbol(symbol), castToRowExpression(expression)));
}

private AggregationNode aggregation(String id, PlanNode source)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import static com.facebook.presto.sql.ExpressionUtils.or;
import static com.facebook.presto.sql.planner.plan.AggregationNode.globalAggregation;
import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL;
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static org.testng.Assert.assertEquals;
Expand Down Expand Up @@ -212,7 +213,7 @@ public void testProject()
equals(AE, BE),
equals(BE, CE),
lessThan(CE, bigintLiteral(10)))),
Assignments.of(D, AE, E, CE));
Assignments.of(D, castToRowExpression(AE), E, castToRowExpression(CE)));

Expression effectivePredicate = effectivePredicateExtractor.extract(node);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.getJoinOrder;
import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.isOriginalOrder;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.facebook.presto.sql.tree.ArithmeticUnaryExpression.Sign.MINUS;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -250,7 +251,7 @@ private PlanNode projectNode(PlanNode source, String symbol, Expression expressi
return new ProjectNode(
idAllocator.getNextId(),
source,
Assignments.of(new Symbol(symbol), expression));
Assignments.of(new Symbol(symbol), castToRowExpression(expression)));
}

private String symbol(String name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public void testProjectionExpressionRewrite()
{
tester().assertThat(zeroRewriter.projectExpressionRewrite())
.on(p -> p.project(
Assignments.of(p.symbol("y"), PlanBuilder.expression("x IS NOT NULL")),
Assignments.of(p.symbol("y"), castToRowExpression(PlanBuilder.expression("x IS NOT NULL"))),
p.values(p.symbol("x"))))
.matches(
project(ImmutableMap.of("y", expression("0")), values("x")));
Expand All @@ -69,7 +69,7 @@ public void testProjectionExpressionNotRewritten()
{
tester().assertThat(zeroRewriter.projectExpressionRewrite())
.on(p -> p.project(
Assignments.of(p.symbol("y"), PlanBuilder.expression("0")),
Assignments.of(p.symbol("y"), castToRowExpression(PlanBuilder.expression("0"))),
p.values(p.symbol("x"))))
.doesNotFire();
}
Expand Down Expand Up @@ -152,11 +152,11 @@ public void testApplyExpressionRewrite()
.on(p -> p.apply(
Assignments.of(
p.symbol("a", BIGINT),
new InPredicate(
castToRowExpression(new InPredicate(
new LongLiteral("1"),
new InListExpression(ImmutableList.of(
new LongLiteral("1"),
new LongLiteral("2"))))),
new LongLiteral("2")))))),
ImmutableList.of(),
p.values(),
p.values()))
Expand All @@ -175,11 +175,11 @@ public void testApplyExpressionNotRewritten()
.on(p -> p.apply(
Assignments.of(
p.symbol("a", BIGINT),
new InPredicate(
castToRowExpression(new InPredicate(
new LongLiteral("0"),
new InListExpression(ImmutableList.of(
new LongLiteral("1"),
new LongLiteral("2"))))),
new LongLiteral("2")))))),
ImmutableList.of(),
p.values(),
p.values()))
Expand Down
Loading