-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Use partial aggregation schema for spilling to avoid column mismatch in GroupedHashAggregateStream #13995
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use partial aggregation schema for spilling to avoid column mismatch in GroupedHashAggregateStream #13995
Changes from 10 commits
da2b11a
01d2b60
e094adb
d066aff
04d9123
242f5ab
270efd7
38ade08
f4fedea
9d6f405
5471775
b682e8c
8a00829
4e312e1
f521846
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,11 +20,6 @@ | |
| #[cfg(feature = "parquet")] | ||
| mod parquet; | ||
|
|
||
| use std::any::Any; | ||
| use std::borrow::Cow; | ||
| use std::collections::HashMap; | ||
| use std::sync::Arc; | ||
|
|
||
| use crate::arrow::record_batch::RecordBatch; | ||
| use crate::arrow::util::pretty; | ||
| use crate::datasource::file_format::csv::CsvFormatFactory; | ||
|
|
@@ -43,6 +38,10 @@ use crate::physical_plan::{ | |
| ExecutionPlan, SendableRecordBatchStream, | ||
| }; | ||
| use crate::prelude::SessionContext; | ||
| use std::any::Any; | ||
| use std::borrow::Cow; | ||
| use std::collections::HashMap; | ||
| use std::sync::Arc; | ||
|
|
||
| use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; | ||
| use arrow::compute::{cast, concat}; | ||
|
|
@@ -1970,8 +1969,20 @@ mod tests { | |
| use std::vec; | ||
|
|
||
| use super::*; | ||
| use crate::arrow::array::{Float64Array, UInt32Array}; | ||
| use crate::assert_batches_sorted_eq; | ||
| use crate::execution::context::SessionConfig; | ||
| use crate::execution::memory_pool::FairSpillPool; | ||
| use crate::execution::runtime_env::RuntimeEnvBuilder; | ||
| use crate::physical_expr::aggregate::AggregateExprBuilder; | ||
| use crate::physical_expr::aggregate::AggregateFunctionExpr; | ||
| use crate::physical_plan::aggregates::AggregateExec; | ||
| use crate::physical_plan::aggregates::AggregateMode; | ||
| use crate::physical_plan::aggregates::PhysicalGroupBy; | ||
| use crate::physical_plan::common; | ||
| use crate::physical_plan::expressions::col as physical_col; | ||
| use crate::physical_plan::memory::MemoryExec; | ||
| use crate::physical_plan::metrics::MetricValue; | ||
| use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr}; | ||
| use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; | ||
|
|
||
|
|
@@ -2743,6 +2754,143 @@ mod tests { | |
| Ok(()) | ||
| } | ||
|
|
||
| // test for https://github.com/apache/datafusion/issues/13949 | ||
| async fn run_test_with_spill_pool_if_necessary( | ||
|
||
| pool_size: usize, | ||
| expect_spill: bool, | ||
| ) -> Result<()> { | ||
| fn create_record_batch( | ||
| schema: &Arc<Schema>, | ||
| data: (Vec<u32>, Vec<f64>), | ||
| ) -> Result<RecordBatch> { | ||
| Ok(RecordBatch::try_new( | ||
| Arc::clone(schema), | ||
| vec![ | ||
| Arc::new(UInt32Array::from(data.0)), | ||
| Arc::new(Float64Array::from(data.1)), | ||
| ], | ||
| )?) | ||
| } | ||
|
|
||
| let schema = Arc::new(Schema::new(vec![ | ||
| Field::new("a", DataType::UInt32, false), | ||
| Field::new("b", DataType::Float64, false), | ||
| ])); | ||
|
|
||
| let batches = vec![ | ||
| create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?, | ||
| create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?, | ||
| ]; | ||
| let plan: Arc<dyn ExecutionPlan> = | ||
| Arc::new(MemoryExec::try_new(&[batches], schema.clone(), None)?); | ||
|
|
||
| let grouping_set = PhysicalGroupBy::new( | ||
| vec![(physical_col("a", &schema)?, "a".to_string())], | ||
| vec![], | ||
| vec![vec![false]], | ||
| ); | ||
|
|
||
| // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count). | ||
| let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![ | ||
| Arc::new( | ||
| AggregateExprBuilder::new( | ||
| datafusion_functions_aggregate::min_max::min_udaf(), | ||
| vec![physical_col("b", &schema)?], | ||
| ) | ||
| .schema(schema.clone()) | ||
| .alias("MIN(b)") | ||
| .build()?, | ||
| ), | ||
| Arc::new( | ||
| AggregateExprBuilder::new( | ||
| datafusion_functions_aggregate::average::avg_udaf(), | ||
| vec![physical_col("b", &schema)?], | ||
| ) | ||
| .schema(schema.clone()) | ||
| .alias("AVG(b)") | ||
| .build()?, | ||
| ), | ||
| ]; | ||
|
|
||
| let single_aggregate = Arc::new(AggregateExec::try_new( | ||
| AggregateMode::Single, | ||
| grouping_set, | ||
| aggregates, | ||
| vec![None, None], | ||
| plan, | ||
| schema.clone(), | ||
| )?); | ||
|
|
||
| let batch_size = 2; | ||
| let memory_pool = Arc::new(FairSpillPool::new(pool_size)); | ||
| let task_ctx = Arc::new( | ||
| TaskContext::default() | ||
| .with_session_config(SessionConfig::new().with_batch_size(batch_size)) | ||
| .with_runtime(Arc::new( | ||
| RuntimeEnvBuilder::new() | ||
| .with_memory_pool(memory_pool) | ||
| .build()?, | ||
| )), | ||
| ); | ||
|
|
||
| let result = | ||
| common::collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; | ||
|
|
||
|
||
| assert_spill_count_metric(expect_spill, single_aggregate); | ||
|
|
||
| #[rustfmt::skip] | ||
| assert_batches_sorted_eq!( | ||
| [ | ||
| "+---+--------+--------+", | ||
| "| a | MIN(b) | AVG(b) |", | ||
| "+---+--------+--------+", | ||
| "| 2 | 1.0 | 1.0 |", | ||
| "| 3 | 2.0 | 2.0 |", | ||
| "| 4 | 3.0 | 3.5 |", | ||
| "+---+--------+--------+", | ||
| ], | ||
| &result | ||
| ); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| fn assert_spill_count_metric( | ||
| expect_spill: bool, | ||
| single_aggregate: Arc<AggregateExec>, | ||
| ) { | ||
| if let Some(metrics_set) = single_aggregate.metrics() { | ||
| let mut spill_count = 0; | ||
|
|
||
| // Inspect metrics for SpillCount | ||
| for metric in metrics_set.iter() { | ||
| if let MetricValue::SpillCount(count) = metric.value() { | ||
| spill_count = count.value(); | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| if expect_spill && spill_count == 0 { | ||
| panic!( | ||
| "Expected spill but SpillCount metric not found or SpillCount was 0." | ||
| ); | ||
| } else if !expect_spill && spill_count > 0 { | ||
| panic!("Expected no spill but found SpillCount metric with value greater than 0."); | ||
| } | ||
| } else { | ||
| panic!("No metrics returned from the operator; cannot verify spilling."); | ||
| } | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn test_aggregate_with_spill_if_necessary() -> Result<()> { | ||
| // test with spill | ||
| run_test_with_spill_pool_if_necessary(2_000, true).await?; | ||
| // test without spill | ||
| run_test_with_spill_pool_if_necessary(20_000, false).await?; | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn test_aggregate_name_collision() -> Result<()> { | ||
| let df = test_table().await?; | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -36,7 +36,7 @@ use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr}; | |||||||
| use crate::{RecordBatchStream, SendableRecordBatchStream}; | ||||||||
|
|
||||||||
| use arrow::array::*; | ||||||||
| use arrow::datatypes::SchemaRef; | ||||||||
| use arrow::datatypes::{Schema, SchemaRef}; | ||||||||
| use arrow_schema::SortOptions; | ||||||||
| use datafusion_common::{internal_err, DataFusionError, Result}; | ||||||||
| use datafusion_execution::disk_manager::RefCountedTempFile; | ||||||||
|
|
@@ -490,6 +490,11 @@ impl GroupedHashAggregateStream { | |||||||
| .collect::<Result<_>>()?; | ||||||||
|
|
||||||||
| let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?; | ||||||||
|
|
||||||||
| // Build partial aggregate schema for spills | ||||||||
| let partial_agg_schema = | ||||||||
| build_partial_agg_schema(&group_schema, &aggregate_exprs)?; | ||||||||
|
|
||||||||
| let spill_expr = group_schema | ||||||||
| .fields | ||||||||
| .into_iter() | ||||||||
|
|
@@ -522,7 +527,7 @@ impl GroupedHashAggregateStream { | |||||||
| let spill_state = SpillState { | ||||||||
| spills: vec![], | ||||||||
| spill_expr, | ||||||||
| spill_schema: Arc::clone(&agg_schema), | ||||||||
| spill_schema: partial_agg_schema, | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like the issue was related only to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hi @korowa ,
In other words, remove these lines, am I correct? datafusion/datafusion/physical-plan/src/aggregates/row_hash.rs Lines 967 to 969 in 487b952
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this line seems to be redundant now -- I'd expect all aggregation modes to have the same spill schema (which is set by this PR), so it shouldn't depend on stream input anymore.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for confirming. |
||||||||
| is_stream_merging: false, | ||||||||
| merging_aggregate_arguments, | ||||||||
| merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()), | ||||||||
|
|
@@ -802,6 +807,45 @@ impl RecordBatchStream for GroupedHashAggregateStream { | |||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| // fix https://github.com/apache/datafusion/issues/13949 | ||||||||
| /// Builds a **partial aggregation** schema by combining the group columns and | ||||||||
| /// the accumulator state columns produced by each aggregate expression. | ||||||||
| /// | ||||||||
| /// # Why Partial Aggregation Schema Is Needed | ||||||||
| /// | ||||||||
| /// In a multi-stage (partial/final) aggregation strategy, each partial-aggregate | ||||||||
| /// operator produces *intermediate* states (e.g., partial sums, counts) rather | ||||||||
| /// than final scalar values. These extra columns do **not** exist in the original | ||||||||
| /// input schema (which may be something like `[colA, colB, ...]`). Instead, | ||||||||
| /// each aggregator adds its own internal state columns (e.g., `[acc_state_1, acc_state_2, ...]`). | ||||||||
| /// | ||||||||
| /// Therefore, when we spill these intermediate states or pass them to another | ||||||||
| /// aggregation operator, we must use a schema that includes both the group | ||||||||
| /// columns **and** the partial-state columns. Otherwise, using the original input | ||||||||
| /// schema to read partial states will result in a column-count mismatch error. | ||||||||
| /// | ||||||||
| /// This helper function constructs such a schema: | ||||||||
| /// `[group_col_1, group_col_2, ..., state_col_1, state_col_2, ...]` | ||||||||
| /// so that partial aggregation data can be handled consistently. | ||||||||
| fn build_partial_agg_schema( | ||||||||
|
||||||||
| group_schema: &SchemaRef, | ||||||||
| aggregate_exprs: &[Arc<AggregateFunctionExpr>], | ||||||||
| ) -> Result<SchemaRef> { | ||||||||
| let fields = group_schema.fields().clone(); | ||||||||
| // convert fields to Vec<Arc<Field>> | ||||||||
| let mut fields = fields.iter().cloned().collect::<Vec<_>>(); | ||||||||
| for expr in aggregate_exprs { | ||||||||
| let state_fields = expr.state_fields(); | ||||||||
| fields.extend( | ||||||||
| state_fields | ||||||||
| .into_iter() | ||||||||
| .flat_map(|inner_vec| inner_vec.into_iter()) // Flatten the Vec<Vec<Field>> to Vec<Field> | ||||||||
| .map(Arc::new), // Wrap each Field in Arc | ||||||||
| ); | ||||||||
| } | ||||||||
| Ok(Arc::new(Schema::new(fields))) | ||||||||
| } | ||||||||
|
|
||||||||
| impl GroupedHashAggregateStream { | ||||||||
| /// Perform group-by aggregation for the given [`RecordBatch`]. | ||||||||
| fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> { | ||||||||
|
|
@@ -964,9 +1008,6 @@ impl GroupedHashAggregateStream { | |||||||
| && self.update_memory_reservation().is_err() | ||||||||
| { | ||||||||
| assert_ne!(self.mode, AggregateMode::Partial); | ||||||||
| // Use input batch (Partial mode) schema for spilling because | ||||||||
| // the spilled data will be merged and re-evaluated later. | ||||||||
| self.spill_state.spill_schema = batch.schema(); | ||||||||
| self.spill()?; | ||||||||
| self.clear_shrink(batch); | ||||||||
| } | ||||||||
|
|
||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor: this import reordering can be reverted to leave the file unmodified