diff --git a/rust/arrow/benches/sort_kernel.rs b/rust/arrow/benches/sort_kernel.rs index 01701d30a9f..74dc0ceae18 100644 --- a/rust/arrow/benches/sort_kernel.rs +++ b/rust/arrow/benches/sort_kernel.rs @@ -33,7 +33,7 @@ fn create_array(size: usize, with_nulls: bool) -> ArrayRef { Arc::new(array) } -fn bench_sort(arr_a: &ArrayRef, array_b: &ArrayRef) { +fn bench_sort(arr_a: &ArrayRef, array_b: &ArrayRef, limit: Option) { let columns = vec![ SortColumn { values: arr_a.clone(), @@ -45,29 +45,76 @@ fn bench_sort(arr_a: &ArrayRef, array_b: &ArrayRef) { }, ]; - criterion::black_box(lexsort(&columns).unwrap()); + criterion::black_box(lexsort(&columns, limit).unwrap()); } fn add_benchmark(c: &mut Criterion) { let arr_a = create_array(2u64.pow(10) as usize, false); let arr_b = create_array(2u64.pow(10) as usize, false); - c.bench_function("sort 2^10", |b| b.iter(|| bench_sort(&arr_a, &arr_b))); + c.bench_function("sort 2^10", |b| b.iter(|| bench_sort(&arr_a, &arr_b, None))); let arr_a = create_array(2u64.pow(12) as usize, false); let arr_b = create_array(2u64.pow(12) as usize, false); - c.bench_function("sort 2^12", |b| b.iter(|| bench_sort(&arr_a, &arr_b))); + c.bench_function("sort 2^12", |b| b.iter(|| bench_sort(&arr_a, &arr_b, None))); let arr_a = create_array(2u64.pow(10) as usize, true); let arr_b = create_array(2u64.pow(10) as usize, true); - c.bench_function("sort nulls 2^10", |b| b.iter(|| bench_sort(&arr_a, &arr_b))); + c.bench_function("sort nulls 2^10", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, None)) + }); let arr_a = create_array(2u64.pow(12) as usize, true); let arr_b = create_array(2u64.pow(12) as usize, true); - c.bench_function("sort nulls 2^12", |b| b.iter(|| bench_sort(&arr_a, &arr_b))); + c.bench_function("sort nulls 2^12", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, None)) + }); + + // with limit + { + let arr_a = create_array(2u64.pow(12) as usize, false); + let arr_b = create_array(2u64.pow(12) as usize, false); + c.bench_function("sort 2^12 limit 10", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(10))) + }); + + let arr_a = create_array(2u64.pow(12) as usize, false); + let arr_b = create_array(2u64.pow(12) as usize, false); + c.bench_function("sort 2^12 limit 100", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(100))) + }); + + let arr_a = create_array(2u64.pow(12) as usize, false); + let arr_b = create_array(2u64.pow(12) as usize, false); + c.bench_function("sort 2^12 limit 1000", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(1000))) + }); + + let arr_a = create_array(2u64.pow(12) as usize, false); + let arr_b = create_array(2u64.pow(12) as usize, false); + c.bench_function("sort 2^12 limit 2^12", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(2u64.pow(12) as usize))) + }); + + let arr_a = create_array(2u64.pow(12) as usize, true); + let arr_b = create_array(2u64.pow(12) as usize, true); + + c.bench_function("sort nulls 2^12 limit 10", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(10))) + }); + c.bench_function("sort nulls 2^12 limit 100", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(100))) + }); + c.bench_function("sort nulls 2^12 limit 1000", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(1000))) + }); + c.bench_function("sort nulls 2^12 limit 2^12", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(2u64.pow(12) as usize))) + }); + } } criterion_group!(benches, add_benchmark); diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index e33b76ed0a1..f5472738e7c 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -17,16 +17,15 @@ //! Defines sort kernel for `ArrayRef` -use std::cmp::{Ordering, Reverse}; +use std::cmp::Ordering; +use std::sync::Arc; use crate::array::*; +use crate::buffer::MutableBuffer; use crate::compute::take; use crate::datatypes::*; use crate::error::{ArrowError, Result}; -use crate::buffer::MutableBuffer; -use num::ToPrimitive; -use std::sync::Arc; use TimeUnit::*; /// Sort the `ArrayRef` using `SortOptions`. @@ -37,10 +36,34 @@ use TimeUnit::*; /// Returns an `ArrowError::ComputeError(String)` if the array type is either unsupported by `sort_to_indices` or `take`. /// pub fn sort(values: &ArrayRef, options: Option) -> Result { - let indices = sort_to_indices(values, options)?; + let indices = sort_to_indices(values, options, None)?; + take(values.as_ref(), &indices, None) +} + +/// Sort the `ArrayRef` partially. +/// It's unstable_sort, may not preserve the order of equal elements +/// Return an sorted `ArrayRef`, discarding the data after limit. +pub fn sort_limit( + values: &ArrayRef, + options: Option, + limit: Option, +) -> Result { + let indices = sort_to_indices(values, options, limit)?; take(values.as_ref(), &indices, None) } +#[inline] +fn sort_by(array: &mut [T], limit: usize, cmp: F) +where + F: FnMut(&T, &T) -> Ordering, +{ + if array.len() == limit { + array.sort_by(cmp); + } else { + partial_sort(array, limit, cmp); + } +} + // implements comparison using IEEE 754 total ordering for f32 // Original implementation from https://doc.rust-lang.org/std/primitive.f64.html#method.total_cmp // TODO to change to use std when it becomes stable @@ -76,118 +99,176 @@ where // partition indices into valid and null indices fn partition_validity(array: &ArrayRef) -> (Vec, Vec) { - let indices = 0..(array.len().to_u32().unwrap()); - indices.partition(|index| array.is_valid(*index as usize)) + match array.null_count() { + // faster path + 0 => ((0..(array.len() as u32)).collect(), vec![]), + _ => { + let indices = 0..(array.len() as u32); + indices.partition(|index| array.is_valid(*index as usize)) + } + } } /// Sort elements from `ArrayRef` into an unsigned integer (`UInt32Array`) of indices. /// For floating point arrays any NaN values are considered to be greater than any other non-null value +/// limit is an option for partial_sort pub fn sort_to_indices( values: &ArrayRef, options: Option, + limit: Option, ) -> Result { let options = options.unwrap_or_default(); let (v, n) = partition_validity(values); match values.data_type() { - DataType::Boolean => sort_boolean(values, v, n, &options), - DataType::Int8 => sort_primitive::(values, v, n, cmp, &options), - DataType::Int16 => sort_primitive::(values, v, n, cmp, &options), - DataType::Int32 => sort_primitive::(values, v, n, cmp, &options), - DataType::Int64 => sort_primitive::(values, v, n, cmp, &options), - DataType::UInt8 => sort_primitive::(values, v, n, cmp, &options), - DataType::UInt16 => sort_primitive::(values, v, n, cmp, &options), - DataType::UInt32 => sort_primitive::(values, v, n, cmp, &options), - DataType::UInt64 => sort_primitive::(values, v, n, cmp, &options), + DataType::Boolean => sort_boolean(values, v, n, &options, limit), + DataType::Int8 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::Int16 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::Int32 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::Int64 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::UInt8 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::UInt16 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::UInt32 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::UInt64 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } DataType::Float32 => { - sort_primitive::(values, v, n, total_cmp_32, &options) + sort_primitive::(values, v, n, total_cmp_32, &options, limit) } DataType::Float64 => { - sort_primitive::(values, v, n, total_cmp_64, &options) + sort_primitive::(values, v, n, total_cmp_64, &options, limit) + } + DataType::Date32 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::Date64 => { + sort_primitive::(values, v, n, cmp, &options, limit) } - DataType::Date32 => sort_primitive::(values, v, n, cmp, &options), - DataType::Date64 => sort_primitive::(values, v, n, cmp, &options), DataType::Time32(Second) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Time32(Millisecond) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Time64(Microsecond) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Time64(Nanosecond) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Timestamp(Second, _) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Timestamp(Millisecond, _) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::( + values, v, n, cmp, &options, limit, + ) } DataType::Timestamp(Microsecond, _) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::( + values, v, n, cmp, &options, limit, + ) } DataType::Timestamp(Nanosecond, _) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::( + values, v, n, cmp, &options, limit, + ) } DataType::Interval(IntervalUnit::YearMonth) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Interval(IntervalUnit::DayTime) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Duration(TimeUnit::Second) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Duration(TimeUnit::Millisecond) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::( + values, v, n, cmp, &options, limit, + ) } DataType::Duration(TimeUnit::Microsecond) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::( + values, v, n, cmp, &options, limit, + ) } DataType::Duration(TimeUnit::Nanosecond) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::( + values, v, n, cmp, &options, limit, + ) } - DataType::Utf8 => sort_string(values, v, n, &options), + DataType::Utf8 => sort_string(values, v, n, &options, limit), DataType::List(field) => match field.data_type() { - DataType::Int8 => sort_list::(values, v, n, &options), - DataType::Int16 => sort_list::(values, v, n, &options), - DataType::Int32 => sort_list::(values, v, n, &options), - DataType::Int64 => sort_list::(values, v, n, &options), - DataType::UInt8 => sort_list::(values, v, n, &options), - DataType::UInt16 => sort_list::(values, v, n, &options), - DataType::UInt32 => sort_list::(values, v, n, &options), - DataType::UInt64 => sort_list::(values, v, n, &options), + DataType::Int8 => sort_list::(values, v, n, &options, limit), + DataType::Int16 => sort_list::(values, v, n, &options, limit), + DataType::Int32 => sort_list::(values, v, n, &options, limit), + DataType::Int64 => sort_list::(values, v, n, &options, limit), + DataType::UInt8 => sort_list::(values, v, n, &options, limit), + DataType::UInt16 => { + sort_list::(values, v, n, &options, limit) + } + DataType::UInt32 => { + sort_list::(values, v, n, &options, limit) + } + DataType::UInt64 => { + sort_list::(values, v, n, &options, limit) + } t => Err(ArrowError::ComputeError(format!( "Sort not supported for list type {:?}", t ))), }, DataType::LargeList(field) => match field.data_type() { - DataType::Int8 => sort_list::(values, v, n, &options), - DataType::Int16 => sort_list::(values, v, n, &options), - DataType::Int32 => sort_list::(values, v, n, &options), - DataType::Int64 => sort_list::(values, v, n, &options), - DataType::UInt8 => sort_list::(values, v, n, &options), - DataType::UInt16 => sort_list::(values, v, n, &options), - DataType::UInt32 => sort_list::(values, v, n, &options), - DataType::UInt64 => sort_list::(values, v, n, &options), + DataType::Int8 => sort_list::(values, v, n, &options, limit), + DataType::Int16 => sort_list::(values, v, n, &options, limit), + DataType::Int32 => sort_list::(values, v, n, &options, limit), + DataType::Int64 => sort_list::(values, v, n, &options, limit), + DataType::UInt8 => sort_list::(values, v, n, &options, limit), + DataType::UInt16 => { + sort_list::(values, v, n, &options, limit) + } + DataType::UInt32 => { + sort_list::(values, v, n, &options, limit) + } + DataType::UInt64 => { + sort_list::(values, v, n, &options, limit) + } t => Err(ArrowError::ComputeError(format!( "Sort not supported for list type {:?}", t ))), }, DataType::FixedSizeList(field, _) => match field.data_type() { - DataType::Int8 => sort_list::(values, v, n, &options), - DataType::Int16 => sort_list::(values, v, n, &options), - DataType::Int32 => sort_list::(values, v, n, &options), - DataType::Int64 => sort_list::(values, v, n, &options), - DataType::UInt8 => sort_list::(values, v, n, &options), - DataType::UInt16 => sort_list::(values, v, n, &options), - DataType::UInt32 => sort_list::(values, v, n, &options), - DataType::UInt64 => sort_list::(values, v, n, &options), + DataType::Int8 => sort_list::(values, v, n, &options, limit), + DataType::Int16 => sort_list::(values, v, n, &options, limit), + DataType::Int32 => sort_list::(values, v, n, &options, limit), + DataType::Int64 => sort_list::(values, v, n, &options, limit), + DataType::UInt8 => sort_list::(values, v, n, &options, limit), + DataType::UInt16 => { + sort_list::(values, v, n, &options, limit) + } + DataType::UInt32 => { + sort_list::(values, v, n, &options, limit) + } + DataType::UInt64 => { + sort_list::(values, v, n, &options, limit) + } t => Err(ArrowError::ComputeError(format!( "Sort not supported for list type {:?}", t @@ -198,28 +279,28 @@ pub fn sort_to_indices( { match key_type.as_ref() { DataType::Int8 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::Int16 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::Int32 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::Int64 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::UInt8 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::UInt16 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::UInt32 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::UInt64 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } t => Err(ArrowError::ComputeError(format!( "Sort not supported for dictionary key type {:?}", @@ -260,6 +341,7 @@ fn sort_boolean( value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, ) -> Result { let values = values .as_any() @@ -278,10 +360,14 @@ fn sort_boolean( let valids_len = valids.len(); let nulls_len = nulls.len(); + let mut len = values.len(); + if let Some(limit) = limit { + len = limit.min(len); + } if !descending { - valids.sort_by(|a, b| a.1.cmp(&b.1)); + sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1)); } else { - valids.sort_by(|a, b| a.1.cmp(&b.1).reverse()); + sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse()); // reverse to keep a stable ordering nulls.reverse(); } @@ -295,17 +381,23 @@ fn sort_boolean( debug_assert_eq!(result_slice.len(), nulls_len + valids_len); if options.nulls_first { - result_slice[0..nulls_len].copy_from_slice(&nulls); - insert_valid_values(result_slice, nulls_len, valids); + let size = nulls_len.min(len); + result_slice[0..nulls_len.min(len)].copy_from_slice(&nulls); + if nulls_len < len { + insert_valid_values(result_slice, nulls_len, &valids[0..len - size]); + } } else { // nulls last - insert_valid_values(result_slice, 0, valids); - result_slice[valids_len..].copy_from_slice(nulls.as_slice()) + let size = valids.len().min(len); + insert_valid_values(result_slice, 0, &valids[0..size]); + if len > size { + result_slice[valids_len..].copy_from_slice(&nulls[0..(len - valids_len)]); + } } let result_data = Arc::new(ArrayData::new( DataType::UInt32, - values.len(), + len, Some(0), None, 0, @@ -324,6 +416,7 @@ fn sort_primitive( null_indices: Vec, cmp: F, options: &SortOptions, + limit: Option, ) -> Result where T: ArrowPrimitiveType, @@ -343,11 +436,15 @@ where let valids_len = valids.len(); let nulls_len = nulls.len(); + let mut len = values.len(); + if let Some(limit) = limit { + len = limit.min(len); + } if !descending { - valids.sort_by(|a, b| cmp(a.1, b.1)); + sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1)); } else { - valids.sort_by(|a, b| cmp(a.1, b.1).reverse()); + sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse()); // reverse to keep a stable ordering nulls.reverse(); } @@ -361,17 +458,23 @@ where debug_assert_eq!(result_slice.len(), nulls_len + valids_len); if options.nulls_first { - result_slice[0..nulls_len].copy_from_slice(&nulls); - insert_valid_values(result_slice, nulls_len, valids); + let size = nulls_len.min(len); + result_slice[0..nulls_len.min(len)].copy_from_slice(&nulls); + if nulls_len < len { + insert_valid_values(result_slice, nulls_len, &valids[0..len - size]); + } } else { // nulls last - insert_valid_values(result_slice, 0, valids); - result_slice[valids_len..].copy_from_slice(nulls.as_slice()) + let size = valids.len().min(len); + insert_valid_values(result_slice, 0, &valids[0..size]); + if len > size { + result_slice[valids_len..].copy_from_slice(&nulls[0..(len - valids_len)]); + } } let result_data = Arc::new(ArrayData::new( DataType::UInt32, - values.len(), + len, Some(0), None, 0, @@ -383,23 +486,18 @@ where } // insert valid and nan values in the correct order depending on the descending flag -fn insert_valid_values( - result_slice: &mut [u32], - offset: usize, - valids: Vec<(u32, T)>, -) { +fn insert_valid_values(result_slice: &mut [u32], offset: usize, valids: &[(u32, T)]) { let valids_len = valids.len(); - // helper to append the index part of the valid tuples let append_valids = move |dst_slice: &mut [u32]| { debug_assert_eq!(dst_slice.len(), valids_len); dst_slice .iter_mut() - .zip(valids.into_iter()) + .zip(valids.iter()) .for_each(|(dst, src)| *dst = src.0) }; - append_valids(&mut result_slice[offset..offset + valids_len]); + append_valids(&mut result_slice[offset..offset + valids.len()]); } /// Sort strings @@ -408,6 +506,7 @@ fn sort_string( value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, ) -> Result { let values = as_string_array(values); @@ -416,6 +515,7 @@ fn sort_string( value_indices, null_indices, options, + limit, |array, idx| array.value(idx as usize), ) } @@ -426,6 +526,7 @@ fn sort_string_dictionary( value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, ) -> Result { let values: &DictionaryArray = as_dictionary_array::(values); @@ -439,6 +540,7 @@ fn sort_string_dictionary( value_indices, null_indices, options, + limit, |array: &PrimitiveArray, idx| -> &str { let key: T::Native = array.value(idx as usize); dict.value(key.to_usize().unwrap()) @@ -454,6 +556,7 @@ fn sort_string_helper<'a, A: Array, F>( value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, value_fn: F, ) -> Result where @@ -464,10 +567,18 @@ where .map(|index| (index, value_fn(&values, index))) .collect::>(); let mut nulls = null_indices; - if !options.descending { - valids.sort_by_key(|a| a.1); + let descending = options.descending; + let mut len = values.len(); + let nulls_len = nulls.len(); + + if let Some(limit) = limit { + len = limit.min(len); + } + if !descending { + sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1)); } else { - valids.sort_by_key(|a| Reverse(a.1)); + sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse()); + // reverse to keep a stable ordering nulls.reverse(); } // collect the order of valid tuplies @@ -475,12 +586,13 @@ where if options.nulls_first { nulls.append(&mut valid_indices); + nulls.truncate(len); return Ok(UInt32Array::from(nulls)); } // no need to sort nulls as they are in the correct order already valid_indices.append(&mut nulls); - + valid_indices.truncate(len); Ok(UInt32Array::from(valid_indices)) } @@ -490,6 +602,7 @@ fn sort_list( value_indices: Vec, mut null_indices: Vec, options: &SortOptions, + limit: Option, ) -> Result where S: OffsetSizeTrait, @@ -517,20 +630,34 @@ where }, ); - if !options.descending { - valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref())) + let mut len = values.len(); + let nulls_len = null_indices.len(); + let descending = options.descending; + + if let Some(limit) = limit { + len = limit.min(len); + } + if !descending { + sort_by(&mut valids, len - nulls_len, |a, b| { + cmp_array(a.1.as_ref(), b.1.as_ref()) + }); } else { - valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()) + sort_by(&mut valids, len - nulls_len, |a, b| { + cmp_array(a.1.as_ref(), b.1.as_ref()).reverse() + }); + // reverse to keep a stable ordering + null_indices.reverse(); } let mut valid_indices: Vec = valids.iter().map(|tuple| tuple.0).collect(); - if options.nulls_first { null_indices.append(&mut valid_indices); + null_indices.truncate(len); return Ok(UInt32Array::from(null_indices)); } valid_indices.append(&mut null_indices); + valid_indices.truncate(len); Ok(UInt32Array::from(valid_indices)) } @@ -595,13 +722,13 @@ pub struct SortColumn { /// nulls_first: false, /// }), /// }, -/// ]).unwrap(); +/// ], None).unwrap(); /// /// assert_eq!(as_primitive_array::(&sorted_columns[0]).value(1), -64); /// assert!(sorted_columns[0].is_null(0)); /// ``` -pub fn lexsort(columns: &[SortColumn]) -> Result> { - let indices = lexsort_to_indices(columns)?; +pub fn lexsort(columns: &[SortColumn], limit: Option) -> Result> { + let indices = lexsort_to_indices(columns, limit)?; columns .iter() .map(|c| take(c.values.as_ref(), &indices, None)) @@ -610,7 +737,10 @@ pub fn lexsort(columns: &[SortColumn]) -> Result> { /// Sort elements lexicographically from a list of `ArrayRef` into an unsigned integer /// (`UInt32Array`) of indices. -pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result { +pub fn lexsort_to_indices( + columns: &[SortColumn], + limit: Option, +) -> Result { if columns.is_empty() { return Err(ArrowError::InvalidArgumentError( "Sort requires at least one column".to_string(), @@ -619,7 +749,7 @@ pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result { if columns.len() == 1 { // fallback to non-lexical sort let column = &columns[0]; - return sort_to_indices(&column.values, column.options); + return sort_to_indices(&column.values, column.options, limit); } let row_count = columns[0].values.len(); @@ -686,22 +816,38 @@ pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result { }; let mut value_indices = (0..row_count).collect::>(); - value_indices.sort_by(lex_comparator); + let mut len = value_indices.len(); + + if let Some(limit) = limit { + len = limit.min(len); + } + sort_by(&mut value_indices, len, lex_comparator); Ok(UInt32Array::from( - value_indices - .into_iter() - .map(|i| i as u32) + (&value_indices)[0..len] + .iter() + .map(|i| *i as u32) .collect::>(), )) } +/// It's unstable_sort, may not preserve the order of equal elements +pub fn partial_sort(v: &mut [T], limit: usize, mut is_less: F) +where + F: FnMut(&T, &T) -> Ordering, +{ + let (before, _mid, _after) = v.select_nth_unstable_by(limit, &mut is_less); + before.sort_unstable_by(is_less); +} + #[cfg(test)] mod tests { use super::*; use crate::compute::util::tests::{ build_fixed_size_list_nullable, build_generic_list_nullable, }; + use rand::rngs::StdRng; + use rand::{Rng, RngCore, SeedableRng}; use std::convert::TryFrom; use std::iter::FromIterator; use std::sync::Arc; @@ -709,17 +855,20 @@ mod tests { fn test_sort_to_indices_boolean_arrays( data: Vec>, options: Option, + limit: Option, expected_data: Vec, ) { let output = BooleanArray::from(data); let expected = UInt32Array::from(expected_data); - let output = sort_to_indices(&(Arc::new(output) as ArrayRef), options).unwrap(); + let output = + sort_to_indices(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); assert_eq!(output, expected) } fn test_sort_to_indices_primitive_arrays( data: Vec>, options: Option, + limit: Option, expected_data: Vec, ) where T: ArrowPrimitiveType, @@ -727,13 +876,15 @@ mod tests { { let output = PrimitiveArray::::from(data); let expected = UInt32Array::from(expected_data); - let output = sort_to_indices(&(Arc::new(output) as ArrayRef), options).unwrap(); + let output = + sort_to_indices(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); assert_eq!(output, expected) } fn test_sort_primitive_arrays( data: Vec>, options: Option, + limit: Option, expected_data: Vec>, ) where T: ArrowPrimitiveType, @@ -741,35 +892,49 @@ mod tests { { let output = PrimitiveArray::::from(data); let expected = Arc::new(PrimitiveArray::::from(expected_data)) as ArrayRef; - let output = sort(&(Arc::new(output) as ArrayRef), options).unwrap(); + let output = match limit { + Some(_) => { + sort_limit(&(Arc::new(output) as ArrayRef), options, limit).unwrap() + } + _ => sort(&(Arc::new(output) as ArrayRef), options).unwrap(), + }; assert_eq!(&output, &expected) } fn test_sort_to_indices_string_arrays( data: Vec>, options: Option, + limit: Option, expected_data: Vec, ) { let output = StringArray::from(data); let expected = UInt32Array::from(expected_data); - let output = sort_to_indices(&(Arc::new(output) as ArrayRef), options).unwrap(); + let output = + sort_to_indices(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); assert_eq!(output, expected) } fn test_sort_string_arrays( data: Vec>, options: Option, + limit: Option, expected_data: Vec>, ) { let output = StringArray::from(data); let expected = Arc::new(StringArray::from(expected_data)) as ArrayRef; - let output = sort(&(Arc::new(output) as ArrayRef), options).unwrap(); + let output = match limit { + Some(_) => { + sort_limit(&(Arc::new(output) as ArrayRef), options, limit).unwrap() + } + _ => sort(&(Arc::new(output) as ArrayRef), options).unwrap(), + }; assert_eq!(&output, &expected) } fn test_sort_string_dict_arrays( data: Vec>, options: Option, + limit: Option, expected_data: Vec>, ) { let array = DictionaryArray::::from_iter(data.into_iter()); @@ -779,7 +944,12 @@ mod tests { .downcast_ref::() .expect("Unable to get dictionary values"); - let sorted = sort(&(Arc::new(array) as ArrayRef), options).unwrap(); + let sorted = match limit { + Some(_) => { + sort_limit(&(Arc::new(array) as ArrayRef), options, limit).unwrap() + } + _ => sort(&(Arc::new(array) as ArrayRef), options).unwrap(), + }; let sorted = sorted .as_any() .downcast_ref::>() @@ -814,6 +984,7 @@ mod tests { fn test_sort_list_arrays( data: Vec>>>, options: Option, + limit: Option, expected_data: Vec>>>, fixed_length: Option, ) where @@ -823,7 +994,10 @@ mod tests { // for FixedSizedList if let Some(length) = fixed_length { let input = Arc::new(build_fixed_size_list_nullable(data.clone(), length)); - let sorted = sort(&(input as ArrayRef), options).unwrap(); + let sorted = match limit { + Some(_) => sort_limit(&(input as ArrayRef), options, limit).unwrap(), + _ => sort(&(input as ArrayRef), options).unwrap(), + }; let expected = Arc::new(build_fixed_size_list_nullable( expected_data.clone(), length, @@ -834,7 +1008,10 @@ mod tests { // for List let input = Arc::new(build_generic_list_nullable::(data.clone())); - let sorted = sort(&(input as ArrayRef), options).unwrap(); + let sorted = match limit { + Some(_) => sort_limit(&(input as ArrayRef), options, limit).unwrap(), + _ => sort(&(input as ArrayRef), options).unwrap(), + }; let expected = Arc::new(build_generic_list_nullable::(expected_data.clone())) as ArrayRef; @@ -843,15 +1020,22 @@ mod tests { // for LargeList let input = Arc::new(build_generic_list_nullable::(data)); - let sorted = sort(&(input as ArrayRef), options).unwrap(); + let sorted = match limit { + Some(_) => sort_limit(&(input as ArrayRef), options, limit).unwrap(), + _ => sort(&(input as ArrayRef), options).unwrap(), + }; let expected = Arc::new(build_generic_list_nullable::(expected_data)) as ArrayRef; assert_eq!(&sorted, &expected); } - fn test_lex_sort_arrays(input: Vec, expected_output: Vec) { - let sorted = lexsort(&input).unwrap(); + fn test_lex_sort_arrays( + input: Vec, + expected_output: Vec, + limit: Option, + ) { + let sorted = lexsort(&input, limit).unwrap(); for (result, expected) in sorted.iter().zip(expected_output.iter()) { assert_eq!(result, expected); @@ -863,21 +1047,25 @@ mod tests { test_sort_to_indices_primitive_arrays::( vec![None, Some(0), Some(2), Some(-1), Some(0), None], None, + None, vec![0, 5, 3, 1, 4, 2], ); test_sort_to_indices_primitive_arrays::( vec![None, Some(0), Some(2), Some(-1), Some(0), None], None, + None, vec![0, 5, 3, 1, 4, 2], ); test_sort_to_indices_primitive_arrays::( vec![None, Some(0), Some(2), Some(-1), Some(0), None], None, + None, vec![0, 5, 3, 1, 4, 2], ); test_sort_to_indices_primitive_arrays::( vec![None, Some(0), Some(2), Some(-1), Some(0), None], None, + None, vec![0, 5, 3, 1, 4, 2], ); test_sort_to_indices_primitive_arrays::( @@ -890,6 +1078,7 @@ mod tests { None, ], None, + None, vec![0, 5, 3, 1, 4, 2], ); test_sort_to_indices_primitive_arrays::( @@ -902,6 +1091,7 @@ mod tests { None, ], None, + None, vec![0, 5, 3, 1, 4, 2], ); @@ -912,6 +1102,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 1, 4, 3, 5, 0], // [2, 4, 1, 3, 5, 0] ); @@ -921,6 +1112,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 1, 4, 3, 5, 0], ); @@ -930,6 +1122,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 1, 4, 3, 5, 0], ); @@ -939,6 +1132,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 1, 4, 3, 5, 0], ); @@ -955,6 +1149,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 1, 4, 3, 5, 0], ); @@ -964,6 +1159,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 1, 4, 3, 5, 0], ); @@ -974,6 +1170,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 1, 4, 3], // [5, 0, 2, 4, 1, 3] ); @@ -983,6 +1180,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 1, 4, 3], // [5, 0, 2, 4, 1, 3] ); @@ -992,6 +1190,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 1, 4, 3], ); @@ -1001,6 +1200,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 1, 4, 3], ); @@ -1010,6 +1210,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 1, 4, 3], ); @@ -1019,6 +1220,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 1, 4, 3], ); } @@ -1029,6 +1231,7 @@ mod tests { test_sort_to_indices_boolean_arrays( vec![None, Some(false), Some(true), Some(true), Some(false), None], None, + None, vec![0, 5, 1, 4, 2, 3], ); @@ -1039,6 +1242,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 3, 1, 4, 5, 0], ); @@ -1049,8 +1253,20 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 3, 1, 4], ); + + // boolean, descending, nulls first, limit + test_sort_to_indices_boolean_arrays( + vec![None, Some(false), Some(true), Some(true), Some(false), None], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![5, 0, 2], + ); } #[test] @@ -1059,21 +1275,25 @@ mod tests { test_sort_primitive_arrays::( vec![None, Some(3), Some(5), Some(2), Some(3), None], None, + None, vec![None, None, Some(2), Some(3), Some(3), Some(5)], ); test_sort_primitive_arrays::( vec![None, Some(3), Some(5), Some(2), Some(3), None], None, + None, vec![None, None, Some(2), Some(3), Some(3), Some(5)], ); test_sort_primitive_arrays::( vec![None, Some(3), Some(5), Some(2), Some(3), None], None, + None, vec![None, None, Some(2), Some(3), Some(3), Some(5)], ); test_sort_primitive_arrays::( vec![None, Some(3), Some(5), Some(2), Some(3), None], None, + None, vec![None, None, Some(2), Some(3), Some(3), Some(5)], ); @@ -1084,6 +1304,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![Some(2), Some(0), Some(0), Some(-1), None, None], ); test_sort_primitive_arrays::( @@ -1092,6 +1313,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![Some(2), Some(0), Some(0), Some(-1), None, None], ); test_sort_primitive_arrays::( @@ -1100,6 +1322,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![Some(2), Some(0), Some(0), Some(-1), None, None], ); test_sort_primitive_arrays::( @@ -1108,6 +1331,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![Some(2), Some(0), Some(0), Some(-1), None, None], ); @@ -1118,6 +1342,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![None, None, Some(2), Some(0), Some(0), Some(-1)], ); test_sort_primitive_arrays::( @@ -1126,6 +1351,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![None, None, Some(2), Some(0), Some(0), Some(-1)], ); test_sort_primitive_arrays::( @@ -1134,6 +1360,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![None, None, Some(2), Some(0), Some(0), Some(-1)], ); test_sort_primitive_arrays::( @@ -1142,14 +1369,27 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![None, None, Some(2), Some(0), Some(0), Some(-1)], ); + + test_sort_primitive_arrays::( + vec![None, Some(0), Some(2), Some(-1), Some(0), None], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![None, None, Some(2)], + ); + test_sort_primitive_arrays::( vec![None, Some(0.0), Some(2.0), Some(-1.0), Some(0.0), None], Some(SortOptions { descending: true, nulls_first: true, }), + None, vec![None, None, Some(2.0), Some(0.0), Some(0.0), Some(-1.0)], ); test_sort_primitive_arrays::( @@ -1158,6 +1398,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![None, None, Some(f64::NAN), Some(2.0), Some(0.0), Some(-1.0)], ); test_sort_primitive_arrays::( @@ -1166,6 +1407,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![Some(f64::NAN), Some(f64::NAN), Some(f64::NAN), Some(1.0)], ); @@ -1176,6 +1418,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![None, None, Some(-1), Some(0), Some(0), Some(2)], ); test_sort_primitive_arrays::( @@ -1184,6 +1427,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![None, None, Some(-1), Some(0), Some(0), Some(2)], ); test_sort_primitive_arrays::( @@ -1192,6 +1436,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![None, None, Some(-1), Some(0), Some(0), Some(2)], ); test_sort_primitive_arrays::( @@ -1200,6 +1445,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![None, None, Some(-1), Some(0), Some(0), Some(2)], ); test_sort_primitive_arrays::( @@ -1208,6 +1454,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![None, None, Some(-1.0), Some(0.0), Some(0.0), Some(2.0)], ); test_sort_primitive_arrays::( @@ -1216,6 +1463,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![None, None, Some(-1.0), Some(0.0), Some(2.0), Some(f64::NAN)], ); test_sort_primitive_arrays::( @@ -1224,8 +1472,31 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![Some(1.0), Some(f64::NAN), Some(f64::NAN), Some(f64::NAN)], ); + + // limit + test_sort_primitive_arrays::( + vec![Some(f64::NAN), Some(f64::NAN), Some(f64::NAN), Some(1.0)], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(2), + vec![Some(1.0), Some(f64::NAN)], + ); + + // limit with actual value + test_sort_primitive_arrays::( + vec![Some(2.0), Some(4.0), Some(3.0), Some(1.0)], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(3), + vec![Some(1.0), Some(2.0), Some(3.0)], + ); } #[test] @@ -1240,6 +1511,7 @@ mod tests { Some("-ad"), ], None, + None, vec![0, 3, 5, 1, 4, 2], ); @@ -1256,6 +1528,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 4, 1, 5, 3, 0], ); @@ -1272,6 +1545,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![0, 3, 5, 1, 4, 2], ); @@ -1288,8 +1562,26 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![3, 0, 2, 4, 1, 5], ); + + test_sort_to_indices_string_arrays( + vec![ + None, + Some("bad"), + Some("sad"), + None, + Some("glad"), + Some("-ad"), + ], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![3, 0, 2], + ); } #[test] @@ -1304,6 +1596,7 @@ mod tests { Some("-ad"), ], None, + None, vec![ None, None, @@ -1327,6 +1620,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![ Some("sad"), Some("glad"), @@ -1350,6 +1644,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![ None, None, @@ -1373,6 +1668,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![ None, None, @@ -1382,6 +1678,23 @@ mod tests { Some("-ad"), ], ); + + test_sort_string_arrays( + vec![ + None, + Some("bad"), + Some("sad"), + None, + Some("glad"), + Some("-ad"), + ], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![None, None, Some("sad")], + ); } #[test] @@ -1396,6 +1709,7 @@ mod tests { Some("-ad"), ], None, + None, vec![ None, None, @@ -1419,6 +1733,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![ Some("sad"), Some("glad"), @@ -1442,6 +1757,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![ None, None, @@ -1465,6 +1781,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![ None, None, @@ -1474,6 +1791,23 @@ mod tests { Some("-ad"), ], ); + + test_sort_string_dict_arrays::( + vec![ + None, + Some("bad"), + Some("sad"), + None, + Some("glad"), + Some("-ad"), + ], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![None, None, Some("sad")], + ); } #[test] @@ -1489,6 +1823,7 @@ mod tests { descending: false, nulls_first: false, }), + None, vec![ Some(vec![Some(1)]), Some(vec![Some(2)]), @@ -1510,6 +1845,7 @@ mod tests { descending: false, nulls_first: false, }), + None, vec![ Some(vec![Some(1), Some(0)]), Some(vec![Some(1), Some(1)]), @@ -1532,6 +1868,7 @@ mod tests { descending: false, nulls_first: false, }), + None, vec![ Some(vec![Some(2), Some(3), Some(4)]), Some(vec![Some(3), Some(3), None]), @@ -1541,6 +1878,23 @@ mod tests { ], Some(3), ); + + test_sort_list_arrays::( + vec![ + Some(vec![Some(1), Some(0)]), + Some(vec![Some(4), Some(3), Some(2), Some(1)]), + Some(vec![Some(2), Some(3), Some(4)]), + Some(vec![Some(3), Some(3), Some(3), Some(3)]), + Some(vec![Some(1), Some(1)]), + ], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(2), + vec![Some(vec![Some(1), Some(0)]), Some(vec![Some(1), Some(1)])], + None, + ); } #[test] @@ -1560,7 +1914,14 @@ mod tests { Some(2), Some(17), ])) as ArrayRef]; - test_lex_sort_arrays(input, expected); + test_lex_sort_arrays(input.clone(), expected, None); + + let expected = vec![Arc::new(PrimitiveArray::::from(vec![ + Some(-1), + Some(0), + Some(2), + ])) as ArrayRef]; + test_lex_sort_arrays(input, expected, Some(3)); } #[test] @@ -1577,7 +1938,7 @@ mod tests { }, ]; assert!( - lexsort(&input).is_err(), + lexsort(&input, None).is_err(), "lexsort should reject columns with different row counts" ); } @@ -1633,7 +1994,7 @@ mod tests { Some(-2), ])) as ArrayRef, ]; - test_lex_sort_arrays(input, expected); + test_lex_sort_arrays(input, expected, None); // test mix of string and in64 with option let input = vec![ @@ -1676,7 +2037,7 @@ mod tests { Some("7"), ])) as ArrayRef, ]; - test_lex_sort_arrays(input, expected); + test_lex_sort_arrays(input, expected, None); // test sort with nulls first let input = vec![ @@ -1719,7 +2080,7 @@ mod tests { Some("world"), ])) as ArrayRef, ]; - test_lex_sort_arrays(input, expected); + test_lex_sort_arrays(input, expected, None); // test sort with nulls last let input = vec![ @@ -1762,7 +2123,7 @@ mod tests { None, ])) as ArrayRef, ]; - test_lex_sort_arrays(input, expected); + test_lex_sort_arrays(input, expected, None); // test sort with opposite options let input = vec![ @@ -1809,6 +2170,33 @@ mod tests { Some("foo"), ])) as ArrayRef, ]; - test_lex_sort_arrays(input, expected); + test_lex_sort_arrays(input, expected, None); + } + + #[test] + fn test_partial_sort() { + let mut before: Vec<&str> = vec![ + "a", "cat", "mat", "on", "sat", "the", "xxx", "xxxx", "fdadfdsf", + ]; + let mut d = before.clone(); + d.sort_unstable(); + + for last in 0..before.len() { + partial_sort(&mut before, last, |a, b| a.cmp(b)); + assert_eq!(&d[0..last], &before.as_slice()[0..last]); + } + } + + #[test] + fn test_partial_rand_sort() { + let size = 1000u32; + let mut rng = StdRng::seed_from_u64(42); + let mut before: Vec = (0..size).map(|_| rng.gen::()).collect(); + let mut d = before.clone(); + let last = (rng.next_u32() % size) as usize; + d.sort_unstable(); + + partial_sort(&mut before, last, |a, b| a.cmp(b)); + assert_eq!(&d[0..last], &before[0..last]); } } diff --git a/rust/datafusion/src/physical_plan/sort.rs b/rust/datafusion/src/physical_plan/sort.rs index 042b9f1da81..c595b0b960e 100644 --- a/rust/datafusion/src/physical_plan/sort.rs +++ b/rust/datafusion/src/physical_plan/sort.rs @@ -156,12 +156,14 @@ fn sort_batches( )?; // sort combined record batch + // TODO: pushup the limit expression to sort let indices = lexsort_to_indices( &expr .iter() .map(|e| e.evaluate_to_sort_column(&combined_batch)) .collect::>>() .map_err(DataFusionError::into_arrow_external_error)?, + None, )?; // reorder all rows based on sorted indices