Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 276 additions & 6 deletions ballista/core/src/execution_plans/shuffle_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use async_trait::async_trait;
use datafusion::arrow::ipc::reader::StreamReader;
use datafusion::common::stats::Precision;
use datafusion::physical_plan::coalesce::{LimitedBatchCoalescer, PushBatchStatus};
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
Expand All @@ -38,12 +39,14 @@ use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::runtime::SpawnedTask;

use datafusion::error::{DataFusionError, Result};
use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::metrics::{
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
};
use datafusion::physical_plan::{
ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning,
PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
};
use futures::{Stream, StreamExt, TryStreamExt};
use futures::{Stream, StreamExt, TryStreamExt, ready};

use crate::error::BallistaError;
use datafusion::execution::context::TaskContext;
Expand Down Expand Up @@ -162,6 +165,7 @@ impl ExecutionPlan for ShuffleReaderExec {
let max_message_size = config.ballista_grpc_client_max_message_size();
let force_remote_read = config.ballista_shuffle_reader_force_remote_read();
let prefer_flight = config.ballista_shuffle_reader_remote_prefer_flight();
let batch_size = config.batch_size();

if force_remote_read {
debug!(
Expand Down Expand Up @@ -197,11 +201,18 @@ impl ExecutionPlan for ShuffleReaderExec {
prefer_flight,
);

let result = RecordBatchStreamAdapter::new(
Arc::new(self.schema.as_ref().clone()),
let input_stream = Box::pin(RecordBatchStreamAdapter::new(
self.schema.clone(),
response_receiver.try_flatten(),
);
Ok(Box::pin(result))
));

Ok(Box::pin(CoalescedShuffleReaderStream::new(
input_stream,
batch_size,
None,
&self.metrics,
partition,
)))
}

fn metrics(&self) -> Option<MetricsSet> {
Expand Down Expand Up @@ -567,6 +578,96 @@ async fn fetch_partition_object_store(
))
}

/// Stream adapter that coalesces the (often small) record batches produced by
/// shuffle reads into batches of a target row count before yielding them
/// downstream.
struct CoalescedShuffleReaderStream {
    /// Schema of the batches emitted by this stream (taken from `input`).
    schema: SchemaRef,
    /// Upstream shuffle-read stream whose batches are buffered and merged.
    input: SendableRecordBatchStream,
    /// Buffers incoming rows and emits batches at the target size, optionally
    /// stopping once an overall row limit is reached.
    coalescer: LimitedBatchCoalescer,
    /// True once the upstream is exhausted or the coalescer's row limit was
    /// hit; after draining buffered batches the stream then ends.
    completed: bool,
    /// Records output row count and elapsed compute time for this partition.
    baseline_metrics: BaselineMetrics,
}

impl CoalescedShuffleReaderStream {
    /// Wraps `input` in a coalescing stream.
    ///
    /// `batch_size` is the target number of rows per emitted batch, and
    /// `limit` optionally caps the total number of rows produced. `metrics`
    /// and `partition` identify where baseline metrics are recorded.
    pub fn new(
        input: SendableRecordBatchStream,
        batch_size: usize,
        limit: Option<usize>,
        metrics: &ExecutionPlanMetricsSet,
        partition: usize,
    ) -> Self {
        let schema = input.schema();
        let coalescer = LimitedBatchCoalescer::new(schema.clone(), batch_size, limit);
        let baseline_metrics = BaselineMetrics::new(metrics, partition);
        Self {
            schema,
            input,
            coalescer,
            completed: false,
            baseline_metrics,
        }
    }
}

impl Stream for CoalescedShuffleReaderStream {
    type Item = Result<RecordBatch>;

    /// Pulls batches from the upstream, pushing them into the coalescer and
    /// yielding coalesced batches; drains any buffered remainder once the
    /// upstream ends or the row limit is reached.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Time spent in this poll is attributed to elapsed_compute; the timer
        // stops when `_timer` is dropped on any return path below.
        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
        let _timer = elapsed_compute.timer();

        loop {
            // If there is already a completed batch ready, return it directly
            if let Some(batch) = self.coalescer.next_completed_batch() {
                self.baseline_metrics.record_output(batch.num_rows());
                return Poll::Ready(Some(Ok(batch)));
            }

            // If the upstream is completed, then it is completed for this stream too
            // (no completed batch remained above, so nothing is left to drain)
            if self.completed {
                return Poll::Ready(None);
            }

            // Pull from upstream; `ready!` returns Poll::Pending here if the
            // upstream is not ready yet.
            match ready!(self.input.poll_next_unpin(cx)) {
                // If upstream is completed, then flush remaining buffered batches
                None => {
                    self.completed = true;
                    if let Err(e) = self.coalescer.finish() {
                        return Poll::Ready(Some(Err(e)));
                    }
                }
                // If upstream is not completed, then push to coalescer
                // (empty batches are skipped)
                Some(Ok(batch)) => {
                    if batch.num_rows() > 0 {
                        // Try to push to coalescer
                        match self.coalescer.push_batch(batch) {
                            // If push is successful, then continue
                            Ok(PushBatchStatus::Continue) => {
                                continue;
                            }
                            // If limit is reached, then finish coalescer and set completed to true;
                            // the loop then drains the remaining completed batches
                            Ok(PushBatchStatus::LimitReached) => {
                                self.completed = true;
                                if let Err(e) = self.coalescer.finish() {
                                    return Poll::Ready(Some(Err(e)));
                                }
                            }
                            Err(e) => return Poll::Ready(Some(Err(e))),
                        }
                    }
                }
                // Propagate upstream errors immediately
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
            }
        }
    }
}

