From 0479a841c7a0ad6c426e6ea64766cae9100b992e Mon Sep 17 00:00:00 2001 From: Gaurav Sehgal Date: Tue, 27 Sep 2022 02:37:43 +0530 Subject: [PATCH 1/2] Fix bug in finding local scaling exchange node --- .../main/java/io/trino/sql/planner/LocalExecutionPlanner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 7a6d7108c8b2..9318e8685dea 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -3501,7 +3501,7 @@ public PhysicalOperation visitExchange(ExchangeNode node, LocalExecutionPlanCont private boolean isLocalScaledWriterExchange(PlanNode node) { Optional result = searchFrom(node) - .where(planNode -> node instanceof ExchangeNode && ((ExchangeNode) node).getScope() == LOCAL) + .where(planNode -> planNode instanceof ExchangeNode && ((ExchangeNode) planNode).getScope() == LOCAL) .findFirst(); return result.isPresent() From 9e7d62dc94205d95e88063009c788315e70d7077 Mon Sep 17 00:00:00 2001 From: Gaurav Sehgal Date: Wed, 28 Sep 2022 03:04:04 +0530 Subject: [PATCH 2/2] Pass partition channel types directly to LocalExchange --- .../operator/exchange/LocalExchange.java | 6 +---- .../sql/planner/LocalExecutionPlanner.java | 11 ++++++---- .../operator/exchange/TestLocalExchange.java | 22 +++++++++---------- .../io/trino/operator/join/JoinTestUtils.java | 6 ++++- .../join/unspilled/JoinTestUtils.java | 7 +++++- 5 files changed, 29 insertions(+), 23 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java index cfb25bb3d554..6f5f2b693e13 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java @@ -88,7 +88,7 @@ public LocalExchange( int defaultConcurrency, PartitioningHandle partitioning, List partitionChannels, - List types, + List partitionChannelTypes, Optional partitionHashChannel, DataSize maxBufferedBytes, BlockTypeOperators blockTypeOperators, @@ -106,10 +106,6 @@ public LocalExchange( .map(buffer -> (Consumer) buffer::addPage) .collect(toImmutableList()); - List partitionChannelTypes = partitionChannels.stream() - .map(types::get) - .collect(toImmutableList()); - this.memoryManager = new LocalExchangeMemoryManager(maxBufferedBytes.toBytes()); if (partitioning.equals(SINGLE_DISTRIBUTION)) { exchangerSupplier = () -> new BroadcastExchanger(buffers, memoryManager); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 9318e8685dea..e1dd28e740e1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -3529,7 +3529,7 @@ private PhysicalOperation createLocalMerge(ExchangeNode node, LocalExecutionPlan operatorsCount, node.getPartitioningScheme().getPartitioning().getHandle(), ImmutableList.of(), - types, + ImmutableList.of(), Optional.empty(), maxLocalExchangeBufferSize, blockTypeOperators, @@ -3583,11 +3583,14 @@ else if (context.getDriverInstanceCount().isPresent()) { } List types = getSourceOperatorTypes(node, context.getTypes()); - List channels = node.getPartitioningScheme().getPartitioning().getArguments().stream() + List partitionChannels = node.getPartitioningScheme().getPartitioning().getArguments().stream() .map(argument -> node.getOutputSymbols().indexOf(argument.getColumn())) .collect(toImmutableList()); Optional hashChannel = node.getPartitioningScheme().getHashColumn() .map(symbol -> node.getOutputSymbols().indexOf(symbol)); + List partitionChannelTypes = partitionChannels.stream() + .map(types::get) + .collect(toImmutableList()); List driverFactoryParametersList = new ArrayList<>(); for (int i = 0; i < node.getSources().size(); i++) { @@ -3603,8 +3606,8 @@ else if (context.getDriverInstanceCount().isPresent()) { session, driverInstanceCount, node.getPartitioningScheme().getPartitioning().getHandle(), - channels, - types, + partitionChannels, + partitionChannelTypes, hashChannel, maxLocalExchangeBufferSize, blockTypeOperators, diff --git a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java index 24d65d2b67d2..d349b292a5ab 100644 --- a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java +++ b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java @@ -111,7 +111,7 @@ public void testGatherSingleWriter() 8, SINGLE_DISTRIBUTION, ImmutableList.of(), - TYPES, + ImmutableList.of(), Optional.empty(), DataSize.ofBytes(retainedSizeOfPages(99)), TYPE_OPERATOR_FACTORY, @@ -185,7 +185,7 @@ public void testBroadcast() 2, FIXED_BROADCAST_DISTRIBUTION, ImmutableList.of(), - TYPES, + ImmutableList.of(), Optional.empty(), LOCAL_EXCHANGE_MAX_BUFFERED_BYTES, TYPE_OPERATOR_FACTORY, @@ -274,7 +274,7 @@ public void testRandom() 2, FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of(), - TYPES, + ImmutableList.of(), Optional.empty(), LOCAL_EXCHANGE_MAX_BUFFERED_BYTES, TYPE_OPERATOR_FACTORY, @@ -325,7 +325,7 @@ public void testScaleWriter() 3, SCALED_WRITER_DISTRIBUTION, ImmutableList.of(), - TYPES, + ImmutableList.of(), Optional.empty(), DataSize.ofBytes(retainedSizeOfPages(4)), TYPE_OPERATOR_FACTORY, @@ -406,7 +406,7 @@ public void testNoWriterScalingWhenOnlyBufferSizeLimitIsExceeded() 3, SCALED_WRITER_DISTRIBUTION, ImmutableList.of(), - TYPES, + ImmutableList.of(), Optional.empty(), DataSize.ofBytes(retainedSizeOfPages(4)), TYPE_OPERATOR_FACTORY, @@ -449,7 +449,7 @@ public void testNoWriterScalingWhenOnlyWriterMinSizeLimitIsExceeded() 3, SCALED_WRITER_DISTRIBUTION, ImmutableList.of(), - TYPES, + ImmutableList.of(), Optional.empty(), DataSize.ofBytes(retainedSizeOfPages(20)), TYPE_OPERATOR_FACTORY, @@ -493,7 +493,7 @@ public void testPassthrough() 2, FIXED_PASSTHROUGH_DISTRIBUTION, ImmutableList.of(), - TYPES, + ImmutableList.of(), Optional.empty(), DataSize.ofBytes(retainedSizeOfPages(1)), TYPE_OPERATOR_FACTORY, @@ -658,7 +658,7 @@ public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHa 2, partitioningHandle, ImmutableList.of(1), - types, + ImmutableList.of(BIGINT), Optional.empty(), LOCAL_EXCHANGE_MAX_BUFFERED_BYTES, TYPE_OPERATOR_FACTORY, @@ -704,15 +704,13 @@ public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHa @Test public void writeUnblockWhenAllReadersFinish() { - ImmutableList types = ImmutableList.of(BIGINT); - LocalExchange localExchange = new LocalExchange( nodePartitioningManager, SESSION, 2, FIXED_BROADCAST_DISTRIBUTION, ImmutableList.of(), - types, + ImmutableList.of(), Optional.empty(), LOCAL_EXCHANGE_MAX_BUFFERED_BYTES, TYPE_OPERATOR_FACTORY, @@ -760,7 +758,7 @@ public void writeUnblockWhenAllReadersFinishAndPagesConsumed() 2, FIXED_BROADCAST_DISTRIBUTION, ImmutableList.of(), - TYPES, + ImmutableList.of(), Optional.empty(), DataSize.ofBytes(1), TYPE_OPERATOR_FACTORY, diff --git a/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java index 6fb25a7411c3..b9e9044623d0 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java @@ -141,13 +141,17 @@ public static BuildSideSetup setupBuildSide( int partitionCount = parallelBuild ? PARTITION_COUNT : 1; List hashChannels = buildPages.getHashChannels().orElseThrow(); + List types = buildPages.getTypes(); + List hashChannelTypes = hashChannels.stream() + .map(types::get) + .collect(toImmutableList()); LocalExchange localExchange = new LocalExchange( nodePartitioningManager, taskContext.getSession(), partitionCount, FIXED_HASH_DISTRIBUTION, hashChannels, - buildPages.getTypes(), + hashChannelTypes, buildPages.getHashChannel(), DataSize.of(32, DataSize.Unit.MEGABYTE), TYPE_OPERATOR_FACTORY, diff --git a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java index 71e15f8533f6..a9969a647b22 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java @@ -34,6 +34,7 @@ import io.trino.operator.join.unspilled.HashBuilderOperator.HashBuilderOperatorFactory; import io.trino.spi.Page; import io.trino.spi.TrinoException; +import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.sql.gen.JoinFilterFunctionCompiler; import io.trino.sql.planner.NodePartitioningManager; @@ -136,13 +137,17 @@ public static BuildSideSetup setupBuildSide( int partitionCount = parallelBuild ? PARTITION_COUNT : 1; List hashChannels = buildPages.getHashChannels().orElseThrow(); + List types = buildPages.getTypes(); + List hashChannelTypes = hashChannels.stream() + .map(types::get) + .collect(toImmutableList()); LocalExchange localExchange = new LocalExchange( nodePartitioningManager, taskContext.getSession(), partitionCount, FIXED_HASH_DISTRIBUTION, hashChannels, - buildPages.getTypes(), + hashChannelTypes, buildPages.getHashChannel(), DataSize.of(32, DataSize.Unit.MEGABYTE), TYPE_OPERATOR_FACTORY,