diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs
index 30341b6f63a6..9287425bf126 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -487,24 +487,27 @@ where
         len = limit.min(len);
     }
     if !descending {
-        sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1));
+        sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
+            cmp(a.1, b.1)
+        });
     } else {
-        sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse());
+        sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
+            cmp(a.1, b.1).reverse()
+        });
         // reverse to keep a stable ordering
         nulls.reverse();
     }
 
     // collect results directly into a buffer instead of a vec to avoid another aligned allocation
-    let mut result = MutableBuffer::new(values.len() * std::mem::size_of::<u32>());
+    let result_capacity = len * std::mem::size_of::<u32>();
+    let mut result = MutableBuffer::new(result_capacity);
     // sets len to capacity so we can access the whole buffer as a typed slice
-    result.resize(values.len() * std::mem::size_of::<u32>(), 0);
+    result.resize(result_capacity, 0);
     let result_slice: &mut [u32] = result.typed_data_mut();
 
-    debug_assert_eq!(result_slice.len(), nulls_len + valids_len);
-
     if options.nulls_first {
         let size = nulls_len.min(len);
-        result_slice[0..nulls_len.min(len)].copy_from_slice(&nulls);
+        result_slice[0..size].copy_from_slice(&nulls[0..size]);
         if nulls_len < len {
             insert_valid_values(result_slice, nulls_len, &valids[0..len - size]);
         }
@@ -1556,6 +1559,48 @@ mod tests {
             Some(3),
             vec![Some(1.0), Some(2.0), Some(3.0)],
         );
+
+        // valid values less than limit with extra nulls
+        test_sort_primitive_arrays::(
+            vec![Some(2.0), None, None, Some(1.0)],
+            Some(SortOptions {
+                descending: false,
+                nulls_first: false,
+            }),
+            Some(3),
+            vec![Some(1.0), Some(2.0), None],
+        );
+
+        test_sort_primitive_arrays::(
+            vec![Some(2.0), None, None, Some(1.0)],
+            Some(SortOptions {
+                descending: false,
+                nulls_first: true,
+            }),
+            Some(3),
+            vec![None, None, Some(1.0)],
+        );
+
+        // more nulls than limit
+        test_sort_primitive_arrays::(
+            vec![Some(2.0), None, None, None],
+            Some(SortOptions {
+                descending: false,
+                nulls_first: true,
+            }),
+            Some(2),
+            vec![None, None],
+        );
+
+        test_sort_primitive_arrays::(
+            vec![Some(2.0), None, None, None],
+            Some(SortOptions {
+                descending: false,
+                nulls_first: false,
+            }),
+            Some(2),
+            vec![Some(2.0), None],
+        );
     }
 
     #[test]