impl RecordBatchStream for CoalescedShuffleReaderStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -1025,10 +1126,179 @@ mod tests {
.unwrap()
}

/// Builds a two-column test batch with `rows` rows: a `UInt32` column holding
/// `0..rows` and a filler string column ("s0", "s1", ...) whose contents are
/// irrelevant to the tests.
fn create_custom_test_batch(rows: usize) -> RecordBatch {
    let numbers = UInt32Array::from_iter_values(0..rows as u32);
    let strings = StringArray::from_iter_values((0..rows).map(|i| format!("s{i}")));
    RecordBatch::try_new(
        create_test_schema(),
        vec![Arc::new(numbers), Arc::new(strings)],
    )
    .unwrap()
}

/// Schema shared by the coalescing tests: a nullable `UInt32` column named
/// "number" and a nullable `Utf8` column named "str".
fn create_test_schema() -> SchemaRef {
    let fields = vec![
        Field::new("number", DataType::UInt32, true),
        Field::new("str", DataType::Utf8, true),
    ];
    Arc::new(Schema::new(fields))
}

use datafusion::physical_plan::memory::MemoryStream;

#[tokio::test]
async fn test_coalesce_stream_logic() -> Result<()> {
    // Feed ten 3-row batches (30 rows total) through the coalescing stream.
    let schema = create_test_schema();
    let input = vec![create_test_batch(); 10];

    let upstream: SendableRecordBatchStream =
        Box::pin(MemoryStream::try_new(input, schema.clone(), None)?);

    // Coalesce to 10-row output batches, no row limit.
    let coalesced = CoalescedShuffleReaderStream::new(
        upstream,
        10,
        None,
        &ExecutionPlanMetricsSet::new(),
        0,
    );

    let output = common::collect(Box::pin(coalesced)).await?;

    // No rows lost: 30 rows in, 30 rows out.
    let total: usize = output.iter().map(|b| b.num_rows()).sum();
    assert_eq!(total, 30);

    // Ten small inputs merged into exactly three 10-row outputs.
    let sizes: Vec<usize> = output.iter().map(|b| b.num_rows()).collect();
    assert_eq!(sizes, vec![10, 10, 10]);

    Ok(())
}

#[tokio::test]
async fn test_coalesce_stream_remainder_flush() -> Result<()> {
    // Ten 3-row batches (30 rows total).
    let schema = create_test_schema();
    let input = vec![create_test_batch(); 10];

    let upstream: SendableRecordBatchStream =
        Box::pin(MemoryStream::try_new(input, schema.clone(), None)?);

    // With a 100-row target, 30 buffered rows can never fill a batch; they
    // must be flushed by the `finish()` path when the upstream ends.
    let coalesced = CoalescedShuffleReaderStream::new(
        upstream,
        100,
        None,
        &ExecutionPlanMetricsSet::new(),
        0,
    );

    let output = common::collect(Box::pin(coalesced)).await?;

    // Everything arrives as a single flushed 30-row batch.
    assert_eq!(output.len(), 1);
    assert_eq!(output[0].num_rows(), 30);

    Ok(())
}

#[tokio::test]
async fn test_coalesce_stream_large_batch() -> Result<()> {
    // Ten 20-row batches (200 rows total), each already larger than the
    // 10-row target.
    let schema = create_test_schema();
    let input = vec![create_custom_test_batch(20); 10];

    let upstream: SendableRecordBatchStream =
        Box::pin(MemoryStream::try_new(input, schema.clone(), None)?);

    let coalesced = CoalescedShuffleReaderStream::new(
        upstream,
        10,
        None,
        &ExecutionPlanMetricsSet::new(),
        0,
    );

    let output = common::collect(Box::pin(coalesced)).await?;

    // Oversized batches pass through unsplit: the coalescer does not split a
    // batch whose size exceeds half the target batch size.
    assert_eq!(output.len(), 10);
    assert_eq!(output[0].num_rows(), 20);

    Ok(())
}

use futures::stream;

#[tokio::test]
async fn test_coalesce_stream_error_propagation() -> Result<()> {
    let schema = create_test_schema();

    // One good 3-row batch followed by an upstream failure.
    let items = vec![
        Ok(create_test_batch()),
        Err(DataFusionError::Execution(
            "Network connection failed".to_string(),
        )),
    ];
    let upstream = Box::pin(RecordBatchStreamAdapter::new(
        schema.clone(),
        stream::iter(items),
    ));

    let coalesced = CoalescedShuffleReaderStream::new(
        upstream,
        10,
        None,
        &ExecutionPlanMetricsSet::new(),
        0,
    );

    // The upstream error must surface through the coalescing stream.
    let result = common::collect(Box::pin(coalesced)).await;
    assert!(result.is_err());
    assert!(
        result
            .unwrap_err()
            .to_string()
            .contains("Network connection failed")
    );

    Ok(())
}
}