datafusion/core/src/dataframe/mod.rs (158 changes: 153 additions & 5 deletions)
@@ -20,11 +20,6 @@
#[cfg(feature = "parquet")]
mod parquet;

use std::any::Any;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;

use crate::arrow::record_batch::RecordBatch;
use crate::arrow::util::pretty;
use crate::datasource::file_format::csv::CsvFormatFactory;
@@ -43,6 +38,10 @@ use crate::physical_plan::{
    ExecutionPlan, SendableRecordBatchStream,
};
use crate::prelude::SessionContext;
use std::any::Any;
Contributor:
minor: this import reordering can be reverted to leave the file unmodified

use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use arrow::compute::{cast, concat};
@@ -1970,8 +1969,20 @@ mod tests {
    use std::vec;

    use super::*;
    use crate::arrow::array::{Float64Array, UInt32Array};
    use crate::assert_batches_sorted_eq;
    use crate::execution::context::SessionConfig;
    use crate::execution::memory_pool::FairSpillPool;
    use crate::execution::runtime_env::RuntimeEnvBuilder;
    use crate::physical_expr::aggregate::AggregateExprBuilder;
    use crate::physical_expr::aggregate::AggregateFunctionExpr;
    use crate::physical_plan::aggregates::AggregateExec;
    use crate::physical_plan::aggregates::AggregateMode;
    use crate::physical_plan::aggregates::PhysicalGroupBy;
    use crate::physical_plan::common;
    use crate::physical_plan::expressions::col as physical_col;
    use crate::physical_plan::memory::MemoryExec;
    use crate::physical_plan::metrics::MetricValue;
    use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr};
    use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name};

@@ -2743,6 +2754,143 @@ mod tests {
        Ok(())
    }

    // test for https://github.com/apache/datafusion/issues/13949
    async fn run_test_with_spill_pool_if_necessary(
Contributor:
I suppose it'll be better to move this test to other aggregate tests in datafusion/physical-plan/src/mod.rs

Contributor Author (@kosiew), Jan 7, 2025:

Contributor:
My bad, yes, I meant aggregates/mod.rs

        pool_size: usize,
        expect_spill: bool,
    ) -> Result<()> {
        fn create_record_batch(
            schema: &Arc<Schema>,
            data: (Vec<u32>, Vec<f64>),
        ) -> Result<RecordBatch> {
            Ok(RecordBatch::try_new(
                Arc::clone(schema),
                vec![
                    Arc::new(UInt32Array::from(data.0)),
                    Arc::new(Float64Array::from(data.1)),
                ],
            )?)
        }

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        let batches = vec![
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
        ];
        let plan: Arc<dyn ExecutionPlan> =
            Arc::new(MemoryExec::try_new(&[batches], schema.clone(), None)?);

        let grouping_set = PhysicalGroupBy::new(
            vec![(physical_col("a", &schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
        );

        // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            Arc::new(
                AggregateExprBuilder::new(
                    datafusion_functions_aggregate::min_max::min_udaf(),
                    vec![physical_col("b", &schema)?],
                )
                .schema(schema.clone())
                .alias("MIN(b)")
                .build()?,
            ),
            Arc::new(
                AggregateExprBuilder::new(
                    datafusion_functions_aggregate::average::avg_udaf(),
                    vec![physical_col("b", &schema)?],
                )
                .schema(schema.clone())
                .alias("AVG(b)")
                .build()?,
            ),
        ];

        let single_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            grouping_set,
            aggregates,
            vec![None, None],
            plan,
            schema.clone(),
        )?);

        let batch_size = 2;
        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
                .with_runtime(Arc::new(
                    RuntimeEnvBuilder::new()
                        .with_memory_pool(memory_pool)
                        .build()?,
                )),
        );

        let result =
            common::collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

Contributor (@2010YOUY01), Jan 4, 2025:
I suggest adding an assertion here to make sure spilling actually happened for certain test cases. Like:

        let metrics = single_aggregate.metrics();
        // ...and assert some metrics inside like 'spill count' is > 0

Contributor Author (@kosiew):
Thanks @2010YOUY01 for the review and suggestions.
I have implemented both.

        assert_spill_count_metric(expect_spill, single_aggregate);

        #[rustfmt::skip]
        assert_batches_sorted_eq!(
            [
                "+---+--------+--------+",
                "| a | MIN(b) | AVG(b) |",
                "+---+--------+--------+",
                "| 2 | 1.0    | 1.0    |",
                "| 3 | 2.0    | 2.0    |",
                "| 4 | 3.0    | 3.5    |",
                "+---+--------+--------+",
            ],
            &result
        );

        Ok(())
    }

    fn assert_spill_count_metric(
        expect_spill: bool,
        single_aggregate: Arc<AggregateExec>,
    ) {
        if let Some(metrics_set) = single_aggregate.metrics() {
            let mut spill_count = 0;

            // Inspect metrics for SpillCount
            for metric in metrics_set.iter() {
                if let MetricValue::SpillCount(count) = metric.value() {
                    spill_count = count.value();
                    break;
                }
            }

            if expect_spill && spill_count == 0 {
                panic!(
                    "Expected spill but SpillCount metric not found or SpillCount was 0."
                );
            } else if !expect_spill && spill_count > 0 {
                panic!("Expected no spill but found SpillCount metric with value greater than 0.");
            }
        } else {
            panic!("No metrics returned from the operator; cannot verify spilling.");
        }
    }

    #[tokio::test]
    async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
        // test with spill
        run_test_with_spill_pool_if_necessary(2_000, true).await?;
        // test without spill
        run_test_with_spill_pool_if_necessary(20_000, false).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_aggregate_name_collision() -> Result<()> {
        let df = test_table().await?;
datafusion/physical-plan/src/aggregates/row_hash.rs (51 changes: 46 additions & 5 deletions)
@@ -36,7 +36,7 @@ use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr};
use crate::{RecordBatchStream, SendableRecordBatchStream};

use arrow::array::*;
use arrow::datatypes::SchemaRef;
use arrow::datatypes::{Schema, SchemaRef};
use arrow_schema::SortOptions;
use datafusion_common::{internal_err, DataFusionError, Result};
use datafusion_execution::disk_manager::RefCountedTempFile;
@@ -490,6 +490,11 @@ impl GroupedHashAggregateStream {
            .collect::<Result<_>>()?;

        let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?;

        // Build partial aggregate schema for spills
        let partial_agg_schema =
            build_partial_agg_schema(&group_schema, &aggregate_exprs)?;

        let spill_expr = group_schema
            .fields
            .into_iter()
@@ -522,7 +527,7 @@
        let spill_state = SpillState {
            spills: vec![],
            spill_expr,
            spill_schema: Arc::clone(&agg_schema),
            spill_schema: partial_agg_schema,
Contributor:
It seems like the issue was related only to AggregateMode::Single[Partitioned] cases, since for both Final and FinalPartitioned, there is a reassignment right before spilling (the new value is a schema for Partial output which is exactly group_by + state fields). Perhaps we can remove this reassignment now and rely on original spill_schema value set on stream creation (before removing it, we need to ensure that spill schema will be equal to intermediate result schema for any aggregation mode which supports spilling)?

Contributor Author (@kosiew), Jan 6, 2025:
hi @korowa,

> remove this reassignment now

In other words, remove these lines, am I correct?

    // Use input batch (Partial mode) schema for spilling because
    // the spilled data will be merged and re-evaluated later.
    self.spill_state.spill_schema = batch.schema();

Contributor:
Yes, this line seems to be redundant now -- I'd expect all aggregation modes to have the same spill schema (which is set by this PR), so it shouldn't depend on stream input anymore.

Contributor Author:
Thanks for confirming.
The lines are removed.

            is_stream_merging: false,
            merging_aggregate_arguments,
            merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
@@ -802,6 +807,45 @@ impl RecordBatchStream for GroupedHashAggregateStream {
    }
}

// fix https://github.com/apache/datafusion/issues/13949
/// Builds a **partial aggregation** schema by combining the group columns and
/// the accumulator state columns produced by each aggregate expression.
///
/// # Why Partial Aggregation Schema Is Needed
///
/// In a multi-stage (partial/final) aggregation strategy, each partial-aggregate
/// operator produces *intermediate* states (e.g., partial sums, counts) rather
/// than final scalar values. These extra columns do **not** exist in the original
/// input schema (which may be something like `[colA, colB, ...]`). Instead,
/// each aggregator adds its own internal state columns (e.g., `[acc_state_1, acc_state_2, ...]`).
///
/// Therefore, when we spill these intermediate states or pass them to another
/// aggregation operator, we must use a schema that includes both the group
/// columns **and** the partial-state columns. Otherwise, using the original input
/// schema to read partial states will result in a column-count mismatch error.
///
/// This helper function constructs such a schema:
/// `[group_col_1, group_col_2, ..., state_col_1, state_col_2, ...]`
/// so that partial aggregation data can be handled consistently.
fn build_partial_agg_schema(
Contributor:
Perhaps instead of the new helper we could reuse aggregates::create_schema?

Contributor Author:
I checked create_schema and it handles aggregates like MIN and MAX well, but it does not handle AVG, which has multiple intermediate states (partial sum, partial count).

Contributor:
If I'm not mistaken, it should for mode = AggregateMode::Partial -- for this case it also returns state_fields instead of the result field.

Contributor Author:
Aaa..... 🤔
Thanks for the pointer. It does work.

    group_schema: &SchemaRef,
    aggregate_exprs: &[Arc<AggregateFunctionExpr>],
) -> Result<SchemaRef> {
    let fields = group_schema.fields().clone();
    // convert fields to Vec<Arc<Field>>
    let mut fields = fields.iter().cloned().collect::<Vec<_>>();
    for expr in aggregate_exprs {
        let state_fields = expr.state_fields();
        fields.extend(
            state_fields
                .into_iter()
                .flat_map(|inner_vec| inner_vec.into_iter()) // Flatten the Vec<Vec<Field>> to Vec<Field>
                .map(Arc::new), // Wrap each Field in Arc
        );
    }
    Ok(Arc::new(Schema::new(fields)))
}
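As a companion to the create_schema discussion above, here is a minimal standalone sketch (not part of the PR) that builds the same two aggregate expressions as the spill test in dataframe/mod.rs and prints their intermediate state fields. It assumes the `datafusion` and `datafusion-functions-aggregate` crates as dependencies; the `datafusion::physical_expr` / `datafusion::physical_plan` paths mirror the `crate::...` imports used elsewhere in this diff, and the printed state-field names are whatever `state_fields()` reports.

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::physical_expr::aggregate::AggregateExprBuilder;
use datafusion::physical_plan::expressions::col as physical_col;
use datafusion_functions_aggregate::average::avg_udaf;
use datafusion_functions_aggregate::min_max::min_udaf;

fn main() -> Result<()> {
    // Same input schema as the spill test above.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::UInt32, false),
        Field::new("b", DataType::Float64, false),
    ]));

    for (udaf, alias) in [(min_udaf(), "MIN(b)"), (avg_udaf(), "AVG(b)")] {
        let expr = AggregateExprBuilder::new(udaf, vec![physical_col("b", &schema)?])
            .schema(Arc::clone(&schema))
            .alias(alias)
            .build()?;
        // MIN(b) should report a single state field, while AVG(b) reports two
        // (its partial sum and partial count), which is why the spill schema
        // must be built from state fields rather than from output columns.
        println!("{alias}: {:?}", expr.state_fields()?);
    }
    Ok(())
}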

impl GroupedHashAggregateStream {
    /// Perform group-by aggregation for the given [`RecordBatch`].
    fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
@@ -964,9 +1008,6 @@ impl GroupedHashAggregateStream {
            && self.update_memory_reservation().is_err()
        {
            assert_ne!(self.mode, AggregateMode::Partial);
            // Use input batch (Partial mode) schema for spilling because
            // the spilled data will be merged and re-evaluated later.
            self.spill_state.spill_schema = batch.schema();
            self.spill()?;
            self.clear_shrink(batch);
        }
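For context on why the removed reassignment (and the original Arc::clone(&agg_schema) default) caused trouble in AggregateMode::Single, the arrow-only sketch below illustrates the three schema shapes involved for the MIN(b)/AVG(b) test case above. The intermediate-state field names and nullability are placeholders for readability, not the exact names that state_fields() produces.

use arrow::datatypes::{DataType, Field, Schema};

fn main() {
    // Input to the Single-mode AggregateExec (what batch.schema() yields).
    let input = Schema::new(vec![
        Field::new("a", DataType::UInt32, false),
        Field::new("b", DataType::Float64, false),
    ]);

    // Final output of the Single-mode AggregateExec (roughly what agg_schema,
    // the previous spill_schema default, describes).
    let output = Schema::new(vec![
        Field::new("a", DataType::UInt32, false),
        Field::new("MIN(b)", DataType::Float64, true),
        Field::new("AVG(b)", DataType::Float64, true),
    ]);

    // What is actually spilled: the group column plus one column per
    // accumulator state (MIN keeps one value, AVG keeps a count and a sum).
    let spill = Schema::new(vec![
        Field::new("a", DataType::UInt32, false),
        Field::new("MIN(b)[min]", DataType::Float64, true),
        Field::new("AVG(b)[count]", DataType::UInt64, true),
        Field::new("AVG(b)[sum]", DataType::Float64, true),
    ]);

    // A spill schema shaped like `input` or `output` cannot describe the
    // 4-column intermediate batches, hence the column-count mismatch that
    // build_partial_agg_schema (group fields + state fields) avoids.
    assert_eq!(input.fields().len(), 2);
    assert_eq!(output.fields().len(), 3);
    assert_eq!(spill.fields().len(), 4);
}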