diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
index 795534157948b..be4654d515b6e 100644
--- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
+++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
@@ -32,6 +32,7 @@
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType;
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinNotNullInferenceStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy;
+import com.facebook.presto.sql.analyzer.FeaturesConfig.LeftJoinArrayContainsToInnerJoinStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialMergePushdownStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.PartitioningPrecisionStrategy;
@@ -287,6 +288,7 @@ public final class SystemSessionProperties
public static final String REWRITE_CROSS_JOIN_OR_TO_INNER_JOIN = "rewrite_cross_join_or_to_inner_join";
public static final String REWRITE_CROSS_JOIN_ARRAY_CONTAINS_TO_INNER_JOIN = "rewrite_cross_join_array_contains_to_inner_join";
public static final String REWRITE_CROSS_JOIN_ARRAY_NOT_CONTAINS_TO_ANTI_JOIN = "rewrite_cross_join_array_not_contains_to_anti_join";
+ public static final String REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN = "rewrite_left_join_array_contains_to_equi_join";
public static final String REWRITE_LEFT_JOIN_NULL_FILTER_TO_SEMI_JOIN = "rewrite_left_join_null_filter_to_semi_join";
public static final String USE_BROADCAST_WHEN_BUILDSIZE_SMALL_PROBESIDE_UNKNOWN = "use_broadcast_when_buildsize_small_probeside_unknown";
public static final String ADD_PARTIAL_NODE_FOR_ROW_NUMBER_WITH_LIMIT = "add_partial_node_for_row_number_with_limit";
@@ -1734,6 +1736,18 @@ public SystemSessionProperties(
"Rewrite cross join with array not contains filter to anti join",
featuresConfig.isRewriteCrossJoinWithArrayNotContainsFilterToAntiJoin(),
false),
+ new PropertyMetadata<>(
+ REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN,
+ format("Set the strategy used to convert left join with array contains to inner join. Options are: %s",
+ Stream.of(LeftJoinArrayContainsToInnerJoinStrategy.values())
+ .map(LeftJoinArrayContainsToInnerJoinStrategy::name)
+ .collect(joining(","))),
+ VARCHAR,
+ LeftJoinArrayContainsToInnerJoinStrategy.class,
+ featuresConfig.getLeftJoinWithArrayContainsToEquiJoinStrategy(),
+ false,
+ value -> LeftJoinArrayContainsToInnerJoinStrategy.valueOf(((String) value).toUpperCase()),
+ LeftJoinArrayContainsToInnerJoinStrategy::name),
new PropertyMetadata<>(
JOINS_NOT_NULL_INFERENCE_STRATEGY,
format("Set the strategy used NOT NULL filter inference on Join Nodes. Options are: %s",
@@ -2955,6 +2969,11 @@ public static boolean isRewriteCrossJoinArrayNotContainsToAntiJoinEnabled(Sessio
return session.getSystemProperty(REWRITE_CROSS_JOIN_ARRAY_NOT_CONTAINS_TO_ANTI_JOIN, Boolean.class);
}
+ public static LeftJoinArrayContainsToInnerJoinStrategy getLeftJoinArrayContainsToInnerJoinStrategy(Session session)
+ {
+ return session.getSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, LeftJoinArrayContainsToInnerJoinStrategy.class);
+ }
+
public static boolean isRewriteLeftJoinNullFilterToSemiJoinEnabled(Session session)
{
return session.getSystemProperty(REWRITE_LEFT_JOIN_NULL_FILTER_TO_SEMI_JOIN, Boolean.class);
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
index c228156135836..94b1b16ed4fde 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
@@ -273,6 +273,7 @@ public class FeaturesConfig
private PushDownFilterThroughCrossJoinStrategy pushDownFilterExpressionEvaluationThroughCrossJoin = PushDownFilterThroughCrossJoinStrategy.REWRITTEN_TO_INNER_JOIN;
private boolean rewriteCrossJoinWithOrFilterToInnerJoin = true;
private boolean rewriteCrossJoinWithArrayContainsFilterToInnerJoin = true;
+ private LeftJoinArrayContainsToInnerJoinStrategy leftJoinWithArrayContainsToEquiJoinStrategy = LeftJoinArrayContainsToInnerJoinStrategy.DISABLED;
private boolean rewriteCrossJoinWithArrayNotContainsFilterToAntiJoin = true;
private JoinNotNullInferenceStrategy joinNotNullInferenceStrategy = NONE;
private boolean leftJoinNullFilterToSemiJoin = true;
@@ -424,6 +425,13 @@ public enum JoinNotNullInferenceStrategy
USE_FUNCTION_METADATA
}
+ // TODO: Implement cost based strategy
+ public enum LeftJoinArrayContainsToInnerJoinStrategy
+ {
+ DISABLED,
+ ALWAYS_ENABLED
+ }
+
public double getCpuCostWeight()
{
return cpuCostWeight;
@@ -2746,6 +2754,19 @@ public FeaturesConfig setRewriteCrossJoinWithArrayContainsFilterToInnerJoin(bool
return this;
}
+ public LeftJoinArrayContainsToInnerJoinStrategy getLeftJoinWithArrayContainsToEquiJoinStrategy()
+ {
+ return leftJoinWithArrayContainsToEquiJoinStrategy;
+ }
+
+ @Config("optimizer.left-join-with-array-contains-to-equi-join-strategy")
+ @ConfigDescription("When to apply rewrite left join with array contains to equi join")
+ public FeaturesConfig setLeftJoinWithArrayContainsToEquiJoinStrategy(LeftJoinArrayContainsToInnerJoinStrategy leftJoinWithArrayContainsToEquiJoinStrategy)
+ {
+ this.leftJoinWithArrayContainsToEquiJoinStrategy = leftJoinWithArrayContainsToEquiJoinStrategy;
+ return this;
+ }
+
public boolean isRewriteCrossJoinWithArrayNotContainsFilterToAntiJoin()
{
return this.rewriteCrossJoinWithArrayNotContainsFilterToAntiJoin;
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
index 57caee0d9b240..e13744b7aae32 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
@@ -46,6 +46,7 @@
import com.facebook.presto.sql.planner.iterative.rule.InlineProjections;
import com.facebook.presto.sql.planner.iterative.rule.InlineSqlFunctions;
import com.facebook.presto.sql.planner.iterative.rule.LeftJoinNullFilterToSemiJoin;
+import com.facebook.presto.sql.planner.iterative.rule.LeftJoinWithArrayContainsToEquiJoinCondition;
import com.facebook.presto.sql.planner.iterative.rule.MergeDuplicateAggregation;
import com.facebook.presto.sql.planner.iterative.rule.MergeFilters;
import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithDistinct;
@@ -522,6 +523,13 @@ public PlanOptimizers(
new CrossJoinWithOrFilterToInnerJoin(metadata.getFunctionAndTypeManager()),
new CrossJoinWithArrayContainsToInnerJoin(metadata.getFunctionAndTypeManager()),
new CrossJoinWithArrayNotContainsToAntiJoin(metadata, metadata.getFunctionAndTypeManager()))),
+ new IterativeOptimizer(
+ metadata,
+ ruleStats,
+ statsCalculator,
+ estimatedExchangesCostCalculator,
+ ImmutableSet.of(
+ new LeftJoinWithArrayContainsToEquiJoinCondition(metadata.getFunctionAndTypeManager()))),
new IterativeOptimizer(
metadata,
ruleStats,
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LeftJoinWithArrayContainsToEquiJoinCondition.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LeftJoinWithArrayContainsToEquiJoinCondition.java
new file mode 100644
index 0000000000000..efccd0654c1ce
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LeftJoinWithArrayContainsToEquiJoinCondition.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.common.type.ArrayType;
+import com.facebook.presto.matching.Captures;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.metadata.FunctionAndTypeManager;
+import com.facebook.presto.spi.plan.PlanNode;
+import com.facebook.presto.spi.relation.CallExpression;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.analyzer.FeaturesConfig.LeftJoinArrayContainsToInnerJoinStrategy;
+import com.facebook.presto.sql.planner.PlannerUtils;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.UnnestNode;
+import com.facebook.presto.sql.relational.FunctionResolution;
+import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.List;
+import java.util.Optional;
+
+import static com.facebook.presto.SystemSessionProperties.getLeftJoinArrayContainsToInnerJoinStrategy;
+import static com.facebook.presto.expressions.LogicalRowExpressions.and;
+import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
+import static com.facebook.presto.sql.planner.VariablesExtractor.extractAll;
+import static com.facebook.presto.sql.planner.plan.Patterns.join;
+import static com.facebook.presto.sql.relational.Expressions.call;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * When the join condition of a left join has pattern of contains(array, element) where array, we can rewrite it as a equi join condition. For example:
+ *
+ * - Left Join
+ * empty join clause
+ * filter: contains(r_array, l_key)
+ * - scan l
+ * - scan r
+ *
+ * into:
+ *
+ * - Left Join
+ * l_key = field
+ * - scan l
+ * - Unnest
+ * field <- unnest distinct_array
+ * - project
+ * distinct_array := remove_nulls(array_distinct(r_array))
+ * - scan r
+ * r_array
+ *
+ */
+public class LeftJoinWithArrayContainsToEquiJoinCondition
+ implements Rule
+{
+ private static final Pattern PATTERN = join().matching(x -> x.getType().equals(JoinNode.Type.LEFT) && x.getCriteria().isEmpty() && x.getFilter().isPresent());
+ private final FunctionAndTypeManager functionAndTypeManager;
+ private final RowExpressionDeterminismEvaluator determinismEvaluator;
+ private final FunctionResolution functionResolution;
+
+ public LeftJoinWithArrayContainsToEquiJoinCondition(FunctionAndTypeManager functionAndTypeManager)
+ {
+ this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
+ this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
+ this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
+ }
+
+ @Override
+ public Pattern getPattern()
+ {
+ return PATTERN;
+ }
+
+ @Override
+ public boolean isEnabled(Session session)
+ {
+ // TODO: implement cost based with HBO
+ return getLeftJoinArrayContainsToInnerJoinStrategy(session).equals(LeftJoinArrayContainsToInnerJoinStrategy.ALWAYS_ENABLED);
+ }
+
+ @Override
+ public Result apply(JoinNode node, Captures captures, Context context)
+ {
+ RowExpression filterPredicate = node.getFilter().get();
+ List leftInput = node.getLeft().getOutputVariables();
+ List rightInput = node.getRight().getOutputVariables();
+ List andConjuncts = extractConjuncts(filterPredicate);
+ Optional arrayContains = andConjuncts.stream().filter(rowExpression -> isSupportedJoinCondition(rowExpression, leftInput, rightInput)).findFirst();
+ if (!arrayContains.isPresent()) {
+ return Result.empty();
+ }
+ List remainingConjuncts = andConjuncts.stream().filter(rowExpression -> !rowExpression.equals(arrayContains.get())).collect(toImmutableList());
+ RowExpression array = ((CallExpression) arrayContains.get()).getArguments().get(0);
+ RowExpression element = ((CallExpression) arrayContains.get()).getArguments().get(1);
+ checkState(array.getType() instanceof ArrayType && ((ArrayType) array.getType()).getElementType().equals(element.getType()));
+
+ PlanNode newLeft = node.getLeft();
+ ImmutableMap.Builder leftAssignment = ImmutableMap.builder();
+ VariableReferenceExpression elementVariable;
+ if (!(element instanceof VariableReferenceExpression)) {
+ elementVariable = context.getVariableAllocator().newVariable(element);
+ leftAssignment.put(elementVariable, element);
+ newLeft = PlannerUtils.addProjections(node.getLeft(), context.getIdAllocator(), leftAssignment.build());
+ }
+ else {
+ elementVariable = (VariableReferenceExpression) element;
+ }
+
+ CallExpression arrayDistinct = call(functionAndTypeManager, "array_distinct", new ArrayType(element.getType()), array);
+ CallExpression arrayFilterNull = call(functionAndTypeManager, "remove_nulls", arrayDistinct.getType(), ImmutableList.of(arrayDistinct));
+ VariableReferenceExpression arrayFilterNullVariable = context.getVariableAllocator().newVariable(arrayFilterNull);
+ PlanNode newRight = PlannerUtils.addProjections(node.getRight(), context.getIdAllocator(), ImmutableMap.of(arrayFilterNullVariable, arrayFilterNull));
+ VariableReferenceExpression unnestVariable = context.getVariableAllocator().newVariable("unnest", element.getType());
+
+ UnnestNode unnestNode = new UnnestNode(newRight.getSourceLocation(),
+ context.getIdAllocator().getNextId(),
+ newRight,
+ newRight.getOutputVariables(),
+ ImmutableMap.of(arrayFilterNullVariable, ImmutableList.of(unnestVariable)),
+ Optional.empty());
+
+ JoinNode.EquiJoinClause equiJoinClause = new JoinNode.EquiJoinClause(elementVariable, unnestVariable);
+
+ return Result.ofPlanNode(new JoinNode(node.getSourceLocation(),
+ context.getIdAllocator().getNextId(),
+ node.getType(),
+ newLeft,
+ unnestNode,
+ ImmutableList.of(equiJoinClause),
+ node.getOutputVariables(),
+ remainingConjuncts.isEmpty() ? Optional.empty() : Optional.of(and(remainingConjuncts)),
+ Optional.empty(),
+ Optional.empty(),
+ node.getDistributionType(),
+ node.getDynamicFilters()));
+ }
+
+ private boolean isSupportedJoinCondition(RowExpression rowExpression, List leftInput, List rightInput)
+ {
+ if (rowExpression instanceof CallExpression && functionResolution.isArrayContainsFunction(((CallExpression) rowExpression).getFunctionHandle())) {
+ RowExpression arrayExpression = ((CallExpression) rowExpression).getArguments().get(0);
+ RowExpression elementExpression = ((CallExpression) rowExpression).getArguments().get(1);
+ return determinismEvaluator.isDeterministic(arrayExpression) && rightInput.containsAll(extractAll(arrayExpression))
+ && determinismEvaluator.isDeterministic(elementExpression) && leftInput.containsAll(extractAll(elementExpression));
+ }
+ return false;
+ }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
index c9246882ddca5..e08c8f1b7782d 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
@@ -21,6 +21,7 @@
import com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType;
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy;
+import com.facebook.presto.sql.analyzer.FeaturesConfig.LeftJoinArrayContainsToInnerJoinStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.PartitioningPrecisionStrategy;
import com.facebook.presto.sql.analyzer.FeaturesConfig.PushDownFilterThroughCrossJoinStrategy;
@@ -244,6 +245,7 @@ public void testDefaults()
.setDefaultJoinSelectivityCoefficient(0)
.setRewriteCrossJoinWithOrFilterToInnerJoin(true)
.setRewriteCrossJoinWithArrayContainsFilterToInnerJoin(true)
+ .setLeftJoinWithArrayContainsToEquiJoinStrategy(LeftJoinArrayContainsToInnerJoinStrategy.DISABLED)
.setRewriteCrossJoinWithArrayNotContainsFilterToAntiJoin(true)
.setLeftJoinNullFilterToSemiJoin(true)
.setBroadcastJoinWithSmallBuildUnknownProbe(false)
@@ -444,6 +446,7 @@ public void testExplicitPropertyMappings()
.put("optimizer.push-down-filter-expression-evaluation-through-cross-join", "DISABLED")
.put("optimizer.rewrite-cross-join-with-or-filter-to-inner-join", "false")
.put("optimizer.rewrite-cross-join-with-array-contains-filter-to-inner-join", "false")
+ .put("optimizer.left-join-with-array-contains-to-equi-join-strategy", "ALWAYS_ENABLED")
.put("optimizer.rewrite-cross-join-with-array-not-contains-filter-to-anti-join", "false")
.put("optimizer.default-join-selectivity-coefficient", "0.5")
.put("optimizer.rewrite-left-join-with-null-filter-to-semi-join", "false")
@@ -644,6 +647,7 @@ public void testExplicitPropertyMappings()
.setPushDownFilterExpressionEvaluationThroughCrossJoin(PushDownFilterThroughCrossJoinStrategy.DISABLED)
.setRewriteCrossJoinWithOrFilterToInnerJoin(false)
.setRewriteCrossJoinWithArrayContainsFilterToInnerJoin(false)
+ .setLeftJoinWithArrayContainsToEquiJoinStrategy(LeftJoinArrayContainsToInnerJoinStrategy.ALWAYS_ENABLED)
.setRewriteCrossJoinWithArrayNotContainsFilterToAntiJoin(false)
.setLeftJoinNullFilterToSemiJoin(false)
.setBroadcastJoinWithSmallBuildUnknownProbe(true)
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLeftJoinWithArrayContainsToEquiJoinCondition.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLeftJoinWithArrayContainsToEquiJoinCondition.java
new file mode 100644
index 0000000000000..3e51a08d13443
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLeftJoinWithArrayContainsToEquiJoinCondition.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.common.type.ArrayType;
+import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import org.testng.annotations.Test;
+
+import java.util.Optional;
+
+import static com.facebook.presto.SystemSessionProperties.REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN;
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.VarcharType.VARCHAR;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.unnest;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
+
+public class TestLeftJoinWithArrayContainsToEquiJoinCondition
+ extends BaseRuleTest
+{
+ @Test
+ public void testTriggerForBigIntArrayRightSide()
+ {
+ tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager()))
+ .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED")
+ .on(p ->
+ {
+ p.variable("left_k1", BIGINT);
+ p.variable("right_array_k1", new ArrayType(BIGINT));
+ return
+ p.join(JoinNode.Type.LEFT,
+ p.values(p.variable("left_k1")),
+ p.values(p.variable("right_array_k1", new ArrayType(BIGINT))),
+ p.rowExpression("contains(right_array_k1, left_k1)"));
+ }).matches(
+ join(
+ JoinNode.Type.LEFT,
+ ImmutableList.of(equiJoinClause("left_k1", "unnest")),
+ values("left_k1"),
+ unnest(
+ ImmutableMap.of("array_distinct", ImmutableList.of("unnest")),
+ project(
+ ImmutableMap.of("array_distinct", expression("remove_nulls(array_distinct(right_array_k1))")),
+ values("right_array_k1")))));
+ }
+
+ @Test
+ public void testNotTriggerForArrayOnLeftSide()
+ {
+ tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager()))
+ .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED")
+ .on(p ->
+ {
+ p.variable("left_array_k1", new ArrayType(BIGINT));
+ p.variable("right_k1", BIGINT);
+ return
+ p.join(JoinNode.Type.LEFT,
+ p.values(p.variable("left_array_k1", new ArrayType(BIGINT))),
+ p.values(p.variable("right_k1")),
+ p.rowExpression("contains(left_array_k1, right_k1)"));
+ }).doesNotFire();
+ }
+
+ @Test
+ public void testMultipleArrayContainsConditions()
+ {
+ tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager()))
+ .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED")
+ .on(p ->
+ {
+ p.variable("left_array_k1", new ArrayType(BIGINT));
+ p.variable("left_k2", BIGINT);
+ p.variable("right_k1", BIGINT);
+ p.variable("right_array_k2", new ArrayType(BIGINT));
+ return
+ p.join(JoinNode.Type.LEFT,
+ p.values(p.variable("left_array_k1", new ArrayType(BIGINT)), p.variable("left_k2")),
+ p.values(p.variable("right_k1"), p.variable("right_array_k2", new ArrayType(BIGINT))),
+ p.rowExpression("contains(left_array_k1, right_k1) and contains(right_array_k2, left_k2)"));
+ }).matches(
+ join(
+ JoinNode.Type.LEFT,
+ ImmutableList.of(equiJoinClause("left_k2", "unnest")),
+ Optional.of("contains(left_array_k1, right_k1)"),
+ values("left_array_k1", "left_k2"),
+ unnest(
+ ImmutableMap.of("array_distinct", ImmutableList.of("unnest")),
+ project(
+ ImmutableMap.of("array_distinct", expression("remove_nulls(array_distinct(right_array_k2))")),
+ values("right_k1", "right_array_k2")))));
+ }
+
+ @Test
+ public void testMultipleInvalidArrayContainsConditions()
+ {
+ tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager()))
+ .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED")
+ .on(p ->
+ {
+ p.variable("left_array_k1", new ArrayType(BIGINT));
+ p.variable("left_k2", BIGINT);
+ p.variable("right_k1", BIGINT);
+ p.variable("right_array_k2", new ArrayType(BIGINT));
+ return
+ p.join(JoinNode.Type.LEFT,
+ p.values(p.variable("left_array_k1", new ArrayType(BIGINT)), p.variable("left_k2")),
+ p.values(p.variable("right_k1"), p.variable("right_array_k2", new ArrayType(BIGINT))),
+ p.rowExpression("contains(left_array_k1, right_k1) or contains(right_array_k2, left_k2)"));
+ }).doesNotFire();
+ }
+
+ @Test
+ public void testArrayContainsWithCast()
+ {
+ tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager()))
+ .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED")
+ .on(p ->
+ {
+ p.variable("right_array_k1", new ArrayType(BIGINT));
+ p.variable("left_k1", VARCHAR);
+ return p.join(JoinNode.Type.LEFT,
+ p.values(p.variable("left_k1", VARCHAR)),
+ p.values(p.variable("right_array_k1", new ArrayType(BIGINT))),
+ p.rowExpression("contains(right_array_k1, CAST(left_k1 AS BIGINT))"));
+ }).matches(
+ join(
+ JoinNode.Type.LEFT,
+ ImmutableList.of(equiJoinClause("cast_left", "unnest")),
+ project(
+ ImmutableMap.of("cast_left", expression("CAST(left_k1 AS BIGINT)")),
+ values("left_k1")),
+ unnest(
+ ImmutableMap.of("array_distinct", ImmutableList.of("unnest")),
+ project(
+ ImmutableMap.of("array_distinct", expression("remove_nulls(array_distinct(right_array_k1))")),
+ values("right_array_k1")))));
+ }
+
+ @Test
+ public void testArrayContainsWithCast2()
+ {
+ tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager()))
+ .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED")
+ .on(p ->
+ {
+ p.variable("right_array_k1", new ArrayType(BIGINT));
+ p.variable("left_k1", VARCHAR);
+ return p.join(JoinNode.Type.LEFT,
+ p.values(p.variable("left_k1", VARCHAR)),
+ p.values(p.variable("right_array_k1", new ArrayType(BIGINT))),
+ p.rowExpression("contains(CAST(right_array_k1 AS ARRAY), left_k1)"));
+ }).matches(
+ join(
+ JoinNode.Type.LEFT,
+ ImmutableList.of(equiJoinClause("left_k1", "unnest")),
+ values("left_k1"),
+ unnest(
+ ImmutableMap.of("array_distinct", ImmutableList.of("unnest")),
+ project(
+ ImmutableMap.of("array_distinct", expression("remove_nulls(array_distinct(CAST(right_array_k1 AS ARRAY)))")),
+ values("right_array_k1")))));
+ }
+
+ @Test
+ public void testArrayContainsWithCoalesce()
+ {
+ tester().assertThat(
+ ImmutableSet.of(
+ new LeftJoinWithArrayContainsToEquiJoinCondition(getFunctionManager())))
+ .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED")
+ .on(p ->
+ {
+ p.variable("right_array_k1", new ArrayType(BIGINT));
+ p.variable("left_k1", VARCHAR);
+ p.variable("left_k2", BIGINT);
+ return
+ p.join(JoinNode.Type.LEFT,
+ p.values(p.variable("left_k1", VARCHAR), p.variable("left_k2", BIGINT)),
+ p.values(p.variable("right_array_k1", new ArrayType(BIGINT))),
+ p.rowExpression("contains(right_array_k1, coalesce(CAST(left_k1 AS BIGINT), left_k2))"));
+ }).matches(
+ join(
+ JoinNode.Type.LEFT,
+ ImmutableList.of(equiJoinClause("expr", "unnest")),
+ project(
+ ImmutableMap.of("expr", expression("COALESCE(CAST(left_k1 AS bigint), left_k2)")),
+ values("left_k1", "left_k2")),
+ unnest(
+ ImmutableMap.of("array_distinct", ImmutableList.of("unnest")),
+ project(
+ ImmutableMap.of("array_distinct", expression("remove_nulls(array_distinct(right_array_k1))")),
+ values("right_array_k1")))));
+ }
+
+ @Test
+ public void testConditionWithAnd()
+ {
+ tester().assertThat(new LeftJoinWithArrayContainsToEquiJoinCondition(getMetadata().getFunctionAndTypeManager()))
+ .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED")
+ .on(p ->
+ {
+ p.variable("right_array_k1", new ArrayType(BIGINT));
+ p.variable("right_k2", BIGINT);
+ p.variable("left_k1", BIGINT);
+ p.variable("left_k2", BIGINT);
+ return
+ p.join(JoinNode.Type.LEFT,
+ p.values(p.variable("left_k1"), p.variable("left_k2")),
+ p.values(p.variable("right_array_k1", new ArrayType(BIGINT)), p.variable("right_k2")),
+ p.rowExpression("contains(right_array_k1, left_k1) and right_k2+left_k2 > 10"));
+ }).matches(
+ join(
+ JoinNode.Type.LEFT,
+ ImmutableList.of(equiJoinClause("left_k1", "unnest")),
+ Optional.of("right_k2+left_k2 > 10"),
+ values("left_k1", "left_k2"),
+ unnest(
+ ImmutableMap.of("array_distinct", ImmutableList.of("unnest")),
+ project(
+ ImmutableMap.of("array_distinct", expression("remove_nulls(array_distinct(right_array_k1))")),
+ values("right_array_k1", "right_k2")))));
+ }
+}
diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
index 45204b5c5b7b5..1ae45256fb507 100644
--- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
+++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
@@ -77,6 +77,7 @@
import static com.facebook.presto.SystemSessionProperties.REWRITE_CROSS_JOIN_ARRAY_CONTAINS_TO_INNER_JOIN;
import static com.facebook.presto.SystemSessionProperties.REWRITE_CROSS_JOIN_ARRAY_NOT_CONTAINS_TO_ANTI_JOIN;
import static com.facebook.presto.SystemSessionProperties.REWRITE_CROSS_JOIN_OR_TO_INNER_JOIN;
+import static com.facebook.presto.SystemSessionProperties.REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN;
import static com.facebook.presto.SystemSessionProperties.REWRITE_LEFT_JOIN_NULL_FILTER_TO_SEMI_JOIN;
import static com.facebook.presto.SystemSessionProperties.SIMPLIFY_PLAN_WITH_EMPTY_INPUT;
import static com.facebook.presto.SystemSessionProperties.USE_DEFAULTS_FOR_CORRELATED_AGGREGATION_PUSHDOWN_THROUGH_OUTER_JOINS;
@@ -6935,6 +6936,72 @@ public void testCrossJoinWithArrayContainsCondition()
assertQuery(enableOptimization, sql, "values (1, 'JAPAN')");
}
+ @Test
+ public void testLeftJoinWithArrayContainsCondition()
+ {
+ Session enableOptimization = Session.builder(getSession())
+ .setSystemProperty(REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN, "ALWAYS_ENABLED")
+ .build();
+
+ String sql = "with t1 as (select * from (values (array[1, 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
+ "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
+ assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')");
+
+ sql = "with t1 as (select * from (values (array[1, 2, 3, null], 10), (array[4, 5, 6, null, null], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
+ "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
+ assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')");
+
+ sql = "with t1 as (select * from (values (array[1, 2, 3], 10), (array[4, 5, 6], 11), (array[null, 9], 12)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b'), (null, 'c'), (9, 'd'), (8, 'd')) t(k, v)) " +
+ "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
+ assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b'), (null, null, 'c'), (12, 9, 'd'), (null, 8, 'd')");
+
+ sql = "with t1 as (select * from (values (array[1, 2, 3, null, null], 10), (array[4, 5, 6, null, null], 11), (array[null, 9], 12)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b'), (null, 'c'), (9, 'd'), (8, 'd')) t(k, v)) " +
+ "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
+ assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b'), (null, null, 'c'), (12, 9, 'd'), (null, 8, 'd')");
+
+ sql = "with t1 as (select * from (values (array[1, 1, 3], 10), (array[4, 4, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
+ "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
+ assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')");
+
+ sql = "with t1 as (select * from (values (array[1, 1, 3, null, null], 10), (array[4, 4, 6, null, null], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
+ "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
+ assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (11, 4, 'b')");
+
+ sql = "with t1 as (select * from (values (array[1, null, 3], 10), (array[4, null, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (null, 'b')) t(k, v)) " +
+ "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
+ assertQuery(enableOptimization, sql, "values (10, 1, 'a'), (NULL, NULL, 'b')");
+
+ sql = "with t1 as (select * from (values (array[1, 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
+ "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k) and t1.k > 10";
+ assertQuery(enableOptimization, sql, "values (NULL, 1, 'a'), (11, 4, 'b')");
+
+ sql = "with t1 as (select * from (values (array[1, 2, 3], 1), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (1, 'a'), (4, 'b')) t(k, v)) " +
+ "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k) or t1.k = t2.k";
+ assertQuery(enableOptimization, sql, "values (1, 1, 'a'), (11, 4, 'b')");
+
+ sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " +
+ "select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on contains(t1.orderkey, o.orderkey) where o.totalprice < 2000";
+ // Because the UDF has different names in H2, which is `array_contains`
+ String h2Sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " +
+ "select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on array_contains(t1.orderkey, o.orderkey) where o.totalprice < 2000";
+ assertQuery(enableOptimization, sql, h2Sql);
+
+ sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " +
+ "select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on contains(t1.orderkey, o.orderkey) and t1.partkey < o.orderkey where o.totalprice < 2000";
+ h2Sql = "with t1 as (select array_agg(orderkey) orderkey, partkey from lineitem l where l.quantity < 5 group by partkey) " +
+ "select t1.partkey, o.orderkey, o.totalprice from orders o left join t1 on array_contains(t1.orderkey, o.orderkey) and t1.partkey < o.orderkey where o.totalprice < 2000";
+ assertQuery(enableOptimization, sql, h2Sql);
+
+ // Element type and array type does not match
+ sql = "with t1 as (select * from (values (array[cast(1 as bigint), 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (cast(1 as integer), 'a'), (4, 'b')) t(k, v)) " +
+ "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
+ assertQuery(enableOptimization, sql, "values (11, 4, 'b'), (10, 1, 'a')");
+
+ sql = "with t1 as (select * from (values (array[cast(1 as integer), 2, 3], 10), (array[4, 5, 6], 11)) t(arr, k)), t2 as (select * from (values (cast(1 as bigint), 'a'), (4, 'b')) t(k, v)) " +
+ "select t1.k, t2.k, t2.v from t2 left join t1 on contains(t1.arr, t2.k)";
+ assertQuery(enableOptimization, sql, "values (11, 4, 'b'), (10, 1, 'a')");
+ }
+
@Test
public void testCrossJoinWithArrayNotContainsCondition()
{