Skip to content

Commit 02a0622

Browse files
committed
Relax filter constrain for index join planning
Allow index join planning with unsupported join conditions and let connectors decide whether to support them or not and how.
1 parent c83e6e5 commit 02a0622

File tree

2 files changed

+55
-60
lines changed

2 files changed

+55
-60
lines changed

presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java

Lines changed: 30 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import com.facebook.presto.spi.relation.ConstantExpression;
4141
import com.facebook.presto.spi.relation.DomainTranslator.ExtractionResult;
4242
import com.facebook.presto.spi.relation.RowExpression;
43-
import com.facebook.presto.spi.relation.SpecialFormExpression;
4443
import com.facebook.presto.spi.relation.VariableReferenceExpression;
4544
import com.facebook.presto.sql.planner.SimplePlanVisitor;
4645
import com.facebook.presto.sql.planner.TypeProvider;
@@ -65,6 +64,7 @@
6564
import java.util.concurrent.atomic.AtomicBoolean;
6665

6766
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
67+
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
6868
import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE;
6969
import static com.facebook.presto.spi.plan.WindowNode.Frame.WindowType.RANGE;
7070
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
@@ -841,61 +841,43 @@ public String toString()
841841
// Traverse the non-equal join condition and extract the lookup variables.
842842
private static void extractFromFilter(RowExpression expression, Context context)
843843
{
844-
if (expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm() == SpecialFormExpression.Form.AND) {
845-
for (RowExpression operand : LogicalRowExpressions.extractConjuncts(expression)) {
846-
extractFromFilter(operand, context);
847-
if (!context.isEligible()) {
848-
return;
849-
}
844+
List<RowExpression> conjuncts = extractConjuncts(expression);
845+
for (RowExpression conjunct : conjuncts) {
846+
// Index lookup condition only supports Equal, BETWEEN and CONTAINS.
847+
if (!(conjunct instanceof CallExpression)) {
848+
continue;
850849
}
851-
return;
852-
}
853850

854-
// Index lookup only supports Equal, BETWEEN and CONTAINS/IN.
855-
if (!(expression instanceof CallExpression)) {
856-
context.markIneligible();
857-
return;
858-
}
851+
CallExpression callExpression = (CallExpression) conjunct;
852+
if (context.getStandardFunctionResolution().isEqualsFunction(callExpression.getFunctionHandle())
853+
&& callExpression.getArguments().size() == 2) {
854+
RowExpression leftArg = callExpression.getArguments().get(0);
855+
RowExpression rightArg = callExpression.getArguments().get(1);
859856

860-
CallExpression callExpression = (CallExpression) expression;
861-
if (context.getStandardFunctionResolution().isEqualsFunction(callExpression.getFunctionHandle())
862-
&& callExpression.getArguments().size() == 2) {
863-
RowExpression leftArg = callExpression.getArguments().get(0);
864-
RowExpression rightArg = callExpression.getArguments().get(1);
857+
VariableReferenceExpression variable = null;
858+
// Check for pattern: constant = variable or variable = constant.
859+
if (isConstant(leftArg) && isVariable(rightArg)) {
860+
variable = (VariableReferenceExpression) rightArg;
861+
}
862+
else if (isVariable(leftArg) && isConstant(rightArg)) {
863+
variable = (VariableReferenceExpression) leftArg;
864+
}
865865

866-
VariableReferenceExpression variable = null;
867-
// Check for pattern: constant = variable or variable = constant.
868-
if (isConstant(leftArg) && isVariable(rightArg)) {
869-
variable = (VariableReferenceExpression) rightArg;
866+
if (variable != null) {
867+
// It is a lookup equal condition only when it's variable=constant.
868+
context.getLookupVariables().add(variable);
869+
}
870870
}
871-
else if (isVariable(leftArg) && isConstant(rightArg)) {
872-
variable = (VariableReferenceExpression) leftArg;
871+
else if (context.getStandardFunctionResolution().isBetweenFunction(callExpression.getFunctionHandle())
872+
&& isVariable(callExpression.getArguments().get(0))) {
873+
context.getLookupVariables().add((VariableReferenceExpression) callExpression.getArguments().get(0));
873874
}
874-
875-
if (variable != null) {
876-
context.getLookupVariables().add(variable);
877-
return;
875+
else if (callExpression.getDisplayName().equalsIgnoreCase("CONTAINS")
876+
&& callExpression.getArguments().size() == 2
877+
&& isVariable(callExpression.getArguments().get(1))) {
878+
context.getLookupVariables().add((VariableReferenceExpression) callExpression.getArguments().get(1));
878879
}
879-
880-
// Equal condition must be constant.
881-
context.markIneligible();
882-
return;
883-
}
884-
885-
if (context.getStandardFunctionResolution().isBetweenFunction(callExpression.getFunctionHandle())
886-
&& isVariable(callExpression.getArguments().get(0))) {
887-
context.getLookupVariables().add((VariableReferenceExpression) callExpression.getArguments().get(0));
888-
return;
889880
}
890-
891-
if (callExpression.getDisplayName().equalsIgnoreCase("CONTAINS")
892-
&& callExpression.getArguments().size() == 2
893-
&& isVariable(callExpression.getArguments().get(1))) {
894-
context.getLookupVariables().add((VariableReferenceExpression) callExpression.getArguments().get(1));
895-
return;
896-
}
897-
898-
context.markIneligible();
899881
}
900882

901883
public static void extractFromSubPlan(PlanNode node, Context context)

presto-tests/src/test/java/com/facebook/presto/tests/TestNativeIndexJoinLogicalPlanner.java

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public void testBasicIndexJoin()
8181
" FROM lineitem\n" +
8282
" WHERE partkey % 8 = 0) l\n" +
8383
joinType + " JOIN orders o\n" +
84-
" ON l.orderkey = o.orderkey",
84+
" ON l.orderkey = o.orderkey\n",
8585
anyTree(indexJoin(
8686
filter(tableScan("lineitem")),
8787
indexSource("orders"))));
@@ -94,7 +94,7 @@ public void testBasicIndexJoin()
9494
" WHERE partkey % 8 = 0) l\n" +
9595
joinType + " JOIN orders o\n" +
9696
" ON l.orderkey = o.orderkey\n" +
97-
" AND l.orderstatus = o.orderstatus",
97+
" AND l.orderstatus = o.orderstatus\n",
9898
anyTree(indexJoin(
9999
project(filter(tableScan("lineitem"))),
100100
indexSource("orders"))));
@@ -107,7 +107,7 @@ public void testBasicIndexJoin()
107107
" WHERE partkey % 8 = 0) l\n" +
108108
joinType + " JOIN orders o\n" +
109109
" ON l.orderkey = o.orderkey\n" +
110-
" AND o.custkey = 100",
110+
" AND o.custkey = 100\n",
111111
anyTree(indexJoin(
112112
project(filter(tableScan("lineitem"))),
113113
filter(indexSource("orders")))));
@@ -125,8 +125,8 @@ public void testNonEqualIndexJoin()
125125
" FROM lineitem\n" +
126126
" WHERE partkey % 8 = 0) l\n" +
127127
joinType + " JOIN orders o\n" +
128-
" ON l.orderkey = o.orderkey" +
129-
" AND o.custkey BETWEEN 1 AND l.partkey",
128+
" ON l.orderkey = o.orderkey\n" +
129+
" AND o.custkey BETWEEN 1 AND l.partkey\n",
130130
anyTree(indexJoin(
131131
filter(tableScan("lineitem")),
132132
indexSource("orders"))));
@@ -138,8 +138,8 @@ public void testNonEqualIndexJoin()
138138
" FROM lineitem\n" +
139139
" WHERE partkey % 8 = 0) l\n" +
140140
joinType + " JOIN orders o\n" +
141-
" ON l.orderkey = o.orderkey" +
142-
" AND CONTAINS(ARRAY[1, l.partkey, 3], o.custkey)",
141+
" ON l.orderkey = o.orderkey\n" +
142+
" AND CONTAINS(ARRAY[1, l.partkey, 3], o.custkey\n)",
143143
anyTree(indexJoin(
144144
filter(tableScan("lineitem")),
145145
indexSource("orders"))));
@@ -151,8 +151,21 @@ public void testNonEqualIndexJoin()
151151
" FROM lineitem\n" +
152152
" WHERE partkey % 8 = 0) l\n" +
153153
joinType + " JOIN orders o\n" +
154-
" ON l.orderkey = o.orderkey" +
155-
" AND o.custkey BETWEEN 1 AND 100",
154+
" ON l.orderkey = o.orderkey\n" +
155+
" AND o.custkey BETWEEN 1 AND 100\n",
156+
anyTree(indexJoin(
157+
filter(tableScan("lineitem")),
158+
filter(indexSource("orders")))));
159+
160+
assertPlan("" +
161+
"SELECT *\n" +
162+
"FROM (\n" +
163+
" SELECT *\n" +
164+
" FROM lineitem\n" +
165+
" WHERE partkey % 8 = 0) l\n" +
166+
joinType + " JOIN orders o\n" +
167+
" ON l.orderkey = o.orderkey\n" +
168+
" AND CONTAINS(ARRAY[1, 2, 3], o.custkey)\n",
156169
anyTree(indexJoin(
157170
filter(tableScan("lineitem")),
158171
filter(indexSource("orders")))));
@@ -164,8 +177,8 @@ public void testNonEqualIndexJoin()
164177
" FROM lineitem\n" +
165178
" WHERE partkey % 8 = 0) l\n" +
166179
joinType + " JOIN orders o\n" +
167-
" ON l.orderkey = o.orderkey" +
168-
" AND CONTAINS(ARRAY[1, 2, 3], o.custkey)",
180+
" ON l.orderkey = o.orderkey\n" +
181+
" AND o.custkey % 100 = 0\n",
169182
anyTree(indexJoin(
170183
filter(tableScan("lineitem")),
171184
filter(indexSource("orders")))));
@@ -182,7 +195,7 @@ public void testPushdownSubfields()
182195
" ) \n" +
183196
"FROM lineitem l\n" +
184197
"JOIN orders_extra o\n" +
185-
" ON l.orderkey = o.orderkey";
198+
" ON l.orderkey = o.orderkey\n";
186199
PlanMatchPattern expectedQueryPlan = output(
187200
project(
188201
indexJoin(

0 commit comments

Comments
 (0)