Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] colocate join needs to consider column equivalence deduction #13344

Merged
merged 4 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.starrocks.catalog.ColocateTableIndex;
import com.starrocks.common.Pair;
import com.starrocks.qe.ConnectContext;
import com.starrocks.server.GlobalStateMgr;
import com.starrocks.sql.optimizer.base.ColumnRefSet;
import com.starrocks.sql.optimizer.base.DistributionProperty;
import com.starrocks.sql.optimizer.base.DistributionSpec;
import com.starrocks.sql.optimizer.base.HashDistributionDesc;
Expand All @@ -24,6 +26,7 @@

import java.util.Arrays;
import java.util.List;
import java.util.Set;

public class ChildOutputPropertyGuarantor extends PropertyDeriverBase<Void, ExpressionContext> {
private final OptimizerContext context;
Expand Down Expand Up @@ -102,16 +105,27 @@ public boolean canColocateJoin(HashDistributionSpec leftLocalDistributionSpec,
Preconditions.checkState(leftLocalDistributionDesc.getColumns().size() ==
rightLocalDistributionDesc.getColumns().size());
}
// check orders of predicate columns is right
// check predicate columns is satisfy bucket hash columns

// The order of equivalence predicates(shuffle columns are derived from them) is
// meaningless, hence it is correct to use a set to save these shuffle pairs. According
// to the distribution column information of the left and right children, we can build
// distribution pairs. We can use colocate join is judged by whether all the distribution
// pairs are exist in the equivalent predicates set.
Set<Pair<Integer, Integer>> shufflePairs = Sets.newHashSet();
for (int i = 0; i < leftShuffleColumns.size(); i++) {
shufflePairs.add(Pair.create(leftShuffleColumns.get(i), rightShuffleColumns.get(i)));
}

for (int i = 0; i < leftLocalDistributionDesc.getColumns().size(); ++i) {
int leftScanColumnId = leftLocalDistributionDesc.getColumns().get(i);
int leftIndex = leftShuffleColumns.indexOf(leftScanColumnId);
ColumnRefSet leftEquivalentCols = leftLocalDistributionSpec.getPropertyInfo()
.getEquivalentColumns(leftScanColumnId);

int rightScanColumnId = rightLocalDistributionDesc.getColumns().get(i);
int rightIndex = rightShuffleColumns.indexOf(rightScanColumnId);
ColumnRefSet rightEquivalentCols = rightLocalDistributionSpec.getPropertyInfo()
.getEquivalentColumns(rightScanColumnId);

if (leftIndex != rightIndex) {
if (!isDistributionPairExist(shufflePairs, leftEquivalentCols, rightEquivalentCols)) {
return false;
}
}
Expand Down Expand Up @@ -440,4 +454,18 @@ private boolean checkChildDistributionSatisfyShuffle(HashDistributionSpec leftDi
}
return leftIndexList.equals(rightIndexList);
}

/**
 * Returns true when at least one (left, right) combination drawn from the two
 * equivalent-column sets appears among the join's shuffle pairs, i.e. the
 * distribution columns of both children can be matched to some equality
 * predicate of the join, making colocate join legal for this column pair.
 */
private boolean isDistributionPairExist(Set<Pair<Integer, Integer>> shufflePairs,
                                        ColumnRefSet leftEquivalentCols,
                                        ColumnRefSet rightEquivalentCols) {
    // Cross-check every candidate combination; the first hit is enough.
    for (int candidateLeft : leftEquivalentCols.getColumnIds()) {
        for (int candidateRight : rightEquivalentCols.getColumnIds()) {
            if (shufflePairs.contains(Pair.create(candidateLeft, candidateRight))) {
                return true;
            }
        }
    }
    return false;
}
}
100 changes: 100 additions & 0 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/ColocateJoinTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// This file is licensed under the Elastic License 2.0. Copyright 2021-present, StarRocks Inc.

package com.starrocks.sql.plan;

import com.google.common.collect.Lists;
import com.starrocks.common.FeConstants;
import org.apache.commons.lang.StringUtils;
import org.junit.Assert;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.List;
import java.util.stream.Stream;

/**
 * Plan tests verifying that colocate join is chosen when the join's equality
 * predicates, combined with column-equivalence deduction, cover the bucket
 * columns of colocated tables. Each parameterized case counts occurrences of
 * "INNER JOIN (COLOCATE)" in the generated fragment plan.
 */
class ColocateJoinTest extends PlanTestBase {

