diff --git a/arrow/src/compute/kernels/window.rs b/arrow/src/compute/kernels/window.rs index a6b342961ac8..a537cbc76e47 100644 --- a/arrow/src/compute/kernels/window.rs +++ b/arrow/src/compute/kernels/window.rs @@ -18,7 +18,7 @@ //! Defines windowing functions, like `shift`ing use crate::array::{Array, ArrayRef}; -use crate::{array::PrimitiveArray, datatypes::ArrowPrimitiveType, error::Result}; +use crate::error::Result; use crate::{ array::{make_array, new_null_array}, compute::concat, @@ -56,23 +56,20 @@ use num::{abs, clamp}; /// let expected: Int32Array = vec![None, None, None].into(); /// assert_eq!(res.as_ref(), &expected); /// ``` -pub fn shift(values: &PrimitiveArray, offset: i64) -> Result -where - T: ArrowPrimitiveType, -{ - let value_len = values.len() as i64; +pub fn shift(array: &Array, offset: i64) -> Result { + let value_len = array.len() as i64; if offset == 0 { - Ok(make_array(values.data_ref().clone())) + Ok(make_array(array.data_ref().clone())) } else if offset == i64::MIN || abs(offset) >= value_len { - Ok(new_null_array(&T::DATA_TYPE, values.len())) + Ok(new_null_array(array.data_type(), array.len())) } else { let slice_offset = clamp(-offset, 0, value_len) as usize; - let length = values.len() - abs(offset) as usize; - let slice = values.slice(slice_offset, length); + let length = array.len() - abs(offset) as usize; + let slice = array.slice(slice_offset, length); // Generate array with remaining `null` items let nulls = abs(offset) as usize; - let null_arr = new_null_array(&T::DATA_TYPE, nulls); + let null_arr = new_null_array(array.data_type(), nulls); // Concatenate both arrays, add nulls after if shift > 0 else before if offset > 0 { @@ -86,7 +83,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::array::Int32Array; + use crate::array::{Float64Array, Int32Array, Int32DictionaryArray}; #[test] fn test_shift_neg() { @@ -104,6 +101,50 @@ mod tests { assert_eq!(res.as_ref(), &expected); } + #[test] + fn test_shift_neg_float64() { + let a: Float64Array = vec![Some(1.), None, Some(4.)].into(); + let res = shift(&a, -1).unwrap(); + let expected: Float64Array = vec![None, Some(4.), None].into(); + assert_eq!(res.as_ref(), &expected); + } + + #[test] + fn test_shift_pos_float64() { + let a: Float64Array = vec![Some(1.), None, Some(4.)].into(); + let res = shift(&a, 1).unwrap(); + let expected: Float64Array = vec![None, Some(1.), None].into(); + assert_eq!(res.as_ref(), &expected); + } + + #[test] + fn test_shift_neg_int32_dict() { + let a: Int32DictionaryArray = [Some("alpha"), None, Some("beta"), Some("alpha")] + .iter() + .copied() + .collect(); + let res = shift(&a, -1).unwrap(); + let expected: Int32DictionaryArray = [None, Some("beta"), Some("alpha"), None] + .iter() + .copied() + .collect(); + assert_eq!(res.as_ref(), &expected); + } + + #[test] + fn test_shift_pos_int32_dict() { + let a: Int32DictionaryArray = [Some("alpha"), None, Some("beta"), Some("alpha")] + .iter() + .copied() + .collect(); + let res = shift(&a, 1).unwrap(); + let expected: Int32DictionaryArray = [None, Some("alpha"), None, Some("beta")] + .iter() + .copied() + .collect(); + assert_eq!(res.as_ref(), &expected); + } + #[test] fn test_shift_nil() { let a: Int32Array = vec![Some(1), None, Some(4)].into();