@@ -26,6 +26,7 @@
 import io.trino.operator.PrecomputedHashGenerator;
 import io.trino.spi.Page;
 import io.trino.spi.type.Type;
+import io.trino.sql.planner.MergePartitioningHandle;
 import io.trino.sql.planner.NodePartitioningManager;
 import io.trino.sql.planner.PartitioningHandle;
 import io.trino.sql.planner.SystemPartitioningHandle;
@@ -134,7 +135,8 @@ else if (partitioning.equals(SCALED_WRITER_DISTRIBUTION)) {
         physicalWrittenBytesSupplier,
         writerMinSize);
 }
-else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent()) {
+else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent() ||
+        (partitioning.getConnectorHandle() instanceof MergePartitioningHandle)) {
     exchangerSupplier = () -> {
         PartitionFunction partitionFunction = createPartitionFunction(
                 nodePartitioningManager,
@@ -224,14 +226,22 @@ private static PartitionFunction createPartitionFunction(
     // The same bucket function (with the same bucket count) as for node
     // partitioning must be used. This way, rows within a single bucket
     // will be processed by a single thread.
-    int bucketCount = nodePartitioningManager.getBucketNodeMap(session, partitioning).getBucketCount();
+    int bucketCount = nodePartitioningManager.getBucketCount(session, partitioning);
     int[] bucketToPartition = new int[bucketCount];

     for (int bucket = 0; bucket < bucketCount; bucket++) {
         // mix the bucket bits so we don't use the same bucket number used to distribute between stages
         int hashedBucket = (int) XxHash64.hash(Long.reverse(bucket));
         bucketToPartition[bucket] = hashedBucket & (partitionCount - 1);
     }
+
+    if (partitioning.getConnectorHandle() instanceof MergePartitioningHandle handle) {
+        return handle.getPartitionFunction(
+                (scheme, types) -> nodePartitioningManager.getPartitionFunction(session, scheme, types, bucketToPartition),
+                partitionChannelTypes,
+                bucketToPartition);
+    }
+
     return new BucketPartitionFunction(
             nodePartitioningManager.getBucketFunction(session, partitioning, partitionChannelTypes, bucketCount),
             bucketToPartition);
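
For context, the bit-mixing step above can be read in isolation. A minimal standalone sketch (class name, main method, and sample counts are illustrative, not from the PR; it assumes io.airlift.slice.XxHash64, the same hash the diff uses, and a power-of-two partitionCount, which the bitmask requires):

```java
import io.airlift.slice.XxHash64;

public final class BucketMixingDemo
{
    private BucketMixingDemo() {}

    public static void main(String[] args)
    {
        int bucketCount = 8;
        int partitionCount = 4; // must be a power of two: the mask below is (partitionCount - 1)

        int[] bucketToPartition = new int[bucketCount];
        for (int bucket = 0; bucket < bucketCount; bucket++) {
            // Long.reverse spreads the low-order bucket bits across the long before
            // hashing, so the local partition is decorrelated from the raw bucket id
            int hashedBucket = (int) XxHash64.hash(Long.reverse(bucket));
            bucketToPartition[bucket] = hashedBucket & (partitionCount - 1);
        }

        for (int bucket = 0; bucket < bucketCount; bucket++) {
            System.out.printf("bucket %d -> local partition %d%n", bucket, bucketToPartition[bucket]);
        }
    }
}
```

The point, per the comment in the diff, is that the raw bucket id already routed rows between stages; reusing it unmixed for local partitioning would make the local assignment correlate with the inter-stage one.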
@@ -358,7 +368,8 @@ else if (partitioning.equals(SCALED_WRITER_DISTRIBUTION)) {
     bufferCount = defaultConcurrency;
     checkArgument(partitionChannels.isEmpty(), "Scaled writer exchange must not have partition channels");
 }
-else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent()) {
+else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent() ||
+        (partitioning.getConnectorHandle() instanceof MergePartitioningHandle)) {
     // partitioned exchange
     bufferCount = defaultConcurrency;
 }
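
Note that the same three-way condition now guards both call sites: the exchanger setup in the earlier hunk and the buffer-count computation here. If one were refactoring, the predicate could be extracted once; a hypothetical helper, not part of this PR:

```java
// Hypothetical helper (illustrative only): true when rows must be hash-partitioned
// across local buffers, i.e. fixed hash distribution, connector-defined
// partitioning, or the new merge partitioning case.
private static boolean isPartitionedExchange(PartitioningHandle partitioning)
{
    return partitioning.equals(FIXED_HASH_DISTRIBUTION)
            || partitioning.getCatalogHandle().isPresent()
            || partitioning.getConnectorHandle() instanceof MergePartitioningHandle;
}
```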
@@ -239,6 +239,15 @@ public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partit
     return new BucketNodeMap(splitToBucket, createArbitraryBucketToNode(nodes, bucketCount));
 }

+public int getBucketCount(Session session, PartitioningHandle partitioning)
+{
+    if (partitioning.getConnectorHandle() instanceof MergePartitioningHandle) {
+        // TODO: can we always use this code path?
+        return getNodePartitioningMap(session, partitioning).getBucketToPartition().length;
+    }
+    return getBucketNodeMap(session, partitioning).getBucketCount();
+}
+
 public int getNodeCount(Session session, PartitioningHandle partitioningHandle)
 {
     return getAllNodes(session, requiredCatalogHandle(partitioningHandle)).size();
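
The TODO above asks whether the getNodePartitioningMap path could serve all handles. The implied invariant, that both lookups agree on the bucket count wherever both apply, can be made concrete; the following check is hypothetical (not in the PR) and assumes checkState from Guava's Preconditions is statically imported, as is common in this codebase:

```java
// Hypothetical consistency check (illustrative only): for handles that support
// both lookups, the two code paths in getBucketCount should report the same count.
private void verifyBucketCountsAgree(Session session, PartitioningHandle partitioning)
{
    int viaPartitioningMap = getNodePartitioningMap(session, partitioning)
            .getBucketToPartition()
            .length;
    int viaBucketNodeMap = getBucketNodeMap(session, partitioning).getBucketCount();
    checkState(viaPartitioningMap == viaBucketNodeMap,
            "bucket count mismatch: %s vs %s", viaPartitioningMap, viaBucketNodeMap);
}
```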
@@ -64,6 +64,7 @@
 import java.time.Instant;
 import java.time.ZonedDateTime;
 import java.time.format.DateTimeFormatter;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -5793,11 +5794,15 @@ public void testMergeSimpleSelectPartitioned()
 }

 @Test(dataProvider = "partitionedAndBucketedProvider")
-public void testMergeUpdateWithVariousLayouts(String partitionPhase)
+public void testMergeUpdateWithVariousLayouts(int writers, String partitioning)
 {
+    Session session = Session.builder(getSession())
+            .setSystemProperty(TASK_WRITER_COUNT, String.valueOf(writers))
+            .build();
+
     String targetTable = "merge_formats_target_" + randomTableSuffix();
     String sourceTable = "merge_formats_source_" + randomTableSuffix();
-    assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) %s", targetTable, partitionPhase));
+    assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) %s", targetTable, partitioning));

     assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable), 3);
     assertQuery("SELECT * FROM " + targetTable, "VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')");
@@ -5811,7 +5816,7 @@ public void testMergeUpdateWithVariousLayouts(String partitionPhase)
             " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_', s.customer)" +
             " WHEN NOT MATCHED THEN INSERT (customer, purchase) VALUES(s.customer, s.purchase)";

-    assertUpdate(sql, 3);
+    assertUpdate(session, sql, 3);

     assertQuery("SELECT * FROM " + targetTable, "VALUES ('Dave', 'dates'), ('Carol_Craig', 'candles'), ('Joe', 'jellybeans')");
     assertUpdate("DROP TABLE " + sourceTable);
@@ -5821,17 +5826,31 @@ public void testMergeUpdateWithVariousLayouts(String partitionPhase)
 @DataProvider
 public Object[][] partitionedAndBucketedProvider()
 {
-    return new Object[][] {
-            {"WITH (partitioning = ARRAY['customer'])"},
-            {"WITH (partitioning = ARRAY['purchase'])"},
-            {"WITH (partitioning = ARRAY['bucket(customer, 3)'])"},
-            {"WITH (partitioning = ARRAY['bucket(purchase, 4)'])"},
-    };
+    List<Integer> writerCounts = ImmutableList.of(1, 4);
+    List<String> partitioningTypes = ImmutableList.<String>builder()
+            .add("")
+            .add("WITH (partitioning = ARRAY['customer'])")
+            .add("WITH (partitioning = ARRAY['purchase'])")
+            .add("WITH (partitioning = ARRAY['bucket(customer, 3)'])")
+            .add("WITH (partitioning = ARRAY['bucket(purchase, 4)'])")
+            .build();
+
+    List<Object[]> data = new ArrayList<>();
+    for (int writers : writerCounts) {
+        for (String partitioning : partitioningTypes) {
+            data.add(new Object[] {writers, partitioning});
+        }
+    }
+    return data.toArray(Object[][]::new);
 }
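
The nested loops above build a plain cross product (2 writer counts x 5 layouts = 10 cases). An equivalent formulation with Guava's Lists.cartesianProduct, purely illustrative and not part of the PR, assuming the same writerCounts and partitioningTypes lists and a static import of Collectors.toList:

```java
// Equivalent to the nested loops: one case per (writers, partitioning) pair.
List<Object[]> data = Lists.cartesianProduct(writerCounts, partitioningTypes).stream()
        .map(List::toArray)
        .collect(toList());
return data.toArray(new Object[0][]);
```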

 @Test(dataProvider = "partitionedAndBucketedProvider")
-public void testMergeMultipleOperations(String partitioning)
+public void testMergeMultipleOperations(int writers, String partitioning)
 {
+    Session session = Session.builder(getSession())
+            .setSystemProperty(TASK_WRITER_COUNT, String.valueOf(writers))
+            .build();
+
     int targetCustomerCount = 32;
     String targetTable = "merge_multiple_" + randomTableSuffix();
     assertUpdate(format("CREATE TABLE %s (purchase INT, zipcode INT, spouse VARCHAR, address VARCHAR, customer VARCHAR) %s", targetTable, partitioning));
@@ -5848,7 +5867,8 @@ public void testMergeMultipleOperations(String partitioning)
             .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 83000, intValue, intValue))
             .collect(joining(", "));

-    assertUpdate(format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, firstMergeSource) +
+    assertUpdate(session,
+            format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, firstMergeSource) +
             " ON t.customer = s.customer" +
             " WHEN MATCHED THEN UPDATE SET purchase = s.purchase, zipcode = s.zipcode, spouse = s.spouse, address = s.address",
             targetCustomerCount / 2);
@@ -5867,7 +5887,8 @@ public void testMergeMultipleOperations(String partitioning)
             .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue))
             .collect(joining(", "));

-    assertUpdate(format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, secondMergeSource) +
+    assertUpdate(session,
+            format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, secondMergeSource) +
             " ON t.customer = s.customer" +
             " WHEN MATCHED AND t.zipcode = 91000 THEN DELETE" +
             " WHEN MATCHED AND s.zipcode = 85000 THEN UPDATE SET zipcode = 60000" +