From abfb7c9bcb88ed907719e8cb885a52e75c46d6d4 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Sun, 28 Feb 2021 21:58:48 +0800 Subject: [PATCH 01/11] [Rust] Introduce limit option for sort kernel 1. Introduce partial_sort 2. Introduce pdqsort --- rust/arrow/Cargo.toml | 2 + rust/arrow/benches/sort_kernel.rs | 2 +- rust/arrow/src/compute/kernels/sort.rs | 415 +++++++++++++++------- rust/datafusion/src/physical_plan/sort.rs | 2 + 4 files changed, 296 insertions(+), 125 deletions(-) diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml index 5ab1f8cc02b..11ad67d6b64 100644 --- a/rust/arrow/Cargo.toml +++ b/rust/arrow/Cargo.toml @@ -51,6 +51,8 @@ flatbuffers = "^0.8" hex = "0.4" prettytable-rs = { version = "0.8.0", optional = true } lexical-core = "^0.7" +partial_sort = "0.1.1" +pdqsort = "1.0.3" [features] default = [] diff --git a/rust/arrow/benches/sort_kernel.rs b/rust/arrow/benches/sort_kernel.rs index 01701d30a9f..3369b245bcf 100644 --- a/rust/arrow/benches/sort_kernel.rs +++ b/rust/arrow/benches/sort_kernel.rs @@ -45,7 +45,7 @@ fn bench_sort(arr_a: &ArrayRef, array_b: &ArrayRef) { }, ]; - criterion::black_box(lexsort(&columns).unwrap()); + criterion::black_box(lexsort(&columns, None).unwrap()); } fn add_benchmark(c: &mut Criterion) { diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index e33b76ed0a1..e6fa5ce047b 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -17,16 +17,16 @@ //! Defines sort kernel for `ArrayRef` -use std::cmp::{Ordering, Reverse}; +use std::cmp::Ordering; +use std::sync::Arc; use crate::array::*; +use crate::buffer::MutableBuffer; use crate::compute::take; use crate::datatypes::*; use crate::error::{ArrowError, Result}; -use crate::buffer::MutableBuffer; -use num::ToPrimitive; -use std::sync::Arc; +use partial_sort::PartialSort; use TimeUnit::*; /// Sort the `ArrayRef` using `SortOptions`. @@ -36,8 +36,12 @@ use TimeUnit::*; /// /// Returns an `ArrowError::ComputeError(String)` if the array type is either unsupported by `sort_to_indices` or `take`. /// -pub fn sort(values: &ArrayRef, options: Option) -> Result { - let indices = sort_to_indices(values, options)?; +pub fn sort( + values: &ArrayRef, + options: Option, + limit: Option, +) -> Result { + let indices = sort_to_indices(values, options, limit)?; take(values.as_ref(), &indices, None) } @@ -76,8 +80,14 @@ where // partition indices into valid and null indices fn partition_validity(array: &ArrayRef) -> (Vec, Vec) { - let indices = 0..(array.len().to_u32().unwrap()); - indices.partition(|index| array.is_valid(*index as usize)) + match array.null_count() { + // faster path + 0 => ((0..(array.len() as u32)).collect(), vec![]), + _ => { + let indices = 0..(array.len() as u32); + indices.partition(|index| array.is_valid(*index as usize)) + } + } } /// Sort elements from `ArrayRef` into an unsigned integer (`UInt32Array`) of indices. @@ -85,109 +95,160 @@ fn partition_validity(array: &ArrayRef) -> (Vec, Vec) { pub fn sort_to_indices( values: &ArrayRef, options: Option, + limit: Option, ) -> Result { let options = options.unwrap_or_default(); let (v, n) = partition_validity(values); match values.data_type() { - DataType::Boolean => sort_boolean(values, v, n, &options), - DataType::Int8 => sort_primitive::(values, v, n, cmp, &options), - DataType::Int16 => sort_primitive::(values, v, n, cmp, &options), - DataType::Int32 => sort_primitive::(values, v, n, cmp, &options), - DataType::Int64 => sort_primitive::(values, v, n, cmp, &options), - DataType::UInt8 => sort_primitive::(values, v, n, cmp, &options), - DataType::UInt16 => sort_primitive::(values, v, n, cmp, &options), - DataType::UInt32 => sort_primitive::(values, v, n, cmp, &options), - DataType::UInt64 => sort_primitive::(values, v, n, cmp, &options), + DataType::Boolean => sort_boolean(values, v, n, &options, limit), + DataType::Int8 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::Int16 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::Int32 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::Int64 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::UInt8 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::UInt16 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::UInt32 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::UInt64 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } DataType::Float32 => { - sort_primitive::(values, v, n, total_cmp_32, &options) + sort_primitive::(values, v, n, total_cmp_32, &options, limit) } DataType::Float64 => { - sort_primitive::(values, v, n, total_cmp_64, &options) + sort_primitive::(values, v, n, total_cmp_64, &options, limit) + } + DataType::Date32 => { + sort_primitive::(values, v, n, cmp, &options, limit) + } + DataType::Date64 => { + sort_primitive::(values, v, n, cmp, &options, limit) } - DataType::Date32 => sort_primitive::(values, v, n, cmp, &options), - DataType::Date64 => sort_primitive::(values, v, n, cmp, &options), DataType::Time32(Second) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Time32(Millisecond) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Time64(Microsecond) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Time64(Nanosecond) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Timestamp(Second, _) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Timestamp(Millisecond, _) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::( + values, v, n, cmp, &options, limit, + ) } DataType::Timestamp(Microsecond, _) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::( + values, v, n, cmp, &options, limit, + ) } DataType::Timestamp(Nanosecond, _) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::( + values, v, n, cmp, &options, limit, + ) } DataType::Interval(IntervalUnit::YearMonth) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Interval(IntervalUnit::DayTime) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Duration(TimeUnit::Second) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::(values, v, n, cmp, &options, limit) } DataType::Duration(TimeUnit::Millisecond) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::( + values, v, n, cmp, &options, limit, + ) } DataType::Duration(TimeUnit::Microsecond) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::( + values, v, n, cmp, &options, limit, + ) } DataType::Duration(TimeUnit::Nanosecond) => { - sort_primitive::(values, v, n, cmp, &options) + sort_primitive::( + values, v, n, cmp, &options, limit, + ) } - DataType::Utf8 => sort_string(values, v, n, &options), + DataType::Utf8 => sort_string(values, v, n, &options, limit), DataType::List(field) => match field.data_type() { - DataType::Int8 => sort_list::(values, v, n, &options), - DataType::Int16 => sort_list::(values, v, n, &options), - DataType::Int32 => sort_list::(values, v, n, &options), - DataType::Int64 => sort_list::(values, v, n, &options), - DataType::UInt8 => sort_list::(values, v, n, &options), - DataType::UInt16 => sort_list::(values, v, n, &options), - DataType::UInt32 => sort_list::(values, v, n, &options), - DataType::UInt64 => sort_list::(values, v, n, &options), + DataType::Int8 => sort_list::(values, v, n, &options, limit), + DataType::Int16 => sort_list::(values, v, n, &options, limit), + DataType::Int32 => sort_list::(values, v, n, &options, limit), + DataType::Int64 => sort_list::(values, v, n, &options, limit), + DataType::UInt8 => sort_list::(values, v, n, &options, limit), + DataType::UInt16 => { + sort_list::(values, v, n, &options, limit) + } + DataType::UInt32 => { + sort_list::(values, v, n, &options, limit) + } + DataType::UInt64 => { + sort_list::(values, v, n, &options, limit) + } t => Err(ArrowError::ComputeError(format!( "Sort not supported for list type {:?}", t ))), }, DataType::LargeList(field) => match field.data_type() { - DataType::Int8 => sort_list::(values, v, n, &options), - DataType::Int16 => sort_list::(values, v, n, &options), - DataType::Int32 => sort_list::(values, v, n, &options), - DataType::Int64 => sort_list::(values, v, n, &options), - DataType::UInt8 => sort_list::(values, v, n, &options), - DataType::UInt16 => sort_list::(values, v, n, &options), - DataType::UInt32 => sort_list::(values, v, n, &options), - DataType::UInt64 => sort_list::(values, v, n, &options), + DataType::Int8 => sort_list::(values, v, n, &options, limit), + DataType::Int16 => sort_list::(values, v, n, &options, limit), + DataType::Int32 => sort_list::(values, v, n, &options, limit), + DataType::Int64 => sort_list::(values, v, n, &options, limit), + DataType::UInt8 => sort_list::(values, v, n, &options, limit), + DataType::UInt16 => { + sort_list::(values, v, n, &options, limit) + } + DataType::UInt32 => { + sort_list::(values, v, n, &options, limit) + } + DataType::UInt64 => { + sort_list::(values, v, n, &options, limit) + } t => Err(ArrowError::ComputeError(format!( "Sort not supported for list type {:?}", t ))), }, DataType::FixedSizeList(field, _) => match field.data_type() { - DataType::Int8 => sort_list::(values, v, n, &options), - DataType::Int16 => sort_list::(values, v, n, &options), - DataType::Int32 => sort_list::(values, v, n, &options), - DataType::Int64 => sort_list::(values, v, n, &options), - DataType::UInt8 => sort_list::(values, v, n, &options), - DataType::UInt16 => sort_list::(values, v, n, &options), - DataType::UInt32 => sort_list::(values, v, n, &options), - DataType::UInt64 => sort_list::(values, v, n, &options), + DataType::Int8 => sort_list::(values, v, n, &options, limit), + DataType::Int16 => sort_list::(values, v, n, &options, limit), + DataType::Int32 => sort_list::(values, v, n, &options, limit), + DataType::Int64 => sort_list::(values, v, n, &options, limit), + DataType::UInt8 => sort_list::(values, v, n, &options, limit), + DataType::UInt16 => { + sort_list::(values, v, n, &options, limit) + } + DataType::UInt32 => { + sort_list::(values, v, n, &options, limit) + } + DataType::UInt64 => { + sort_list::(values, v, n, &options, limit) + } t => Err(ArrowError::ComputeError(format!( "Sort not supported for list type {:?}", t @@ -198,28 +259,28 @@ pub fn sort_to_indices( { match key_type.as_ref() { DataType::Int8 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::Int16 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::Int32 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::Int64 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::UInt8 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::UInt16 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::UInt32 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } DataType::UInt64 => { - sort_string_dictionary::(values, v, n, &options) + sort_string_dictionary::(values, v, n, &options, limit) } t => Err(ArrowError::ComputeError(format!( "Sort not supported for dictionary key type {:?}", @@ -260,6 +321,7 @@ fn sort_boolean( value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, ) -> Result { let values = values .as_any() @@ -278,12 +340,27 @@ fn sort_boolean( let valids_len = valids.len(); let nulls_len = nulls.len(); - if !descending { - valids.sort_by(|a, b| a.1.cmp(&b.1)); - } else { - valids.sort_by(|a, b| a.1.cmp(&b.1).reverse()); - // reverse to keep a stable ordering - nulls.reverse(); + let mut len = values.len(); + match limit { + Some(limit) => { + len = limit.min(len); + if !descending { + valids.partial_sort(len, |a, b| cmp(a.1, b.1)); + } else { + valids.partial_sort(len, |a, b| cmp(a.1, b.1).reverse()); + // reverse to keep a stable ordering + nulls.reverse(); + } + } + _ => { + if !descending { + pdqsort::sort_by(&mut valids, |a, b| cmp(a.1, b.1)); + } else { + pdqsort::sort_by(&mut valids, |a, b| cmp(a.1, b.1).reverse()); + // reverse to keep a stable ordering + nulls.reverse(); + } + } } // collect results directly into a buffer instead of a vec to avoid another aligned allocation @@ -295,17 +372,23 @@ fn sort_boolean( debug_assert_eq!(result_slice.len(), nulls_len + valids_len); if options.nulls_first { - result_slice[0..nulls_len].copy_from_slice(&nulls); - insert_valid_values(result_slice, nulls_len, valids); + let size = nulls_len.min(len); + result_slice[0..nulls_len.min(len)].copy_from_slice(&nulls); + if nulls_len < len { + insert_valid_values(result_slice, nulls_len, valids, len - size); + } } else { // nulls last - insert_valid_values(result_slice, 0, valids); - result_slice[valids_len..].copy_from_slice(nulls.as_slice()) + let size = valids.len().min(len); + insert_valid_values(result_slice, 0, valids, size); + if len > size { + result_slice[valids_len..].copy_from_slice(&nulls[0..(len - valids_len)]); + } } let result_data = Arc::new(ArrayData::new( DataType::UInt32, - values.len(), + len, Some(0), None, 0, @@ -324,6 +407,7 @@ fn sort_primitive( null_indices: Vec, cmp: F, options: &SortOptions, + limit: Option, ) -> Result where T: ArrowPrimitiveType, @@ -343,13 +427,29 @@ where let valids_len = valids.len(); let nulls_len = nulls.len(); - - if !descending { - valids.sort_by(|a, b| cmp(a.1, b.1)); - } else { - valids.sort_by(|a, b| cmp(a.1, b.1).reverse()); - // reverse to keep a stable ordering - nulls.reverse(); + let mut len = values.len(); + + match limit { + Some(limit) => { + len = limit.min(len); + + if !descending { + valids.partial_sort(len, |a, b| cmp(a.1, b.1)); + } else { + valids.partial_sort(len, |a, b| cmp(a.1, b.1).reverse()); + // reverse to keep a stable ordering + nulls.reverse(); + } + } + _ => { + if !descending { + pdqsort::sort_by(&mut valids, |a, b| cmp(a.1, b.1)); + } else { + pdqsort::sort_by(&mut valids, |a, b| cmp(a.1, b.1).reverse()); + // reverse to keep a stable ordering + nulls.reverse(); + } + } } // collect results directly into a buffer instead of a vec to avoid another aligned allocation @@ -361,17 +461,23 @@ where debug_assert_eq!(result_slice.len(), nulls_len + valids_len); if options.nulls_first { - result_slice[0..nulls_len].copy_from_slice(&nulls); - insert_valid_values(result_slice, nulls_len, valids); + let size = nulls_len.min(len); + result_slice[0..nulls_len.min(len)].copy_from_slice(&nulls); + if nulls_len < len { + insert_valid_values(result_slice, nulls_len, valids, len - size); + } } else { // nulls last - insert_valid_values(result_slice, 0, valids); - result_slice[valids_len..].copy_from_slice(nulls.as_slice()) + let size = valids.len().min(len); + insert_valid_values(result_slice, 0, valids, size); + if len > size { + result_slice[valids_len..].copy_from_slice(&nulls[0..(len - valids_len)]); + } } let result_data = Arc::new(ArrayData::new( DataType::UInt32, - values.len(), + len, Some(0), None, 0, @@ -387,19 +493,19 @@ fn insert_valid_values( result_slice: &mut [u32], offset: usize, valids: Vec<(u32, T)>, + len: usize, ) { let valids_len = valids.len(); - // helper to append the index part of the valid tuples let append_valids = move |dst_slice: &mut [u32]| { debug_assert_eq!(dst_slice.len(), valids_len); dst_slice .iter_mut() - .zip(valids.into_iter()) + .zip(valids.as_slice()[0..len].iter()) .for_each(|(dst, src)| *dst = src.0) }; - append_valids(&mut result_slice[offset..offset + valids_len]); + append_valids(&mut result_slice[offset..offset + len]); } /// Sort strings @@ -408,6 +514,7 @@ fn sort_string( value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, ) -> Result { let values = as_string_array(values); @@ -416,6 +523,7 @@ fn sort_string( value_indices, null_indices, options, + limit, |array, idx| array.value(idx as usize), ) } @@ -426,6 +534,7 @@ fn sort_string_dictionary( value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, ) -> Result { let values: &DictionaryArray = as_dictionary_array::(values); @@ -439,6 +548,7 @@ fn sort_string_dictionary( value_indices, null_indices, options, + limit, |array: &PrimitiveArray, idx| -> &str { let key: T::Native = array.value(idx as usize); dict.value(key.to_usize().unwrap()) @@ -454,6 +564,7 @@ fn sort_string_helper<'a, A: Array, F>( value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, value_fn: F, ) -> Result where @@ -464,23 +575,42 @@ where .map(|index| (index, value_fn(&values, index))) .collect::>(); let mut nulls = null_indices; - if !options.descending { - valids.sort_by_key(|a| a.1); - } else { - valids.sort_by_key(|a| Reverse(a.1)); - nulls.reverse(); + let descending = options.descending; + let mut len = values.len(); + match limit { + Some(limit) => { + len = limit.min(len); + if !descending { + valids.partial_sort(len, |a, b| cmp(a.1, b.1)); + } else { + valids.partial_sort(len, |a, b| cmp(a.1, b.1).reverse()); + // reverse to keep a stable ordering + nulls.reverse(); + } + } + _ => { + if !descending { + valids.sort_by(|a, b| cmp(a.1, b.1)); + } else { + valids.sort_by(|a, b| cmp(a.1, b.1).reverse()); + // reverse to keep a stable ordering + nulls.reverse(); + } + } } + // collect the order of valid tuplies let mut valid_indices: Vec = valids.iter().map(|tuple| tuple.0).collect(); if options.nulls_first { nulls.append(&mut valid_indices); + nulls.truncate(len); return Ok(UInt32Array::from(nulls)); } // no need to sort nulls as they are in the correct order already valid_indices.append(&mut nulls); - + valid_indices.truncate(len); Ok(UInt32Array::from(valid_indices)) } @@ -490,6 +620,7 @@ fn sort_list( value_indices: Vec, mut null_indices: Vec, options: &SortOptions, + limit: Option, ) -> Result where S: OffsetSizeTrait, @@ -517,20 +648,43 @@ where }, ); - if !options.descending { - valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref())) - } else { - valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()) + let mut len = values.len(); + let descending = options.descending; + match limit { + Some(limit) => { + len = limit.min(len); + + if !descending { + valids.partial_sort(len, |a, b| cmp_array(a.1.as_ref(), b.1.as_ref())); + } else { + valids.partial_sort(len, |a, b| { + cmp_array(a.1.as_ref(), b.1.as_ref()).reverse() + }); + // reverse to keep a stable ordering + null_indices.reverse(); + } + } + _ => { + if !descending { + valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref())); + } else { + valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()); + // reverse to keep a stable ordering + null_indices.reverse(); + } + } } let mut valid_indices: Vec = valids.iter().map(|tuple| tuple.0).collect(); if options.nulls_first { null_indices.append(&mut valid_indices); + null_indices.truncate(len); return Ok(UInt32Array::from(null_indices)); } valid_indices.append(&mut null_indices); + valid_indices.truncate(len); Ok(UInt32Array::from(valid_indices)) } @@ -595,13 +749,13 @@ pub struct SortColumn { /// nulls_first: false, /// }), /// }, -/// ]).unwrap(); +/// ], None).unwrap(); /// /// assert_eq!(as_primitive_array::(&sorted_columns[0]).value(1), -64); /// assert!(sorted_columns[0].is_null(0)); /// ``` -pub fn lexsort(columns: &[SortColumn]) -> Result> { - let indices = lexsort_to_indices(columns)?; +pub fn lexsort(columns: &[SortColumn], limit: Option) -> Result> { + let indices = lexsort_to_indices(columns, limit)?; columns .iter() .map(|c| take(c.values.as_ref(), &indices, None)) @@ -610,7 +764,10 @@ pub fn lexsort(columns: &[SortColumn]) -> Result> { /// Sort elements lexicographically from a list of `ArrayRef` into an unsigned integer /// (`UInt32Array`) of indices. -pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result { +pub fn lexsort_to_indices( + columns: &[SortColumn], + limit: Option, +) -> Result { if columns.is_empty() { return Err(ArrowError::InvalidArgumentError( "Sort requires at least one column".to_string(), @@ -619,7 +776,7 @@ pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result { if columns.len() == 1 { // fallback to non-lexical sort let column = &columns[0]; - return sort_to_indices(&column.values, column.options); + return sort_to_indices(&column.values, column.options, limit); } let row_count = columns[0].values.len(); @@ -686,12 +843,19 @@ pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result { }; let mut value_indices = (0..row_count).collect::>(); - value_indices.sort_by(lex_comparator); + let mut len = value_indices.len(); + match limit { + Some(limit) => { + len = len.min(limit); + value_indices.partial_sort(len, lex_comparator); + } + None => value_indices.sort_by(lex_comparator), + } Ok(UInt32Array::from( - value_indices - .into_iter() - .map(|i| i as u32) + (&value_indices)[0..len] + .iter() + .map(|i| *i as u32) .collect::>(), )) } @@ -713,7 +877,8 @@ mod tests { ) { let output = BooleanArray::from(data); let expected = UInt32Array::from(expected_data); - let output = sort_to_indices(&(Arc::new(output) as ArrayRef), options).unwrap(); + let output = + sort_to_indices(&(Arc::new(output) as ArrayRef), options, None).unwrap(); assert_eq!(output, expected) } @@ -727,7 +892,8 @@ mod tests { { let output = PrimitiveArray::::from(data); let expected = UInt32Array::from(expected_data); - let output = sort_to_indices(&(Arc::new(output) as ArrayRef), options).unwrap(); + let output = + sort_to_indices(&(Arc::new(output) as ArrayRef), options, None).unwrap(); assert_eq!(output, expected) } @@ -741,7 +907,7 @@ mod tests { { let output = PrimitiveArray::::from(data); let expected = Arc::new(PrimitiveArray::::from(expected_data)) as ArrayRef; - let output = sort(&(Arc::new(output) as ArrayRef), options).unwrap(); + let output = sort(&(Arc::new(output) as ArrayRef), options, None).unwrap(); assert_eq!(&output, &expected) } @@ -752,7 +918,8 @@ mod tests { ) { let output = StringArray::from(data); let expected = UInt32Array::from(expected_data); - let output = sort_to_indices(&(Arc::new(output) as ArrayRef), options).unwrap(); + let output = + sort_to_indices(&(Arc::new(output) as ArrayRef), options, None).unwrap(); assert_eq!(output, expected) } @@ -763,7 +930,7 @@ mod tests { ) { let output = StringArray::from(data); let expected = Arc::new(StringArray::from(expected_data)) as ArrayRef; - let output = sort(&(Arc::new(output) as ArrayRef), options).unwrap(); + let output = sort(&(Arc::new(output) as ArrayRef), options, None).unwrap(); assert_eq!(&output, &expected) } @@ -779,7 +946,7 @@ mod tests { .downcast_ref::() .expect("Unable to get dictionary values"); - let sorted = sort(&(Arc::new(array) as ArrayRef), options).unwrap(); + let sorted = sort(&(Arc::new(array) as ArrayRef), options, None).unwrap(); let sorted = sorted .as_any() .downcast_ref::>() @@ -823,7 +990,7 @@ mod tests { // for FixedSizedList if let Some(length) = fixed_length { let input = Arc::new(build_fixed_size_list_nullable(data.clone(), length)); - let sorted = sort(&(input as ArrayRef), options).unwrap(); + let sorted = sort(&(input as ArrayRef), options, None).unwrap(); let expected = Arc::new(build_fixed_size_list_nullable( expected_data.clone(), length, @@ -834,7 +1001,7 @@ mod tests { // for List let input = Arc::new(build_generic_list_nullable::(data.clone())); - let sorted = sort(&(input as ArrayRef), options).unwrap(); + let sorted = sort(&(input as ArrayRef), options, None).unwrap(); let expected = Arc::new(build_generic_list_nullable::(expected_data.clone())) as ArrayRef; @@ -843,7 +1010,7 @@ mod tests { // for LargeList let input = Arc::new(build_generic_list_nullable::(data)); - let sorted = sort(&(input as ArrayRef), options).unwrap(); + let sorted = sort(&(input as ArrayRef), options, None).unwrap(); let expected = Arc::new(build_generic_list_nullable::(expected_data)) as ArrayRef; @@ -851,7 +1018,7 @@ mod tests { } fn test_lex_sort_arrays(input: Vec, expected_output: Vec) { - let sorted = lexsort(&input).unwrap(); + let sorted = lexsort(&input, None).unwrap(); for (result, expected) in sorted.iter().zip(expected_output.iter()) { assert_eq!(result, expected); @@ -1577,7 +1744,7 @@ mod tests { }, ]; assert!( - lexsort(&input).is_err(), + lexsort(&input, None).is_err(), "lexsort should reject columns with different row counts" ); } diff --git a/rust/datafusion/src/physical_plan/sort.rs b/rust/datafusion/src/physical_plan/sort.rs index 042b9f1da81..c595b0b960e 100644 --- a/rust/datafusion/src/physical_plan/sort.rs +++ b/rust/datafusion/src/physical_plan/sort.rs @@ -156,12 +156,14 @@ fn sort_batches( )?; // sort combined record batch + // TODO: pushup the limit expression to sort let indices = lexsort_to_indices( &expr .iter() .map(|e| e.evaluate_to_sort_column(&combined_batch)) .collect::>>() .map_err(DataFusionError::into_arrow_external_error)?, + None, )?; // reorder all rows based on sorted indices From 06c503c9559633738065047b4ddb41a44fd2c7d4 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Mon, 1 Mar 2021 15:25:01 +0800 Subject: [PATCH 02/11] Add tests and benchmarks --- rust/arrow/benches/sort_kernel.rs | 41 +++- rust/arrow/src/compute/kernels/sort.rs | 272 ++++++++++++++++++++----- 2 files changed, 253 insertions(+), 60 deletions(-) diff --git a/rust/arrow/benches/sort_kernel.rs b/rust/arrow/benches/sort_kernel.rs index 3369b245bcf..105b3e66614 100644 --- a/rust/arrow/benches/sort_kernel.rs +++ b/rust/arrow/benches/sort_kernel.rs @@ -33,7 +33,7 @@ fn create_array(size: usize, with_nulls: bool) -> ArrayRef { Arc::new(array) } -fn bench_sort(arr_a: &ArrayRef, array_b: &ArrayRef) { +fn bench_sort(arr_a: &ArrayRef, array_b: &ArrayRef, limit: Option) { let columns = vec![ SortColumn { values: arr_a.clone(), @@ -45,29 +45,58 @@ fn bench_sort(arr_a: &ArrayRef, array_b: &ArrayRef) { }, ]; - criterion::black_box(lexsort(&columns, None).unwrap()); + criterion::black_box(lexsort(&columns, limit).unwrap()); } fn add_benchmark(c: &mut Criterion) { let arr_a = create_array(2u64.pow(10) as usize, false); let arr_b = create_array(2u64.pow(10) as usize, false); - c.bench_function("sort 2^10", |b| b.iter(|| bench_sort(&arr_a, &arr_b))); + c.bench_function("sort 2^10", |b| b.iter(|| bench_sort(&arr_a, &arr_b, None))); let arr_a = create_array(2u64.pow(12) as usize, false); let arr_b = create_array(2u64.pow(12) as usize, false); - c.bench_function("sort 2^12", |b| b.iter(|| bench_sort(&arr_a, &arr_b))); + c.bench_function("sort 2^12", |b| b.iter(|| bench_sort(&arr_a, &arr_b, None))); let arr_a = create_array(2u64.pow(10) as usize, true); let arr_b = create_array(2u64.pow(10) as usize, true); - c.bench_function("sort nulls 2^10", |b| b.iter(|| bench_sort(&arr_a, &arr_b))); + c.bench_function("sort nulls 2^10", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, None)) + }); let arr_a = create_array(2u64.pow(12) as usize, true); let arr_b = create_array(2u64.pow(12) as usize, true); - c.bench_function("sort nulls 2^12", |b| b.iter(|| bench_sort(&arr_a, &arr_b))); + c.bench_function("sort nulls 2^12", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, None)) + }); + + /// with limit + { + let arr_a = create_array(2u64.pow(12) as usize, false); + let arr_b = create_array(2u64.pow(12) as usize, false); + c.bench_function("sort 2^12 limit 10", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(10))) + }); + + let arr_a = create_array(2u64.pow(12) as usize, false); + let arr_b = create_array(2u64.pow(12) as usize, false); + c.bench_function("sort 2^12 limit 100", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(100))) + }); + + let arr_a = create_array(2u64.pow(12) as usize, true); + let arr_b = create_array(2u64.pow(12) as usize, true); + + c.bench_function("sort nulls 2^12 limit 10", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(10))) + }); + c.bench_function("sort nulls 2^12 limit 100", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(10))) + }); + } } criterion_group!(benches, add_benchmark); diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index e6fa5ce047b..18e9268d771 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -375,12 +375,12 @@ fn sort_boolean( let size = nulls_len.min(len); result_slice[0..nulls_len.min(len)].copy_from_slice(&nulls); if nulls_len < len { - insert_valid_values(result_slice, nulls_len, valids, len - size); + insert_valid_values(result_slice, nulls_len, &valids[0..len - size]); } } else { // nulls last let size = valids.len().min(len); - insert_valid_values(result_slice, 0, valids, size); + insert_valid_values(result_slice, 0, &valids[0..size]); if len > size { result_slice[valids_len..].copy_from_slice(&nulls[0..(len - valids_len)]); } @@ -464,12 +464,12 @@ where let size = nulls_len.min(len); result_slice[0..nulls_len.min(len)].copy_from_slice(&nulls); if nulls_len < len { - insert_valid_values(result_slice, nulls_len, valids, len - size); + insert_valid_values(result_slice, nulls_len, &valids[0..len - size]); } } else { // nulls last let size = valids.len().min(len); - insert_valid_values(result_slice, 0, valids, size); + insert_valid_values(result_slice, 0, &valids[0..size]); if len > size { result_slice[valids_len..].copy_from_slice(&nulls[0..(len - valids_len)]); } @@ -489,23 +489,18 @@ where } // insert valid and nan values in the correct order depending on the descending flag -fn insert_valid_values( - result_slice: &mut [u32], - offset: usize, - valids: Vec<(u32, T)>, - len: usize, -) { +fn insert_valid_values(result_slice: &mut [u32], offset: usize, valids: &[(u32, T)]) { let valids_len = valids.len(); // helper to append the index part of the valid tuples let append_valids = move |dst_slice: &mut [u32]| { debug_assert_eq!(dst_slice.len(), valids_len); dst_slice .iter_mut() - .zip(valids.as_slice()[0..len].iter()) + .zip(valids.iter()) .for_each(|(dst, src)| *dst = src.0) }; - append_valids(&mut result_slice[offset..offset + len]); + append_valids(&mut result_slice[offset..offset + valids.len()]); } /// Sort strings @@ -650,29 +645,21 @@ where let mut len = values.len(); let descending = options.descending; + match limit { Some(limit) => { len = limit.min(len); - - if !descending { - valids.partial_sort(len, |a, b| cmp_array(a.1.as_ref(), b.1.as_ref())); - } else { - valids.partial_sort(len, |a, b| { - cmp_array(a.1.as_ref(), b.1.as_ref()).reverse() - }); - // reverse to keep a stable ordering - null_indices.reverse(); - } - } - _ => { - if !descending { - valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref())); - } else { - valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()); - // reverse to keep a stable ordering - null_indices.reverse(); - } } + _ => {} + } + + /// we are not using partial_sort here, because array is ArrayRef. Something is not working good in that. + if !descending { + valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref())); + } else { + valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()); + // reverse to keep a stable ordering + null_indices.reverse(); } let mut valid_indices: Vec = valids.iter().map(|tuple| tuple.0).collect(); @@ -873,18 +860,20 @@ mod tests { fn test_sort_to_indices_boolean_arrays( data: Vec>, options: Option, + limit: Option, expected_data: Vec, ) { let output = BooleanArray::from(data); let expected = UInt32Array::from(expected_data); let output = - sort_to_indices(&(Arc::new(output) as ArrayRef), options, None).unwrap(); + sort_to_indices(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); assert_eq!(output, expected) } fn test_sort_to_indices_primitive_arrays( data: Vec>, options: Option, + limit: Option, expected_data: Vec, ) where T: ArrowPrimitiveType, @@ -893,13 +882,14 @@ mod tests { let output = PrimitiveArray::::from(data); let expected = UInt32Array::from(expected_data); let output = - sort_to_indices(&(Arc::new(output) as ArrayRef), options, None).unwrap(); + sort_to_indices(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); assert_eq!(output, expected) } fn test_sort_primitive_arrays( data: Vec>, options: Option, + limit: Option, expected_data: Vec>, ) where T: ArrowPrimitiveType, @@ -907,36 +897,39 @@ mod tests { { let output = PrimitiveArray::::from(data); let expected = Arc::new(PrimitiveArray::::from(expected_data)) as ArrayRef; - let output = sort(&(Arc::new(output) as ArrayRef), options, None).unwrap(); + let output = sort(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); assert_eq!(&output, &expected) } fn test_sort_to_indices_string_arrays( data: Vec>, options: Option, + limit: Option, expected_data: Vec, ) { let output = StringArray::from(data); let expected = UInt32Array::from(expected_data); let output = - sort_to_indices(&(Arc::new(output) as ArrayRef), options, None).unwrap(); + sort_to_indices(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); assert_eq!(output, expected) } fn test_sort_string_arrays( data: Vec>, options: Option, + limit: Option, expected_data: Vec>, ) { let output = StringArray::from(data); let expected = Arc::new(StringArray::from(expected_data)) as ArrayRef; - let output = sort(&(Arc::new(output) as ArrayRef), options, None).unwrap(); + let output = sort(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); assert_eq!(&output, &expected) } fn test_sort_string_dict_arrays( data: Vec>, options: Option, + limit: Option, expected_data: Vec>, ) { let array = DictionaryArray::::from_iter(data.into_iter()); @@ -946,7 +939,7 @@ mod tests { .downcast_ref::() .expect("Unable to get dictionary values"); - let sorted = sort(&(Arc::new(array) as ArrayRef), options, None).unwrap(); + let sorted = sort(&(Arc::new(array) as ArrayRef), options, limit).unwrap(); let sorted = sorted .as_any() .downcast_ref::>() @@ -981,6 +974,7 @@ mod tests { fn test_sort_list_arrays( data: Vec>>>, options: Option, + limit: Option, expected_data: Vec>>>, fixed_length: Option, ) where @@ -990,7 +984,7 @@ mod tests { // for FixedSizedList if let Some(length) = fixed_length { let input = Arc::new(build_fixed_size_list_nullable(data.clone(), length)); - let sorted = sort(&(input as ArrayRef), options, None).unwrap(); + let sorted = sort(&(input as ArrayRef), options, limit).unwrap(); let expected = Arc::new(build_fixed_size_list_nullable( expected_data.clone(), length, @@ -1000,25 +994,29 @@ mod tests { } // for List - let input = Arc::new(build_generic_list_nullable::(data.clone())); - let sorted = sort(&(input as ArrayRef), options, None).unwrap(); - let expected = - Arc::new(build_generic_list_nullable::(expected_data.clone())) - as ArrayRef; - - assert_eq!(&sorted, &expected); + // let input = Arc::new(build_generic_list_nullable::(data.clone())); + // let sorted = sort(&(input as ArrayRef), options, limit).unwrap(); + // let expected = + // Arc::new(build_generic_list_nullable::(expected_data.clone())) + // as ArrayRef; + // + // assert_eq!(&sorted, &expected); // for LargeList let input = Arc::new(build_generic_list_nullable::(data)); - let sorted = sort(&(input as ArrayRef), options, None).unwrap(); + let sorted = sort(&(input as ArrayRef), options, limit).unwrap(); let expected = Arc::new(build_generic_list_nullable::(expected_data)) as ArrayRef; assert_eq!(&sorted, &expected); } - fn test_lex_sort_arrays(input: Vec, expected_output: Vec) { - let sorted = lexsort(&input, None).unwrap(); + fn test_lex_sort_arrays( + input: Vec, + expected_output: Vec, + limit: Option, + ) { + let sorted = lexsort(&input, limit).unwrap(); for (result, expected) in sorted.iter().zip(expected_output.iter()) { assert_eq!(result, expected); @@ -1030,21 +1028,25 @@ mod tests { test_sort_to_indices_primitive_arrays::( vec![None, Some(0), Some(2), Some(-1), Some(0), None], None, + None, vec![0, 5, 3, 1, 4, 2], ); test_sort_to_indices_primitive_arrays::( vec![None, Some(0), Some(2), Some(-1), Some(0), None], None, + None, vec![0, 5, 3, 1, 4, 2], ); test_sort_to_indices_primitive_arrays::( vec![None, Some(0), Some(2), Some(-1), Some(0), None], None, + None, vec![0, 5, 3, 1, 4, 2], ); test_sort_to_indices_primitive_arrays::( vec![None, Some(0), Some(2), Some(-1), Some(0), None], None, + None, vec![0, 5, 3, 1, 4, 2], ); test_sort_to_indices_primitive_arrays::( @@ -1057,6 +1059,7 @@ mod tests { None, ], None, + None, vec![0, 5, 3, 1, 4, 2], ); test_sort_to_indices_primitive_arrays::( @@ -1069,6 +1072,7 @@ mod tests { None, ], None, + None, vec![0, 5, 3, 1, 4, 2], ); @@ -1079,6 +1083,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 1, 4, 3, 5, 0], // [2, 4, 1, 3, 5, 0] ); @@ -1088,6 +1093,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 1, 4, 3, 5, 0], ); @@ -1097,6 +1103,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 1, 4, 3, 5, 0], ); @@ -1106,6 +1113,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 1, 4, 3, 5, 0], ); @@ -1122,6 +1130,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 1, 4, 3, 5, 0], ); @@ -1131,6 +1140,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 1, 4, 3, 5, 0], ); @@ -1141,6 +1151,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 1, 4, 3], // [5, 0, 2, 4, 1, 3] ); @@ -1150,6 +1161,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 1, 4, 3], // [5, 0, 2, 4, 1, 3] ); @@ -1159,6 +1171,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 1, 4, 3], ); @@ -1168,6 +1181,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 1, 4, 3], ); @@ -1177,6 +1191,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 1, 4, 3], ); @@ -1186,6 +1201,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 1, 4, 3], ); } @@ -1196,6 +1212,7 @@ mod tests { test_sort_to_indices_boolean_arrays( vec![None, Some(false), Some(true), Some(true), Some(false), None], None, + None, vec![0, 5, 1, 4, 2, 3], ); @@ -1206,6 +1223,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 3, 1, 4, 5, 0], ); @@ -1216,8 +1234,20 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![5, 0, 2, 3, 1, 4], ); + + // boolean, descending, nulls first, limit + test_sort_to_indices_boolean_arrays( + vec![None, Some(false), Some(true), Some(true), Some(false), None], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![5, 0, 2], + ); } #[test] @@ -1226,21 +1256,25 @@ mod tests { test_sort_primitive_arrays::( vec![None, Some(3), Some(5), Some(2), Some(3), None], None, + None, vec![None, None, Some(2), Some(3), Some(3), Some(5)], ); test_sort_primitive_arrays::( vec![None, Some(3), Some(5), Some(2), Some(3), None], None, + None, vec![None, None, Some(2), Some(3), Some(3), Some(5)], ); test_sort_primitive_arrays::( vec![None, Some(3), Some(5), Some(2), Some(3), None], None, + None, vec![None, None, Some(2), Some(3), Some(3), Some(5)], ); test_sort_primitive_arrays::( vec![None, Some(3), Some(5), Some(2), Some(3), None], None, + None, vec![None, None, Some(2), Some(3), Some(3), Some(5)], ); @@ -1251,6 +1285,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![Some(2), Some(0), Some(0), Some(-1), None, None], ); test_sort_primitive_arrays::( @@ -1259,6 +1294,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![Some(2), Some(0), Some(0), Some(-1), None, None], ); test_sort_primitive_arrays::( @@ -1267,6 +1303,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![Some(2), Some(0), Some(0), Some(-1), None, None], ); test_sort_primitive_arrays::( @@ -1275,6 +1312,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![Some(2), Some(0), Some(0), Some(-1), None, None], ); @@ -1285,6 +1323,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![None, None, Some(2), Some(0), Some(0), Some(-1)], ); test_sort_primitive_arrays::( @@ -1293,6 +1332,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![None, None, Some(2), Some(0), Some(0), Some(-1)], ); test_sort_primitive_arrays::( @@ -1301,6 +1341,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![None, None, Some(2), Some(0), Some(0), Some(-1)], ); test_sort_primitive_arrays::( @@ -1309,14 +1350,27 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![None, None, Some(2), Some(0), Some(0), Some(-1)], ); + + test_sort_primitive_arrays::( + vec![None, Some(0), Some(2), Some(-1), Some(0), None], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![None, None, Some(2)], + ); + test_sort_primitive_arrays::( vec![None, Some(0.0), Some(2.0), Some(-1.0), Some(0.0), None], Some(SortOptions { descending: true, nulls_first: true, }), + None, vec![None, None, Some(2.0), Some(0.0), Some(0.0), Some(-1.0)], ); test_sort_primitive_arrays::( @@ -1325,6 +1379,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![None, None, Some(f64::NAN), Some(2.0), Some(0.0), Some(-1.0)], ); test_sort_primitive_arrays::( @@ -1333,6 +1388,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![Some(f64::NAN), Some(f64::NAN), Some(f64::NAN), Some(1.0)], ); @@ -1343,6 +1399,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![None, None, Some(-1), Some(0), Some(0), Some(2)], ); test_sort_primitive_arrays::( @@ -1351,6 +1408,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![None, None, Some(-1), Some(0), Some(0), Some(2)], ); test_sort_primitive_arrays::( @@ -1359,6 +1417,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![None, None, Some(-1), Some(0), Some(0), Some(2)], ); test_sort_primitive_arrays::( @@ -1367,6 +1426,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![None, None, Some(-1), Some(0), Some(0), Some(2)], ); test_sort_primitive_arrays::( @@ -1375,6 +1435,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![None, None, Some(-1.0), Some(0.0), Some(0.0), Some(2.0)], ); test_sort_primitive_arrays::( @@ -1383,6 +1444,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![None, None, Some(-1.0), Some(0.0), Some(2.0), Some(f64::NAN)], ); test_sort_primitive_arrays::( @@ -1391,6 +1453,18 @@ mod tests { descending: false, nulls_first: true, }), + None, + vec![Some(1.0), Some(f64::NAN), Some(f64::NAN), Some(f64::NAN)], + ); + + // limit + test_sort_primitive_arrays::( + vec![Some(f64::NAN), Some(f64::NAN), Some(f64::NAN), Some(1.0)], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(4), vec![Some(1.0), Some(f64::NAN), Some(f64::NAN), Some(f64::NAN)], ); } @@ -1407,6 +1481,7 @@ mod tests { Some("-ad"), ], None, + None, vec![0, 3, 5, 1, 4, 2], ); @@ -1423,6 +1498,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![2, 4, 1, 5, 3, 0], ); @@ -1439,6 +1515,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![0, 3, 5, 1, 4, 2], ); @@ -1455,8 +1532,26 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![3, 0, 2, 4, 1, 5], ); + + test_sort_to_indices_string_arrays( + vec![ + None, + Some("bad"), + Some("sad"), + None, + Some("glad"), + Some("-ad"), + ], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![3, 0, 2], + ); } #[test] @@ -1471,6 +1566,7 @@ mod tests { Some("-ad"), ], None, + None, vec![ None, None, @@ -1494,6 +1590,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![ Some("sad"), Some("glad"), @@ -1517,6 +1614,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![ None, None, @@ -1540,6 +1638,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![ None, None, @@ -1549,6 +1648,23 @@ mod tests { Some("-ad"), ], ); + + test_sort_string_arrays( + vec![ + None, + Some("bad"), + Some("sad"), + None, + Some("glad"), + Some("-ad"), + ], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![None, None, Some("sad")], + ); } #[test] @@ -1563,6 +1679,7 @@ mod tests { Some("-ad"), ], None, + None, vec![ None, None, @@ -1586,6 +1703,7 @@ mod tests { descending: true, nulls_first: false, }), + None, vec![ Some("sad"), Some("glad"), @@ -1609,6 +1727,7 @@ mod tests { descending: false, nulls_first: true, }), + None, vec![ None, None, @@ -1632,6 +1751,7 @@ mod tests { descending: true, nulls_first: true, }), + None, vec![ None, None, @@ -1641,6 +1761,23 @@ mod tests { Some("-ad"), ], ); + + test_sort_string_dict_arrays::( + vec![ + None, + Some("bad"), + Some("sad"), + None, + Some("glad"), + Some("-ad"), + ], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![None, None, Some("sad")], + ); } #[test] @@ -1656,6 +1793,7 @@ mod tests { descending: false, nulls_first: false, }), + None, vec![ Some(vec![Some(1)]), Some(vec![Some(2)]), @@ -1677,6 +1815,7 @@ mod tests { descending: false, nulls_first: false, }), + None, vec![ Some(vec![Some(1), Some(0)]), Some(vec![Some(1), Some(1)]), @@ -1699,6 +1838,7 @@ mod tests { descending: false, nulls_first: false, }), + None, vec![ Some(vec![Some(2), Some(3), Some(4)]), Some(vec![Some(3), Some(3), None]), @@ -1708,6 +1848,23 @@ mod tests { ], Some(3), ); + + test_sort_list_arrays::( + vec![ + Some(vec![Some(1), Some(0)]), + Some(vec![Some(4), Some(3), Some(2), Some(1)]), + Some(vec![Some(2), Some(3), Some(4)]), + Some(vec![Some(3), Some(3), Some(3), Some(3)]), + Some(vec![Some(1), Some(1)]), + ], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(2), + vec![Some(vec![Some(1), Some(0)]), Some(vec![Some(1), Some(1)])], + None, + ); } #[test] @@ -1727,7 +1884,14 @@ mod tests { Some(2), Some(17), ])) as ArrayRef]; - test_lex_sort_arrays(input, expected); + test_lex_sort_arrays(input.clone(), expected, None); + + let expected = vec![Arc::new(PrimitiveArray::::from(vec![ + Some(-1), + Some(0), + Some(2), + ])) as ArrayRef]; + test_lex_sort_arrays(input, expected, Some(3)); } #[test] @@ -1800,7 +1964,7 @@ mod tests { Some(-2), ])) as ArrayRef, ]; - test_lex_sort_arrays(input, expected); + test_lex_sort_arrays(input, expected, None); // test mix of string and in64 with option let input = vec![ @@ -1843,7 +2007,7 @@ mod tests { Some("7"), ])) as ArrayRef, ]; - test_lex_sort_arrays(input, expected); + test_lex_sort_arrays(input, expected, None); // test sort with nulls first let input = vec![ @@ -1886,7 +2050,7 @@ mod tests { Some("world"), ])) as ArrayRef, ]; - test_lex_sort_arrays(input, expected); + test_lex_sort_arrays(input, expected, None); // test sort with nulls last let input = vec![ @@ -1929,7 +2093,7 @@ mod tests { None, ])) as ArrayRef, ]; - test_lex_sort_arrays(input, expected); + test_lex_sort_arrays(input, expected, None); // test sort with opposite options let input = vec![ @@ -1976,6 +2140,6 @@ mod tests { Some("foo"), ])) as ArrayRef, ]; - test_lex_sort_arrays(input, expected); + test_lex_sort_arrays(input, expected, None); } } From c7df0d625823e328e1d9a4787d98dbb8b4d1bed5 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Mon, 1 Mar 2021 17:50:44 +0800 Subject: [PATCH 03/11] update comment --- rust/arrow/src/compute/kernels/sort.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index 18e9268d771..ab5db6aca9d 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -653,7 +653,7 @@ where _ => {} } - /// we are not using partial_sort here, because array is ArrayRef. Something is not working good in that. + // we are not using partial_sort here, because array is ArrayRef. Something is not working good in that. if !descending { valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref())); } else { From 4f4dc62ebc681b015a7a9818174e7ae4a4489fbd Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Mon, 1 Mar 2021 17:52:28 +0800 Subject: [PATCH 04/11] update styles --- rust/arrow/src/compute/kernels/sort.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index ab5db6aca9d..e7fb081a5b4 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -646,11 +646,8 @@ where let mut len = values.len(); let descending = options.descending; - match limit { - Some(limit) => { - len = limit.min(len); - } - _ => {} + if let Some(size) = limit { + len = size.min(len); } // we are not using partial_sort here, because array is ArrayRef. Something is not working good in that. From dc3167e3826177af1a6feec516a528a4ee38c674 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Tue, 2 Mar 2021 11:24:31 +0800 Subject: [PATCH 05/11] Remove pdqsort && add a separate function partial_sort --- rust/arrow/Cargo.toml | 1 - rust/arrow/benches/sort_kernel.rs | 13 +++++- rust/arrow/src/compute/kernels/sort.rs | 65 +++++++++++++++++++------- 3 files changed, 59 insertions(+), 20 deletions(-) diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml index 11ad67d6b64..cdff0f3fea6 100644 --- a/rust/arrow/Cargo.toml +++ b/rust/arrow/Cargo.toml @@ -52,7 +52,6 @@ hex = "0.4" prettytable-rs = { version = "0.8.0", optional = true } lexical-core = "^0.7" partial_sort = "0.1.1" -pdqsort = "1.0.3" [features] default = [] diff --git a/rust/arrow/benches/sort_kernel.rs b/rust/arrow/benches/sort_kernel.rs index 105b3e66614..038833fe732 100644 --- a/rust/arrow/benches/sort_kernel.rs +++ b/rust/arrow/benches/sort_kernel.rs @@ -73,7 +73,7 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| bench_sort(&arr_a, &arr_b, None)) }); - /// with limit + // with limit { let arr_a = create_array(2u64.pow(12) as usize, false); let arr_b = create_array(2u64.pow(12) as usize, false); @@ -87,6 +87,12 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| bench_sort(&arr_a, &arr_b, Some(100))) }); + let arr_a = create_array(2u64.pow(12) as usize, false); + let arr_b = create_array(2u64.pow(12) as usize, false); + c.bench_function("sort 2^12 limit 1000", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(1000))) + }); + let arr_a = create_array(2u64.pow(12) as usize, true); let arr_b = create_array(2u64.pow(12) as usize, true); @@ -94,7 +100,10 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| bench_sort(&arr_a, &arr_b, Some(10))) }); c.bench_function("sort nulls 2^12 limit 100", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, Some(10))) + b.iter(|| bench_sort(&arr_a, &arr_b, Some(100))) + }); + c.bench_function("sort nulls 2^12 limit 1000", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(1000))) }); } } diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index e7fb081a5b4..b134813bc02 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -36,7 +36,14 @@ use TimeUnit::*; /// /// Returns an `ArrowError::ComputeError(String)` if the array type is either unsupported by `sort_to_indices` or `take`. /// -pub fn sort( +pub fn sort(values: &ArrayRef, options: Option) -> Result { + let indices = sort_to_indices(values, options, None)?; + take(values.as_ref(), &indices, None) +} + +/// Sort the `ArrayRef` partially. +/// Return an sorted `ArrayRef`, discarding the data after limit. +pub fn partial_sort( values: &ArrayRef, options: Option, limit: Option, @@ -354,9 +361,9 @@ fn sort_boolean( } _ => { if !descending { - pdqsort::sort_by(&mut valids, |a, b| cmp(a.1, b.1)); + valids.sort_by(|a, b| cmp(a.1, b.1)); } else { - pdqsort::sort_by(&mut valids, |a, b| cmp(a.1, b.1).reverse()); + valids.sort_by(|a, b| cmp(a.1, b.1).reverse()); // reverse to keep a stable ordering nulls.reverse(); } @@ -443,9 +450,9 @@ where } _ => { if !descending { - pdqsort::sort_by(&mut valids, |a, b| cmp(a.1, b.1)); + valids.sort_by(|a, b| cmp(a.1, b.1)); } else { - pdqsort::sort_by(&mut valids, |a, b| cmp(a.1, b.1).reverse()); + valids.sort_by(|a, b| cmp(a.1, b.1).reverse()); // reverse to keep a stable ordering nulls.reverse(); } @@ -894,7 +901,12 @@ mod tests { { let output = PrimitiveArray::::from(data); let expected = Arc::new(PrimitiveArray::::from(expected_data)) as ArrayRef; - let output = sort(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); + let output = match limit { + Some(_) => { + partial_sort(&(Arc::new(output) as ArrayRef), options, limit).unwrap() + } + _ => sort(&(Arc::new(output) as ArrayRef), options).unwrap(), + }; assert_eq!(&output, &expected) } @@ -919,7 +931,12 @@ mod tests { ) { let output = StringArray::from(data); let expected = Arc::new(StringArray::from(expected_data)) as ArrayRef; - let output = sort(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); + let output = match limit { + Some(_) => { + partial_sort(&(Arc::new(output) as ArrayRef), options, limit).unwrap() + } + _ => sort(&(Arc::new(output) as ArrayRef), options).unwrap(), + }; assert_eq!(&output, &expected) } @@ -936,7 +953,12 @@ mod tests { .downcast_ref::() .expect("Unable to get dictionary values"); - let sorted = sort(&(Arc::new(array) as ArrayRef), options, limit).unwrap(); + let sorted = match limit { + Some(_) => { + partial_sort(&(Arc::new(array) as ArrayRef), options, limit).unwrap() + } + _ => sort(&(Arc::new(array) as ArrayRef), options).unwrap(), + }; let sorted = sorted .as_any() .downcast_ref::>() @@ -981,7 +1003,10 @@ mod tests { // for FixedSizedList if let Some(length) = fixed_length { let input = Arc::new(build_fixed_size_list_nullable(data.clone(), length)); - let sorted = sort(&(input as ArrayRef), options, limit).unwrap(); + let sorted = match limit { + Some(_) => partial_sort(&(input as ArrayRef), options, limit).unwrap(), + _ => sort(&(input as ArrayRef), options).unwrap(), + }; let expected = Arc::new(build_fixed_size_list_nullable( expected_data.clone(), length, @@ -991,17 +1016,23 @@ mod tests { } // for List - // let input = Arc::new(build_generic_list_nullable::(data.clone())); - // let sorted = sort(&(input as ArrayRef), options, limit).unwrap(); - // let expected = - // Arc::new(build_generic_list_nullable::(expected_data.clone())) - // as ArrayRef; - // - // assert_eq!(&sorted, &expected); + let input = Arc::new(build_generic_list_nullable::(data.clone())); + let sorted = match limit { + Some(_) => partial_sort(&(input as ArrayRef), options, limit).unwrap(), + _ => sort(&(input as ArrayRef), options).unwrap(), + }; + let expected = + Arc::new(build_generic_list_nullable::(expected_data.clone())) + as ArrayRef; + + assert_eq!(&sorted, &expected); // for LargeList let input = Arc::new(build_generic_list_nullable::(data)); - let sorted = sort(&(input as ArrayRef), options, limit).unwrap(); + let sorted = match limit { + Some(_) => partial_sort(&(input as ArrayRef), options, limit).unwrap(), + _ => sort(&(input as ArrayRef), options).unwrap(), + }; let expected = Arc::new(build_generic_list_nullable::(expected_data)) as ArrayRef; From b98b0b3604cd1e23d850ef943c2d166efe49ea68 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Thu, 4 Mar 2021 18:02:24 +0800 Subject: [PATCH 06/11] 1. Introduce a wrap function for sort and partial_sort, if limit size equals to len size, it will use merge sort by default. 2. update partial_sort version to 0.1.2 --- rust/arrow/Cargo.toml | 2 +- rust/arrow/benches/sort_kernel.rs | 9 ++ rust/arrow/src/compute/kernels/sort.rs | 129 +++++++++++-------------- 3 files changed, 64 insertions(+), 76 deletions(-) diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml index cdff0f3fea6..4fc6bcd6fe8 100644 --- a/rust/arrow/Cargo.toml +++ b/rust/arrow/Cargo.toml @@ -51,7 +51,7 @@ flatbuffers = "^0.8" hex = "0.4" prettytable-rs = { version = "0.8.0", optional = true } lexical-core = "^0.7" -partial_sort = "0.1.1" +partial_sort = "0.1.2" [features] default = [] diff --git a/rust/arrow/benches/sort_kernel.rs b/rust/arrow/benches/sort_kernel.rs index 038833fe732..74dc0ceae18 100644 --- a/rust/arrow/benches/sort_kernel.rs +++ b/rust/arrow/benches/sort_kernel.rs @@ -93,6 +93,12 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| bench_sort(&arr_a, &arr_b, Some(1000))) }); + let arr_a = create_array(2u64.pow(12) as usize, false); + let arr_b = create_array(2u64.pow(12) as usize, false); + c.bench_function("sort 2^12 limit 2^12", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(2u64.pow(12) as usize))) + }); + let arr_a = create_array(2u64.pow(12) as usize, true); let arr_b = create_array(2u64.pow(12) as usize, true); @@ -105,6 +111,9 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("sort nulls 2^12 limit 1000", |b| { b.iter(|| bench_sort(&arr_a, &arr_b, Some(1000))) }); + c.bench_function("sort nulls 2^12 limit 2^12", |b| { + b.iter(|| bench_sort(&arr_a, &arr_b, Some(2u64.pow(12) as usize))) + }); } } diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index b134813bc02..e7fd5fa0268 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -52,6 +52,18 @@ pub fn partial_sort( take(values.as_ref(), &indices, None) } +#[inline] +fn sort_by(array: &mut [T], limit: usize, cmp: F) +where + F: FnMut(&T, &T) -> Ordering, +{ + if array.len() == limit { + array.sort_by(cmp); + } else { + array.partial_sort(limit, cmp); + } +} + // implements comparison using IEEE 754 total ordering for f32 // Original implementation from https://doc.rust-lang.org/std/primitive.f64.html#method.total_cmp // TODO to change to use std when it becomes stable @@ -348,26 +360,15 @@ fn sort_boolean( let nulls_len = nulls.len(); let mut len = values.len(); - match limit { - Some(limit) => { - len = limit.min(len); - if !descending { - valids.partial_sort(len, |a, b| cmp(a.1, b.1)); - } else { - valids.partial_sort(len, |a, b| cmp(a.1, b.1).reverse()); - // reverse to keep a stable ordering - nulls.reverse(); - } - } - _ => { - if !descending { - valids.sort_by(|a, b| cmp(a.1, b.1)); - } else { - valids.sort_by(|a, b| cmp(a.1, b.1).reverse()); - // reverse to keep a stable ordering - nulls.reverse(); - } - } + if let Some(limit) = limit { + len = limit.min(len); + } + if !descending { + sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1)); + } else { + sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse()); + // reverse to keep a stable ordering + nulls.reverse(); } // collect results directly into a buffer instead of a vec to avoid another aligned allocation @@ -436,27 +437,15 @@ where let nulls_len = nulls.len(); let mut len = values.len(); - match limit { - Some(limit) => { - len = limit.min(len); - - if !descending { - valids.partial_sort(len, |a, b| cmp(a.1, b.1)); - } else { - valids.partial_sort(len, |a, b| cmp(a.1, b.1).reverse()); - // reverse to keep a stable ordering - nulls.reverse(); - } - } - _ => { - if !descending { - valids.sort_by(|a, b| cmp(a.1, b.1)); - } else { - valids.sort_by(|a, b| cmp(a.1, b.1).reverse()); - // reverse to keep a stable ordering - nulls.reverse(); - } - } + if let Some(limit) = limit { + len = limit.min(len); + } + if !descending { + sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1)); + } else { + sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse()); + // reverse to keep a stable ordering + nulls.reverse(); } // collect results directly into a buffer instead of a vec to avoid another aligned allocation @@ -579,28 +568,18 @@ where let mut nulls = null_indices; let descending = options.descending; let mut len = values.len(); - match limit { - Some(limit) => { - len = limit.min(len); - if !descending { - valids.partial_sort(len, |a, b| cmp(a.1, b.1)); - } else { - valids.partial_sort(len, |a, b| cmp(a.1, b.1).reverse()); - // reverse to keep a stable ordering - nulls.reverse(); - } - } - _ => { - if !descending { - valids.sort_by(|a, b| cmp(a.1, b.1)); - } else { - valids.sort_by(|a, b| cmp(a.1, b.1).reverse()); - // reverse to keep a stable ordering - nulls.reverse(); - } - } - } + let nulls_len = nulls.len(); + if let Some(limit) = limit { + len = limit.min(len); + } + if !descending { + sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1)); + } else { + sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse()); + // reverse to keep a stable ordering + nulls.reverse(); + } // collect the order of valid tuplies let mut valid_indices: Vec = valids.iter().map(|tuple| tuple.0).collect(); @@ -651,23 +630,25 @@ where ); let mut len = values.len(); + let nulls_len = null_indices.len(); let descending = options.descending; - if let Some(size) = limit { - len = size.min(len); + if let Some(limit) = limit { + len = limit.min(len); } - - // we are not using partial_sort here, because array is ArrayRef. Something is not working good in that. if !descending { - valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref())); + sort_by(&mut valids, len - nulls_len, |a, b| { + cmp_array(a.1.as_ref(), b.1.as_ref()) + }); } else { - valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()); + sort_by(&mut valids, len - nulls_len, |a, b| { + cmp_array(a.1.as_ref(), b.1.as_ref()).reverse() + }); // reverse to keep a stable ordering null_indices.reverse(); } let mut valid_indices: Vec = valids.iter().map(|tuple| tuple.0).collect(); - if options.nulls_first { null_indices.append(&mut valid_indices); null_indices.truncate(len); @@ -835,13 +816,11 @@ pub fn lexsort_to_indices( let mut value_indices = (0..row_count).collect::>(); let mut len = value_indices.len(); - match limit { - Some(limit) => { - len = len.min(limit); - value_indices.partial_sort(len, lex_comparator); - } - None => value_indices.sort_by(lex_comparator), + + if let Some(limit) = limit { + len = limit.min(len); } + sort_by(&mut value_indices, len, lex_comparator); Ok(UInt32Array::from( (&value_indices)[0..len] From bb6a77ee383af70d5add7e8ad4125668cb8012e7 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Sun, 7 Mar 2021 11:23:15 +0800 Subject: [PATCH 07/11] Update tests --- rust/arrow/src/compute/kernels/sort.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index e7fd5fa0268..8a1c7c882c0 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -1471,8 +1471,19 @@ mod tests { descending: false, nulls_first: true, }), - Some(4), - vec![Some(1.0), Some(f64::NAN), Some(f64::NAN), Some(f64::NAN)], + Some(2), + vec![Some(1.0), Some(f64::NAN)], + ); + + // limit with actual value + test_sort_primitive_arrays::( + vec![Some(2.0), Some(4.0), Some(3.0), Some(1.0)], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(3), + vec![Some(1.0), Some(2.0), Some(3.0)], ); } From 04616686c6e71a54fab254209a6586fdb290d71e Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Mon, 8 Mar 2021 13:07:12 +0800 Subject: [PATCH 08/11] Include partial_sort to arrow --- rust/arrow/Cargo.toml | 1 - rust/arrow/src/compute/kernels/sort.rs | 188 +++++++++++++++++++++++-- 2 files changed, 180 insertions(+), 9 deletions(-) diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml index 4fc6bcd6fe8..5ab1f8cc02b 100644 --- a/rust/arrow/Cargo.toml +++ b/rust/arrow/Cargo.toml @@ -51,7 +51,6 @@ flatbuffers = "^0.8" hex = "0.4" prettytable-rs = { version = "0.8.0", optional = true } lexical-core = "^0.7" -partial_sort = "0.1.2" [features] default = [] diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index 8a1c7c882c0..8270d8e73c3 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -17,6 +17,7 @@ //! Defines sort kernel for `ArrayRef` +use core::{mem, ptr}; use std::cmp::Ordering; use std::sync::Arc; @@ -26,7 +27,6 @@ use crate::compute::take; use crate::datatypes::*; use crate::error::{ArrowError, Result}; -use partial_sort::PartialSort; use TimeUnit::*; /// Sort the `ArrayRef` using `SortOptions`. @@ -43,7 +43,7 @@ pub fn sort(values: &ArrayRef, options: Option) -> Result /// Sort the `ArrayRef` partially. /// Return an sorted `ArrayRef`, discarding the data after limit. -pub fn partial_sort( +pub fn sort_limit( values: &ArrayRef, options: Option, limit: Option, @@ -111,6 +111,7 @@ fn partition_validity(array: &ArrayRef) -> (Vec, Vec) { /// Sort elements from `ArrayRef` into an unsigned integer (`UInt32Array`) of indices. /// For floating point arrays any NaN values are considered to be greater than any other non-null value +/// limit is an option for partial_sort pub fn sort_to_indices( values: &ArrayRef, options: Option, @@ -830,12 +831,156 @@ pub fn lexsort_to_indices( )) } +/// partial_sort is Rust version of [std::partial_sort](https://en.cppreference.com/w/cpp/algorithm/partial_sort) +/// +/// # Example +// ``` +/// let mut vec = vec![4, 4, 3, 3, 1, 1, 2, 2]; +/// vec.partial_sort(4, |a, b| a.cmp(b)); +/// println!("{:?}", vec); +/// ``` + +pub trait PartialSort { + type Item; + + fn partial_sort(&mut self, _: usize, _: F) + where + F: FnMut(&Self::Item, &Self::Item) -> Ordering; +} + +impl PartialSort for [T] { + type Item = T; + + fn partial_sort(&mut self, limit: usize, mut cmp: F) + where + F: FnMut(&Self::Item, &Self::Item) -> Ordering, + { + partial_sort(self, limit, |a, b| cmp(a, b) == Ordering::Less); + } +} + +pub fn partial_sort(v: &mut [T], limit: usize, mut is_less: F) +where + F: FnMut(&T, &T) -> bool, +{ + debug_assert!(limit <= v.len()); + make_heap(v, limit, &mut is_less); + + // unsafe because we use `get_unchecked` + unsafe { + for i in limit..v.len() { + if is_less(v.get_unchecked(i), v.get_unchecked(0)) { + v.swap(0, i); + adjust_heap(v, 0, limit, &mut is_less); + } + } + sort_heap(v, limit, &mut is_less); + } +} + +/// make a heap with last limit size elements +#[inline] +fn make_heap(v: &mut [T], limit: usize, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + if limit < 2 { + return; + } + + let len = limit; + let mut parent = (len - 2) / 2; + + loop { + adjust_heap(v, parent, len, is_less); + if parent == 0 { + return; + } + parent -= 1; + } +} + +/// adjust_heap is a sift down adjust operation for the heap +#[inline] +fn adjust_heap(v: &mut [T], hole_index: usize, len: usize, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + let mut left_child = hole_index * 2 + 1; + + // Panic safety: + // + // If `is_less` panics at any point during the process, `hole` will get dropped and + // fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it + // initially held exactly once. + let mut tmp = mem::ManuallyDrop::new(unsafe { ptr::read(&v[hole_index]) }); + let mut hole = InsertionHole { + src: &mut *tmp, + dest: &mut v[hole_index], + }; + + unsafe { + while left_child < len { + if left_child + 1 < len + && is_less(v.get_unchecked(left_child), v.get_unchecked(left_child + 1)) + { + left_child += 1; + } + + if is_less(&*tmp, v.get_unchecked(left_child)) { + ptr::copy_nonoverlapping(&v[left_child], hole.dest, 1); + hole.dest = &mut v[left_child]; + } else { + break; + } + + left_child = left_child * 2 + 1; + } + // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. + } + + // This code is copy from std library `src/slice.rs` + // When dropped, copies from `src` into `dest`. + struct InsertionHole { + src: *mut T, + dest: *mut T, + } + + impl Drop for InsertionHole { + fn drop(&mut self) { + // SAFETY: + // we ensure src/dest point to a properly initialized value of type T + // src is valid for reads of `count * size_of::()` bytes. + // dest is valid for reads of `count * size_of::()` bytes. + // Both `src` and `dst` are properly aligned. + unsafe { + ptr::copy_nonoverlapping(self.src, self.dest, 1); + } + } + } +} + +#[inline] +fn sort_heap(v: &mut [T], last: usize, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + let mut last = last; + while last > 1 { + v.swap(0, last - 1); + adjust_heap(v, 0, last - 1, is_less); + last -= 1; + } +} + #[cfg(test)] mod tests { use super::*; use crate::compute::util::tests::{ build_fixed_size_list_nullable, build_generic_list_nullable, }; + use rand::rngs::StdRng; + use rand::{Rng, RngCore, SeedableRng}; use std::convert::TryFrom; use std::iter::FromIterator; use std::sync::Arc; @@ -882,7 +1027,7 @@ mod tests { let expected = Arc::new(PrimitiveArray::::from(expected_data)) as ArrayRef; let output = match limit { Some(_) => { - partial_sort(&(Arc::new(output) as ArrayRef), options, limit).unwrap() + sort_limit(&(Arc::new(output) as ArrayRef), options, limit).unwrap() } _ => sort(&(Arc::new(output) as ArrayRef), options).unwrap(), }; @@ -912,7 +1057,7 @@ mod tests { let expected = Arc::new(StringArray::from(expected_data)) as ArrayRef; let output = match limit { Some(_) => { - partial_sort(&(Arc::new(output) as ArrayRef), options, limit).unwrap() + sort_limit(&(Arc::new(output) as ArrayRef), options, limit).unwrap() } _ => sort(&(Arc::new(output) as ArrayRef), options).unwrap(), }; @@ -934,7 +1079,7 @@ mod tests { let sorted = match limit { Some(_) => { - partial_sort(&(Arc::new(array) as ArrayRef), options, limit).unwrap() + sort_limit(&(Arc::new(array) as ArrayRef), options, limit).unwrap() } _ => sort(&(Arc::new(array) as ArrayRef), options).unwrap(), }; @@ -983,7 +1128,7 @@ mod tests { if let Some(length) = fixed_length { let input = Arc::new(build_fixed_size_list_nullable(data.clone(), length)); let sorted = match limit { - Some(_) => partial_sort(&(input as ArrayRef), options, limit).unwrap(), + Some(_) => sort_limit(&(input as ArrayRef), options, limit).unwrap(), _ => sort(&(input as ArrayRef), options).unwrap(), }; let expected = Arc::new(build_fixed_size_list_nullable( @@ -997,7 +1142,7 @@ mod tests { // for List let input = Arc::new(build_generic_list_nullable::(data.clone())); let sorted = match limit { - Some(_) => partial_sort(&(input as ArrayRef), options, limit).unwrap(), + Some(_) => sort_limit(&(input as ArrayRef), options, limit).unwrap(), _ => sort(&(input as ArrayRef), options).unwrap(), }; let expected = @@ -1009,7 +1154,7 @@ mod tests { // for LargeList let input = Arc::new(build_generic_list_nullable::(data)); let sorted = match limit { - Some(_) => partial_sort(&(input as ArrayRef), options, limit).unwrap(), + Some(_) => sort_limit(&(input as ArrayRef), options, limit).unwrap(), _ => sort(&(input as ArrayRef), options).unwrap(), }; let expected = @@ -2160,4 +2305,31 @@ mod tests { ]; test_lex_sort_arrays(input, expected, None); } + + #[test] + fn test_partial_sort() { + let mut before: Vec<&str> = vec![ + "a", "cat", "mat", "on", "sat", "the", "xxx", "xxxx", "fdadfdsf", + ]; + let mut d = before.clone(); + d.sort(); + + for last in 0..before.len() { + before.partial_sort(last, |a, b| a.cmp(b)); + assert_eq!(&d[0..last], &before.as_slice()[0..last]); + } + } + + #[test] + fn test_partial_rand_sort() { + let size = 1000u32; + let mut rng = StdRng::seed_from_u64(42); + let mut before: Vec = (0..size).map(|_| rng.gen::()).collect(); + let mut d = before.clone(); + let last = (rng.next_u32() % size) as usize; + d.sort(); + + before.partial_sort(last, |a, b| a.cmp(b)); + assert_eq!(&d[0..last], &before[0..last]); + } } From 86672cc549eec4d2a8899a3362e6e9336fe49190 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Mon, 8 Mar 2021 23:23:00 +0800 Subject: [PATCH 09/11] Make clippy test happy --- rust/arrow/src/compute/kernels/sort.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index 8270d8e73c3..050ca67fa14 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -2312,7 +2312,7 @@ mod tests { "a", "cat", "mat", "on", "sat", "the", "xxx", "xxxx", "fdadfdsf", ]; let mut d = before.clone(); - d.sort(); + d.sort_unstable(); for last in 0..before.len() { before.partial_sort(last, |a, b| a.cmp(b)); @@ -2327,7 +2327,7 @@ mod tests { let mut before: Vec = (0..size).map(|_| rng.gen::()).collect(); let mut d = before.clone(); let last = (rng.next_u32() % size) as usize; - d.sort(); + d.sort_unstable(); before.partial_sort(last, |a, b| a.cmp(b)); assert_eq!(&d[0..last], &before[0..last]); From 2dd16a797b972657bae42a40a4c7e415b7bf7721 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Fri, 12 Mar 2021 15:25:24 +0800 Subject: [PATCH 10/11] partial_sort from select_nth_unstable_by + sort --- rust/arrow/src/compute/kernels/sort.rs | 147 +------------------------ 1 file changed, 6 insertions(+), 141 deletions(-) diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index 050ca67fa14..5a666844cb3 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -17,7 +17,6 @@ //! Defines sort kernel for `ArrayRef` -use core::{mem, ptr}; use std::cmp::Ordering; use std::sync::Arc; @@ -60,7 +59,7 @@ where if array.len() == limit { array.sort_by(cmp); } else { - array.partial_sort(limit, cmp); + partial_sort(array, limit, cmp); } } @@ -831,146 +830,12 @@ pub fn lexsort_to_indices( )) } -/// partial_sort is Rust version of [std::partial_sort](https://en.cppreference.com/w/cpp/algorithm/partial_sort) -/// -/// # Example -// ``` -/// let mut vec = vec![4, 4, 3, 3, 1, 1, 2, 2]; -/// vec.partial_sort(4, |a, b| a.cmp(b)); -/// println!("{:?}", vec); -/// ``` - -pub trait PartialSort { - type Item; - - fn partial_sort(&mut self, _: usize, _: F) - where - F: FnMut(&Self::Item, &Self::Item) -> Ordering; -} - -impl PartialSort for [T] { - type Item = T; - - fn partial_sort(&mut self, limit: usize, mut cmp: F) - where - F: FnMut(&Self::Item, &Self::Item) -> Ordering, - { - partial_sort(self, limit, |a, b| cmp(a, b) == Ordering::Less); - } -} - pub fn partial_sort(v: &mut [T], limit: usize, mut is_less: F) where - F: FnMut(&T, &T) -> bool, -{ - debug_assert!(limit <= v.len()); - make_heap(v, limit, &mut is_less); - - // unsafe because we use `get_unchecked` - unsafe { - for i in limit..v.len() { - if is_less(v.get_unchecked(i), v.get_unchecked(0)) { - v.swap(0, i); - adjust_heap(v, 0, limit, &mut is_less); - } - } - sort_heap(v, limit, &mut is_less); - } -} - -/// make a heap with last limit size elements -#[inline] -fn make_heap(v: &mut [T], limit: usize, is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - if limit < 2 { - return; - } - - let len = limit; - let mut parent = (len - 2) / 2; - - loop { - adjust_heap(v, parent, len, is_less); - if parent == 0 { - return; - } - parent -= 1; - } -} - -/// adjust_heap is a sift down adjust operation for the heap -#[inline] -fn adjust_heap(v: &mut [T], hole_index: usize, len: usize, is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - let mut left_child = hole_index * 2 + 1; - - // Panic safety: - // - // If `is_less` panics at any point during the process, `hole` will get dropped and - // fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it - // initially held exactly once. - let mut tmp = mem::ManuallyDrop::new(unsafe { ptr::read(&v[hole_index]) }); - let mut hole = InsertionHole { - src: &mut *tmp, - dest: &mut v[hole_index], - }; - - unsafe { - while left_child < len { - if left_child + 1 < len - && is_less(v.get_unchecked(left_child), v.get_unchecked(left_child + 1)) - { - left_child += 1; - } - - if is_less(&*tmp, v.get_unchecked(left_child)) { - ptr::copy_nonoverlapping(&v[left_child], hole.dest, 1); - hole.dest = &mut v[left_child]; - } else { - break; - } - - left_child = left_child * 2 + 1; - } - // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. - } - - // This code is copy from std library `src/slice.rs` - // When dropped, copies from `src` into `dest`. - struct InsertionHole { - src: *mut T, - dest: *mut T, - } - - impl Drop for InsertionHole { - fn drop(&mut self) { - // SAFETY: - // we ensure src/dest point to a properly initialized value of type T - // src is valid for reads of `count * size_of::()` bytes. - // dest is valid for reads of `count * size_of::()` bytes. - // Both `src` and `dst` are properly aligned. - unsafe { - ptr::copy_nonoverlapping(self.src, self.dest, 1); - } - } - } -} - -#[inline] -fn sort_heap(v: &mut [T], last: usize, is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, + F: FnMut(&T, &T) -> Ordering, { - let mut last = last; - while last > 1 { - v.swap(0, last - 1); - adjust_heap(v, 0, last - 1, is_less); - last -= 1; - } + let (before, _mid, _after) = v.select_nth_unstable_by(limit, &mut is_less); + before.sort_unstable_by(is_less); } #[cfg(test)] @@ -2315,7 +2180,7 @@ mod tests { d.sort_unstable(); for last in 0..before.len() { - before.partial_sort(last, |a, b| a.cmp(b)); + partial_sort(&mut before, last, |a, b| a.cmp(b)); assert_eq!(&d[0..last], &before.as_slice()[0..last]); } } @@ -2329,7 +2194,7 @@ mod tests { let last = (rng.next_u32() % size) as usize; d.sort_unstable(); - before.partial_sort(last, |a, b| a.cmp(b)); + partial_sort(&mut before, last, |a, b| a.cmp(b)); assert_eq!(&d[0..last], &before[0..last]); } } From fd6a2342bea440922cab47ec0ca6bdba7fb8f708 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Sun, 14 Mar 2021 22:30:58 +0800 Subject: [PATCH 11/11] Add docs about unstable_sort --- rust/arrow/src/compute/kernels/sort.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index 5a666844cb3..f5472738e7c 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -41,6 +41,7 @@ pub fn sort(values: &ArrayRef, options: Option) -> Result } /// Sort the `ArrayRef` partially. +/// It's unstable_sort, may not preserve the order of equal elements /// Return an sorted `ArrayRef`, discarding the data after limit. pub fn sort_limit( values: &ArrayRef, @@ -830,6 +831,7 @@ pub fn lexsort_to_indices( )) } +/// It's unstable_sort, may not preserve the order of equal elements pub fn partial_sort(v: &mut [T], limit: usize, mut is_less: F) where F: FnMut(&T, &T) -> Ordering,