    @BeforeAll
    public static void beforeClass() throws Exception {
        PlanTestBase.beforeClass();
        FeConstants.runningUnitTest = true;
        // Extra single-bucket-column table in the same colocate group,
        // used by the three-table join cases below.
        starRocksAssert.withTable("CREATE TABLE `colocate_t2_1` (\n" +
                "  `v7` bigint NULL COMMENT \"\",\n" +
                "  `v8` bigint NULL COMMENT \"\",\n" +
                "  `v9` bigint NULL\n" +
                ") ENGINE=OLAP\n" +
                "DUPLICATE KEY(`v7`, `v8`, v9)\n" +
                "DISTRIBUTED BY HASH(`v7`) BUCKETS 3\n" +
                "PROPERTIES (\n" +
                "\"replication_num\" = \"1\",\n" +
                "\"in_memory\" = \"false\",\n" +
                "\"storage_format\" = \"DEFAULT\",\n" +
                "\"colocate_with\" = \"colocate_group_1\"" +
                ");");
    }

    // Plans that must contain exactly one colocate join.
    @ParameterizedTest(name = "sql_{index}: {0}.")
    @MethodSource("colocateJoinOnceSqls")
    void testColocateJoinOnce(String sql) throws Exception {
        String plan = getFragmentPlan(sql);
        int count = StringUtils.countMatches(plan, "INNER JOIN (COLOCATE)");
        // Pass the plan as the message so a failure prints the full plan.
        Assert.assertEquals(plan, 1, count);
    }

    // Plans that must contain exactly two colocate joins.
    @ParameterizedTest(name = "sql_{index}: {0}.")
    @MethodSource("colocateJoinTwiceSqls")
    void testColocateJoinTwice(String sql) throws Exception {
        String plan = getFragmentPlan(sql);
        int count = StringUtils.countMatches(plan, "INNER JOIN (COLOCATE)");
        Assert.assertEquals(plan, 2, count);
    }


    private static Stream<Arguments> colocateJoinOnceSqls() {
        List<String> sqls = Lists.newArrayList();

        // sqls should colocate join but not support now
        // (kept here as documentation of known gaps; intentionally not returned)
        List<String> unsupportedSqls = Lists.newArrayList();
        sqls.add("select * from colocate_t0 join colocate_t1 on v1 = v5 and v1 = v4");
        sqls.add("select * from colocate_t0 join colocate_t1 on v2 = v4 and v1 = v4");
        sqls.add("select * from colocate_t0 join colocate_t1 on v1 + v2 = v4 + v5 and v1 = v4 + 1 and v1 = v4");
        sqls.add("select * from colocate_t0, colocate_t1 where v1 = v5 and v1 = v4");
        sqls.add("select * from colocate_t0, colocate_t1 where v2 = v4 and v1 = v4");
        sqls.add("select * from colocate_t0, colocate_t1 where v1 + v2 = v4 + v5 and v1 = v4 + 1 and v1 = v4");

        sqls.add("select * from colocate_t0, colocate_t1, colocate_t2_1 where v1 = v5 and v5 = v7");

        // TODO(packy) now we cannot derive v1 = v7 plan from the below sqls
        unsupportedSqls.add("select * from colocate_t0 join colocate_t1 on v1 = v5 join colocate_t2 on v5 = v7");
        unsupportedSqls.add("select * from colocate_t0 join colocate_t1 on v1 = v5 + v6 join colocate_t2 on v5 + v6 = v7");
        unsupportedSqls.add("select * from colocate_t0, colocate_t1, colocate_t2_1 where v1 = v5 + v6 and v5 + v6 = v7");
        return sqls.stream().map(e -> Arguments.of(e));
    }

