Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,39 +56,37 @@ class HashDistributionSplitAssigner
private final Set<PlanNodeId> replicatedSources;
private final Set<PlanNodeId> allSources;
private final FaultTolerantPartitioningScheme sourcePartitioningScheme;
private final Map<Integer, TaskPartition> outputPartitionToTaskPartition;
private final Map<Integer, TaskPartition> sourcePartitionToTaskPartition;

private final Set<Integer> createdTaskPartitions = new HashSet<>();
private final Set<PlanNodeId> completedSources = new HashSet<>();
private final ListMultimap<PlanNodeId, Split> replicatedSplits = ArrayListMultimap.create();

private int nextTaskPartitionId;
private boolean allTaskPartitionsCreated;

public static HashDistributionSplitAssigner create(
Optional<CatalogHandle> catalogRequirement,
Set<PlanNodeId> partitionedSources,
Set<PlanNodeId> replicatedSources,
FaultTolerantPartitioningScheme sourcePartitioningScheme,
Map<PlanNodeId, OutputDataSizeEstimate> outputDataSizeEstimates,
Map<PlanNodeId, OutputDataSizeEstimate> sourceDataSizeEstimates,
PlanFragment fragment,
long targetPartitionSizeInBytes,
int targetMaxTaskCount)
{
if (fragment.getPartitioning().equals(SCALED_WRITER_HASH_DISTRIBUTION)) {
verify(

fragment.getPartitionedSources().isEmpty() && fragment.getRemoteSourceNodes().size() == 1,
verify(fragment.getPartitionedSources().isEmpty() && fragment.getRemoteSourceNodes().size() == 1,
"SCALED_WRITER_HASH_DISTRIBUTION fragments are expected to have exactly one remote source and no table scans");
}
return new HashDistributionSplitAssigner(
catalogRequirement,
partitionedSources,
replicatedSources,
sourcePartitioningScheme,
createOutputPartitionToTaskPartition(
createSourcePartitionToTaskPartition(
sourcePartitioningScheme,
partitionedSources,
outputDataSizeEstimates,
sourceDataSizeEstimates,
targetPartitionSizeInBytes,
targetMaxTaskCount,
sourceId -> fragment.getPartitioning().equals(SCALED_WRITER_HASH_DISTRIBUTION),
Expand All @@ -102,33 +100,60 @@ public static HashDistributionSplitAssigner create(
Set<PlanNodeId> partitionedSources,
Set<PlanNodeId> replicatedSources,
FaultTolerantPartitioningScheme sourcePartitioningScheme,
Map<Integer, TaskPartition> outputPartitionToTaskPartition)
Map<Integer, TaskPartition> sourcePartitionToTaskPartition)
{
this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null");
this.replicatedSources = ImmutableSet.copyOf(requireNonNull(replicatedSources, "replicatedSources is null"));
allSources = ImmutableSet.<PlanNodeId>builder()
this.allSources = ImmutableSet.<PlanNodeId>builder()
.addAll(partitionedSources)
.addAll(replicatedSources)
.build();
this.sourcePartitioningScheme = requireNonNull(sourcePartitioningScheme, "sourcePartitioningScheme is null");
this.outputPartitionToTaskPartition = ImmutableMap.copyOf(requireNonNull(outputPartitionToTaskPartition, "outputPartitionToTaskPartition is null"));
this.sourcePartitionToTaskPartition = ImmutableMap.copyOf(requireNonNull(sourcePartitionToTaskPartition, "sourcePartitionToTaskPartition is null"));
}

@Override
public AssignmentResult assign(PlanNodeId planNodeId, ListMultimap<Integer, Split> splits, boolean noMoreSplits)
{
AssignmentResult.Builder assignment = AssignmentResult.builder();

if (!allTaskPartitionsCreated) {
// create tasks all at once
int nextTaskPartitionId = 0;
for (int sourcePartitionId = 0; sourcePartitionId < sourcePartitioningScheme.getPartitionCount(); sourcePartitionId++) {
TaskPartition taskPartition = sourcePartitionToTaskPartition.get(sourcePartitionId);
verify(taskPartition != null, "taskPartition not found for sourcePartitionId: %s", sourcePartitionId);

for (SubPartition subPartition : taskPartition.getSubPartitions()) {
if (!subPartition.isIdAssigned()) {
int taskPartitionId = nextTaskPartitionId++;
subPartition.assignId(taskPartitionId);
Set<HostAddress> hostRequirement = sourcePartitioningScheme.getNodeRequirement(sourcePartitionId)
.map(InternalNode::getHostAndPort)
.map(ImmutableSet::of)
.orElse(ImmutableSet.of());
assignment.addPartition(new Partition(
taskPartitionId,
new NodeRequirements(catalogRequirement, hostRequirement)));
createdTaskPartitions.add(taskPartitionId);
}
}
}
assignment.setNoMorePartitions();

allTaskPartitionsCreated = true;
}

if (replicatedSources.contains(planNodeId)) {
replicatedSplits.putAll(planNodeId, splits.values());
for (Integer partitionId : createdTaskPartitions) {
assignment.updatePartition(new PartitionUpdate(partitionId, planNodeId, ImmutableList.copyOf(splits.values()), noMoreSplits));
}
}
else {
splits.forEach((outputPartitionId, split) -> {
TaskPartition taskPartition = outputPartitionToTaskPartition.get(outputPartitionId);
verify(taskPartition != null, "taskPartition not found for outputPartitionId: %s", outputPartitionId);
splits.forEach((sourcePartitionId, split) -> {
TaskPartition taskPartition = sourcePartitionToTaskPartition.get(sourcePartitionId);
verify(taskPartition != null, "taskPartition not found for sourcePartitionId: %s", sourcePartitionId);

List<SubPartition> subPartitions;
if (taskPartition.getSplitBy().isPresent() && taskPartition.getSplitBy().get().equals(planNodeId)) {
Expand All @@ -139,27 +164,6 @@ public AssignmentResult assign(PlanNodeId planNodeId, ListMultimap<Integer, Spli
}

for (SubPartition subPartition : subPartitions) {
if (!subPartition.isIdAssigned()) {
int taskPartitionId = nextTaskPartitionId++;
// Assigns lazily to ensure task ids are incremental and with no gaps.
// Gaps can occur when scanning over a bucketed table as some buckets may contain no data.
subPartition.assignId(taskPartitionId);
Set<HostAddress> hostRequirement = sourcePartitioningScheme.getNodeRequirement(outputPartitionId)
.map(InternalNode::getHostAndPort)
.map(ImmutableSet::of)
.orElse(ImmutableSet.of());
assignment.addPartition(new Partition(
taskPartitionId,
new NodeRequirements(catalogRequirement, hostRequirement)));
for (PlanNodeId replicatedSource : replicatedSplits.keySet()) {
assignment.updatePartition(new PartitionUpdate(taskPartitionId, replicatedSource, replicatedSplits.get(replicatedSource), completedSources.contains(replicatedSource)));
}
for (PlanNodeId completedSource : completedSources) {
assignment.updatePartition(new PartitionUpdate(taskPartitionId, completedSource, ImmutableList.of(), true));
}
createdTaskPartitions.add(taskPartitionId);
}

assignment.updatePartition(new PartitionUpdate(subPartition.getId(), planNodeId, ImmutableList.of(split), false));
}
});
Expand All @@ -170,23 +174,11 @@ public AssignmentResult assign(PlanNodeId planNodeId, ListMultimap<Integer, Spli
for (Integer taskPartition : createdTaskPartitions) {
assignment.updatePartition(new PartitionUpdate(taskPartition, planNodeId, ImmutableList.of(), true));
}

if (completedSources.containsAll(allSources)) {
if (createdTaskPartitions.isEmpty()) {
assignment.addPartition(new Partition(
0,
new NodeRequirements(catalogRequirement, ImmutableSet.of())));
for (PlanNodeId replicatedSource : replicatedSplits.keySet()) {
assignment.updatePartition(new PartitionUpdate(0, replicatedSource, replicatedSplits.get(replicatedSource), true));
}
for (PlanNodeId completedSource : completedSources) {
assignment.updatePartition(new PartitionUpdate(0, completedSource, ImmutableList.of(), true));
}
createdTaskPartitions.add(0);
}
for (Integer taskPartition : createdTaskPartitions) {
assignment.sealPartition(taskPartition);
}
assignment.setNoMorePartitions();
replicatedSplits.clear();
}
}
Expand All @@ -202,10 +194,10 @@ public AssignmentResult finish()
}

@VisibleForTesting
static Map<Integer, TaskPartition> createOutputPartitionToTaskPartition(
static Map<Integer, TaskPartition> createSourcePartitionToTaskPartition(
FaultTolerantPartitioningScheme sourcePartitioningScheme,
Set<PlanNodeId> partitionedSources,
Map<PlanNodeId, OutputDataSizeEstimate> outputDataSizeEstimates,
Map<PlanNodeId, OutputDataSizeEstimate> sourceDataSizeEstimates,
long targetPartitionSizeInBytes,
int targetMaxTaskCount,
Predicate<PlanNodeId> canSplit,
Expand All @@ -214,14 +206,14 @@ static Map<Integer, TaskPartition> createOutputPartitionToTaskPartition(
int partitionCount = sourcePartitioningScheme.getPartitionCount();
if (sourcePartitioningScheme.isExplicitPartitionToNodeMappingPresent() ||
partitionedSources.isEmpty() ||
!outputDataSizeEstimates.keySet().containsAll(partitionedSources)) {
!sourceDataSizeEstimates.keySet().containsAll(partitionedSources)) {
// if bucket scheme is set explicitly or if estimates are missing create one task partition per output partition
return IntStream.range(0, partitionCount)
.boxed()
.collect(toImmutableMap(Function.identity(), (key) -> new TaskPartition(1, Optional.empty())));
}

List<OutputDataSizeEstimate> partitionedSourcesEstimates = outputDataSizeEstimates.entrySet().stream()
List<OutputDataSizeEstimate> partitionedSourcesEstimates = sourceDataSizeEstimates.entrySet().stream()
.filter(entry -> partitionedSources.contains(entry.getKey()))
.map(Map.Entry::getValue)
.collect(toImmutableList());
Expand Down Expand Up @@ -249,7 +241,7 @@ static Map<Integer, TaskPartition> createOutputPartitionToTaskPartition(
partitionSizeInBytes,
targetPartitionSizeInBytes,
partitionedSources,
outputDataSizeEstimates,
sourceDataSizeEstimates,
partitionId,
canSplit);
result.put(partitionId, taskPartition);
Expand All @@ -268,13 +260,13 @@ private static TaskPartition createTaskPartition(
long partitionSizeInBytes,
long targetPartitionSizeInBytes,
Set<PlanNodeId> partitionedSources,
Map<PlanNodeId, OutputDataSizeEstimate> outputDataSizeEstimates,
Map<PlanNodeId, OutputDataSizeEstimate> sourceDataSizeEstimates,
int partitionId,
Predicate<PlanNodeId> canSplit)
{
if (partitionSizeInBytes > targetPartitionSizeInBytes) {
// try to assign multiple sub-partitions if possible
Map<PlanNodeId, Long> sourceSizes = getSourceSizes(partitionedSources, outputDataSizeEstimates, partitionId);
Map<PlanNodeId, Long> sourceSizes = getSourceSizes(partitionedSources, sourceDataSizeEstimates, partitionId);
PlanNodeId largestSource = sourceSizes.entrySet().stream()
.max(Map.Entry.comparingByValue())
.map(Map.Entry::getKey)
Expand All @@ -289,10 +281,10 @@ private static TaskPartition createTaskPartition(
return new TaskPartition(1, Optional.empty());
}

private static Map<PlanNodeId, Long> getSourceSizes(Set<PlanNodeId> partitionedSources, Map<PlanNodeId, OutputDataSizeEstimate> outputDataSizeEstimates, int partitionId)
private static Map<PlanNodeId, Long> getSourceSizes(Set<PlanNodeId> partitionedSources, Map<PlanNodeId, OutputDataSizeEstimate> sourceDataSizeEstimates, int partitionId)
{
return partitionedSources.stream()
.collect(toImmutableMap(Function.identity(), source -> outputDataSizeEstimates.get(source).getPartitionSizeInBytes(partitionId)));
.collect(toImmutableMap(Function.identity(), source -> sourceDataSizeEstimates.get(source).getPartitionSizeInBytes(partitionId)));
}

private record PartitionAssignment(TaskPartition taskPartition, long assignedDataSizeInBytes)
Expand Down
Loading