diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs
index 577c19b54ade0..cb03fde3ae914 100644
--- a/datafusion/src/physical_plan/expressions/nth_value.rs
+++ b/datafusion/src/physical_plan/expressions/nth_value.rs
@@ -20,7 +20,7 @@
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
 use crate::scalar::ScalarValue;
-use arrow::array::{new_empty_array, new_null_array, ArrayRef};
+use arrow::array::{new_null_array, ArrayRef};
 use arrow::datatypes::{DataType, Field};
 use std::any::Any;
 use std::sync::Arc;
@@ -127,19 +127,20 @@ impl BuiltInWindowFunctionExpr for NthValue {
                 value.len()
             )));
         }
-        if num_rows == 0 {
-            return Ok(new_empty_array(value.data_type()));
-        }
+        assert!(num_rows > 0, "Impossibly got empty values");
         let index: usize = match self.kind {
             NthValueKind::First => 0,
             NthValueKind::Last => (num_rows as usize) - 1,
             NthValueKind::Nth(n) => (n as usize) - 1,
         };
+
         Ok(if index >= num_rows {
-            new_null_array(value.data_type(), num_rows)
+            let data_type: &DataType = value.data_type();
+            new_null_array(data_type, num_rows)
         } else {
-            let value = ScalarValue::try_from_array(value, index)?;
-            value.to_array_of_size(num_rows)
+            // Read the row selected by `index`, then repeat it for all `num_rows` rows.
+            let scalar = ScalarValue::try_from_array(value, index)?;
+            scalar.to_array_of_size(num_rows)
         })
     }
 }
diff --git a/datafusion/src/physical_plan/expressions/row_number.rs b/datafusion/src/physical_plan/expressions/row_number.rs
index 0444ee971f40d..7c4d0324e830c 100644
--- a/datafusion/src/physical_plan/expressions/row_number.rs
+++ b/datafusion/src/physical_plan/expressions/row_number.rs
@@ -59,7 +59,7 @@ impl BuiltInWindowFunctionExpr for RowNumber {
 
     fn evaluate(&self, num_rows: usize, _values: &[ArrayRef]) -> Result<ArrayRef> {
         Ok(Arc::new(UInt64Array::from_iter_values(
-            (1..num_rows + 1).map(|i| i as u64),
+            1..(num_rows as u64) + 1,
         )))
     }
 }
diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs
index 4f56aa7d38262..366f59a3a4054 100644
--- a/datafusion/src/physical_plan/window_functions.rs
+++ b/datafusion/src/physical_plan/window_functions.rs
@@ -20,13 +20,13 @@
 //!
 //! see also https://www.postgresql.org/docs/current/functions-window.html
 
-use crate::arrow::array::ArrayRef;
 use crate::arrow::datatypes::Field;
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{
     aggregates, aggregates::AggregateFunction, functions::Signature,
     type_coercion::data_types, PhysicalExpr,
 };
+use arrow::array::ArrayRef;
 use arrow::datatypes::DataType;
 use std::any::Any;
 use std::sync::Arc;
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index 2f539057c82f4..a5f23374a53bb 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -29,6 +29,7 @@ use crate::physical_plan::{
     Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
     RecordBatchStream, SendableRecordBatchStream, WindowExpr,
 };
+use crate::scalar::ScalarValue;
 use arrow::compute::concat;
 use arrow::{
     array::ArrayRef,
@@ -42,6 +43,7 @@ use futures::Future;
 use pin_project_lite::pin_project;
 use std::any::Any;
 use std::convert::TryInto;
+use std::iter;
 use std::ops::Range;
 use std::pin::Pin;
 use std::sync::Arc;
@@ -187,11 +189,13 @@ impl WindowExpr for BuiltInWindowExpr {
                     .collect::<Vec<_>>();
                 self.window.evaluate(len, &values)
             })
-            .collect::<Result<Vec<_>>>()?
-            .into_iter()
-            .collect::<Vec<ArrayRef>>();
-        let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
-        concat(&results).map_err(DataFusionError::ArrowError)
+            .collect::<Result<Vec<_>>>()?;
+        if results.len() == 1 {
+            Ok(results[0].clone())
+        } else {
+            let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
+            concat(&results).map_err(DataFusionError::ArrowError)
+        }
     }
 }
 
@@ -246,21 +250,25 @@ impl AggregateWindowExpr {
         let values = self.evaluate_args(batch)?;
         let results = partition_points
             .iter()
-            .map(|partition_range| {
+            .map::<Result<ArrayRef>, _>(|partition_range| {
                 let sort_partition_points =
                     find_ranges_in_range(partition_range, &sort_partition_points);
                 let mut window_accumulators = self.create_accumulator()?;
-                sort_partition_points
+                let result = sort_partition_points
                     .iter()
                     .map(|range| window_accumulators.scan_peers(&values, range))
-                    .collect::<Result<Vec<_>>>()
+                    .collect::<Result<Vec<_>>>()?
+                    .into_iter()
+                    .flatten();
+                ScalarValue::iter_to_array(result)
             })
-            .collect::<Result<Vec<Vec<ArrayRef>>>>()?
-            .into_iter()
-            .flatten()
-            .collect::<Vec<ArrayRef>>();
-        let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
-        concat(&results).map_err(DataFusionError::ArrowError)
+            .collect::<Result<Vec<ArrayRef>>>()?;
+        if results.len() == 1 {
+            Ok(results[0].clone())
+        } else {
+            let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
+            concat(&results).map_err(DataFusionError::ArrowError)
+        }
     }
 
     fn group_based_evaluate(&self, _batch: &RecordBatch) -> Result<ArrayRef> {
@@ -328,7 +336,7 @@ impl AggregateWindowAccumulator {
         &mut self,
         values: &[ArrayRef],
         value_range: &Range<usize>,
-    ) -> Result<ArrayRef> {
+    ) -> Result<impl Iterator<Item = ScalarValue>> {
         if value_range.is_empty() {
             return Err(DataFusionError::Internal(
                 "Value range cannot be empty".to_owned(),
@@ -341,7 +349,7 @@ impl AggregateWindowAccumulator {
             .collect::<Vec<_>>();
         self.accumulator.update_batch(&values)?;
         let value = self.accumulator.evaluate()?;
-        Ok(value.to_array_of_size(len))
+        Ok(iter::repeat(value).take(len))
     }
 }
 