Skip to content

Commit

Permalink
Extract CoalesceBatchesStream to a struct
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 23, 2024
1 parent 77311a5 commit 117901b
Showing 1 changed file with 142 additions and 113 deletions.
255 changes: 142 additions & 113 deletions datafusion/physical-plan/src/coalesce_batches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::task::{ready, Context, Poll};

use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use super::{DisplayAs, ExecutionPlanProperties, PlanProperties, Statistics};
Expand Down Expand Up @@ -146,10 +146,7 @@ impl ExecutionPlan for CoalesceBatchesExec {
) -> Result<SendableRecordBatchStream> {
Ok(Box::pin(CoalesceBatchesStream {
input: self.input.execute(partition, context)?,
schema: self.input.schema(),
target_batch_size: self.target_batch_size,
buffer: Vec::new(),
buffered_rows: 0,
coalescer: BatchCoalescer::new(self.input.schema(), self.target_batch_size),
is_closed: false,
baseline_metrics: BaselineMetrics::new(&self.metrics, partition),
}))
Expand All @@ -167,14 +164,8 @@ impl ExecutionPlan for CoalesceBatchesExec {
struct CoalesceBatchesStream {
/// The input plan
input: SendableRecordBatchStream,
/// The input schema
schema: SchemaRef,
/// Minimum number of rows for coalesces batches
target_batch_size: usize,
/// Buffered batches
buffer: Vec<RecordBatch>,
/// Buffered row count
buffered_rows: usize,
/// Buffer for combining batches
coalescer: BatchCoalescer,
/// Whether the stream has finished returning all of its data or not
is_closed: bool,
/// Execution metrics
Expand Down Expand Up @@ -213,66 +204,35 @@ impl CoalesceBatchesStream {
let input_batch = self.input.poll_next_unpin(cx);
// records time on drop
let _timer = cloned_time.timer();
match input_batch {
Poll::Ready(x) => match x {
Some(Ok(batch)) => {
if batch.num_rows() >= self.target_batch_size
&& self.buffer.is_empty()
{
return Poll::Ready(Some(Ok(batch)));
} else if batch.num_rows() == 0 {
// discard empty batches
} else {
// add to the buffered batches
self.buffered_rows += batch.num_rows();
self.buffer.push(batch);
// check to see if we have enough batches yet
if self.buffered_rows >= self.target_batch_size {
// combine the batches and return
let batch = concat_batches(
&self.schema,
&self.buffer,
self.buffered_rows,
)?;
// reset buffer state
self.buffer.clear();
self.buffered_rows = 0;
// return batch
return Poll::Ready(Some(Ok(batch)));
}
}
}
None => {
self.is_closed = true;
// we have reached the end of the input stream but there could still
// be buffered batches
if self.buffer.is_empty() {
return Poll::Ready(None);
} else {
// combine the batches and return
let batch = concat_batches(
&self.schema,
&self.buffer,
self.buffered_rows,
)?;
// reset buffer state
self.buffer.clear();
self.buffered_rows = 0;
// return batch
return Poll::Ready(Some(Ok(batch)));
}
match ready!(input_batch) {
Some(result) => {
let Ok(input_batch) = result else {
return Poll::Ready(Some(result)); // pass back error
};
// Buffer the batch and either get more input if not enough
// rows yet or output
match self.coalescer.push_batch(input_batch) {
Ok(None) => continue,
res => return Poll::Ready(res.transpose()),
}
other => return Poll::Ready(other),
},
Poll::Pending => return Poll::Pending,
}
None => {
self.is_closed = true;
// we have reached the end of the input stream but there could still
// be buffered batches
return match self.coalescer.finish() {
Ok(None) => Poll::Ready(None),
res => Poll::Ready(res.transpose()),
};
}
}
}
}
}

impl RecordBatchStream for CoalesceBatchesStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
self.coalescer.schema()
}
}

Expand All @@ -290,26 +250,106 @@ pub fn concat_batches(
arrow::compute::concat_batches(schema, batches)
}

/// Concatenating multiple record batches into larger batches
///
/// TODO ASCII ART
///
/// Notes:
///
/// 1. The output is exactly the same order as the input rows
///
/// 2. The output is a sequence of batches, with all but the last being at least
/// `target_batch_size` rows.
///
/// 3. Eventually this may also be able to handle other optimizations such as a
/// combined filter/coalesce operation.
#[derive(Debug)]
struct BatchCoalescer {
/// The input schema
schema: SchemaRef,
/// Minimum number of rows for coalesces batches
target_batch_size: usize,
/// Buffered batches
buffer: Vec<RecordBatch>,
/// Buffered row count
buffered_rows: usize,
}

impl BatchCoalescer {
/// Create a new BatchCoalescer that produces batches of at least `target_batch_size` rows
fn new(schema: SchemaRef, target_batch_size: usize) -> Self {
Self {
schema,
target_batch_size,
buffer: vec![],
buffered_rows: 0,
}
}

/// Return the schema of the output batches
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}

/// Add a batch to the coalescer, returning a batch if the target batch size is reached
fn push_batch(&mut self, batch: RecordBatch) -> Result<Option<RecordBatch>> {
if batch.num_rows() >= self.target_batch_size && self.buffer.is_empty() {
return Ok(Some(batch));
}
// discard empty batches
if batch.num_rows() == 0 {
return Ok(None);
}
// add to the buffered batches
self.buffered_rows += batch.num_rows();
self.buffer.push(batch);
// check to see if we have enough batches yet
let batch = if self.buffered_rows >= self.target_batch_size {
// combine the batches and return
let batch = concat_batches(&self.schema, &self.buffer, self.buffered_rows)?;
// reset buffer state
self.buffer.clear();
self.buffered_rows = 0;
// return batch
Some(batch)
} else {
None
};
Ok(batch)
}

/// Finish the coalescing process, returning all buffered data as a final,
/// single batch, if any
fn finish(&mut self) -> Result<Option<RecordBatch>> {
if self.buffer.is_empty() {
Ok(None)
} else {
// combine the batches and return
let batch = concat_batches(&self.schema, &self.buffer, self.buffered_rows)?;
// reset buffer state
self.buffer.clear();
self.buffered_rows = 0;
// return batch
Ok(Some(batch))
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{memory::MemoryExec, repartition::RepartitionExec, Partitioning};

use arrow::datatypes::{DataType, Field, Schema};
use arrow_array::UInt32Array;

#[tokio::test(flavor = "multi_thread")]
async fn test_concat_batches() -> Result<()> {
let schema = test_schema();
let partition = create_vec_batches(&schema, 10);
let partitions = vec![partition];

let output_partitions = coalesce_batches(&schema, partitions, 21).await?;
assert_eq!(1, output_partitions.len());
let Scenario { schema, batch } = uint32_scenario();

// input is 10 batches x 8 rows (80 rows)
let input = std::iter::repeat(batch).take(10);

// expected output is batches of at least 20 rows (except for the final batch)
let batches = &output_partitions[0];
let batches = do_coalesce_batches(&schema, input, 21);
assert_eq!(4, batches.len());
assert_eq!(24, batches[0].num_rows());
assert_eq!(24, batches[1].num_rows());
Expand All @@ -319,54 +359,43 @@ mod tests {
Ok(())
}

fn test_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
}

async fn coalesce_batches(
// Coalesce the batches with a BatchCoalescer function with the given input
// and target batch size returning the resulting batches
fn do_coalesce_batches(
schema: &SchemaRef,
input_partitions: Vec<Vec<RecordBatch>>,
input: impl IntoIterator<Item = RecordBatch>,
target_batch_size: usize,
) -> Result<Vec<Vec<RecordBatch>>> {
) -> Vec<RecordBatch> {
// create physical plan
let exec = MemoryExec::try_new(&input_partitions, Arc::clone(schema), None)?;
let exec =
RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(1))?;
let exec: Arc<dyn ExecutionPlan> =
Arc::new(CoalesceBatchesExec::new(Arc::new(exec), target_batch_size));

// execute and collect results
let output_partition_count = exec.output_partitioning().partition_count();
let mut output_partitions = Vec::with_capacity(output_partition_count);
for i in 0..output_partition_count {
// execute this *output* partition and collect all batches
let task_ctx = Arc::new(TaskContext::default());
let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
let mut batches = vec![];
while let Some(result) = stream.next().await {
batches.push(result?);
}
output_partitions.push(batches);
let mut coalescer = BatchCoalescer::new(Arc::clone(schema), target_batch_size);
let mut output_batches: Vec<_> = input
.into_iter()
.filter_map(|batch| coalescer.push_batch(batch).unwrap())
.collect();
if let Some(batch) = coalescer.finish().unwrap() {
output_batches.push(batch);
}
Ok(output_partitions)
output_batches
}

/// Create vector batches
fn create_vec_batches(schema: &Schema, n: usize) -> Vec<RecordBatch> {
let batch = create_batch(schema);
let mut vec = Vec::with_capacity(n);
for _ in 0..n {
vec.push(batch.clone());
}
vec
/// Test scenario
#[derive(Debug)]
struct Scenario {
schema: Arc<Schema>,
batch: RecordBatch,
}

/// Create batch
fn create_batch(schema: &Schema) -> RecordBatch {
RecordBatch::try_new(
Arc::new(schema.clone()),
/// a batch of 8 rows of UInt32
fn uint32_scenario() -> Scenario {
let schema =
Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));

let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
)
.unwrap()
.unwrap();

Scenario { schema, batch }
}
}

0 comments on commit 117901b

Please sign in to comment.