9 changes: 4 additions & 5 deletions datafusion/core/src/dataframe/mod.rs
@@ -20,11 +20,6 @@
 #[cfg(feature = "parquet")]
 mod parquet;
 
-use std::any::Any;
-use std::borrow::Cow;
-use std::collections::HashMap;
-use std::sync::Arc;
-
 use crate::arrow::record_batch::RecordBatch;
 use crate::arrow::util::pretty;
 use crate::datasource::file_format::csv::CsvFormatFactory;
@@ -43,6 +38,10 @@ use crate::physical_plan::{
     ExecutionPlan, SendableRecordBatchStream,
 };
 use crate::prelude::SessionContext;
+use std::any::Any;
+use std::borrow::Cow;
+use std::collections::HashMap;
+use std::sync::Arc;
 
 use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
 use arrow::compute::{cast, concat};

Contributor:
minor: this import reordering can be reverted to leave the file unmodified
134 changes: 134 additions & 0 deletions datafusion/physical-plan/src/aggregates/mod.rs
@@ -1317,6 +1317,7 @@ mod tests {
     use crate::execution_plan::Boundedness;
     use crate::expressions::col;
     use crate::memory::MemoryExec;
+    use crate::metrics::MetricValue;
     use crate::test::assert_is_pending;
     use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
     use crate::RecordBatchStream;
@@ -2783,4 +2784,137 @@
         assert_eq!(aggr_schema, expected_schema);
         Ok(())
     }
+
+    // test for https://github.com/apache/datafusion/issues/13949
+    async fn run_test_with_spill_pool_if_necessary(
+        pool_size: usize,
+        expect_spill: bool,
+    ) -> Result<()> {
+        fn create_record_batch(
+            schema: &Arc<Schema>,
+            data: (Vec<u32>, Vec<f64>),
+        ) -> Result<RecordBatch> {
+            Ok(RecordBatch::try_new(
+                Arc::clone(schema),
+                vec![
+                    Arc::new(UInt32Array::from(data.0)),
+                    Arc::new(Float64Array::from(data.1)),
+                ],
+            )?)
+        }
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::UInt32, false),
+            Field::new("b", DataType::Float64, false),
+        ]));
+
+        let batches = vec![
+            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
+            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
+        ];
+        let plan: Arc<dyn ExecutionPlan> =
+            Arc::new(MemoryExec::try_new(&[batches], Arc::clone(&schema), None)?);
+
+        let grouping_set = PhysicalGroupBy::new(
+            vec![(col("a", &schema)?, "a".to_string())],
+            vec![],
+            vec![vec![false]],
+        );
+
+        // Test with MIN for a simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
+        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
+            Arc::new(
+                AggregateExprBuilder::new(
+                    datafusion_functions_aggregate::min_max::min_udaf(),
+                    vec![col("b", &schema)?],
+                )
+                .schema(Arc::clone(&schema))
+                .alias("MIN(b)")
+                .build()?,
+            ),
+            Arc::new(
+                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
+                    .schema(Arc::clone(&schema))
+                    .alias("AVG(b)")
+                    .build()?,
+            ),
+        ];
+
+        let single_aggregate = Arc::new(AggregateExec::try_new(
+            AggregateMode::Single,
+            grouping_set,
+            aggregates,
+            vec![None, None],
+            plan,
+            Arc::clone(&schema),
+        )?);
+
+        let batch_size = 2;
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        let task_ctx = Arc::new(
+            TaskContext::default()
+                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
+                .with_runtime(Arc::new(
+                    RuntimeEnvBuilder::new()
+                        .with_memory_pool(memory_pool)
+                        .build()?,
+                )),
+        );
+
+        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
+
+        assert_spill_count_metric(expect_spill, single_aggregate);
+
+        #[rustfmt::skip]
+        assert_batches_sorted_eq!(
+            [
+                "+---+--------+--------+",
+                "| a | MIN(b) | AVG(b) |",
+                "+---+--------+--------+",
+                "| 2 | 1.0    | 1.0    |",
+                "| 3 | 2.0    | 2.0    |",
+                "| 4 | 3.0    | 3.5    |",
+                "+---+--------+--------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+
+    fn assert_spill_count_metric(
+        expect_spill: bool,
+        single_aggregate: Arc<AggregateExec>,
+    ) {
+        if let Some(metrics_set) = single_aggregate.metrics() {
+            let mut spill_count = 0;
+
+            // Inspect metrics for SpillCount
+            for metric in metrics_set.iter() {
+                if let MetricValue::SpillCount(count) = metric.value() {
+                    spill_count = count.value();
+                    break;
+                }
+            }
+
+            if expect_spill && spill_count == 0 {
+                panic!(
+                    "Expected spill but SpillCount metric not found or SpillCount was 0."
+                );
+            } else if !expect_spill && spill_count > 0 {
+                panic!("Expected no spill but found SpillCount metric with value greater than 0.");
+            }
+        } else {
+            panic!("No metrics returned from the operator; cannot verify spilling.");
+        }
+    }
+
+    #[tokio::test]
+    async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
+        // test with spill
+        run_test_with_spill_pool_if_necessary(2_000, true).await?;
+        // test without spill
+        run_test_with_spill_pool_if_necessary(20_000, false).await?;
+        Ok(())
+    }
 }
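
The test's inline comment above distinguishes MIN, which keeps a single intermediate state column, from AVG, which keeps two (a running sum and a row count). A minimal sketch of what that means for the schemas involved — the state-field names here are illustrative, not DataFusion's actual accumulator state names:

```rust
use arrow::datatypes::{DataType, Field, Schema};

fn main() {
    // Final output schema: the group key plus one finished value per aggregate.
    let final_schema = Schema::new(vec![
        Field::new("a", DataType::UInt32, false),
        Field::new("MIN(b)", DataType::Float64, true),
        Field::new("AVG(b)", DataType::Float64, true),
    ]);

    // Intermediate (partial) schema: MIN contributes one state column, AVG
    // contributes two, so spilled intermediate rows carry more columns than
    // the final output does.
    let partial_schema = Schema::new(vec![
        Field::new("a", DataType::UInt32, false),
        Field::new("MIN(b)[min]", DataType::Float64, true),
        Field::new("AVG(b)[sum]", DataType::Float64, true),
        Field::new("AVG(b)[count]", DataType::UInt64, true),
    ]);

    assert_ne!(final_schema, partial_schema);
    println!("partial schema has {} fields", partial_schema.fields().len());
}
```

Spilling rows shaped like `partial_schema` while declaring `final_schema` is exactly the mismatch this test exercises.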
34 changes: 28 additions & 6 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -24,8 +24,8 @@ use std::vec;
 use crate::aggregates::group_values::{new_group_values, GroupValues};
 use crate::aggregates::order::GroupOrderingFull;
 use crate::aggregates::{
-    evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode,
-    PhysicalGroupBy,
+    create_schema, evaluate_group_by, evaluate_many, evaluate_optional, group_schema,
+    AggregateMode, PhysicalGroupBy,
 };
 use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
 use crate::sorts::sort::sort_batch;
@@ -490,6 +490,31 @@ impl GroupedHashAggregateStream {
             .collect::<Result<_>>()?;
 
         let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?;
+
+        // fix https://github.com/apache/datafusion/issues/13949
+        // Builds a **partial aggregation** schema by combining the group columns and
+        // the accumulator state columns produced by each aggregate expression.
+        //
+        // # Why Partial Aggregation Schema Is Needed
+        //
+        // In a multi-stage (partial/final) aggregation strategy, each partial-aggregate
+        // operator produces *intermediate* states (e.g., partial sums, counts) rather
+        // than final scalar values. These extra columns do **not** exist in the original
+        // input schema (which may be something like `[colA, colB, ...]`). Instead,
+        // each aggregator adds its own internal state columns (e.g., `[acc_state_1, acc_state_2, ...]`).
+        //
+        // Therefore, when we spill these intermediate states or pass them to another
+        // aggregation operator, we must use a schema that includes both the group
+        // columns **and** the partial-state columns.
+        let partial_agg_schema = create_schema(
+            &agg.input().schema(),
+            &agg_group_by,
+            &aggregate_exprs,
+            AggregateMode::Partial,
+        )?;
+
+        let partial_agg_schema = Arc::new(partial_agg_schema);
+
         let spill_expr = group_schema
             .fields
             .into_iter()
@@ -522,7 +547,7 @@
         let spill_state = SpillState {
             spills: vec![],
             spill_expr,
-            spill_schema: Arc::clone(&agg_schema),
+            spill_schema: partial_agg_schema,
             is_stream_merging: false,
             merging_aggregate_arguments,
             merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),

Contributor:
It seems like the issue was related only to the AggregateMode::Single[Partitioned] cases, since for both Final and FinalPartitioned there is a reassignment right before spilling (the new value is the schema for Partial output, which is exactly group_by + state fields). Perhaps we can remove this reassignment now and rely on the original spill_schema value set on stream creation? (Before removing it, we need to ensure that the spill schema will be equal to the intermediate result schema for any aggregation mode which supports spilling.)

Contributor Author (@kosiew, Jan 6, 2025):
hi @korowa,

> remove this reassignment now

In other words, remove these lines, am I correct?

```rust
// Use input batch (Partial mode) schema for spilling because
// the spilled data will be merged and re-evaluated later.
self.spill_state.spill_schema = batch.schema();
```

Contributor:
Yes, this line seems to be redundant now -- I'd expect all aggregation modes to have the same spill schema (which is set by this PR), so it shouldn't depend on the stream input anymore.

Contributor Author:
Thanks for confirming. The lines are removed.
@@ -964,9 +989,6 @@ impl GroupedHashAggregateStream {
             && self.update_memory_reservation().is_err()
         {
             assert_ne!(self.mode, AggregateMode::Partial);
-            // Use input batch (Partial mode) schema for spilling because
-            // the spilled data will be merged and re-evaluated later.
-            self.spill_state.spill_schema = batch.schema();
             self.spill()?;
             self.clear_shrink(batch);
         }
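
The upshot of the review thread above is that the spill schema is now fixed once, at stream creation, for every aggregation mode, instead of being reassigned per batch. To see why the declared schema must match the spilled rows, here is a hedged round-trip sketch in plain Arrow IPC (the real spill path differs in detail, and the state-field names are illustrative):

```rust
use std::sync::Arc;

use arrow::array::{Float64Array, UInt32Array, UInt64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // An intermediate-state batch: the group key plus AVG's (sum, count) state.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::UInt32, false),
        Field::new("AVG(b)[sum]", DataType::Float64, true),
        Field::new("AVG(b)[count]", DataType::UInt64, true),
    ]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(UInt32Array::from(vec![2, 3])),
            Arc::new(Float64Array::from(vec![1.0, 2.0])),
            Arc::new(UInt64Array::from(vec![1, 1])),
        ],
    )?;

    // "Spill": serialize the batch together with the schema it actually has.
    let mut buf = Vec::new();
    {
        let mut writer = StreamWriter::try_new(&mut buf, &schema)?;
        writer.write(&batch)?;
        writer.finish()?;
    }

    // Re-read: the stream hands back the writer's schema, so the merge phase
    // must expect this three-column intermediate layout rather than the
    // two-column final output ([a, AVG(b)]).
    let reader = StreamReader::try_new(buf.as_slice(), None)?;
    assert_eq!(reader.schema(), schema);
    for maybe_batch in reader {
        assert_eq!(maybe_batch?.num_columns(), 3);
    }
    Ok(())
}
```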