diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java index f5adf59cca75b..ce66a176211e8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -169,6 +169,7 @@ public final class SystemSessionProperties public static final String LEGACY_TIMESTAMP = "legacy_timestamp"; public static final String ENABLE_INTERMEDIATE_AGGREGATIONS = "enable_intermediate_aggregations"; public static final String PUSH_AGGREGATION_THROUGH_JOIN = "push_aggregation_through_join"; + public static final String PUSH_SEMI_JOIN_THROUGH_UNION = "push_semi_join_through_union"; public static final String PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN = "push_partial_aggregation_through_join"; public static final String PARSE_DECIMAL_LITERALS_AS_DOUBLE = "parse_decimal_literals_as_double"; public static final String FORCE_SINGLE_NODE_OUTPUT = "force_single_node_output"; @@ -907,6 +908,11 @@ public SystemSessionProperties( "Allow pushing aggregations below joins", featuresConfig.isPushAggregationThroughJoin(), false), + booleanProperty( + PUSH_SEMI_JOIN_THROUGH_UNION, + "Allow pushing semi joins through union", + featuresConfig.isPushSemiJoinThroughUnion(), + false), booleanProperty( PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, "Push partial aggregations below joins", @@ -2488,6 +2494,11 @@ public static boolean shouldPushAggregationThroughJoin(Session session) return session.getSystemProperty(PUSH_AGGREGATION_THROUGH_JOIN, Boolean.class); } + public static boolean isPushSemiJoinThroughUnion(Session session) + { + return session.getSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, Boolean.class); + } + public static boolean isNativeExecutionEnabled(Session session) { return session.getSystemProperty(NATIVE_EXECUTION_ENABLED, Boolean.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 5e3adba503208..ec50b6a2063a3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -154,6 +154,7 @@ public class FeaturesConfig private double defaultJoinSelectivityCoefficient; private double defaultWriterReplicationCoefficient = 3; private boolean pushAggregationThroughJoin = true; + private boolean pushSemiJoinThroughUnion; private double memoryRevokingTarget = 0.5; private double memoryRevokingThreshold = 0.9; private boolean useMarkDistinct = true; @@ -1625,6 +1626,19 @@ public FeaturesConfig setPushAggregationThroughJoin(boolean value) return this; } + public boolean isPushSemiJoinThroughUnion() + { + return pushSemiJoinThroughUnion; + } + + @Config("optimizer.push-semi-join-through-union") + @ConfigDescription("Push semi join through union to allow parallel semi join execution") + public FeaturesConfig setPushSemiJoinThroughUnion(boolean pushSemiJoinThroughUnion) + { + this.pushSemiJoinThroughUnion = pushSemiJoinThroughUnion; + return this; + } + public boolean isForceSingleNodeOutput() { return forceSingleNodeOutput; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index e7bd31adf70ac..b9c2a3c2f6655 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -106,6 +106,7 @@ import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughUnion; import com.facebook.presto.sql.planner.iterative.rule.PushRemoteExchangeThroughAssignUniqueId; import com.facebook.presto.sql.planner.iterative.rule.PushRemoteExchangeThroughGroupId; +import com.facebook.presto.sql.planner.iterative.rule.PushSemiJoinThroughUnion; import com.facebook.presto.sql.planner.iterative.rule.PushTableWriteThroughUnion; import com.facebook.presto.sql.planner.iterative.rule.PushTopNThroughUnion; import com.facebook.presto.sql.planner.iterative.rule.RandomizeSourceKeyInSemiJoin; @@ -618,6 +619,12 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new LeftJoinNullFilterToSemiJoin(metadata.getFunctionAndTypeManager()))), + new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new PushSemiJoinThroughUnion())), new IterativeOptimizer( metadata, ruleStats, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushSemiJoinThroughUnion.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushSemiJoinThroughUnion.java new file mode 100644 index 0000000000000..fbae8c601415a --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushSemiJoinThroughUnion.java @@ -0,0 +1,242 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.SemiJoinNode; +import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.RowExpressionVariableInliner; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ListMultimap; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.isPushSemiJoinThroughUnion; +import static com.facebook.presto.sql.planner.optimizations.SetOperationNodeUtils.fromListMultimap; +import static com.facebook.presto.sql.planner.plan.Patterns.semiJoin; + +/** + * Pushes a SemiJoinNode through a UnionNode (on the probe/source side). + *

+ * Transforms: + *

+ *     - SemiJoin (sourceJoinVar=c, output=sjOut)
+ *         - Union (output c from [a1, a2])
+ *             - source1 (outputs a1)
+ *             - source2 (outputs a2)
+ *         - filteringSource
+ * 
+ * into: + *
+ *     - Union (output sjOut from [sjOut_0, sjOut_1], c from [a1, a2])
+ *         - SemiJoin (sourceJoinVar=a1, output=sjOut_0)
+ *             - source1
+ *             - filteringSource
+ *         - SemiJoin (sourceJoinVar=a2, output=sjOut_1)
+ *             - source2
+ *             - filteringSource
+ * 
+ *

+ * Also handles the case where a ProjectNode sits between the SemiJoin and Union: + *

+ *     - SemiJoin
+ *         - Project
+ *             - Union
+ *         - filteringSource
+ * 
+ * In this case, the project is pushed into each union branch before the semi join. + */ +public class PushSemiJoinThroughUnion + implements Rule +{ + private static final Pattern PATTERN = semiJoin(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public boolean isEnabled(Session session) + { + return isPushSemiJoinThroughUnion(session); + } + + @Override + public Result apply(SemiJoinNode semiJoinNode, Captures captures, Context context) + { + PlanNode source = context.getLookup().resolve(semiJoinNode.getSource()); + + if (source instanceof UnionNode) { + return pushThroughUnion(semiJoinNode, (UnionNode) source, Optional.empty(), context); + } + + if (source instanceof ProjectNode) { + ProjectNode projectNode = (ProjectNode) source; + PlanNode projectSource = context.getLookup().resolve(projectNode.getSource()); + if (projectSource instanceof UnionNode) { + return pushThroughUnion(semiJoinNode, (UnionNode) projectSource, Optional.of(projectNode), context); + } + } + + return Result.empty(); + } + + private Result pushThroughUnion( + SemiJoinNode semiJoinNode, + UnionNode unionNode, + Optional projectNode, + Context context) + { + ImmutableList.Builder newSources = ImmutableList.builder(); + ImmutableListMultimap.Builder outputMappings = + ImmutableListMultimap.builder(); + + for (int i = 0; i < unionNode.getSources().size(); i++) { + Map unionVarMap = unionNode.sourceVariableMap(i); + + PlanNode branchSource; + VariableReferenceExpression mappedSourceJoinVar; + Optional mappedSourceHashVar; + Map branchDynamicFilters; + + if (projectNode.isPresent()) { + // Push the project into each union branch, translating its assignments + ProjectNode project = projectNode.get(); + Assignments.Builder assignments = Assignments.builder(); + Map projectVarMapping = new HashMap<>(); + + for (Map.Entry entry : project.getAssignments().entrySet()) { + RowExpression translatedExpression = RowExpressionVariableInliner.inlineVariables(unionVarMap, entry.getValue()); + VariableReferenceExpression newVar = context.getVariableAllocator().newVariable(translatedExpression); + assignments.put(newVar, translatedExpression); + projectVarMapping.put(entry.getKey(), newVar); + } + + branchSource = new ProjectNode( + project.getSourceLocation(), + context.getIdAllocator().getNextId(), + unionNode.getSources().get(i), + assignments.build(), + project.getLocality()); + + // Map the semi-join source variables through the project variable mapping + mappedSourceJoinVar = projectVarMapping.get(semiJoinNode.getSourceJoinVariable()); + if (mappedSourceJoinVar == null) { + return Result.empty(); + } + mappedSourceHashVar = semiJoinNode.getSourceHashVariable().map(projectVarMapping::get); + if (mappedSourceHashVar.isPresent() && mappedSourceHashVar.get() == null) { + return Result.empty(); + } + + // Build output-to-input mappings for original union output variables, + // mapped through the project + for (VariableReferenceExpression semiJoinOutputVar : semiJoinNode.getOutputVariables()) { + if (semiJoinOutputVar.equals(semiJoinNode.getSemiJoinOutput())) { + continue; // handled separately below + } + // This variable comes from the project's output. Map it to the per-branch project output. + VariableReferenceExpression branchVar = projectVarMapping.get(semiJoinOutputVar); + if (branchVar != null) { + outputMappings.put(semiJoinOutputVar, branchVar); + } + } + + // Remap dynamic filter source variables through the project variable mapping + branchDynamicFilters = remapDynamicFilters(semiJoinNode.getDynamicFilters(), projectVarMapping); + } + else { + branchSource = unionNode.getSources().get(i); + + // Map the semi-join source variables through the union variable mapping + mappedSourceJoinVar = unionVarMap.get(semiJoinNode.getSourceJoinVariable()); + if (mappedSourceJoinVar == null) { + return Result.empty(); + } + mappedSourceHashVar = semiJoinNode.getSourceHashVariable().map(unionVarMap::get); + if (mappedSourceHashVar.isPresent() && mappedSourceHashVar.get() == null) { + return Result.empty(); + } + + // Build output-to-input mappings for original union output variables + for (VariableReferenceExpression unionOutputVar : unionNode.getOutputVariables()) { + outputMappings.put(unionOutputVar, unionVarMap.get(unionOutputVar)); + } + + // Remap dynamic filter source variables through the union variable mapping + branchDynamicFilters = remapDynamicFilters(semiJoinNode.getDynamicFilters(), unionVarMap); + } + + // Allocate new semiJoinOutput variable for each branch + VariableReferenceExpression newSemiJoinOutput = + context.getVariableAllocator().newVariable(semiJoinNode.getSemiJoinOutput()); + + // Build new SemiJoinNode for this branch + SemiJoinNode newSemiJoin = new SemiJoinNode( + semiJoinNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + branchSource, + semiJoinNode.getFilteringSource(), + mappedSourceJoinVar, + semiJoinNode.getFilteringSourceJoinVariable(), + newSemiJoinOutput, + mappedSourceHashVar, + semiJoinNode.getFilteringSourceHashVariable(), + semiJoinNode.getDistributionType(), + branchDynamicFilters); + + newSources.add(newSemiJoin); + + // Add the semiJoinOutput mapping + outputMappings.put(semiJoinNode.getSemiJoinOutput(), newSemiJoinOutput); + } + + ListMultimap mappings = outputMappings.build(); + + return Result.ofPlanNode(new UnionNode( + unionNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + newSources.build(), + ImmutableList.copyOf(semiJoinNode.getOutputVariables()), + fromListMultimap(mappings))); + } + + private static Map remapDynamicFilters( + Map dynamicFilters, + Map variableMapping) + { + ImmutableMap.Builder remapped = ImmutableMap.builder(); + for (Map.Entry entry : dynamicFilters.entrySet()) { + VariableReferenceExpression mappedVar = variableMapping.get(entry.getValue()); + if (mappedVar != null) { + remapped.put(entry.getKey(), mappedVar); + } + } + return remapped.build(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 2af699b59ed69..2f488fbc2debc 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -134,6 +134,7 @@ public void testDefaults() .setExchangeChecksumEnabled(false) .setEnableIntermediateAggregations(false) .setPushAggregationThroughJoin(true) + .setPushSemiJoinThroughUnion(false) .setForceSingleNodeOutput(true) .setPagesIndexEagerCompactionEnabled(false) .setFilterAndProjectMinOutputPageSize(new DataSize(500, KILOBYTE)) @@ -343,6 +344,7 @@ public void testExplicitPropertyMappings() .put("optimizer.retry-query-with-history-based-optimization", "true") .put("optimizer.treat-low-confidence-zero-estimation-as-unknown", "true") .put("optimizer.push-aggregation-through-join", "false") + .put("optimizer.push-semi-join-through-union", "true") .put("optimizer.aggregation-partition-merging", "top_down") .put("experimental.spill-enabled", "true") .put("experimental.join-spill-enabled", "false") @@ -564,6 +566,7 @@ public void testExplicitPropertyMappings() .setTreatLowConfidenceZeroEstimationAsUnknownEnabled(true) .setAggregationPartitioningMergingStrategy(TOP_DOWN) .setPushAggregationThroughJoin(false) + .setPushSemiJoinThroughUnion(true) .setSpillEnabled(true) .setJoinSpillingEnabled(false) .setSpillerSpillPaths("/tmp/custom/spill/path1,/tmp/custom/spill/path2") diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushSemiJoinThroughUnion.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushSemiJoinThroughUnion.java new file mode 100644 index 0000000000000..00eee9d8f1f9a --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushSemiJoinThroughUnion.java @@ -0,0 +1,261 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.PUSH_SEMI_JOIN_THROUGH_UNION; +import static com.facebook.presto.common.function.OperatorType.MULTIPLY; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.union; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.sql.relational.Expressions.constant; + +public class TestPushSemiJoinThroughUnion + extends BaseRuleTest +{ + @Test + public void testDoesNotFireWhenSourceIsNotUnion() + { + tester().assertThat(new PushSemiJoinThroughUnion()) + .setSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, "true") + .on(p -> { + VariableReferenceExpression sourceJoinVar = p.variable("sourceJoinVar"); + VariableReferenceExpression filterJoinVar = p.variable("filterJoinVar"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput", BOOLEAN); + return p.semiJoin( + sourceJoinVar, + filterJoinVar, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + p.values(sourceJoinVar), + p.values(filterJoinVar)); + }) + .doesNotFire(); + } + + @Test + public void testPushThroughTwoBranchUnion() + { + tester().assertThat(new PushSemiJoinThroughUnion()) + .setSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression filterJoinVar = p.variable("filterJoinVar"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput", BOOLEAN); + return p.semiJoin( + c, + filterJoinVar, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + p.union( + ImmutableListMultimap.builder() + .put(c, a) + .put(c, b) + .build(), + ImmutableList.of( + p.values(a), + p.values(b))), + p.values(filterJoinVar)); + }) + .matches( + union( + semiJoin("a", "filterJoinVar", "semiJoinOutput_0", + values("a"), + values("filterJoinVar")), + semiJoin("b", "filterJoinVar", "semiJoinOutput_1", + values("b"), + values("filterJoinVar")))); + } + + @Test + public void testPushThroughThreeBranchUnion() + { + tester().assertThat(new PushSemiJoinThroughUnion()) + .setSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression filterJoinVar = p.variable("filterJoinVar"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput", BOOLEAN); + return p.semiJoin( + c, + filterJoinVar, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + p.union( + ImmutableListMultimap.builder() + .put(c, a) + .put(c, b) + .put(c, d) + .build(), + ImmutableList.of( + p.values(a), + p.values(b), + p.values(d))), + p.values(filterJoinVar)); + }) + .matches( + union( + semiJoin("a", "filterJoinVar", "semiJoinOutput_0", + values("a"), + values("filterJoinVar")), + semiJoin("b", "filterJoinVar", "semiJoinOutput_1", + values("b"), + values("filterJoinVar")), + semiJoin("d", "filterJoinVar", "semiJoinOutput_2", + values("d"), + values("filterJoinVar")))); + } + + @Test + public void testPushThroughProjectOverUnion() + { + FunctionResolution functionResolution = new FunctionResolution(tester().getMetadata().getFunctionAndTypeManager().getFunctionAndTypeResolver()); + tester().assertThat(new PushSemiJoinThroughUnion()) + .setSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression cTimes3 = p.variable("c_times_3"); + VariableReferenceExpression filterJoinVar = p.variable("filterJoinVar"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput", BOOLEAN); + return p.semiJoin( + cTimes3, + filterJoinVar, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + p.project( + assignment( + cTimes3, + call("c * 3", functionResolution.arithmeticFunction(MULTIPLY, BIGINT, BIGINT), BIGINT, c, constant(3L, BIGINT))), + p.union( + ImmutableListMultimap.builder() + .put(c, a) + .put(c, b) + .build(), + ImmutableList.of( + p.values(a), + p.values(b)))), + p.values(filterJoinVar)); + }) + .matches( + union( + semiJoin( + project( + ImmutableMap.of("a_times_3", expression("a * 3")), + values("a")), + values("filterJoinVar")), + semiJoin( + project( + ImmutableMap.of("b_times_3", expression("b * 3")), + values("b")), + values("filterJoinVar")))); + } + + @Test + public void testPushThroughUnionWithHashVariables() + { + tester().assertThat(new PushSemiJoinThroughUnion()) + .setSystemProperty(PUSH_SEMI_JOIN_THROUGH_UNION, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression aHash = p.variable("aHash"); + VariableReferenceExpression bHash = p.variable("bHash"); + VariableReferenceExpression cHash = p.variable("cHash"); + VariableReferenceExpression filterJoinVar = p.variable("filterJoinVar"); + VariableReferenceExpression filterHash = p.variable("filterHash"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput", BOOLEAN); + return p.semiJoin( + c, + filterJoinVar, + semiJoinOutput, + Optional.of(cHash), + Optional.of(filterHash), + p.union( + ImmutableListMultimap.builder() + .put(c, a) + .put(c, b) + .put(cHash, aHash) + .put(cHash, bHash) + .build(), + ImmutableList.of( + p.values(a, aHash), + p.values(b, bHash))), + p.values(filterJoinVar, filterHash)); + }) + .matches( + union( + semiJoin("a", "filterJoinVar", "semiJoinOutput_0", + values("a", "aHash"), + values("filterJoinVar", "filterHash")), + semiJoin("b", "filterJoinVar", "semiJoinOutput_1", + values("b", "bHash"), + values("filterJoinVar", "filterHash")))); + } + + @Test + public void testDoesNotFireWhenDisabled() + { + tester().assertThat(new PushSemiJoinThroughUnion()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression filterJoinVar = p.variable("filterJoinVar"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput", BOOLEAN); + return p.semiJoin( + c, + filterJoinVar, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + p.union( + ImmutableListMultimap.builder() + .put(c, a) + .put(c, b) + .build(), + ImmutableList.of( + p.values(a), + p.values(b))), + p.values(filterJoinVar)); + }) + .doesNotFire(); + } +}