From e9e013c568250c33fe2cfbb1fa83af94415be6e2 Mon Sep 17 00:00:00 2001 From: David Phillips Date: Mon, 26 Sep 2022 16:20:40 -0700 Subject: [PATCH] Fix MERGE when task_writer_count > 1 --- .../operator/exchange/LocalExchange.java | 17 +++++-- .../sql/planner/NodePartitioningManager.java | 9 ++++ .../iceberg/BaseIcebergConnectorTest.java | 45 ++++++++++++++----- 3 files changed, 56 insertions(+), 15 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java index a9be3416e296..cfb25bb3d554 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java @@ -26,6 +26,7 @@ import io.trino.operator.PrecomputedHashGenerator; import io.trino.spi.Page; import io.trino.spi.type.Type; +import io.trino.sql.planner.MergePartitioningHandle; import io.trino.sql.planner.NodePartitioningManager; import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.planner.SystemPartitioningHandle; @@ -134,7 +135,8 @@ else if (partitioning.equals(SCALED_WRITER_DISTRIBUTION)) { physicalWrittenBytesSupplier, writerMinSize); } - else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent()) { + else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent() || + (partitioning.getConnectorHandle() instanceof MergePartitioningHandle)) { exchangerSupplier = () -> { PartitionFunction partitionFunction = createPartitionFunction( nodePartitioningManager, @@ -224,14 +226,22 @@ private static PartitionFunction createPartitionFunction( // The same bucket function (with the same bucket count) as for node // partitioning must be used. This way rows within a single bucket // will be being processed by single thread. 
- int bucketCount = nodePartitioningManager.getBucketNodeMap(session, partitioning).getBucketCount(); + int bucketCount = nodePartitioningManager.getBucketCount(session, partitioning); int[] bucketToPartition = new int[bucketCount]; + for (int bucket = 0; bucket < bucketCount; bucket++) { // mix the bucket bits so we don't use the same bucket number used to distribute between stages int hashedBucket = (int) XxHash64.hash(Long.reverse(bucket)); bucketToPartition[bucket] = hashedBucket & (partitionCount - 1); } + if (partitioning.getConnectorHandle() instanceof MergePartitioningHandle handle) { + return handle.getPartitionFunction( + (scheme, types) -> nodePartitioningManager.getPartitionFunction(session, scheme, types, bucketToPartition), + partitionChannelTypes, + bucketToPartition); + } + return new BucketPartitionFunction( nodePartitioningManager.getBucketFunction(session, partitioning, partitionChannelTypes, bucketCount), bucketToPartition); @@ -358,7 +368,8 @@ else if (partitioning.equals(SCALED_WRITER_DISTRIBUTION)) { bufferCount = defaultConcurrency; checkArgument(partitionChannels.isEmpty(), "Scaled writer exchange must not have partition channels"); } - else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent()) { + else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent() || + (partitioning.getConnectorHandle() instanceof MergePartitioningHandle)) { // partitioned exchange bufferCount = defaultConcurrency; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java index e0d8343b1140..015ae824525d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java @@ -239,6 +239,15 @@ public BucketNodeMap getBucketNodeMap(Session session, 
PartitioningHandle partit return new BucketNodeMap(splitToBucket, createArbitraryBucketToNode(nodes, bucketCount)); } + public int getBucketCount(Session session, PartitioningHandle partitioning) + { + if (partitioning.getConnectorHandle() instanceof MergePartitioningHandle) { + // TODO: can we always use this code path? + return getNodePartitioningMap(session, partitioning).getBucketToPartition().length; + } + return getBucketNodeMap(session, partitioning).getBucketCount(); + } + public int getNodeCount(Session session, PartitioningHandle partitioningHandle) { return getAllNodes(session, requiredCatalogHandle(partitioningHandle)).size(); diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java index b8f069570537..98b71d8323b0 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java @@ -64,6 +64,7 @@ import java.time.Instant; import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; @@ -5793,11 +5794,15 @@ public void testMergeSimpleSelectPartitioned() } @Test(dataProvider = "partitionedAndBucketedProvider") - public void testMergeUpdateWithVariousLayouts(String partitionPhase) + public void testMergeUpdateWithVariousLayouts(int writers, String partitioning) { + Session session = Session.builder(getSession()) + .setSystemProperty(TASK_WRITER_COUNT, String.valueOf(writers)) + .build(); + String targetTable = "merge_formats_target_" + randomTableSuffix(); String sourceTable = "merge_formats_source_" + randomTableSuffix(); - assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) %s", targetTable, partitionPhase)); + assertUpdate(format("CREATE TABLE %s (customer 
VARCHAR, purchase VARCHAR) %s", targetTable, partitioning)); assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable), 3); assertQuery("SELECT * FROM " + targetTable, "VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')"); @@ -5811,7 +5816,7 @@ public void testMergeUpdateWithVariousLayouts(String partitionPhase) " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_', s.customer)" + " WHEN NOT MATCHED THEN INSERT (customer, purchase) VALUES(s.customer, s.purchase)"; - assertUpdate(sql, 3); + assertUpdate(session, sql, 3); assertQuery("SELECT * FROM " + targetTable, "VALUES ('Dave', 'dates'), ('Carol_Craig', 'candles'), ('Joe', 'jellybeans')"); assertUpdate("DROP TABLE " + sourceTable); @@ -5821,17 +5826,31 @@ @DataProvider public Object[][] partitionedAndBucketedProvider() { - return new Object[][] { - {"WITH (partitioning = ARRAY['customer'])"}, - {"WITH (partitioning = ARRAY['purchase'])"}, - {"WITH (partitioning = ARRAY['bucket(customer, 3)'])"}, - {"WITH (partitioning = ARRAY['bucket(purchase, 4)'])"}, - }; + List<Integer> writerCounts = ImmutableList.of(1, 4); + List<String> partitioningTypes = ImmutableList.<String>builder() + .add("") + .add("WITH (partitioning = ARRAY['customer'])") + .add("WITH (partitioning = ARRAY['purchase'])") + .add("WITH (partitioning = ARRAY['bucket(customer, 3)'])") + .add("WITH (partitioning = ARRAY['bucket(purchase, 4)'])") + .build(); + + List<Object[]> data = new ArrayList<>(); + for (int writers : writerCounts) { + for (String partitioning : partitioningTypes) { + data.add(new Object[] {writers, partitioning}); + } + } + return data.toArray(Object[][]::new); } @Test(dataProvider = "partitionedAndBucketedProvider") - public void testMergeMultipleOperations(String partitioning) + public void testMergeMultipleOperations(int writers, String partitioning) { + Session session = 
Session.builder(getSession()) + .setSystemProperty(TASK_WRITER_COUNT, String.valueOf(writers)) + .build(); + int targetCustomerCount = 32; String targetTable = "merge_multiple_" + randomTableSuffix(); assertUpdate(format("CREATE TABLE %s (purchase INT, zipcode INT, spouse VARCHAR, address VARCHAR, customer VARCHAR) %s", targetTable, partitioning)); @@ -5848,7 +5867,8 @@ public void testMergeMultipleOperations(String partitioning) .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 83000, intValue, intValue)) .collect(joining(", ")); - assertUpdate(format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, firstMergeSource) + + assertUpdate(session, + format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, firstMergeSource) + " ON t.customer = s.customer" + " WHEN MATCHED THEN UPDATE SET purchase = s.purchase, zipcode = s.zipcode, spouse = s.spouse, address = s.address", targetCustomerCount / 2); @@ -5867,7 +5887,8 @@ public void testMergeMultipleOperations(String partitioning) .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) .collect(joining(", ")); - assertUpdate(format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, secondMergeSource) + + assertUpdate(session, + format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, secondMergeSource) + " ON t.customer = s.customer" + " WHEN MATCHED AND t.zipcode = 91000 THEN DELETE" + " WHEN MATCHED AND s.zipcode = 85000 THEN UPDATE SET zipcode = 60000" +