diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 795534157948b..be4654d515b6e 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -32,6 +32,7 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinNotNullInferenceStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; +import com.facebook.presto.sql.analyzer.FeaturesConfig.LeftJoinArrayContainsToInnerJoinStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialMergePushdownStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartitioningPrecisionStrategy; @@ -287,6 +288,7 @@ public final class SystemSessionProperties public static final String REWRITE_CROSS_JOIN_OR_TO_INNER_JOIN = "rewrite_cross_join_or_to_inner_join"; public static final String REWRITE_CROSS_JOIN_ARRAY_CONTAINS_TO_INNER_JOIN = "rewrite_cross_join_array_contains_to_inner_join"; public static final String REWRITE_CROSS_JOIN_ARRAY_NOT_CONTAINS_TO_ANTI_JOIN = "rewrite_cross_join_array_not_contains_to_anti_join"; + public static final String REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN = "rewrite_left_join_array_contains_to_equi_join"; public static final String REWRITE_LEFT_JOIN_NULL_FILTER_TO_SEMI_JOIN = "rewrite_left_join_null_filter_to_semi_join"; public static final String USE_BROADCAST_WHEN_BUILDSIZE_SMALL_PROBESIDE_UNKNOWN = "use_broadcast_when_buildsize_small_probeside_unknown"; public static final String ADD_PARTIAL_NODE_FOR_ROW_NUMBER_WITH_LIMIT = "add_partial_node_for_row_number_with_limit"; @@ -1734,6 +1736,18 @@ public SystemSessionProperties( "Rewrite cross join with array not 
contains filter to anti join", featuresConfig.isRewriteCrossJoinWithArrayNotContainsFilterToAntiJoin(), false), + new PropertyMetadata<>( + REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, + format("Set the strategy used to convert left join with array contains to inner join. Options are: %s", + Stream.of(LeftJoinArrayContainsToInnerJoinStrategy.values()) + .map(LeftJoinArrayContainsToInnerJoinStrategy::name) + .collect(joining(","))), + VARCHAR, + LeftJoinArrayContainsToInnerJoinStrategy.class, + featuresConfig.getLeftJoinWithArrayContainsToEquiJoinStrategy(), + false, + value -> LeftJoinArrayContainsToInnerJoinStrategy.valueOf(((String) value).toUpperCase()), + LeftJoinArrayContainsToInnerJoinStrategy::name), new PropertyMetadata<>( JOINS_NOT_NULL_INFERENCE_STRATEGY, format("Set the strategy used NOT NULL filter inference on Join Nodes. Options are: %s", @@ -2955,6 +2969,11 @@ public static boolean isRewriteCrossJoinArrayNotContainsToAntiJoinEnabled(Sessio return session.getSystemProperty(REWRITE_CROSS_JOIN_ARRAY_NOT_CONTAINS_TO_ANTI_JOIN, Boolean.class); } + public static LeftJoinArrayContainsToInnerJoinStrategy getLeftJoinArrayContainsToInnerJoinStrategy(Session session) + { + return session.getSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, LeftJoinArrayContainsToInnerJoinStrategy.class); + } + public static boolean isRewriteLeftJoinNullFilterToSemiJoinEnabled(Session session) { return session.getSystemProperty(REWRITE_LEFT_JOIN_NULL_FILTER_TO_SEMI_JOIN, Boolean.class); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index c228156135836..94b1b16ed4fde 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -273,6 +273,7 @@ public class FeaturesConfig private PushDownFilterThroughCrossJoinStrategy 
pushDownFilterExpressionEvaluationThroughCrossJoin = PushDownFilterThroughCrossJoinStrategy.REWRITTEN_TO_INNER_JOIN; private boolean rewriteCrossJoinWithOrFilterToInnerJoin = true; private boolean rewriteCrossJoinWithArrayContainsFilterToInnerJoin = true; + private LeftJoinArrayContainsToInnerJoinStrategy leftJoinWithArrayContainsToEquiJoinStrategy = LeftJoinArrayContainsToInnerJoinStrategy.DISABLED; private boolean rewriteCrossJoinWithArrayNotContainsFilterToAntiJoin = true; private JoinNotNullInferenceStrategy joinNotNullInferenceStrategy = NONE; private boolean leftJoinNullFilterToSemiJoin = true; @@ -424,6 +425,13 @@ public enum JoinNotNullInferenceStrategy USE_FUNCTION_METADATA } + // TODO: Implement cost based strategy + public enum LeftJoinArrayContainsToInnerJoinStrategy + { + DISABLED, + ALWAYS_ENABLED + } + public double getCpuCostWeight() { return cpuCostWeight; @@ -2746,6 +2754,19 @@ public FeaturesConfig setRewriteCrossJoinWithArrayContainsFilterToInnerJoin(bool return this; } + public LeftJoinArrayContainsToInnerJoinStrategy getLeftJoinWithArrayContainsToEquiJoinStrategy() + { + return leftJoinWithArrayContainsToEquiJoinStrategy; + } + + @Config("optimizer.left-join-with-array-contains-to-equi-join-strategy") + @ConfigDescription("When to apply rewrite left join with array contains to equi join") + public FeaturesConfig setLeftJoinWithArrayContainsToEquiJoinStrategy(LeftJoinArrayContainsToInnerJoinStrategy leftJoinWithArrayContainsToEquiJoinStrategy) + { + this.leftJoinWithArrayContainsToEquiJoinStrategy = leftJoinWithArrayContainsToEquiJoinStrategy; + return this; + } + public boolean isRewriteCrossJoinWithArrayNotContainsFilterToAntiJoin() { return this.rewriteCrossJoinWithArrayNotContainsFilterToAntiJoin; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 57caee0d9b240..e13744b7aae32 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -46,6 +46,7 @@ import com.facebook.presto.sql.planner.iterative.rule.InlineProjections; import com.facebook.presto.sql.planner.iterative.rule.InlineSqlFunctions; import com.facebook.presto.sql.planner.iterative.rule.LeftJoinNullFilterToSemiJoin; +import com.facebook.presto.sql.planner.iterative.rule.LeftJoinWithArrayContainsToEquiJoinCondition; import com.facebook.presto.sql.planner.iterative.rule.MergeDuplicateAggregation; import com.facebook.presto.sql.planner.iterative.rule.MergeFilters; import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithDistinct; @@ -522,6 +523,13 @@ public PlanOptimizers( new CrossJoinWithOrFilterToInnerJoin(metadata.getFunctionAndTypeManager()), new CrossJoinWithArrayContainsToInnerJoin(metadata.getFunctionAndTypeManager()), new CrossJoinWithArrayNotContainsToAntiJoin(metadata, metadata.getFunctionAndTypeManager()))), + new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of( + new LeftJoinWithArrayContainsToEquiJoinCondition(metadata.getFunctionAndTypeManager()))), new IterativeOptimizer( metadata, ruleStats, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LeftJoinWithArrayContainsToEquiJoinCondition.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LeftJoinWithArrayContainsToEquiJoinCondition.java new file mode 100644 index 0000000000000..efccd0654c1ce --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LeftJoinWithArrayContainsToEquiJoinCondition.java @@ -0,0 +1,165 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.analyzer.FeaturesConfig.LeftJoinArrayContainsToInnerJoinStrategy; +import com.facebook.presto.sql.planner.PlannerUtils; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.UnnestNode; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.getLeftJoinArrayContainsToInnerJoinStrategy; +import static com.facebook.presto.expressions.LogicalRowExpressions.and; +import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts; +import static com.facebook.presto.sql.planner.VariablesExtractor.extractAll; +import static com.facebook.presto.sql.planner.plan.Patterns.join; +import static 
com.facebook.presto.sql.relational.Expressions.call; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +/** + * When the join condition of a left join has the pattern contains(array, element), where array is a deterministic expression over the right input and element is a deterministic expression over the left input, we can rewrite it as an equi join condition. For example: + *
+ * - Left Join
+ *      empty join clause
+ *      filter: contains(r_array, l_key)
+ *      - scan l
+ *      - scan r
+ * 
+ * into: + *
+ *     - Left Join
+ *          l_key = field
+ *          - scan l
+ *          - Unnest
+ *              field <- unnest distinct_array
+ *              - project
+ *                  distinct_array := remove_nulls(array_distinct(r_array))
+ *                  - scan r
+ *                      r_array
+ * 
+ */ +public class LeftJoinWithArrayContainsToEquiJoinCondition + implements Rule +{ + private static final Pattern PATTERN = join().matching(x -> x.getType().equals(JoinNode.Type.LEFT) && x.getCriteria().isEmpty() && x.getFilter().isPresent()); + private final FunctionAndTypeManager functionAndTypeManager; + private final RowExpressionDeterminismEvaluator determinismEvaluator; + private final FunctionResolution functionResolution; + + public LeftJoinWithArrayContainsToEquiJoinCondition(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager); + this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public boolean isEnabled(Session session) + { + // TODO: implement cost based with HBO + return getLeftJoinArrayContainsToInnerJoinStrategy(session).equals(LeftJoinArrayContainsToInnerJoinStrategy.ALWAYS_ENABLED); + } + + @Override + public Result apply(JoinNode node, Captures captures, Context context) + { + RowExpression filterPredicate = node.getFilter().get(); + List leftInput = node.getLeft().getOutputVariables(); + List rightInput = node.getRight().getOutputVariables(); + List andConjuncts = extractConjuncts(filterPredicate); + Optional arrayContains = andConjuncts.stream().filter(rowExpression -> isSupportedJoinCondition(rowExpression, leftInput, rightInput)).findFirst(); + if (!arrayContains.isPresent()) { + return Result.empty(); + } + List remainingConjuncts = andConjuncts.stream().filter(rowExpression -> !rowExpression.equals(arrayContains.get())).collect(toImmutableList()); + RowExpression array = ((CallExpression) arrayContains.get()).getArguments().get(0); + RowExpression element = ((CallExpression) 
arrayContains.get()).getArguments().get(1); + checkState(array.getType() instanceof ArrayType && ((ArrayType) array.getType()).getElementType().equals(element.getType())); + + PlanNode newLeft = node.getLeft(); + ImmutableMap.Builder leftAssignment = ImmutableMap.builder(); + VariableReferenceExpression elementVariable; + if (!(element instanceof VariableReferenceExpression)) { + elementVariable = context.getVariableAllocator().newVariable(element); + leftAssignment.put(elementVariable, element); + newLeft = PlannerUtils.addProjections(node.getLeft(), context.getIdAllocator(), leftAssignment.build()); + } + else { + elementVariable = (VariableReferenceExpression) element; + } + + CallExpression arrayDistinct = call(functionAndTypeManager, "array_distinct", new ArrayType(element.getType()), array); + CallExpression arrayFilterNull = call(functionAndTypeManager, "remove_nulls", arrayDistinct.getType(), ImmutableList.of(arrayDistinct)); + VariableReferenceExpression arrayFilterNullVariable = context.getVariableAllocator().newVariable(arrayFilterNull); + PlanNode newRight = PlannerUtils.addProjections(node.getRight(), context.getIdAllocator(), ImmutableMap.of(arrayFilterNullVariable, arrayFilterNull)); + VariableReferenceExpression unnestVariable = context.getVariableAllocator().newVariable("unnest", element.getType()); + + UnnestNode unnestNode = new UnnestNode(newRight.getSourceLocation(), + context.getIdAllocator().getNextId(), + newRight, + newRight.getOutputVariables(), + ImmutableMap.of(arrayFilterNullVariable, ImmutableList.of(unnestVariable)), + Optional.empty()); + + JoinNode.EquiJoinClause equiJoinClause = new JoinNode.EquiJoinClause(elementVariable, unnestVariable); + + return Result.ofPlanNode(new JoinNode(node.getSourceLocation(), + context.getIdAllocator().getNextId(), + node.getType(), + newLeft, + unnestNode, + ImmutableList.of(equiJoinClause), + node.getOutputVariables(), + remainingConjuncts.isEmpty() ? 
Optional.empty() : Optional.of(and(remainingConjuncts)), + Optional.empty(), + Optional.empty(), + node.getDistributionType(), + node.getDynamicFilters())); + } + + private boolean isSupportedJoinCondition(RowExpression rowExpression, List leftInput, List rightInput) + { + if (rowExpression instanceof CallExpression && functionResolution.isArrayContainsFunction(((CallExpression) rowExpression).getFunctionHandle())) { + RowExpression arrayExpression = ((CallExpression) rowExpression).getArguments().get(0); + RowExpression elementExpression = ((CallExpression) rowExpression).getArguments().get(1); + return determinismEvaluator.isDeterministic(arrayExpression) && rightInput.containsAll(extractAll(arrayExpression)) + && determinismEvaluator.isDeterministic(elementExpression) && leftInput.containsAll(extractAll(elementExpression)); + } + return false; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index c9246882ddca5..e08c8f1b7782d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -21,6 +21,7 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; +import com.facebook.presto.sql.analyzer.FeaturesConfig.LeftJoinArrayContainsToInnerJoinStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartitioningPrecisionStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.PushDownFilterThroughCrossJoinStrategy; @@ -244,6 +245,7 @@ public void testDefaults() .setDefaultJoinSelectivityCoefficient(0) 
.setRewriteCrossJoinWithOrFilterToInnerJoin(true) .setRewriteCrossJoinWithArrayContainsFilterToInnerJoin(true) + .setLeftJoinWithArrayContainsToEquiJoinStrategy(LeftJoinArrayContainsToInnerJoinStrategy.DISABLED) .setRewriteCrossJoinWithArrayNotContainsFilterToAntiJoin(true) .setLeftJoinNullFilterToSemiJoin(true) .setBroadcastJoinWithSmallBuildUnknownProbe(false) @@ -444,6 +446,7 @@ public void testExplicitPropertyMappings() .put("optimizer.push-down-filter-expression-evaluation-through-cross-join", "DISABLED") .put("optimizer.rewrite-cross-join-with-or-filter-to-inner-join", "false") .put("optimizer.rewrite-cross-join-with-array-contains-filter-to-inner-join", "false") + .put("optimizer.left-join-with-array-contains-to-equi-join-strategy", "ALWAYS_ENABLED") .put("optimizer.rewrite-cross-join-with-array-not-contains-filter-to-anti-join", "false") .put("optimizer.default-join-selectivity-coefficient", "0.5") .put("optimizer.rewrite-left-join-with-null-filter-to-semi-join", "false") @@ -644,6 +647,7 @@ public void testExplicitPropertyMappings() .setPushDownFilterExpressionEvaluationThroughCrossJoin(PushDownFilterThroughCrossJoinStrategy.DISABLED) .setRewriteCrossJoinWithOrFilterToInnerJoin(false) .setRewriteCrossJoinWithArrayContainsFilterToInnerJoin(false) + .setLeftJoinWithArrayContainsToEquiJoinStrategy(LeftJoinArrayContainsToInnerJoinStrategy.ALWAYS_ENABLED) .setRewriteCrossJoinWithArrayNotContainsFilterToAntiJoin(false) .setLeftJoinNullFilterToSemiJoin(false) .setBroadcastJoinWithSmallBuildUnknownProbe(true) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLeftJoinWithArrayContainsToEquiJoinCondition.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLeftJoinWithArrayContainsToEquiJoinCondition.java new file mode 100644 index 0000000000000..3e51a08d13443 --- /dev/null +++ 
b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLeftJoinWithArrayContainsToEquiJoinCondition.java @@ -0,0 +1,241 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.unnest; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestLeftJoinWithArrayContainsToEquiJoinCondition + extends 
BaseRuleTest +{ + @Test + public void testTriggerForBigIntArrayRightSide() + { + tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED") + .on(p -> + { + p.variable("left_k1", BIGINT); + p.variable("right_array_k1", new ArrayType(BIGINT)); + return + p.join(JoinNode.Type.LEFT, + p.values(p.variable("left_k1")), + p.values(p.variable("right_array_k1", new ArrayType(BIGINT))), + p.rowExpression("contains(right_array_k1, left_k1)")); + }).matches( + join( + JoinNode.Type.LEFT, + ImmutableList.of(equiJoinClause("left_k1", "unnest")), + values("left_k1"), + unnest( + ImmutableMap.of("array_distinct", ImmutableList.of("unnest")), + project( + ImmutableMap.of("array_distinct", expression("remove_nulls(array_distinct(right_array_k1))")), + values("right_array_k1"))))); + } + + @Test + public void testNotTriggerForArrayOnLeftSide() + { + tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED") + .on(p -> + { + p.variable("left_array_k1", new ArrayType(BIGINT)); + p.variable("right_k1", BIGINT); + return + p.join(JoinNode.Type.LEFT, + p.values(p.variable("left_array_k1", new ArrayType(BIGINT))), + p.values(p.variable("right_k1")), + p.rowExpression("contains(left_array_k1, right_k1)")); + }).doesNotFire(); + } + + @Test + public void testMultipleArrayContainsConditions() + { + tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED") + .on(p -> + { + p.variable("left_array_k1", new ArrayType(BIGINT)); + p.variable("left_k2", BIGINT); + p.variable("right_k1", BIGINT); + p.variable("right_array_k2", new ArrayType(BIGINT)); + return + 
p.join(JoinNode.Type.LEFT, + p.values(p.variable("left_array_k1", new ArrayType(BIGINT)), p.variable("left_k2")), + p.values(p.variable("right_k1"), p.variable("right_array_k2", new ArrayType(BIGINT))), + p.rowExpression("contains(left_array_k1, right_k1) and contains(right_array_k2, left_k2)")); + }).matches( + join( + JoinNode.Type.LEFT, + ImmutableList.of(equiJoinClause("left_k2", "unnest")), + Optional.of("contains(left_array_k1, right_k1)"), + values("left_array_k1", "left_k2"), + unnest( + ImmutableMap.of("array_distinct", ImmutableList.of("unnest")), + project( + ImmutableMap.of("array_distinct", expression("remove_nulls(array_distinct(right_array_k2))")), + values("right_k1", "right_array_k2"))))); + } + + @Test + public void testMultipleInvalidArrayContainsConditions() + { + tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED") + .on(p -> + { + p.variable("left_array_k1", new ArrayType(BIGINT)); + p.variable("left_k2", BIGINT); + p.variable("right_k1", BIGINT); + p.variable("right_array_k2", new ArrayType(BIGINT)); + return + p.join(JoinNode.Type.LEFT, + p.values(p.variable("left_array_k1", new ArrayType(BIGINT)), p.variable("left_k2")), + p.values(p.variable("right_k1"), p.variable("right_array_k2", new ArrayType(BIGINT))), + p.rowExpression("contains(left_array_k1, right_k1) or contains(right_array_k2, left_k2)")); + }).doesNotFire(); + } + + @Test + public void testArrayContainsWithCast() + { + tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED") + .on(p -> + { + p.variable("right_array_k1", new ArrayType(BIGINT)); + p.variable("left_k1", VARCHAR); + return p.join(JoinNode.Type.LEFT, + p.values(p.variable("left_k1", VARCHAR)), + p.values(p.variable("right_array_k1", 
new ArrayType(BIGINT))), + p.rowExpression("contains(right_array_k1, CAST(left_k1 AS BIGINT))")); + }).matches( + join( + JoinNode.Type.LEFT, + ImmutableList.of(equiJoinClause("cast_left", "unnest")), + project( + ImmutableMap.of("cast_left", expression("CAST(left_k1 AS BIGINT)")), + values("left_k1")), + unnest( + ImmutableMap.of("array_distinct", ImmutableList.of("unnest")), + project( + ImmutableMap.of("array_distinct", expression("remove_nulls(array_distinct(right_array_k1))")), + values("right_array_k1"))))); + } + + @Test + public void testArrayContainsWithCast2() + { + tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED") + .on(p -> + { + p.variable("right_array_k1", new ArrayType(BIGINT)); + p.variable("left_k1", VARCHAR); + return p.join(JoinNode.Type.LEFT, + p.values(p.variable("left_k1", VARCHAR)), + p.values(p.variable("right_array_k1", new ArrayType(BIGINT))), + p.rowExpression("contains(CAST(right_array_k1 AS ARRAY), left_k1)")); + }).matches( + join( + JoinNode.Type.LEFT, + ImmutableList.of(equiJoinClause("left_k1", "unnest")), + values("left_k1"), + unnest( + ImmutableMap.of("array_distinct", ImmutableList.of("unnest")), + project( + ImmutableMap.of("array_distinct", expression("remove_nulls(array_distinct(CAST(right_array_k1 AS ARRAY)))")), + values("right_array_k1"))))); + } + + @Test + public void testArrayContainsWithCoalesce() + { + tester().assertThat( + ImmutableSet.of( + new LeftJoinWithArrayContainsToEquiJoinCondition(getFunctionManager()))) + .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED") + .on(p -> + { + p.variable("right_array_k1", new ArrayType(BIGINT)); + p.variable("left_k1", VARCHAR); + p.variable("left_k2", BIGINT); + return + p.join(JoinNode.Type.LEFT, + p.values(p.variable("left_k1", VARCHAR), p.variable("left_k2", BIGINT)), + 
p.values(p.variable("right_array_k1", new ArrayType(BIGINT))), + p.rowExpression("contains(right_array_k1, coalesce(CAST(left_k1 AS BIGINT), left_k2))")); + }).matches( + join( + JoinNode.Type.LEFT, + ImmutableList.of(equiJoinClause("expr", "unnest")), + project( + ImmutableMap.of("expr", expression("COALESCE(CAST(left_k1 AS bigint), left_k2)")), + values("left_k1", "left_k2")), + unnest( + ImmutableMap.of("array_distinct", ImmutableList.of("unnest")), + project( + ImmutableMap.of("array_distinct", expression("remove_nulls(array_distinct(right_array_k1))")), + values("right_array_k1"))))); + } + + @Test + public void testConditionWithAnd() + { + tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED") + .on(p -> + { + p.variable("right_array_k1", new ArrayType(BIGINT)); + p.variable("right_k2", BIGINT); + p.variable("left_k1", BIGINT); + p.variable("left_k2", BIGINT); + return + p.join(JoinNode.Type.LEFT, + p.values(p.variable("left_k1"), p.variable("left_k2")), + p.values(p.variable("right_array_k1", new ArrayType(BIGINT)), p.variable("right_k2")), + p.rowExpression("contains(right_array_k1, left_k1) and right_k2+left_k2 > 10")); + }).matches( + join( + JoinNode.Type.LEFT, + ImmutableList.of(equiJoinClause("left_k1", "unnest")), + Optional.of("right_k2+left_k2 > 10"), + values("left_k1", "left_k2"), + unnest( + ImmutableMap.of("array_distinct", ImmutableList.of("unnest")), + project( + ImmutableMap.of("array_distinct", expression("remove_nulls(array_distinct(right_array_k1))")), + values("right_array_k1", "right_k2"))))); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 45204b5c5b7b5..1ae45256fb507 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ 
b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -77,6 +77,7 @@ import static com.facebook.presto.SystemSessionProperties.REWRITE_CROSS_JOIN_ARRAY_CONTAINS_TO_INNER_JOIN; import static com.facebook.presto.SystemSessionProperties.REWRITE_CROSS_JOIN_ARRAY_NOT_CONTAINS_TO_ANTI_JOIN; import static com.facebook.presto.SystemSessionProperties.REWRITE_CROSS_JOIN_OR_TO_INNER_JOIN; +import static com.facebook.presto.SystemSessionProperties.REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN; import static com.facebook.presto.SystemSessionProperties.REWRITE_LEFT_JOIN_NULL_FILTER_TO_SEMI_JOIN; import static com.facebook.presto.SystemSessionProperties.SIMPLIFY_PLAN_WITH_EMPTY_INPUT; import static com.facebook.presto.SystemSessionProperties.USE_DEFAULTS_FOR_CORRELATED_AGGREGATION_PUSHDOWN_THROUGH_OUTER_JOINS; @@ -6935,6 +6936,72 @@ public void testCrossJoinWithArrayContainsCondition() assertQuery(enableOptimization, sql, "values (1, 'JAPAN')"); } + @Test + public void testLeftJoinWithArrayContainsCondition() + { + Session enableOptimization = Session.builder(getSession()) + .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED") + .build(); + + String sql = "with t1 as (select * from (values (array[1, 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')"); + + sql = "with t1 as (select * from (values (array[1, 2, 3, null], 10), (array[4, 5, 6, null, null], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')"); + + sql = "with t1 as (select * from (values (array[1, 2, 3], 10), (array[4, 5, 6], 11), (array[null, 9], 12)) t(arr, k)), t2 as (select * 
from (values (1, 'a'), (4, 'b'), (null, 'c'), (9, 'd'), (8, 'd')) t(k, v)) " + + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b'), (null, null, 'c'), (12, 9, 'd'), (null, 8, 'd')"); + + sql = "with t1 as (select * from (values (array[1, 2, 3, null, null], 10), (array[4, 5, 6, null, null], 11), (array[null, 9], 12)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b'), (null, 'c'), (9, 'd'), (8, 'd')) t(k, v)) " + + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b'), (null, null, 'c'), (12, 9, 'd'), (null, 8, 'd')"); + + sql = "with t1 as (select * from (values (array[1, 1, 3], 10), (array[4, 4, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')"); + + sql = "with t1 as (select * from (values (array[1, 1, 3, null, null], 10), (array[4, 4, 6, null, null], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')"); + + sql = "with t1 as (select * from (values (array[1, null, 3], 10), (array[4, null, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (null, 'b')) t(k, v)) " + + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; + assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (NULL, NULL, 'b')"); + + sql = "with t1 as (select * from (values (array[1, 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k) and t1.k > 10"; + assertQuery(enableOptimization, 
sql, "values (NULL, 1, 'a'), (11, 4, 'b')"); + + sql = "with t1 as (select * from (values (array[1, 2, 3], 1), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " + + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k) or t1.k = t2.k"; + assertQuery(enableOptimization, sql, "values (1, 1, 'a'), (11, 4, 'b')"); + + sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " + + "select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on contains(t1.orderkey, o.orderkey) where o.totalprice < 2000"; + // The function is named `array_contains` in H2, so the H2 query uses that name + String h2Sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " + + "select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on array_contains(t1.orderkey, o.orderkey) where o.totalprice < 2000"; + assertQuery(enableOptimization, sql, h2Sql); + + sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " + + "select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on contains(t1.orderkey, o.orderkey) and t1.partkey < o.orderkey where o.totalprice < 2000"; + h2Sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " + + "select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on array_contains(t1.orderkey, o.orderkey) and t1.partkey < o.orderkey where o.totalprice < 2000"; + assertQuery(enableOptimization, sql, h2Sql); + + // Element type and array element type do not match + sql = "with t1 as (select * from (values (array[cast(1 as bigint), 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (cast(1 as integer), 'a'), (4, 'b')) t(k, v)) " + + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr,
t2.k)"; + assertQuery(enableOptimization, sql, "values (11, 4, 'b'), (10, 1, 'a')"); + + sql = "with t1 as (select * from (values (array[cast(1 as integer), 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (cast(1 as bigint), 'a'), (4, 'b')) t(k, v)) " + + "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)"; + assertQuery(enableOptimization, sql, "values (11, 4, 'b'), (10, 1, 'a')"); + } + @Test public void testCrossJoinWithArrayNotContainsCondition() {