@@ -68,11 +68,13 @@
 import io.trino.sql.planner.PlanFragment;
 import io.trino.sql.planner.SplitSourceFactory;
 import io.trino.sql.planner.SubPlan;
+import io.trino.sql.planner.optimizations.PlanNodeSearcher;
 import io.trino.sql.planner.plan.PlanFragmentId;
 import io.trino.sql.planner.plan.PlanNode;
 import io.trino.sql.planner.plan.PlanNodeId;
 import io.trino.sql.planner.plan.RemoteSourceNode;
 import io.trino.sql.planner.plan.TableScanNode;
+import io.trino.sql.planner.plan.TableWriterNode;
 import io.trino.tracing.TrinoAttributes;
 
 import java.net.URI;
@@ -112,6 +114,8 @@
 import static io.airlift.concurrent.MoreFutures.tryGetFutureValue;
 import static io.airlift.concurrent.MoreFutures.whenAnyComplete;
 import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
+import static io.trino.SystemSessionProperties.getMaxHashPartitionCount;
+import static io.trino.SystemSessionProperties.getMaxWriterTaskCount;
 import static io.trino.SystemSessionProperties.getQueryRetryAttempts;
 import static io.trino.SystemSessionProperties.getRetryDelayScaleFactor;
 import static io.trino.SystemSessionProperties.getRetryInitialDelay;
@@ -875,6 +879,7 @@ public static DistributedStagesScheduler create(
 partitioning.partitionCount));
 
 Map<PlanFragmentId, Optional<int[]>> bucketToPartitionMap = createBucketToPartitionMap(
+queryStateMachine.getSession(),
 coordinatorStagesScheduler.getBucketToPartitionForStagesConsumedByCoordinator(),
 stageManager,
 partitioningCache);
@@ -956,6 +961,7 @@ public static DistributedStagesScheduler create(
 }
 
 private static Map<PlanFragmentId, Optional<int[]>> createBucketToPartitionMap(
+Session session,
 Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator,
 StageManager stageManager,
 Function<PartitioningKey, NodePartitionMap> partitioningCache)
@@ -969,7 +975,7 @@ private static Map<PlanFragmentId, Optional<int[]>> createBucketToPartitionMap(
 partitioningCache,
 fragment.getRoot(),
 fragment.getRemoteSourceNodes(),
-fragment.getPartitionCount());
+getFragmentMaxPartitionCount(session, fragment));
 for (SqlStage childStage : stageManager.getChildren(stage.getStageId())) {
 result.put(childStage.getFragment().getId(), bucketToPartition);
 }
@@ -982,7 +988,7 @@ private static Optional<int[]> getBucketToPartition(
 Function<PartitioningKey, NodePartitionMap> partitioningCache,
 PlanNode fragmentRoot,
 List<RemoteSourceNode> remoteSourceNodes,
-Optional<Integer> partitionCount)
+int partitionCount)
 {
 if (partitioningHandle.equals(SOURCE_DISTRIBUTION) || partitioningHandle.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
 return Optional.of(new int[1]);
@@ -1050,7 +1056,7 @@ private static StageScheduler createStageScheduler(
 Span stageSpan = stageExecution.getStageSpan();
 PlanFragment fragment = stageExecution.getFragment();
 PartitioningHandle partitioningHandle = fragment.getPartitioning();
-Optional<Integer> partitionCount = fragment.getPartitionCount();
+int partitionCount = getFragmentMaxPartitionCount(session, fragment);
 Map<PlanNodeId, SplitSource> splitSources = splitSourceFactory.createSplitSources(session, stageSpan, fragment);
 if (!splitSources.isEmpty()) {
 queryStateMachine.addStateChangeListener(new StateChangeListener<>()
@@ -1119,15 +1125,14 @@ public void stateChanged(QueryState newState)
 .collect(toImmutableList());
 Supplier<Collection<TaskStatus>> writerTasksProvider = stageExecution::getTaskStatuses;
 
-checkState(partitionCount.isPresent(), "Partition count cannot be empty when scale writers is used");
 ScaledWriterScheduler scheduler = new ScaledWriterScheduler(
 stageExecution,
 sourceTasksProvider,
 writerTasksProvider,
 nodeScheduler.createNodeSelector(session, Optional.empty()),
 executor,
 getWriterScalingMinDataProcessed(session),
-partitionCount.get());
+partitionCount);
 
 whenAllStages(childStageExecutions, StageExecution.State::isDone)
 .addListener(scheduler::finish, directExecutor());
@@ -1153,7 +1158,7 @@ public void stateChanged(QueryState newState)
 List<InternalNode> stageNodeList;
 if (fragment.getRemoteSourceNodes().stream().allMatch(node -> node.getExchangeType() == REPLICATE)) {
 // no remote source
-bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle);
+bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle, partitionCount);
 stageNodeList = new ArrayList<>(nodeScheduler.createNodeSelector(session, catalogHandle).allNodes());
 Collections.shuffle(stageNodeList);
 }
@@ -1176,6 +1181,13 @@ public void stateChanged(QueryState newState)
 tableExecuteContextManager);
 }
 
+private static int getFragmentMaxPartitionCount(Session session, PlanFragment fragment)
+{
+return fragment.getPartitionCount().orElseGet(() -> PlanNodeSearcher.searchFrom(fragment.getRoot())
+.whereIsInstanceOfAny(TableWriterNode.class)
+.matches() ? getMaxWriterTaskCount(session) : getMaxHashPartitionCount(session));
+}
+
 private static void closeSplitSources(Collection<SplitSource> splitSources)
 {
 for (SplitSource source : splitSources) {
@@ -1576,12 +1588,12 @@ public Optional<StageId> getFailedStageId()
 }
 }
 
-private record PartitioningKey(PartitioningHandle handle, Optional<Integer> partitionCount)
+private record PartitioningKey(PartitioningHandle handle, int partitionCount)
 {
-public PartitioningKey(PartitioningHandle handle, Optional<Integer> partitionCount)
+public PartitioningKey(PartitioningHandle handle, int partitionCount)
 {
 this.handle = requireNonNull(handle, "handle cannot be null");
-this.partitionCount = requireNonNull(partitionCount, "partitionCount cannot be null");
+this.partitionCount = partitionCount;
 }
 }
 }
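The getFragmentMaxPartitionCount helper added above resolves a fragment's partition count eagerly instead of threading an Optional through the scheduler: an explicit per-fragment count wins; otherwise fragments that write (contain a TableWriterNode) fall back to the max writer task count, and all other fragments fall back to the max hash partition count. A minimal standalone sketch of that fallback, with plain parameters standing in for the session properties and the PlanNodeSearcher check (names here are illustrative, not Trino APIs):

import java.util.Optional;

// Sketch of the fallback performed by getFragmentMaxPartitionCount.
// hasTableWriter stands in for the PlanNodeSearcher/TableWriterNode check;
// maxWriterTasks and maxHashPartitions stand in for the getMaxWriterTaskCount
// and getMaxHashPartitionCount session properties.
final class PartitionCountFallback
{
    static int resolve(Optional<Integer> fragmentPartitionCount, boolean hasTableWriter, int maxWriterTasks, int maxHashPartitions)
    {
        // An explicit per-fragment count always wins; otherwise pick the limit
        // appropriate for the fragment type.
        return fragmentPartitionCount.orElseGet(() -> hasTableWriter ? maxWriterTasks : maxHashPartitions);
    }

    public static void main(String[] args)
    {
        System.out.println(resolve(Optional.of(8), true, 100, 1000));    // 8: explicit count wins
        System.out.println(resolve(Optional.empty(), true, 100, 1000));  // 100: write fragment
        System.out.println(resolve(Optional.empty(), false, 100, 1000)); // 1000: read-side fragment
    }
}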
@@ -287,9 +287,9 @@ public static int getBucketCount(Session session, NodePartitioningManager nodePa
 {
 if (partitioning.getConnectorHandle() instanceof MergePartitioningHandle) {
 // TODO: can we always use this code path?
-return nodePartitioningManager.getNodePartitioningMap(session, partitioning).getBucketToPartition().length;
+return nodePartitioningManager.getNodePartitioningMap(session, partitioning, 1000).getBucketToPartition().length;
 }
-return nodePartitioningManager.getBucketNodeMap(session, partitioning).getBucketCount();
+return nodePartitioningManager.getBucketCount(session, partitioning);
 }
 
 private static boolean isSystemPartitioning(PartitioningHandle partitioning)
@@ -147,7 +147,7 @@ public static PartitionFunction createPartitionFunction(
 // compared to only a single hive bucket reaching the min limit.
 int bucketCount = (handle.getConnectorHandle() instanceof SystemPartitioningHandle)
 ? SCALE_WRITERS_PARTITION_COUNT
-: nodePartitioningManager.getBucketNodeMap(session, handle).getBucketCount();
+: nodePartitioningManager.getBucketCount(session, handle);
 return nodePartitioningManager.getPartitionFunction(
 session,
 scheme,
@@ -51,7 +51,6 @@
 import static com.google.common.base.Verify.verify;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxPartitionCount;
-import static io.trino.SystemSessionProperties.getMaxHashPartitionCount;
 import static io.trino.SystemSessionProperties.getRetryPolicy;
 import static io.trino.execution.TaskManagerConfig.MAX_WRITER_COUNT;
 import static io.trino.operator.exchange.LocalExchange.SCALE_WRITERS_MAX_PARTITIONS_PER_WRITER;
@@ -136,12 +135,7 @@ public BucketFunction getBucketFunction(Session session, PartitioningHandle part
 return bucketFunction;
 }
 
-public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle)
-{
-return getNodePartitioningMap(session, partitioningHandle, new HashMap<>(), new AtomicReference<>(), Optional.empty());
-}
-
-public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle, Optional<Integer> partitionCount)
+public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle, int partitionCount)
 {
 return getNodePartitioningMap(session, partitioningHandle, new HashMap<>(), new AtomicReference<>(), partitionCount);
 }
@@ -155,7 +149,7 @@ private NodePartitionMap getNodePartitioningMap(
 PartitioningHandle partitioningHandle,
 Map<Integer, List<InternalNode>> bucketToNodeCache,
 AtomicReference<List<InternalNode>> systemPartitioningCache,
-Optional<Integer> partitionCount)
+int partitionCount)
 {
 requireNonNull(session, "session is null");
 requireNonNull(partitioningHandle, "partitioningHandle is null");
@@ -188,9 +182,10 @@
 }
 else {
 CatalogHandle catalogHandle = requiredCatalogHandle(partitioningHandle);
+List<InternalNode> allNodes = getAllNodes(session, catalogHandle);
 bucketToNode = bucketToNodeCache.computeIfAbsent(
 connectorBucketNodeMap.getBucketCount(),
-bucketCount -> createArbitraryBucketToNode(connectorBucketNodeMap.getCacheKeyHint(), getAllNodes(session, catalogHandle), bucketCount));
+bucketCount -> createArbitraryBucketToNode(connectorBucketNodeMap.getCacheKeyHint(), allNodes.subList(0, Math.min(allNodes.size(), partitionCount)), bucketCount));
 }
 }
 
@@ -215,7 +210,7 @@ private NodePartitionMap getNodePartitioningMap(
 return new NodePartitionMap(partitionToNode, bucketToPartition, getSplitToBucket(session, partitioningHandle, bucketToNode.size()));
 }
 
-private List<InternalNode> systemBucketToNode(Session session, PartitioningHandle partitioningHandle, AtomicReference<List<InternalNode>> nodesCache, Optional<Integer> partitionCount)
+private List<InternalNode> systemBucketToNode(Session session, PartitioningHandle partitioningHandle, AtomicReference<List<InternalNode>> nodesCache, int partitionCount)
 {
 SystemPartitioning partitioning = ((SystemPartitioningHandle) partitioningHandle.getConnectorHandle()).getPartitioning();
 
@@ -227,7 +222,7 @@ private List<InternalNode> systemBucketToNode(Session session, PartitioningHandl
 case FIXED -> {
 List<InternalNode> value = nodesCache.get();
 if (value == null) {
-value = nodeSelector.selectRandomNodes(partitionCount.orElse(getMaxHashPartitionCount(session)));
+value = nodeSelector.selectRandomNodes(partitionCount);
 nodesCache.set(value);
 }
 yield value;
@@ -238,7 +233,13 @@ private List<InternalNode> systemBucketToNode(Session session, PartitioningHandl
 return nodes;
 }
 
-public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partitioningHandle)
+public int getBucketCount(Session session, PartitioningHandle partitioningHandle)
+{
+// we don't care about partition count at all, just bucket count
+return getBucketNodeMap(session, partitioningHandle, 1000).getBucketCount();
+}
+
+public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partitioningHandle, int partitionCount)
 {
 Optional<ConnectorBucketNodeMap> bucketNodeMap = getConnectorBucketNodeMap(session, partitioningHandle);
 int bucketCount = bucketNodeMap.map(ConnectorBucketNodeMap::getBucketCount).orElseGet(() -> getDefaultBucketCount(session, partitioningHandle));
@@ -250,6 +251,7 @@ public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partit
 
 long seed = bucketNodeMap.map(ConnectorBucketNodeMap::getCacheKeyHint).orElse(ThreadLocalRandom.current().nextLong());
 List<InternalNode> nodes = getAllNodes(session, requiredCatalogHandle(partitioningHandle));
+nodes = nodes.subList(0, Math.min(nodes.size(), partitionCount));
 return new BucketNodeMap(splitToBucket, createArbitraryBucketToNode(seed, nodes, bucketCount));
 }
 
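With partitionCount now a required int, getNodePartitioningMap and getBucketNodeMap cap the candidate node list at partitionCount before buckets are assigned to nodes. A small sketch of that capping under assumed names (capToPartitionCount is illustrative, not a Trino API):

import java.util.List;

// Sketch of the node-list capping added above: keep at most partitionCount nodes.
final class NodeListCap
{
    static <T> List<T> capToPartitionCount(List<T> nodes, int partitionCount)
    {
        // subList never exceeds the list size, so a large partitionCount keeps every node.
        return nodes.subList(0, Math.min(nodes.size(), partitionCount));
    }

    public static void main(String[] args)
    {
        List<String> nodes = List.of("node-1", "node-2", "node-3", "node-4");
        System.out.println(capToPartitionCount(nodes, 2));  // [node-1, node-2]
        System.out.println(capToPartitionCount(nodes, 10)); // all four nodes kept
    }
}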