Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ public final class SystemSessionProperties
public static final String LEGACY_TIMESTAMP = "legacy_timestamp";
public static final String ENABLE_INTERMEDIATE_AGGREGATIONS = "enable_intermediate_aggregations";
public static final String PUSH_AGGREGATION_THROUGH_JOIN = "push_aggregation_through_join";
public static final String PUSH_SEMI_JOIN_THROUGH_UNION = "push_semi_join_through_union";
public static final String PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN = "push_partial_aggregation_through_join";
public static final String PARSE_DECIMAL_LITERALS_AS_DOUBLE = "parse_decimal_literals_as_double";
public static final String FORCE_SINGLE_NODE_OUTPUT = "force_single_node_output";
Expand Down Expand Up @@ -907,6 +908,11 @@ public SystemSessionProperties(
"Allow pushing aggregations below joins",
featuresConfig.isPushAggregationThroughJoin(),
false),
booleanProperty(
PUSH_SEMI_JOIN_THROUGH_UNION,
"Allow pushing semi joins through union",
featuresConfig.isPushSemiJoinThroughUnion(),
false),
booleanProperty(
PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN,
"Push partial aggregations below joins",
Expand Down Expand Up @@ -2488,6 +2494,11 @@ public static boolean shouldPushAggregationThroughJoin(Session session)
return session.getSystemProperty(PUSH_AGGREGATION_THROUGH_JOIN, Boolean.class);
}

/**
 * Reads the {@code push_semi_join_through_union} system session property.
 *
 * @param session session whose system properties are consulted
 * @return the boolean value of the property for {@code session}
 */
public static boolean isPushSemiJoinThroughUnion(Session session)
{
    Boolean pushSemiJoinThroughUnion = session.getSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, Boolean.class);
    return pushSemiJoinThroughUnion;
}

