Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 10 additions & 61 deletions native/core/src/execution/shuffle/shuffle_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ impl ExecutionPlan for ShuffleWriterExec {
futures::stream::once(
external_shuffle(
input,
partition,
self.output_data_file.clone(),
self.output_index_file.clone(),
self.partitioning.clone(),
Expand All @@ -205,7 +204,6 @@ impl ExecutionPlan for ShuffleWriterExec {
#[allow(clippy::too_many_arguments)]
async fn external_shuffle(
mut input: SendableRecordBatchStream,
partition_id: usize,
output_data_file: String,
output_index_file: String,
partitioning: Partitioning,
Expand All @@ -216,7 +214,6 @@ async fn external_shuffle(
) -> Result<SendableRecordBatchStream> {
let schema = input.schema();
let mut repartitioner = ShuffleRepartitioner::try_new(
partition_id,
output_data_file,
output_index_file,
Arc::clone(&schema),
Expand Down Expand Up @@ -294,7 +291,6 @@ struct ShuffleRepartitioner {
num_output_partitions: usize,
runtime: Arc<RuntimeEnv>,
metrics: ShuffleRepartitionerMetrics,
reservation: MemoryReservation,

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main change: removing the memory tracking in `ShuffleRepartitioner`, because memory is already tracked by the reservation held in each `PartitionBuffer` instance.

/// Hashes for each row in the current batch
hashes_buf: Vec<u32>,
/// Partition ids for each row in the current batch
Expand All @@ -306,7 +302,6 @@ struct ShuffleRepartitioner {
impl ShuffleRepartitioner {
#[allow(clippy::too_many_arguments)]
pub fn try_new(
partition_id: usize,
output_data_file: String,
output_index_file: String,
schema: SchemaRef,
Expand All @@ -318,9 +313,6 @@ impl ShuffleRepartitioner {
enable_fast_encoding: bool,
) -> Result<Self> {
let num_output_partitions = partitioning.partition_count();
let reservation = MemoryConsumer::new(format!("ShuffleRepartitioner[{}]", partition_id))
.with_can_spill(true)
.register(&runtime.memory_pool);

let mut hashes_buf = Vec::with_capacity(batch_size);
let mut partition_ids = Vec::with_capacity(batch_size);
Expand Down Expand Up @@ -352,7 +344,6 @@ impl ShuffleRepartitioner {
num_output_partitions,
runtime,
metrics,
reservation,
hashes_buf,
partition_ids,
batch_size,
Expand Down Expand Up @@ -472,41 +463,12 @@ impl ShuffleRepartitioner {
.enumerate()
.filter(|(_, (start, end))| start < end)
{
let mut mem_diff = self
.append_rows_to_partition(
input.columns(),
&shuffled_partition_ids[start..end],
partition_id,
)
.await?;

if mem_diff > 0 {
let mem_increase = mem_diff as usize;

let try_grow = {
let mut mempool_timer = self.metrics.mempool_time.timer();
let result = self.reservation.try_grow(mem_increase);
mempool_timer.stop();
result
};

if try_grow.is_err() {
self.spill().await?;
let mut mempool_timer = self.metrics.mempool_time.timer();
self.reservation.free();
self.reservation.try_grow(mem_increase)?;
mempool_timer.stop();
mem_diff = 0;
}
}

if mem_diff < 0 {
let mem_used = self.reservation.size();
let mem_decrease = mem_used.min(-mem_diff as usize);
let mut mempool_timer = self.metrics.mempool_time.timer();
self.reservation.shrink(mem_decrease);
mempool_timer.stop();
}
Comment on lines -475 to -509

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need any of this memory accounting here, because it is already handled within `append_rows_to_partition`.

self.append_rows_to_partition(
input.columns(),
&shuffled_partition_ids[start..end],
partition_id,
)
.await?;
}
}
Partitioning::UnknownPartitioning(n) if *n == 1 => {
Expand Down Expand Up @@ -593,10 +555,6 @@ impl ShuffleRepartitioner {

write_time.stop();

let mut mempool_timer = self.metrics.mempool_time.timer();
self.reservation.free();
mempool_timer.stop();

elapsed_compute.stop();

// shuffle writer always has empty output
Expand All @@ -608,7 +566,10 @@ impl ShuffleRepartitioner {
}

fn used(&self) -> usize {
self.reservation.size()
self.buffered_partitions
.iter()
.map(|b| b.reservation.size())
.sum()
}

fn spilled_bytes(&self) -> usize {
Expand Down Expand Up @@ -639,7 +600,6 @@ impl ShuffleRepartitioner {
for p in &mut self.buffered_partitions {
spilled_bytes += p.spill(&self.runtime, &self.metrics)?;
}
self.reservation.free();

self.metrics.spill_count.add(1);
self.metrics.spilled_bytes.add(spilled_bytes);
Expand Down Expand Up @@ -673,10 +633,6 @@ impl ShuffleRepartitioner {
// spill partitions and retry.
self.spill().await?;

let mut mempool_timer = self.metrics.mempool_time.timer();
self.reservation.free();
mempool_timer.stop();

start_index = new_start;
let output = &mut self.buffered_partitions[partition_id];
output_ret = output.append_rows(columns, indices, start_index, &self.metrics);
Expand Down Expand Up @@ -1125,7 +1081,6 @@ mod test {
let runtime_env = create_runtime(memory_limit);
let metrics_set = ExecutionPlanMetricsSet::new();
let mut repartitioner = ShuffleRepartitioner::try_new(
0,
"/tmp/data.out".to_string(),
"/tmp/index.out".to_string(),
batch.schema(),
Expand All @@ -1145,11 +1100,6 @@ mod test {
assert!(repartitioner.buffered_partitions[0].spill_file.is_none());
assert!(repartitioner.buffered_partitions[1].spill_file.is_none());

// TODO: note that we are currently double counting the memory usage
// because we reserve the memory twice - once at the repartitioner level
// and then again in each PartitionBuffer
// https://github.com/apache/datafusion-comet/issues/1448
assert_eq!(212992, repartitioner.reservation.size());
assert_eq!(
106496,
repartitioner.buffered_partitions[0].reservation.size()
Expand All @@ -1173,7 +1123,6 @@ mod test {
// insert another batch after spilling
repartitioner.insert_batch(batch.clone()).await.unwrap();

assert_eq!(212992, repartitioner.reservation.size());
assert_eq!(
106496,
repartitioner.buffered_partitions[0].reservation.size()
Expand Down