diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 60a09301ae0f..0eb8b3c42504 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -20,11 +20,6 @@
 #[cfg(feature = "parquet")]
 mod parquet;
 
-use std::any::Any;
-use std::borrow::Cow;
-use std::collections::HashMap;
-use std::sync::Arc;
-
 use crate::arrow::record_batch::RecordBatch;
 use crate::arrow::util::pretty;
 use crate::datasource::file_format::csv::CsvFormatFactory;
@@ -43,6 +38,10 @@ use crate::physical_plan::{
     ExecutionPlan, SendableRecordBatchStream,
 };
 use crate::prelude::SessionContext;
+use std::any::Any;
+use std::borrow::Cow;
+use std::collections::HashMap;
+use std::sync::Arc;
 
 use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
 use arrow::compute::{cast, concat};
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index c04211d679ca..4812fa41347d 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -1317,6 +1317,7 @@ mod tests {
     use crate::execution_plan::Boundedness;
     use crate::expressions::col;
     use crate::memory::MemoryExec;
+    use crate::metrics::MetricValue;
     use crate::test::assert_is_pending;
     use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
     use crate::RecordBatchStream;
@@ -2783,4 +2784,137 @@ mod tests {
         assert_eq!(aggr_schema, expected_schema);
         Ok(())
     }
+
+    // Test for https://github.com/apache/datafusion/issues/13949
+    async fn run_test_with_spill_pool_if_necessary(
+        pool_size: usize,
+        expect_spill: bool,
+    ) -> Result<()> {
+        fn create_record_batch(
+            schema: &Arc<Schema>,
+            data: (Vec<u32>, Vec<f64>),
+        ) -> Result<RecordBatch> {
+            Ok(RecordBatch::try_new(
+                Arc::clone(schema),
+                vec![
+                    Arc::new(UInt32Array::from(data.0)),
+                    Arc::new(Float64Array::from(data.1)),
+                ],
+            )?)
+        }
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::UInt32, false),
+            Field::new("b", DataType::Float64, false),
+        ]));
+
+        let batches = vec![
+            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
+            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
+        ];
+        let plan: Arc<dyn ExecutionPlan> =
+            Arc::new(MemoryExec::try_new(&[batches], Arc::clone(&schema), None)?);
+
+        let grouping_set = PhysicalGroupBy::new(
+            vec![(col("a", &schema)?, "a".to_string())],
+            vec![],
+            vec![vec![false]],
+        );
+
+        // Test with MIN for a simple intermediate state (min) and AVG for
+        // multiple intermediate states (partial sum, partial count).
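+        // If the spill schema were derived from the final output schema instead
+        // of the partial one, AVG's two intermediate state columns could not
+        // round-trip through the spill files.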
+        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
+            Arc::new(
+                AggregateExprBuilder::new(
+                    datafusion_functions_aggregate::min_max::min_udaf(),
+                    vec![col("b", &schema)?],
+                )
+                .schema(Arc::clone(&schema))
+                .alias("MIN(b)")
+                .build()?,
+            ),
+            Arc::new(
+                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
+                    .schema(Arc::clone(&schema))
+                    .alias("AVG(b)")
+                    .build()?,
+            ),
+        ];
+
+        let single_aggregate = Arc::new(AggregateExec::try_new(
+            AggregateMode::Single,
+            grouping_set,
+            aggregates,
+            vec![None, None],
+            plan,
+            Arc::clone(&schema),
+        )?);
+
+        let batch_size = 2;
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        let task_ctx = Arc::new(
+            TaskContext::default()
+                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
+                .with_runtime(Arc::new(
+                    RuntimeEnvBuilder::new()
+                        .with_memory_pool(memory_pool)
+                        .build()?,
+                )),
+        );
+
+        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
+
+        assert_spill_count_metric(expect_spill, single_aggregate);
+
+        #[rustfmt::skip]
+        assert_batches_sorted_eq!(
+            [
+                "+---+--------+--------+",
+                "| a | MIN(b) | AVG(b) |",
+                "+---+--------+--------+",
+                "| 2 | 1.0    | 1.0    |",
+                "| 3 | 2.0    | 2.0    |",
+                "| 4 | 3.0    | 3.5    |",
+                "+---+--------+--------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+
+    fn assert_spill_count_metric(
+        expect_spill: bool,
+        single_aggregate: Arc<AggregateExec>,
+    ) {
+        if let Some(metrics_set) = single_aggregate.metrics() {
+            let mut spill_count = 0;
+
+            // Inspect metrics for SpillCount
+            for metric in metrics_set.iter() {
+                if let MetricValue::SpillCount(count) = metric.value() {
+                    spill_count = count.value();
+                    break;
+                }
+            }
+
+            if expect_spill && spill_count == 0 {
+                panic!(
+                    "Expected spill but SpillCount metric not found or SpillCount was 0."
+                );
+            } else if !expect_spill && spill_count > 0 {
+                panic!("Expected no spill but found SpillCount metric with value greater than 0.");
+            }
+        } else {
+            panic!("No metrics returned from the operator; cannot verify spilling.");
+        }
+    }
+
+    #[tokio::test]
+    async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
+        // test with spill
+        run_test_with_spill_pool_if_necessary(2_000, true).await?;
+        // test without spill
+        run_test_with_spill_pool_if_necessary(20_000, false).await?;
+        Ok(())
+    }
 }
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 98787d740c20..cdb3b2199cdc 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -24,8 +24,8 @@ use std::vec;
 use crate::aggregates::group_values::{new_group_values, GroupValues};
 use crate::aggregates::order::GroupOrderingFull;
 use crate::aggregates::{
-    evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode,
-    PhysicalGroupBy,
+    create_schema, evaluate_group_by, evaluate_many, evaluate_optional, group_schema,
+    AggregateMode, PhysicalGroupBy,
 };
 use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
 use crate::sorts::sort::sort_batch;
@@ -490,6 +490,31 @@ impl GroupedHashAggregateStream {
             .collect::<Result<_>>()?;
 
         let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?;
+
+        // Fix for https://github.com/apache/datafusion/issues/13949
+        // Builds a **partial aggregation** schema by combining the group columns and
+        // the accumulator state columns produced by each aggregate expression.
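+        // For example, grouping by `a` while computing `AVG(b)` conceptually
+        // yields state columns `[a, sum(b), count(b)]` rather than the final
+        // `[a, avg(b)]`.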
+        //
+        // # Why Partial Aggregation Schema Is Needed
+        //
+        // In a multi-stage (partial/final) aggregation strategy, each partial-aggregate
+        // operator produces *intermediate* states (e.g., partial sums, counts) rather
+        // than final scalar values. These extra columns do **not** exist in the original
+        // input schema (which may be something like `[colA, colB, ...]`). Instead,
+        // each aggregator adds its own internal state columns (e.g., `[acc_state_1, acc_state_2, ...]`).
+        //
+        // Therefore, when we spill these intermediate states or pass them to another
+        // aggregation operator, we must use a schema that includes both the group
+        // columns **and** the partial-state columns.
+        let partial_agg_schema = create_schema(
+            &agg.input().schema(),
+            &agg_group_by,
+            &aggregate_exprs,
+            AggregateMode::Partial,
+        )?;
+
+        let partial_agg_schema = Arc::new(partial_agg_schema);
+
         let spill_expr = group_schema
             .fields
             .into_iter()
@@ -522,7 +547,7 @@ impl GroupedHashAggregateStream {
         let spill_state = SpillState {
             spills: vec![],
             spill_expr,
-            spill_schema: Arc::clone(&agg_schema),
+            spill_schema: partial_agg_schema,
             is_stream_merging: false,
             merging_aggregate_arguments,
             merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
@@ -964,9 +989,6 @@ impl GroupedHashAggregateStream {
             && self.update_memory_reservation().is_err()
         {
             assert_ne!(self.mode, AggregateMode::Partial);
-            // Use input batch (Partial mode) schema for spilling because
-            // the spilled data will be merged and re-evaluated later.
-            self.spill_state.spill_schema = batch.schema();
             self.spill()?;
             self.clear_shrink(batch);
         }