Skip to content
Closed
246 changes: 193 additions & 53 deletions datafusion/physical-plan/src/topk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use arrow::{
};
use datafusion_expr::{ColumnarValue, Operator};
use std::mem::size_of;
use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};
use std::panic::{AssertUnwindSafe, catch_unwind};
use std::{any::Any, cmp::Ordering, collections::BinaryHeap, sync::Arc};

use super::metrics::{
BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, RecordOutput,
Expand All @@ -35,7 +36,7 @@ use crate::{SendableRecordBatchStream, stream::RecordBatchStreamAdapter};
use arrow::array::{ArrayRef, RecordBatch};
use arrow::datatypes::SchemaRef;
use datafusion_common::{
HashMap, Result, ScalarValue, internal_datafusion_err, internal_err,
DataFusionError, HashMap, Result, ScalarValue, internal_datafusion_err, internal_err,
};
use datafusion_execution::{
memory_pool::{MemoryConsumer, MemoryReservation},
Expand Down Expand Up @@ -109,8 +110,6 @@ pub struct TopK {
metrics: TopKMetrics,
/// Reservation
reservation: MemoryReservation,
/// The target number of rows for output batches
batch_size: usize,
/// sort expressions
expr: LexOrdering,
/// row converter, for sort keys
Expand Down Expand Up @@ -216,7 +215,6 @@ impl TopK {
schema: Arc::clone(&schema),
metrics: TopKMetrics::new(metrics, partition_id),
reservation,
batch_size,
expr,
row_converter,
scratch_rows,
Expand Down Expand Up @@ -588,7 +586,6 @@ impl TopK {
schema,
metrics,
reservation: _,
batch_size,
expr: _,
row_converter: _,
scratch_rows: _,
Expand All @@ -605,20 +602,10 @@ impl TopK {

// break into record batches as needed
let mut batches = vec![];
if let Some(mut batch) = heap.emit()? {
for batch in heap.emit()? {
(&batch).record_output(&metrics.baseline);

loop {
if batch.num_rows() <= batch_size {
batches.push(Ok(batch));
break;
} else {
batches.push(Ok(batch.slice(0, batch_size)));
let remaining_length = batch.num_rows() - batch_size;
batch = batch.slice(batch_size, remaining_length);
}
}
};
batches.push(Ok(batch));
}
Ok(Box::pin(RecordBatchStreamAdapter::new(
schema,
futures::stream::iter(batches),
Expand Down Expand Up @@ -748,47 +735,92 @@ impl TopKHeap {
}

/// Returns the values stored in this heap, from values low to
/// high, as a single [`RecordBatch`], resetting the inner heap
pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
/// high, as [`RecordBatch`]es, resetting the inner heap
pub fn emit(&mut self) -> Result<Vec<RecordBatch>> {
Ok(self.emit_with_state()?.0)
}

/// Returns the values stored in this heap, from values low to
/// high, as a single [`RecordBatch`], and a sorted vec of the
/// high, as [`RecordBatch`]es, and a sorted vec of the
/// current heap's contents
pub fn emit_with_state(&mut self) -> Result<(Option<RecordBatch>, Vec<TopKRow>)> {
pub fn emit_with_state(&mut self) -> Result<(Vec<RecordBatch>, Vec<TopKRow>)> {
// generate sorted rows
let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();

if self.store.is_empty() {
return Ok((None, topk_rows));
return Ok((Vec::new(), topk_rows));
}

// Collect the batches into a vec and store the "batch_id -> array_pos" mapping, to then
// build the `indices` vec below. This is needed since the batch ids are not continuous.
let batches = self.interleave_topk_rows(&topk_rows, self.batch_size)?;

Ok((batches, topk_rows))
}

fn interleave_topk_rows(
Comment thread
aviralgarg05 marked this conversation as resolved.
Comment thread
aviralgarg05 marked this conversation as resolved.
&self,
topk_rows: &[TopKRow],
max_rows_per_batch: usize,
) -> Result<Vec<RecordBatch>> {
// Collect the batches into a vec and store the "batch_id -> array_pos" mapping.
// This is needed since the batch ids are not continuous.
let mut record_batches = Vec::new();
let mut batch_id_array_pos = HashMap::new();
for (array_pos, (batch_id, batch)) in self.store.batches.iter().enumerate() {
record_batches.push(&batch.batch);
batch_id_array_pos.insert(*batch_id, array_pos);
}

let indices: Vec<_> = topk_rows
let all_indices = topk_rows
.iter()
.map(|k| (batch_id_array_pos[&k.batch_id], k.index))
.collect();
.map(|k| {
let array_pos = batch_id_array_pos.get(&k.batch_id).ok_or_else(|| {
internal_datafusion_err!(
"TopK row references missing batch id {}",
k.batch_id
)
})?;
Ok((*array_pos, k.index))
})
.collect::<Result<Vec<_>>>()?;

// At this point `indices` contains indexes within the
// rows and `input_arrays` contains a reference to the
// relevant RecordBatch for that index. `interleave_record_batch` pulls
// them together into a single new batch
let new_batch = interleave_record_batch(&record_batches, &indices)?;
let max_rows_per_batch = max_rows_per_batch.max(1);
let mut batches = Vec::new();
let mut start = 0;
while start < all_indices.len() {
let remaining_rows = all_indices.len() - start;
let mut chunk_size = remaining_rows.min(max_rows_per_batch);

Ok((Some(new_batch), topk_rows))
loop {
let indices = &all_indices[start..start + chunk_size];
match try_interleave_record_batch(&record_batches, indices) {
Ok(batch) => {
batches.push(batch);
start += chunk_size;
break;
}
Err(InterleaveError::Overflow(_)) if chunk_size > 1 => {
chunk_size = chunk_size.div_ceil(2);
if chunk_size == 0 {
return internal_err!(
"Invalid TopK chunk size during interleave"
);
}
}
Err(InterleaveError::Overflow(message)) => {
return internal_err!(
"TopK failed to interleave a single row due to offset overflow: {message}"
);
}
Err(InterleaveError::DataFusion(err)) => return Err(err),
}
}
}

Ok(batches)
}

/// Compact this heap, rewriting all stored batches into a single
/// input batch
/// Compact this heap, rewriting all stored batches into new input
/// batches.
pub fn maybe_compact(&mut self) -> Result<()> {
// we compact if the number of "unused" rows in the store is
// past some pre-defined threshold. Target holding up to
Expand All @@ -802,32 +834,40 @@ impl TopKHeap {
if self.store.len() <= 2 || unused_rows < max_unused_rows {
return Ok(());
}
// at first, compact the entire thing always into a new batch
// at first, compact the entire thing always into new batches
// (maybe we can get fancier in the future about ignoring
// batches that have a high usage ratio already

// Note: new batch is in the same order as inner
let num_rows = self.inner.len();
let (new_batch, mut topk_rows) = self.emit_with_state()?;
let Some(new_batch) = new_batch else {
// Note: new batches are in the same order as inner
let (new_batches, mut topk_rows) = self.emit_with_state()?;
if new_batches.is_empty() {
return Ok(());
};
}

// clear all old entries in store (this invalidates all
// store_ids in `inner`)
self.store.clear();

let mut batch_entry = self.register_batch(new_batch);
batch_entry.uses = num_rows;

// rewrite all existing entries to use the new batch, and
// remove old entries. The sortedness and their relative
// position do not change
for (i, topk_row) in topk_rows.iter_mut().enumerate() {
topk_row.batch_id = batch_entry.id;
topk_row.index = i;
// rewrite all existing entries to use the compacted batches.
// The sortedness and their relative position do not change.
let mut row_offset = 0;
for new_batch in new_batches {
let mut batch_entry = self.register_batch(new_batch);
batch_entry.uses = batch_entry.batch.num_rows();

for (index, topk_row) in topk_rows[row_offset..row_offset + batch_entry.uses]
.iter_mut()
.enumerate()
{
topk_row.batch_id = batch_entry.id;
topk_row.index = index;
}
row_offset += batch_entry.uses;
self.insert_batch_entry(batch_entry);
}
self.insert_batch_entry(batch_entry);

debug_assert_eq!(row_offset, topk_rows.len());

// restore the heap
self.inner = BinaryHeap::from(topk_rows);

Expand Down Expand Up @@ -884,6 +924,56 @@ impl TopKHeap {
}
}

/// Failure modes reported by `try_interleave_record_batch`.
enum InterleaveError {
    /// The failure text (from a returned error or a caught panic) contained
    /// "overflow", matched case-insensitively; carries that text.
    Overflow(String),
    /// Any other failure, expressed as a [`DataFusionError`].
    DataFusion(DataFusionError),
}

/// Runs `interleave_record_batch` over `indices` and classifies any failure.
///
/// The call is wrapped in `catch_unwind`, so both a returned error and a
/// panic are turned into an [`InterleaveError`]:
///
/// * if the error or panic text contains "overflow" (see
///   `is_overflow_message`), the result is `InterleaveError::Overflow`
///   carrying that text;
/// * otherwise a returned error is converted into a `DataFusionError`, and a
///   panic is reported as an internal DataFusion error.
///
/// NOTE(review): `catch_unwind` only intercepts unwinding panics; under
/// `panic = "abort"` a panic here would still terminate the process.
fn try_interleave_record_batch(
    record_batches: &[&RecordBatch],
    indices: &[(usize, usize)],
) -> std::result::Result<RecordBatch, InterleaveError> {
    let attempt = catch_unwind(AssertUnwindSafe(|| {
        interleave_record_batch(record_batches, indices)
    }));

    let failure = match attempt {
        // Success: nothing to classify.
        Ok(Ok(batch)) => return Ok(batch),
        // The kernel returned an error value.
        Ok(Err(err)) => {
            let message = err.to_string();
            if is_overflow_message(&message) {
                InterleaveError::Overflow(message)
            } else {
                InterleaveError::DataFusion(err.into())
            }
        }
        // The kernel panicked; recover whatever text the payload carries.
        Err(payload) => {
            let message = panic_message(payload.as_ref());
            if is_overflow_message(&message) {
                InterleaveError::Overflow(message)
            } else {
                InterleaveError::DataFusion(internal_datafusion_err!(
                    "TopK interleave panicked: {message}"
                ))
            }
        }
    };

    Err(failure)
}

/// Returns `true` when `message` contains the word "overflow", compared
/// ASCII case-insensitively.
fn is_overflow_message(message: &str) -> bool {
    const NEEDLE: &[u8] = b"overflow";
    // Slide a needle-sized window over the raw bytes; only ASCII letters are
    // case-folded, so non-ASCII bytes must match exactly (and never do).
    message
        .as_bytes()
        .windows(NEEDLE.len())
        .any(|window| window.eq_ignore_ascii_case(NEEDLE))
}

/// Extracts a human-readable message from a panic payload.
///
/// Payloads that are a `&str` or a `String` are returned as an owned
/// `String`; any other payload type yields `"unknown panic payload"`.
fn panic_message(payload: &(dyn Any + Send)) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic payload".to_string())
}

/// Represents one of the top K rows held in this heap. Orders
/// according to memcmp of row (e.g. the arrow Row format, but could
/// also be primitive values)
Expand Down Expand Up @@ -1110,6 +1200,56 @@ mod tests {
assert_eq!(record_batch_store.batches_size, 0);
}

/// Checks that `emit_with_state` splits its output into several batches
/// (here 2, 2 and 1 rows for five rows in total) while keeping the rows in
/// sorted order across those batches.
#[test]
fn test_topk_heap_emit_with_state_respects_batch_size() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    // presumably (k, batch_size): keep 5 rows, emit at most 2 per batch —
    // consistent with the [2, 2, 1] expectation below.
    let mut heap = TopKHeap::new(5, 2);

    // Two input batches whose rows interleave once sorted. Each row's sort
    // key is its own value encoded as a single byte.
    let inputs: [&[u8]; 2] = [&[5, 3, 1], &[4, 2]];
    for keys in inputs {
        let values: Vec<i32> = keys.iter().copied().map(i32::from).collect();
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(values))],
        )?;

        let mut entry = heap.register_batch(batch);
        for (index, key) in keys.iter().copied().enumerate() {
            heap.add(&mut entry, [key], index);
        }
        heap.insert_batch_entry(entry);
    }

    let (batches, topk_rows) = heap.emit_with_state()?;

    // Five rows split into three batches.
    assert_eq!(batches.len(), 3);
    let row_counts: Vec<_> = batches.iter().map(|batch| batch.num_rows()).collect();
    assert_eq!(row_counts, vec![2, 2, 1]);

    // The returned rows are sorted by key.
    let sorted_keys: Vec<_> = topk_rows.iter().map(|row| row.row[0]).collect();
    assert_eq!(sorted_keys, vec![1, 2, 3, 4, 5]);

    // Concatenated, the batches hold the values in ascending order.
    assert_batches_eq!(
        &[
            "+---+",
            "| a |",
            "+---+",
            "| 1 |",
            "| 2 |",
            "| 3 |",
            "| 4 |",
            "| 5 |",
            "+---+",
        ],
        &batches
    );

    Ok(())
}

/// This test validates that the `try_finish` method marks the TopK operator as finished
/// when the prefix (on column "a") of the last row in the current batch is strictly greater
/// than the max top‑k row.
Expand Down