diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java
index f5adf59cca75b..ce66a176211e8 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java
@@ -169,6 +169,7 @@ public final class SystemSessionProperties
public static final String LEGACY_TIMESTAMP = "legacy_timestamp";
public static final String ENABLE_INTERMEDIATE_AGGREGATIONS = "enable_intermediate_aggregations";
public static final String PUSH_AGGREGATION_THROUGH_JOIN = "push_aggregation_through_join";
+ public static final String PUSH_SEMI_JOIN_THROUGH_UNION = "push_semi_join_through_union";
public static final String PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN = "push_partial_aggregation_through_join";
public static final String PARSE_DECIMAL_LITERALS_AS_DOUBLE = "parse_decimal_literals_as_double";
public static final String FORCE_SINGLE_NODE_OUTPUT = "force_single_node_output";
@@ -907,6 +908,11 @@ public SystemSessionProperties(
"Allow pushing aggregations below joins",
featuresConfig.isPushAggregationThroughJoin(),
false),
+ booleanProperty(
+ PUSH_SEMI_JOIN_THROUGH_UNION,
+ "Allow pushing semi joins through union",
+ featuresConfig.isPushSemiJoinThroughUnion(),
+ false),
booleanProperty(
PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN,
"Push partial aggregations below joins",
@@ -2488,6 +2494,11 @@ public static boolean shouldPushAggregationThroughJoin(Session session)
return session.getSystemProperty(PUSH_AGGREGATION_THROUGH_JOIN, Boolean.class);
}
+ public static boolean isPushSemiJoinThroughUnion(Session session)
+ {
+ return session.getSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, Boolean.class);
+ }
+
public static boolean isNativeExecutionEnabled(Session session)
{
return session.getSystemProperty(NATIVE_EXECUTION_ENABLED, Boolean.class);
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
index 5e3adba503208..ec50b6a2063a3 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
@@ -154,6 +154,7 @@ public class FeaturesConfig
private double defaultJoinSelectivityCoefficient;
private double defaultWriterReplicationCoefficient = 3;
private boolean pushAggregationThroughJoin = true;
+ private boolean pushSemiJoinThroughUnion;
private double memoryRevokingTarget = 0.5;
private double memoryRevokingThreshold = 0.9;
private boolean useMarkDistinct = true;
@@ -1625,6 +1626,19 @@ public FeaturesConfig setPushAggregationThroughJoin(boolean value)
return this;
}
+ public boolean isPushSemiJoinThroughUnion()
+ {
+ return pushSemiJoinThroughUnion;
+ }
+
+ @Config("optimizer.push-semi-join-through-union")
+ @ConfigDescription("Push semi join through union to allow parallel semi join execution")
+ public FeaturesConfig setPushSemiJoinThroughUnion(boolean pushSemiJoinThroughUnion)
+ {
+ this.pushSemiJoinThroughUnion = pushSemiJoinThroughUnion;
+ return this;
+ }
+
public boolean isForceSingleNodeOutput()
{
return forceSingleNodeOutput;
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
index e7bd31adf70ac..b9c2a3c2f6655 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
@@ -106,6 +106,7 @@
import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.PushRemoteExchangeThroughAssignUniqueId;
import com.facebook.presto.sql.planner.iterative.rule.PushRemoteExchangeThroughGroupId;
+import com.facebook.presto.sql.planner.iterative.rule.PushSemiJoinThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.PushTableWriteThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.PushTopNThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.RandomizeSourceKeyInSemiJoin;
@@ -618,6 +619,12 @@ public PlanOptimizers(
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(new LeftJoinNullFilterToSemiJoin(metadata.getFunctionAndTypeManager()))),
+ new IterativeOptimizer(
+ metadata,
+ ruleStats,
+ statsCalculator,
+ estimatedExchangesCostCalculator,
+ ImmutableSet.of(new PushSemiJoinThroughUnion())),
new IterativeOptimizer(
metadata,
ruleStats,
diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushSemiJoinThroughUnion.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushSemiJoinThroughUnion.java
new file mode 100644
index 0000000000000..fbae8c601415a
--- /dev/null
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushSemiJoinThroughUnion.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Captures;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.plan.Assignments;
+import com.facebook.presto.spi.plan.PlanNode;
+import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.plan.SemiJoinNode;
+import com.facebook.presto.spi.plan.UnionNode;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.RowExpressionVariableInliner;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableListMultimap;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ListMultimap;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.facebook.presto.SystemSessionProperties.isPushSemiJoinThroughUnion;
+import static com.facebook.presto.sql.planner.optimizations.SetOperationNodeUtils.fromListMultimap;
+import static com.facebook.presto.sql.planner.plan.Patterns.semiJoin;
+
+/**
+ * Pushes a SemiJoinNode through a UnionNode (on the probe/source side).
+ *
+ * Transforms:
+ *
+ * - SemiJoin (sourceJoinVar=c, output=sjOut)
+ * - Union (output c from [a1, a2])
+ * - source1 (outputs a1)
+ * - source2 (outputs a2)
+ * - filteringSource
+ *
+ * into:
+ *
+ * - Union (output sjOut from [sjOut_0, sjOut_1], c from [a1, a2])
+ * - SemiJoin (sourceJoinVar=a1, output=sjOut_0)
+ * - source1
+ * - filteringSource
+ * - SemiJoin (sourceJoinVar=a2, output=sjOut_1)
+ * - source2
+ * - filteringSource
+ *
+ *
+ * Also handles the case where a ProjectNode sits between the SemiJoin and Union:
+ *
+ * - SemiJoin
+ * - Project
+ * - Union
+ * - filteringSource
+ *
+ * In this case, a copy of the project, rewritten in terms of each branch's variables, is placed in that branch beneath the new semi join.
+ */
+public class PushSemiJoinThroughUnion
+        implements Rule<SemiJoinNode>
+{
+    private static final Pattern<SemiJoinNode> PATTERN = semiJoin();
+
+    @Override
+    public Pattern<SemiJoinNode> getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public boolean isEnabled(Session session)
+    {
+        return isPushSemiJoinThroughUnion(session);
+    }
+
+    @Override
+    public Result apply(SemiJoinNode semiJoinNode, Captures captures, Context context)
+    {
+        // Only SemiJoin -> Union and SemiJoin -> Project -> Union are rewritten; any other source shape yields an empty result.
+        PlanNode source = context.getLookup().resolve(semiJoinNode.getSource());
+
+        if (source instanceof UnionNode) {
+            return pushThroughUnion(semiJoinNode, (UnionNode) source, Optional.empty(), context);
+        }
+
+        if (source instanceof ProjectNode) {
+            ProjectNode projectNode = (ProjectNode) source;
+            PlanNode projectSource = context.getLookup().resolve(projectNode.getSource());
+            if (projectSource instanceof UnionNode) {
+                return pushThroughUnion(semiJoinNode, (UnionNode) projectSource, Optional.of(projectNode), context);
+            }
+        }
+
+        return Result.empty();
+    }
+
+    /**
+     * Rewrites the semi join as a union of one semi join per union source, placing a translated copy of {@code projectNode} (when present) beneath each one.
+     * Every branch reuses the original filtering source; an empty result is returned when the join or hash variable cannot be mapped into a branch.
+     */
+    private Result pushThroughUnion(SemiJoinNode semiJoinNode, UnionNode unionNode, Optional<ProjectNode> projectNode, Context context)
+    {
+        ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();
+        ImmutableListMultimap.Builder<VariableReferenceExpression, VariableReferenceExpression> outputMappings =
+                ImmutableListMultimap.builder();
+
+        for (int i = 0; i < unionNode.getSources().size(); i++) {
+            Map<VariableReferenceExpression, VariableReferenceExpression> unionVarMap = unionNode.sourceVariableMap(i);
+
+            PlanNode branchSource;
+            VariableReferenceExpression mappedSourceJoinVar;
+            Optional<VariableReferenceExpression> mappedSourceHashVar;
+            Map<String, VariableReferenceExpression> branchDynamicFilters;
+
+            if (projectNode.isPresent()) {
+                // Copy the project into this branch: rewrite each assignment onto the branch's variables and give it a fresh output variable
+                ProjectNode project = projectNode.get();
+                Assignments.Builder assignments = Assignments.builder();
+                Map<VariableReferenceExpression, VariableReferenceExpression> projectVarMapping = new HashMap<>();
+
+                for (Map.Entry<VariableReferenceExpression, RowExpression> entry : project.getAssignments().entrySet()) {
+                    RowExpression translatedExpression = RowExpressionVariableInliner.inlineVariables(unionVarMap, entry.getValue());
+                    VariableReferenceExpression newVar = context.getVariableAllocator().newVariable(translatedExpression);
+                    assignments.put(newVar, translatedExpression);
+                    projectVarMapping.put(entry.getKey(), newVar);
+                }
+
+                branchSource = new ProjectNode(
+                        project.getSourceLocation(),
+                        context.getIdAllocator().getNextId(),
+                        unionNode.getSources().get(i),
+                        assignments.build(),
+                        project.getLocality());
+
+                // Map the semi-join source variables through the project mapping. Optional.map() collapses a null lookup to empty, so a lost hash variable is detected as present-before / absent-after.
+                mappedSourceJoinVar = projectVarMapping.get(semiJoinNode.getSourceJoinVariable());
+                if (mappedSourceJoinVar == null) {
+                    return Result.empty();
+                }
+                mappedSourceHashVar = semiJoinNode.getSourceHashVariable().map(projectVarMapping::get);
+                if (semiJoinNode.getSourceHashVariable().isPresent() && !mappedSourceHashVar.isPresent()) {
+                    return Result.empty();
+                }
+
+                // Build output-to-input mappings for the semi join's pass-through outputs, mapped through the project
+                for (VariableReferenceExpression semiJoinOutputVar : semiJoinNode.getOutputVariables()) {
+                    if (semiJoinOutputVar.equals(semiJoinNode.getSemiJoinOutput())) {
+                        continue; // handled separately below
+                    }
+                    // This variable comes from the project's output. Map it to the per-branch project output.
+                    VariableReferenceExpression branchVar = projectVarMapping.get(semiJoinOutputVar);
+                    if (branchVar != null) {
+                        outputMappings.put(semiJoinOutputVar, branchVar);
+                    }
+                }
+
+                // Remap dynamic filter variables through the project mapping; filters with no mapping entry are dropped. NOTE(review): confirm these are source-side variables.
+                branchDynamicFilters = remapDynamicFilters(semiJoinNode.getDynamicFilters(), projectVarMapping);
+            }
+            else {
+                branchSource = unionNode.getSources().get(i);
+
+                // Map the semi-join source variables through the union mapping. Optional.map() collapses a null lookup to empty, so a lost hash variable is detected as present-before / absent-after.
+                mappedSourceJoinVar = unionVarMap.get(semiJoinNode.getSourceJoinVariable());
+                if (mappedSourceJoinVar == null) {
+                    return Result.empty();
+                }
+                mappedSourceHashVar = semiJoinNode.getSourceHashVariable().map(unionVarMap::get);
+                if (semiJoinNode.getSourceHashVariable().isPresent() && !mappedSourceHashVar.isPresent()) {
+                    return Result.empty();
+                }
+
+                // Build output-to-input mappings for original union output variables
+                for (VariableReferenceExpression unionOutputVar : unionNode.getOutputVariables()) {
+                    outputMappings.put(unionOutputVar, unionVarMap.get(unionOutputVar));
+                }
+
+                // Remap dynamic filter variables through the union mapping; filters with no mapping entry are dropped. NOTE(review): confirm these are source-side variables.
+                branchDynamicFilters = remapDynamicFilters(semiJoinNode.getDynamicFilters(), unionVarMap);
+            }
+
+            // Each branch needs its own semiJoinOutput variable; the new union maps them all back to the original one
+            VariableReferenceExpression newSemiJoinOutput = context.getVariableAllocator().newVariable(semiJoinNode.getSemiJoinOutput());
+
+            // Build the semi join for this branch. NOTE(review): the same filtering-source node is referenced by every branch - confirm later passes accept a shared subtree.
+            SemiJoinNode newSemiJoin = new SemiJoinNode(
+                    semiJoinNode.getSourceLocation(),
+                    context.getIdAllocator().getNextId(),
+                    branchSource,
+                    semiJoinNode.getFilteringSource(),
+                    mappedSourceJoinVar,
+                    semiJoinNode.getFilteringSourceJoinVariable(),
+                    newSemiJoinOutput,
+                    mappedSourceHashVar,
+                    semiJoinNode.getFilteringSourceHashVariable(),
+                    semiJoinNode.getDistributionType(),
+                    branchDynamicFilters);
+
+            newSources.add(newSemiJoin);
+
+            // Add the semiJoinOutput mapping
+            outputMappings.put(semiJoinNode.getSemiJoinOutput(), newSemiJoinOutput);
+        }
+
+        ListMultimap<VariableReferenceExpression, VariableReferenceExpression> mappings = outputMappings.build();
+
+        return Result.ofPlanNode(new UnionNode(
+                unionNode.getSourceLocation(),
+                context.getIdAllocator().getNextId(),
+                newSources.build(),
+                ImmutableList.copyOf(semiJoinNode.getOutputVariables()),
+                fromListMultimap(mappings)));
+    }
+
+    /** Returns the dynamic filters whose variable has an entry in {@code variableMapping}, re-pointed at the mapped variable; filters without an entry are omitted. */
+    private static Map<String, VariableReferenceExpression> remapDynamicFilters(
+            Map<String, VariableReferenceExpression> dynamicFilters,
+            Map<VariableReferenceExpression, VariableReferenceExpression> variableMapping)
+    {
+        ImmutableMap.Builder<String, VariableReferenceExpression> remapped = ImmutableMap.builder();
+        for (Map.Entry<String, VariableReferenceExpression> entry : dynamicFilters.entrySet()) {
+            VariableReferenceExpression mappedVar = variableMapping.get(entry.getValue());
+            if (mappedVar != null) {
+                remapped.put(entry.getKey(), mappedVar);
+            }
+        }
+        return remapped.build();
+    }
+}
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
index 2af699b59ed69..2f488fbc2debc 100644
--- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
@@ -134,6 +134,7 @@ public void testDefaults()
.setExchangeChecksumEnabled(false)
.setEnableIntermediateAggregations(false)
.setPushAggregationThroughJoin(true)
+ .setPushSemiJoinThroughUnion(false)
.setForceSingleNodeOutput(true)
.setPagesIndexEagerCompactionEnabled(false)
.setFilterAndProjectMinOutputPageSize(new DataSize(500, KILOBYTE))
@@ -343,6 +344,7 @@ public void testExplicitPropertyMappings()
.put("optimizer.retry-query-with-history-based-optimization", "true")
.put("optimizer.treat-low-confidence-zero-estimation-as-unknown", "true")
.put("optimizer.push-aggregation-through-join", "false")
+ .put("optimizer.push-semi-join-through-union", "true")
.put("optimizer.aggregation-partition-merging", "top_down")
.put("experimental.spill-enabled", "true")
.put("experimental.join-spill-enabled", "false")
@@ -564,6 +566,7 @@ public void testExplicitPropertyMappings()
.setTreatLowConfidenceZeroEstimationAsUnknownEnabled(true)
.setAggregationPartitioningMergingStrategy(TOP_DOWN)
.setPushAggregationThroughJoin(false)
+ .setPushSemiJoinThroughUnion(true)
.setSpillEnabled(true)
.setJoinSpillingEnabled(false)
.setSpillerSpillPaths("/tmp/custom/spill/path1,/tmp/custom/spill/path2")
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushSemiJoinThroughUnion.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushSemiJoinThroughUnion.java
new file mode 100644
index 0000000000000..00eee9d8f1f9a
--- /dev/null
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushSemiJoinThroughUnion.java
@@ -0,0 +1,261 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
+import com.facebook.presto.sql.relational.FunctionResolution;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableListMultimap;
+import com.google.common.collect.ImmutableMap;
+import org.testng.annotations.Test;
+
+import java.util.Optional;
+
+import static com.facebook.presto.SystemSessionProperties.PUSH_SEMI_JOIN_THROUGH_UNION;
+import static com.facebook.presto.common.function.OperatorType.MULTIPLY;
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.union;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
+import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment;
+import static com.facebook.presto.sql.relational.Expressions.call;
+import static com.facebook.presto.sql.relational.Expressions.constant;
+
+/** Rule tests for {@link PushSemiJoinThroughUnion}. */
+public class TestPushSemiJoinThroughUnion
+        extends BaseRuleTest
+{
+    // A semi join whose source is a plain values node (no union beneath it) must be left untouched.
+    @Test
+    public void testDoesNotFireWhenSourceIsNotUnion()
+    {
+        tester().assertThat(new PushSemiJoinThroughUnion())
+                .setSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, "true")
+                .on(p -> {
+                    VariableReferenceExpression sourceJoinVar = p.variable("sourceJoinVar");
+                    VariableReferenceExpression filterJoinVar = p.variable("filterJoinVar");
+                    VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput", BOOLEAN);
+                    return p.semiJoin(
+                            sourceJoinVar,
+                            filterJoinVar,
+                            semiJoinOutput,
+                            Optional.empty(),
+                            Optional.empty(),
+                            p.values(sourceJoinVar),
+                            p.values(filterJoinVar));
+                })
+                .doesNotFire();
+    }
+
+    // Each union input is expected to get its own semi join over the same filtering source,
+    // keyed on that input's own variable (a / b) instead of the union output c.
+    @Test
+    public void testPushThroughTwoBranchUnion()
+    {
+        tester().assertThat(new PushSemiJoinThroughUnion())
+                .setSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, "true")
+                .on(p -> {
+                    VariableReferenceExpression a = p.variable("a");
+                    VariableReferenceExpression b = p.variable("b");
+                    VariableReferenceExpression c = p.variable("c");
+                    VariableReferenceExpression filterJoinVar = p.variable("filterJoinVar");
+                    VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput", BOOLEAN);
+                    return p.semiJoin(
+                            c,
+                            filterJoinVar,
+                            semiJoinOutput,
+                            Optional.empty(),
+                            Optional.empty(),
+                            p.union(
+                                    ImmutableListMultimap.<VariableReferenceExpression, VariableReferenceExpression>builder()
+                                            .put(c, a)
+                                            .put(c, b)
+                                            .build(),
+                                    ImmutableList.of(p.values(a), p.values(b))),
+                            p.values(filterJoinVar));
+                })
+                .matches(
+                        union(
+                                semiJoin("a", "filterJoinVar", "semiJoinOutput_0",
+                                        values("a"),
+                                        values("filterJoinVar")),
+                                semiJoin("b", "filterJoinVar", "semiJoinOutput_1",
+                                        values("b"),
+                                        values("filterJoinVar"))));
+    }
+
+    // Same shape as the two-branch case, with a third union input.
+    @Test
+    public void testPushThroughThreeBranchUnion()
+    {
+        tester().assertThat(new PushSemiJoinThroughUnion())
+                .setSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, "true")
+                .on(p -> {
+                    VariableReferenceExpression a = p.variable("a");
+                    VariableReferenceExpression b = p.variable("b");
+                    VariableReferenceExpression d = p.variable("d");
+                    VariableReferenceExpression c = p.variable("c");
+                    VariableReferenceExpression filterJoinVar = p.variable("filterJoinVar");
+                    VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput", BOOLEAN);
+                    return p.semiJoin(
+                            c,
+                            filterJoinVar,
+                            semiJoinOutput,
+                            Optional.empty(),
+                            Optional.empty(),
+                            p.union(
+                                    ImmutableListMultimap.<VariableReferenceExpression, VariableReferenceExpression>builder()
+                                            .put(c, a)
+                                            .put(c, b)
+                                            .put(c, d)
+                                            .build(),
+                                    ImmutableList.of(p.values(a), p.values(b), p.values(d))),
+                            p.values(filterJoinVar));
+                })
+                .matches(
+                        union(
+                                semiJoin("a", "filterJoinVar", "semiJoinOutput_0",
+                                        values("a"),
+                                        values("filterJoinVar")),
+                                semiJoin("b", "filterJoinVar", "semiJoinOutput_1",
+                                        values("b"),
+                                        values("filterJoinVar")),
+                                semiJoin("d", "filterJoinVar", "semiJoinOutput_2",
+                                        values("d"),
+                                        values("filterJoinVar"))));
+    }
+
+    // With a project between the semi join and the union, each branch is expected to contain its own project
+    // whose expression is rewritten onto that branch's variable (a * 3, b * 3).
+    @Test
+    public void testPushThroughProjectOverUnion()
+    {
+        FunctionResolution functionResolution = new FunctionResolution(tester().getMetadata().getFunctionAndTypeManager().getFunctionAndTypeResolver());
+        tester().assertThat(new PushSemiJoinThroughUnion())
+                .setSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, "true")
+                .on(p -> {
+                    VariableReferenceExpression a = p.variable("a");
+                    VariableReferenceExpression b = p.variable("b");
+                    VariableReferenceExpression c = p.variable("c");
+                    VariableReferenceExpression cTimes3 = p.variable("c_times_3");
+                    VariableReferenceExpression filterJoinVar = p.variable("filterJoinVar");
+                    VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput", BOOLEAN);
+                    return p.semiJoin(
+                            cTimes3,
+                            filterJoinVar,
+                            semiJoinOutput,
+                            Optional.empty(),
+                            Optional.empty(),
+                            p.project(
+                                    assignment(
+                                            cTimes3,
+                                            call("c * 3", functionResolution.arithmeticFunction(MULTIPLY, BIGINT, BIGINT), BIGINT, c, constant(3L, BIGINT))),
+                                    p.union(
+                                            ImmutableListMultimap.<VariableReferenceExpression, VariableReferenceExpression>builder()
+                                                    .put(c, a)
+                                                    .put(c, b)
+                                                    .build(),
+                                            ImmutableList.of(p.values(a), p.values(b)))),
+                            p.values(filterJoinVar));
+                })
+                .matches(
+                        union(
+                                semiJoin(
+                                        project(
+                                                ImmutableMap.of("a_times_3", expression("a * 3")),
+                                                values("a")),
+                                        values("filterJoinVar")),
+                                semiJoin(
+                                        project(
+                                                ImmutableMap.of("b_times_3", expression("b * 3")),
+                                                values("b")),
+                                        values("filterJoinVar"))));
+    }
+
+    // Hash variables are supplied on both sides and carried through the union alongside the join keys.
+    // NOTE(review): the semiJoin(...) pattern used below takes no hash aliases, so the mapped hash variables are presumably not asserted - confirm.
+    @Test
+    public void testPushThroughUnionWithHashVariables()
+    {
+        tester().assertThat(new PushSemiJoinThroughUnion())
+                .setSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, "true")
+                .on(p -> {
+                    VariableReferenceExpression a = p.variable("a");
+                    VariableReferenceExpression b = p.variable("b");
+                    VariableReferenceExpression c = p.variable("c");
+                    VariableReferenceExpression aHash = p.variable("aHash");
+                    VariableReferenceExpression bHash = p.variable("bHash");
+                    VariableReferenceExpression cHash = p.variable("cHash");
+                    VariableReferenceExpression filterJoinVar = p.variable("filterJoinVar");
+                    VariableReferenceExpression filterHash = p.variable("filterHash");
+                    VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput", BOOLEAN);
+                    return p.semiJoin(
+                            c,
+                            filterJoinVar,
+                            semiJoinOutput,
+                            Optional.of(cHash),
+                            Optional.of(filterHash),
+                            p.union(
+                                    ImmutableListMultimap.<VariableReferenceExpression, VariableReferenceExpression>builder()
+                                            .put(c, a)
+                                            .put(c, b)
+                                            .put(cHash, aHash)
+                                            .put(cHash, bHash)
+                                            .build(),
+                                    ImmutableList.of(p.values(a, aHash), p.values(b, bHash))),
+                            p.values(filterJoinVar, filterHash));
+                })
+                .matches(
+                        union(
+                                semiJoin("a", "filterJoinVar", "semiJoinOutput_0",
+                                        values("a", "aHash"),
+                                        values("filterJoinVar", "filterHash")),
+                                semiJoin("b", "filterJoinVar", "semiJoinOutput_1",
+                                        values("b", "bHash"),
+                                        values("filterJoinVar", "filterHash"))));
+    }
+
+    // The session property is deliberately not set here, so the rule runs with its default value and must not fire;
+    // the same input plan is rewritten in testPushThroughTwoBranchUnion once the property is "true".
+    @Test
+    public void testDoesNotFireWhenDisabled()
+    {
+        tester().assertThat(new PushSemiJoinThroughUnion())
+                .on(p -> {
+                    VariableReferenceExpression a = p.variable("a");
+                    VariableReferenceExpression b = p.variable("b");
+                    VariableReferenceExpression c = p.variable("c");
+                    VariableReferenceExpression filterJoinVar = p.variable("filterJoinVar");
+                    VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput", BOOLEAN);
+                    return p.semiJoin(
+                            c,
+                            filterJoinVar,
+                            semiJoinOutput,
+                            Optional.empty(),
+                            Optional.empty(),
+                            p.union(
+                                    ImmutableListMultimap.<VariableReferenceExpression, VariableReferenceExpression>builder()
+                                            .put(c, a)
+                                            .put(c, b)
+                                            .build(),
+                                    ImmutableList.of(p.values(a), p.values(b))),
+                            p.values(filterJoinVar));
+                })
+                .doesNotFire();
+    }
+}