diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 258f4140bc1e..6a0ae202c067 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -369,14 +369,26 @@ impl AggregateExec { new_requirement.extend(req); new_requirement = collapse_lex_req(new_requirement); - let input_order_mode = - if indices.len() == groupby_exprs.len() && !indices.is_empty() { - InputOrderMode::Sorted - } else if !indices.is_empty() { - InputOrderMode::PartiallySorted(indices) - } else { - InputOrderMode::Linear - }; + // If our aggregation has grouping sets then our base grouping exprs will + // be expanded based on the flags in `group_by.groups` where for each + // group we swap the grouping expr for `null` if the flag is `true` + // That means that each index in `indices` is valid if and only if + // it is not null in every group + let indices: Vec = indices + .into_iter() + .filter(|idx| group_by.groups.iter().all(|group| !group[*idx])) + .collect(); + + let input_order_mode = if indices.len() == groupby_exprs.len() + && !indices.is_empty() + && group_by.groups.len() == 1 + { + InputOrderMode::Sorted + } else if !indices.is_empty() { + InputOrderMode::PartiallySorted(indices) + } else { + InputOrderMode::Linear + }; // construct a map from the input expression to the output expression of the Aggregation group by let projection_mapping = @@ -1180,6 +1192,7 @@ mod tests { use arrow::array::{Float64Array, UInt32Array}; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::DataType; + use arrow_array::{Float32Array, Int32Array}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError, ScalarValue, @@ -1195,7 +1208,9 @@ mod tests { use datafusion_physical_expr::expressions::{lit, OrderSensitiveArrayAgg}; use datafusion_physical_expr::PhysicalSortExpr; + use crate::common::collect; use datafusion_physical_expr_common::aggregate::create_aggregate_expr; + use datafusion_physical_expr_common::expressions::Literal; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -2267,4 +2282,94 @@ mod tests { assert_eq!(new_agg.schema(), aggregate_exec.schema()); Ok(()) } + + #[tokio::test] + async fn test_agg_exec_group_by_const() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, true), + Field::new("b", DataType::Float32, true), + Field::new("const", DataType::Int32, false), + ])); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1)))); + + let groups = PhysicalGroupBy::new( + vec![ + (col_a, "a".to_string()), + (col_b, "b".to_string()), + (const_expr, "const".to_string()), + ], + vec![ + ( + Arc::new(Literal::new(ScalarValue::Float32(None))), + "a".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Float32(None))), + "b".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Int32(None))), + "const".to_string(), + ), + ], + vec![ + vec![false, true, true], + vec![true, false, true], + vec![true, true, false], + ], + ); + + let aggregates: Vec> = vec![create_aggregate_expr( + count_udaf().as_ref(), + &[lit(1)], + &[datafusion_expr::lit(1)], + &[], + &[], + schema.as_ref(), + "1", + false, + false, + )?]; + + let input_batches = (0..4) + .map(|_| { + let a = Arc::new(Float32Array::from(vec![0.; 8192])); + let b = Arc::new(Float32Array::from(vec![0.; 8192])); + let c = Arc::new(Int32Array::from(vec![1; 8192])); + + RecordBatch::try_new(schema.clone(), vec![a, b, c]).unwrap() + }) + .collect(); + + let input = + Arc::new(MemoryExec::try_new(&[input_batches], schema.clone(), None)?); + + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups, + aggregates.clone(), + vec![None], + input, + schema, + )?); + + let output = + collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?; + + let expected = [ + "+-----+-----+-------+----------+", + "| a | b | const | 1[count] |", + "+-----+-----+-------+----------+", + "| | 0.0 | | 32768 |", + "| 0.0 | | | 32768 |", + "| | | 1 | 32768 |", + "+-----+-----+-------+----------+", + ]; + assert_batches_sorted_eq!(expected, &output); + + Ok(()) + } }