    private static Stream<Arguments> colocateJoinTwiceSqls() {
        List<String> sqls = Lists.newArrayList();
        // sqls should colocate join but not support now
        // (kept here as documentation of known gaps; intentionally not returned)
        List<String> unsupportedSqls = Lists.newArrayList();
        sqls.add("select * from colocate_t0 join colocate_t1 on v1 = v4 join colocate_t2_1 on v4 = v7");
        sqls.add("select * from colocate_t0 join colocate_t1 on v1 = v5 and v1 = v4 join colocate_t2_1 on v5 = v7 and v7 = v2");
        sqls.add("select * from colocate_t0 join colocate_t1 on v1 = v5 join colocate_t2_1 on v1 = v4 and v1 = v7");


        sqls.add("select * from colocate_t0, colocate_t1, colocate_t2_1 where v1 = v4 and v4 = v7");
        sqls.add("select * from colocate_t0, colocate_t1, colocate_t2_1 where v1 = v5 and v1 = v4 and v5 = v7 and v7 = v2");
        sqls.add("select * from colocate_t0, colocate_t1, colocate_t2_1 where v1 = v5 and v1 = v4 and v1 = v7");


        // TODO(packy) equivalence over expression columns is not yet deduced,
        // so these plans miss one colocate join today
        unsupportedSqls.add("select * from colocate_t0 join colocate_t1 on v1 = v5 and v1 = v4 + v6 and v1 = v4 " +
                "join colocate_t2_1 on v4 + v6 = v7");
        unsupportedSqls.add("select * from colocate_t0 join colocate_t1 on v1 + v2 = v4 and v1 + v2 = v5 - v4 " +
                "join colocate_t2_1 on v5 - v4 = v7 and v7 = v1");
        unsupportedSqls.add("select * from colocate_t0 join colocate_t1 on v1 + v2 = v4 - v3 and v1 = v4 + v5 " +
                "join colocate_t2_1 on v4 + v5 = v4 and v4 + v5 = v7");
        return sqls.stream().map(e -> Arguments.of(e));
    }

}
36 changes: 15 additions & 21 deletions fe/fe-core/src/test/resources/sql/enumerate-plan/tpch-q18.sql
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,8 @@ TOP-N (order by [[13: O_TOTALPRICE DESC NULLS LAST, 14: O_ORDERDATE ASC NULLS FI
SCAN (columns[10: O_ORDERKEY, 11: O_CUSTKEY, 13: O_TOTALPRICE, 14: O_ORDERDATE] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[1: C_CUSTKEY, 2: C_NAME] predicate[null])
EXCHANGE SHUFFLE[37]
AGGREGATE ([GLOBAL] aggregate [{54: sum=sum(41: L_QUANTITY)}] group by [[37: L_ORDERKEY]] having [54: sum > 315.0]
SCAN (columns[37: L_ORDERKEY, 41: L_QUANTITY] predicate[null])
AGGREGATE ([GLOBAL] aggregate [{54: sum=sum(41: L_QUANTITY)}] group by [[37: L_ORDERKEY]] having [54: sum > 315.0]
SCAN (columns[37: L_ORDERKEY, 41: L_QUANTITY] predicate[null])
[end]
[plan-6]
TOP-N (order by [[13: O_TOTALPRICE DESC NULLS LAST, 14: O_ORDERDATE ASC NULLS FIRST]])
Expand All @@ -127,9 +126,8 @@ TOP-N (order by [[13: O_TOTALPRICE DESC NULLS LAST, 14: O_ORDERDATE ASC NULLS FI
SCAN (columns[10: O_ORDERKEY, 11: O_CUSTKEY, 13: O_TOTALPRICE, 14: O_ORDERDATE] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[1: C_CUSTKEY, 2: C_NAME] predicate[null])
EXCHANGE SHUFFLE[37]
AGGREGATE ([GLOBAL] aggregate [{54: sum=sum(41: L_QUANTITY)}] group by [[37: L_ORDERKEY]] having [54: sum > 315.0]
SCAN (columns[37: L_ORDERKEY, 41: L_QUANTITY] predicate[null])
AGGREGATE ([GLOBAL] aggregate [{54: sum=sum(41: L_QUANTITY)}] group by [[37: L_ORDERKEY]] having [54: sum > 315.0]
SCAN (columns[37: L_ORDERKEY, 41: L_QUANTITY] predicate[null])
[end]
[plan-7]
TOP-N (order by [[13: O_TOTALPRICE DESC NULLS LAST, 14: O_ORDERDATE ASC NULLS FIRST]])
Expand All @@ -143,10 +141,9 @@ TOP-N (order by [[13: O_TOTALPRICE DESC NULLS LAST, 14: O_ORDERDATE ASC NULLS FI
SCAN (columns[10: O_ORDERKEY, 11: O_CUSTKEY, 13: O_TOTALPRICE, 14: O_ORDERDATE] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[1: C_CUSTKEY, 2: C_NAME] predicate[null])
EXCHANGE SHUFFLE[37]
AGGREGATE ([GLOBAL] aggregate [{54: sum=sum(54: sum)}] group by [[37: L_ORDERKEY]] having [54: sum > 315.0]
AGGREGATE ([LOCAL] aggregate [{54: sum=sum(41: L_QUANTITY)}] group by [[37: L_ORDERKEY]] having [null]
SCAN (columns[37: L_ORDERKEY, 41: L_QUANTITY] predicate[null])
AGGREGATE ([GLOBAL] aggregate [{54: sum=sum(54: sum)}] group by [[37: L_ORDERKEY]] having [54: sum > 315.0]
AGGREGATE ([LOCAL] aggregate [{54: sum=sum(41: L_QUANTITY)}] group by [[37: L_ORDERKEY]] having [null]
SCAN (columns[37: L_ORDERKEY, 41: L_QUANTITY] predicate[null])
[end]
[plan-8]
TOP-N (order by [[13: O_TOTALPRICE DESC NULLS LAST, 14: O_ORDERDATE ASC NULLS FIRST]])
Expand All @@ -160,10 +157,9 @@ TOP-N (order by [[13: O_TOTALPRICE DESC NULLS LAST, 14: O_ORDERDATE ASC NULLS FI
SCAN (columns[10: O_ORDERKEY, 11: O_CUSTKEY, 13: O_TOTALPRICE, 14: O_ORDERDATE] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[1: C_CUSTKEY, 2: C_NAME] predicate[null])
EXCHANGE SHUFFLE[37]
AGGREGATE ([GLOBAL] aggregate [{54: sum=sum(54: sum)}] group by [[37: L_ORDERKEY]] having [54: sum > 315.0]
AGGREGATE ([LOCAL] aggregate [{54: sum=sum(41: L_QUANTITY)}] group by [[37: L_ORDERKEY]] having [null]
SCAN (columns[37: L_ORDERKEY, 41: L_QUANTITY] predicate[null])
AGGREGATE ([GLOBAL] aggregate [{54: sum=sum(54: sum)}] group by [[37: L_ORDERKEY]] having [54: sum > 315.0]
AGGREGATE ([LOCAL] aggregate [{54: sum=sum(41: L_QUANTITY)}] group by [[37: L_ORDERKEY]] having [null]
SCAN (columns[37: L_ORDERKEY, 41: L_QUANTITY] predicate[null])
[end]
[plan-9]
TOP-N (order by [[13: O_TOTALPRICE DESC NULLS LAST, 14: O_ORDERDATE ASC NULLS FIRST]])
Expand All @@ -175,9 +171,8 @@ TOP-N (order by [[13: O_TOTALPRICE DESC NULLS LAST, 14: O_ORDERDATE ASC NULLS FI
SCAN (columns[20: L_ORDERKEY, 24: L_QUANTITY] predicate[null])
EXCHANGE SHUFFLE[10]
SCAN (columns[10: O_ORDERKEY, 11: O_CUSTKEY, 13: O_TOTALPRICE, 14: O_ORDERDATE] predicate[null])
EXCHANGE SHUFFLE[37]
AGGREGATE ([GLOBAL] aggregate [{54: sum=sum(41: L_QUANTITY)}] group by [[37: L_ORDERKEY]] having [54: sum > 315.0]
SCAN (columns[37: L_ORDERKEY, 41: L_QUANTITY] predicate[null])
AGGREGATE ([GLOBAL] aggregate [{54: sum=sum(41: L_QUANTITY)}] group by [[37: L_ORDERKEY]] having [54: sum > 315.0]
SCAN (columns[37: L_ORDERKEY, 41: L_QUANTITY] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[1: C_CUSTKEY, 2: C_NAME] predicate[null])
[end]
Expand All @@ -191,10 +186,9 @@ TOP-N (order by [[13: O_TOTALPRICE DESC NULLS LAST, 14: O_ORDERDATE ASC NULLS FI
SCAN (columns[20: L_ORDERKEY, 24: L_QUANTITY] predicate[null])
EXCHANGE SHUFFLE[10]
SCAN (columns[10: O_ORDERKEY, 11: O_CUSTKEY, 13: O_TOTALPRICE, 14: O_ORDERDATE] predicate[null])
EXCHANGE SHUFFLE[37]
AGGREGATE ([GLOBAL] aggregate [{54: sum=sum(54: sum)}] group by [[37: L_ORDERKEY]] having [54: sum > 315.0]
AGGREGATE ([LOCAL] aggregate [{54: sum=sum(41: L_QUANTITY)}] group by [[37: L_ORDERKEY]] having [null]
SCAN (columns[37: L_ORDERKEY, 41: L_QUANTITY] predicate[null])
AGGREGATE ([GLOBAL] aggregate [{54: sum=sum(54: sum)}] group by [[37: L_ORDERKEY]] having [54: sum > 315.0]
AGGREGATE ([LOCAL] aggregate [{54: sum=sum(41: L_QUANTITY)}] group by [[37: L_ORDERKEY]] having [null]
SCAN (columns[37: L_ORDERKEY, 41: L_QUANTITY] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[1: C_CUSTKEY, 2: C_NAME] predicate[null])
[end]
Expand Down