public static boolean isNativeExecutionEnabled(Session session)
{
return session.getSystemProperty(NATIVE_EXECUTION_ENABLED, Boolean.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ public class FeaturesConfig
private double defaultJoinSelectivityCoefficient;
private double defaultWriterReplicationCoefficient = 3;
private boolean pushAggregationThroughJoin = true;
private boolean pushSemiJoinThroughUnion;
private double memoryRevokingTarget = 0.5;
private double memoryRevokingThreshold = 0.9;
private boolean useMarkDistinct = true;
Expand Down Expand Up @@ -1625,6 +1626,19 @@ public FeaturesConfig setPushAggregationThroughJoin(boolean value)
return this;
}

/**
 * Returns the configured value of the push-semi-join-through-union flag.
 *
 * @return the current value of the {@code pushSemiJoinThroughUnion} field
 */
public boolean isPushSemiJoinThroughUnion()
{
    return this.pushSemiJoinThroughUnion;
}

/**
 * Sets the push-semi-join-through-union flag; bound to the
 * {@code optimizer.push-semi-join-through-union} configuration property.
 *
 * @param pushSemiJoinThroughUnion new value of the flag
 * @return this {@code FeaturesConfig}, so setter calls can be chained
 */
@Config("optimizer.push-semi-join-through-union")
@ConfigDescription("Push semi join through union to allow parallel semi join execution")
public FeaturesConfig setPushSemiJoinThroughUnion(boolean pushSemiJoinThroughUnion)
{
this.pushSemiJoinThroughUnion = pushSemiJoinThroughUnion;
return this;
}

public boolean isForceSingleNodeOutput()
{
return forceSingleNodeOutput;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.PushRemoteExchangeThroughAssignUniqueId;
import com.facebook.presto.sql.planner.iterative.rule.PushRemoteExchangeThroughGroupId;
import com.facebook.presto.sql.planner.iterative.rule.PushSemiJoinThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.PushTableWriteThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.PushTopNThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.RandomizeSourceKeyInSemiJoin;
Expand Down Expand Up @@ -618,6 +619,12 @@ public PlanOptimizers(
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(new LeftJoinNullFilterToSemiJoin(metadata.getFunctionAndTypeManager()))),
new IterativeOptimizer(
metadata,
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(new PushSemiJoinThroughUnion())),
new IterativeOptimizer(
metadata,
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.RowExpressionVariableInliner;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.isPushSemiJoinThroughUnion;
import static com.facebook.presto.sql.planner.optimizations.SetOperationNodeUtils.fromListMultimap;
import static com.facebook.presto.sql.planner.plan.Patterns.semiJoin;

/**
* Pushes a SemiJoinNode through a UnionNode (on the probe/source side).
* <p>
* Transforms:
* <pre>
* - SemiJoin (sourceJoinVar=c, output=sjOut)
* - Union (output c from [a1, a2])
* - source1 (outputs a1)
* - source2 (outputs a2)
* - filteringSource
* </pre>
* into:
* <pre>
* - Union (output sjOut from [sjOut_0, sjOut_1], c from [a1, a2])
* - SemiJoin (sourceJoinVar=a1, output=sjOut_0)
* - source1
* - filteringSource
* - SemiJoin (sourceJoinVar=a2, output=sjOut_1)
* - source2
* - filteringSource
* </pre>
* <p>
* Also handles the case where a ProjectNode sits between the SemiJoin and Union:
* <pre>
* - SemiJoin
* - Project
* - Union
* - filteringSource
* </pre>
* In this case, the project is pushed into each union branch before the semi join.
*/
public class PushSemiJoinThroughUnion
implements Rule<SemiJoinNode>
{
private static final Pattern<SemiJoinNode> PATTERN = semiJoin();

@Override
public Pattern<SemiJoinNode> getPattern()
{
return PATTERN;
}

@Override
public boolean isEnabled(Session session)
{
return isPushSemiJoinThroughUnion(session);
}

@Override
public Result apply(SemiJoinNode semiJoinNode, Captures captures, Context context)
{
PlanNode source = context.getLookup().resolve(semiJoinNode.getSource());

if (source instanceof UnionNode) {
return pushThroughUnion(semiJoinNode, (UnionNode) source, Optional.empty(), context);
}

if (source instanceof ProjectNode) {
ProjectNode projectNode = (ProjectNode) source;
PlanNode projectSource = context.getLookup().resolve(projectNode.getSource());
if (projectSource instanceof UnionNode) {
return pushThroughUnion(semiJoinNode, (UnionNode) projectSource, Optional.of(projectNode), context);
}
}

return Result.empty();
}

private Result pushThroughUnion(
SemiJoinNode semiJoinNode,
UnionNode unionNode,
Optional<ProjectNode> projectNode,
Context context)
{
ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();
ImmutableListMultimap.Builder<VariableReferenceExpression, VariableReferenceExpression> outputMappings =
ImmutableListMultimap.builder();

for (int i = 0; i < unionNode.getSources().size(); i++) {
Map<VariableReferenceExpression, VariableReferenceExpression> unionVarMap = unionNode.sourceVariableMap(i);

PlanNode branchSource;
VariableReferenceExpression mappedSourceJoinVar;
Optional<VariableReferenceExpression> mappedSourceHashVar;
Map<String, VariableReferenceExpression> branchDynamicFilters;

if (projectNode.isPresent()) {
// Push the project into each union branch, translating its assignments
ProjectNode project = projectNode.get();
Assignments.Builder assignments = Assignments.builder();
Map<VariableReferenceExpression, VariableReferenceExpression> projectVarMapping = new HashMap<>();

for (Map.Entry<VariableReferenceExpression, RowExpression> entry : project.getAssignments().entrySet()) {
RowExpression translatedExpression = RowExpressionVariableInliner.inlineVariables(unionVarMap, entry.getValue());
VariableReferenceExpression newVar = context.getVariableAllocator().newVariable(translatedExpression);
assignments.put(newVar, translatedExpression);
projectVarMapping.put(entry.getKey(), newVar);
}

branchSource = new ProjectNode(
project.getSourceLocation(),
context.getIdAllocator().getNextId(),
unionNode.getSources().get(i),
assignments.build(),
project.getLocality());

// Map the semi-join source variables through the project variable mapping
mappedSourceJoinVar = projectVarMapping.get(semiJoinNode.getSourceJoinVariable());
if (mappedSourceJoinVar == null) {
return Result.empty();
}
mappedSourceHashVar = semiJoinNode.getSourceHashVariable().map(projectVarMapping::get);
if (mappedSourceHashVar.isPresent() && mappedSourceHashVar.get() == null) {
return Result.empty();
}

// Build output-to-input mappings for original union output variables,
// mapped through the project
for (VariableReferenceExpression semiJoinOutputVar : semiJoinNode.getOutputVariables()) {
if (semiJoinOutputVar.equals(semiJoinNode.getSemiJoinOutput())) {
continue; // handled separately below
}
// This variable comes from the project's output. Map it to the per-branch project output.
VariableReferenceExpression branchVar = projectVarMapping.get(semiJoinOutputVar);
if (branchVar != null) {
outputMappings.put(semiJoinOutputVar, branchVar);
}
}

// Remap dynamic filter source variables through the project variable mapping
branchDynamicFilters = remapDynamicFilters(semiJoinNode.getDynamicFilters(), projectVarMapping);
}
else {
branchSource = unionNode.getSources().get(i);

// Map the semi-join source variables through the union variable mapping
mappedSourceJoinVar = unionVarMap.get(semiJoinNode.getSourceJoinVariable());
if (mappedSourceJoinVar == null) {
return Result.empty();
}
mappedSourceHashVar = semiJoinNode.getSourceHashVariable().map(unionVarMap::get);
if (mappedSourceHashVar.isPresent() && mappedSourceHashVar.get() == null) {
return Result.empty();
}

// Build output-to-input mappings for original union output variables
for (VariableReferenceExpression unionOutputVar : unionNode.getOutputVariables()) {
outputMappings.put(unionOutputVar, unionVarMap.get(unionOutputVar));
}

// Remap dynamic filter source variables through the union variable mapping
branchDynamicFilters = remapDynamicFilters(semiJoinNode.getDynamicFilters(), unionVarMap);
}

// Allocate new semiJoinOutput variable for each branch
VariableReferenceExpression newSemiJoinOutput =
context.getVariableAllocator().newVariable(semiJoinNode.getSemiJoinOutput());

// Build new SemiJoinNode for this branch
SemiJoinNode newSemiJoin = new SemiJoinNode(
semiJoinNode.getSourceLocation(),
context.getIdAllocator().getNextId(),
branchSource,
semiJoinNode.getFilteringSource(),
mappedSourceJoinVar,
semiJoinNode.getFilteringSourceJoinVariable(),
newSemiJoinOutput,
mappedSourceHashVar,
semiJoinNode.getFilteringSourceHashVariable(),
semiJoinNode.getDistributionType(),
branchDynamicFilters);

newSources.add(newSemiJoin);

// Add the semiJoinOutput mapping
outputMappings.put(semiJoinNode.getSemiJoinOutput(), newSemiJoinOutput);
}

ListMultimap<VariableReferenceExpression, VariableReferenceExpression> mappings = outputMappings.build();

return Result.ofPlanNode(new UnionNode(
unionNode.getSourceLocation(),
context.getIdAllocator().getNextId(),
newSources.build(),
ImmutableList.copyOf(semiJoinNode.getOutputVariables()),
fromListMultimap(mappings)));
}

private static Map<String, VariableReferenceExpression> remapDynamicFilters(
Map<String, VariableReferenceExpression> dynamicFilters,
Map<VariableReferenceExpression, VariableReferenceExpression> variableMapping)
{
ImmutableMap.Builder<String, VariableReferenceExpression> remapped = ImmutableMap.builder();
for (Map.Entry<String, VariableReferenceExpression> entry : dynamicFilters.entrySet()) {
VariableReferenceExpression mappedVar = variableMapping.get(entry.getValue());
if (mappedVar != null) {
remapped.put(entry.getKey(), mappedVar);
}
}
return remapped.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ public void testDefaults()
.setExchangeChecksumEnabled(false)
.setEnableIntermediateAggregations(false)
.setPushAggregationThroughJoin(true)
.setPushSemiJoinThroughUnion(false)
.setForceSingleNodeOutput(true)
.setPagesIndexEagerCompactionEnabled(false)
.setFilterAndProjectMinOutputPageSize(new DataSize(500, KILOBYTE))
Expand Down Expand Up @@ -343,6 +344,7 @@ public void testExplicitPropertyMappings()
.put("optimizer.retry-query-with-history-based-optimization", "true")
.put("optimizer.treat-low-confidence-zero-estimation-as-unknown", "true")
.put("optimizer.push-aggregation-through-join", "false")
.put("optimizer.push-semi-join-through-union", "true")
.put("optimizer.aggregation-partition-merging", "top_down")
.put("experimental.spill-enabled", "true")
.put("experimental.join-spill-enabled", "false")
Expand Down Expand Up @@ -564,6 +566,7 @@ public void testExplicitPropertyMappings()
.setTreatLowConfidenceZeroEstimationAsUnknownEnabled(true)
.setAggregationPartitioningMergingStrategy(TOP_DOWN)
.setPushAggregationThroughJoin(false)
.setPushSemiJoinThroughUnion(true)
.setSpillEnabled(true)
.setJoinSpillingEnabled(false)
.setSpillerSpillPaths("/tmp/custom/spill/path1,/tmp/custom/spill/path2")
Expand Down
Loading
Loading