From f201a4c05d4df410f0776b744ad3be335cc10a63 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 16 Mar 2021 06:39:06 -0400 Subject: [PATCH] ARROW-11979: Combine limit into SortOptions --- rust/arrow/src/compute/kernels/sort.rs | 185 ++++++++++++++----- rust/datafusion/src/physical_plan/planner.rs | 1 + rust/datafusion/src/physical_plan/sort.rs | 2 + 3 files changed, 146 insertions(+), 42 deletions(-) diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index f5472738e7c..3e6685b8a8e 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -35,19 +35,44 @@ use TimeUnit::*; /// /// Returns an `ArrowError::ComputeError(String)` if the array type is either unsupported by `sort_to_indices` or `take`. /// +/// # Example +/// ```rust +/// # use std::sync::Arc; +/// # use arrow::array::{Int32Array, ArrayRef}; +/// # use arrow::error::Result; +/// # use arrow::compute::kernels::sort::{sort, SortOptions}; +/// # fn main() -> Result<()> { +/// let array: ArrayRef = Arc::new(Int32Array::from(vec![5, 4, 3, 2, 1])); +/// +/// // Sort the array +/// let sorted_array = sort(&array, None).unwrap(); +/// let sorted_array = sorted_array.as_any().downcast_ref::().unwrap(); +/// assert_eq!(sorted_array, &Int32Array::from(vec![1, 2, 3, 4, 5])); +/// +/// // Find the the top 2 items +/// let options = Some(SortOptions { +/// limit: Some(2), +/// ..Default::default() +/// }); +/// let sorted_array = sort(&array, options).unwrap(); +/// let sorted_array = sorted_array.as_any().downcast_ref::().unwrap(); +/// assert_eq!(sorted_array, &Int32Array::from(vec![1, 2])); +/// +/// // Find the bottom top 2 items +/// let options = Some(SortOptions { +/// descending: true, +/// limit: Some(2), +/// ..Default::default() +/// }); +/// let sorted_array = sort(&array, options).unwrap(); +/// let sorted_array = sorted_array.as_any().downcast_ref::().unwrap(); +/// assert_eq!(sorted_array, &Int32Array::from(vec![5, 4])); + +/// # Ok(()) +/// # } +/// ``` pub fn sort(values: &ArrayRef, options: Option) -> Result { - let indices = sort_to_indices(values, options, None)?; - take(values.as_ref(), &indices, None) -} - -/// Sort the `ArrayRef` partially. -/// It's unstable_sort, may not preserve the order of equal elements -/// Return an sorted `ArrayRef`, discarding the data after limit. -pub fn sort_limit( - values: &ArrayRef, - options: Option, - limit: Option, -) -> Result { + let limit = options.as_ref().map(|s| s.limit).unwrap_or(None); let indices = sort_to_indices(values, options, limit)?; take(values.as_ref(), &indices, None) } @@ -322,6 +347,10 @@ pub struct SortOptions { pub descending: bool, /// Whether to sort nulls first pub nulls_first: bool, + /// If Some(limit), only first `limit` elements in the sort order + /// in the output. Any data data after the limit will be + /// discarded. + pub limit: Option, } impl Default for SortOptions { @@ -330,6 +359,8 @@ impl Default for SortOptions { descending: false, // default to nulls first to match spark's behavior nulls_first: true, + // Keep all rows + limit: None, } } } @@ -720,6 +751,7 @@ pub struct SortColumn { /// options: Some(SortOptions { /// descending: true, /// nulls_first: false, +/// ..Default::default() /// }), /// }, /// ], None).unwrap(); @@ -881,6 +913,33 @@ mod tests { assert_eq!(output, expected) } + // TODO remove this function + // Combine the options and limit for testing purposes + fn combine_options( + options: Option, + limit: Option, + ) -> Option { + match limit { + None => options, + Some(limit) => match options { + Some(mut options) => { + assert!( + options.limit.is_none() || options.limit == Some(limit), + "conflicting limit specified: {:?} {:?}", + options, + limit + ); + options.limit = Some(limit); + Some(options) + } + None => Some(SortOptions { + limit: Some(limit), + ..Default::default() + }), + }, + } + } + fn test_sort_primitive_arrays( data: Vec>, options: Option, @@ -892,12 +951,8 @@ mod tests { { let output = PrimitiveArray::::from(data); let expected = Arc::new(PrimitiveArray::::from(expected_data)) as ArrayRef; - let output = match limit { - Some(_) => { - sort_limit(&(Arc::new(output) as ArrayRef), options, limit).unwrap() - } - _ => sort(&(Arc::new(output) as ArrayRef), options).unwrap(), - }; + let options = combine_options(options, limit); + let output = sort(&(Arc::new(output) as ArrayRef), options).unwrap(); assert_eq!(&output, &expected) } @@ -922,12 +977,8 @@ mod tests { ) { let output = StringArray::from(data); let expected = Arc::new(StringArray::from(expected_data)) as ArrayRef; - let output = match limit { - Some(_) => { - sort_limit(&(Arc::new(output) as ArrayRef), options, limit).unwrap() - } - _ => sort(&(Arc::new(output) as ArrayRef), options).unwrap(), - }; + let options = combine_options(options, limit); + let output = sort(&(Arc::new(output) as ArrayRef), options).unwrap(); assert_eq!(&output, &expected) } @@ -944,12 +995,8 @@ mod tests { .downcast_ref::() .expect("Unable to get dictionary values"); - let sorted = match limit { - Some(_) => { - sort_limit(&(Arc::new(array) as ArrayRef), options, limit).unwrap() - } - _ => sort(&(Arc::new(array) as ArrayRef), options).unwrap(), - }; + let options = combine_options(options, limit); + let sorted = sort(&(Arc::new(array) as ArrayRef), options).unwrap(); let sorted = sorted .as_any() .downcast_ref::>() @@ -994,10 +1041,8 @@ mod tests { // for FixedSizedList if let Some(length) = fixed_length { let input = Arc::new(build_fixed_size_list_nullable(data.clone(), length)); - let sorted = match limit { - Some(_) => sort_limit(&(input as ArrayRef), options, limit).unwrap(), - _ => sort(&(input as ArrayRef), options).unwrap(), - }; + let options = combine_options(options, limit); + let sorted = sort(&(input as ArrayRef), options).unwrap(); let expected = Arc::new(build_fixed_size_list_nullable( expected_data.clone(), length, @@ -1008,10 +1053,8 @@ mod tests { // for List let input = Arc::new(build_generic_list_nullable::(data.clone())); - let sorted = match limit { - Some(_) => sort_limit(&(input as ArrayRef), options, limit).unwrap(), - _ => sort(&(input as ArrayRef), options).unwrap(), - }; + let options = combine_options(options, limit); + let sorted = sort(&(input as ArrayRef), options).unwrap(); let expected = Arc::new(build_generic_list_nullable::(expected_data.clone())) as ArrayRef; @@ -1020,10 +1063,8 @@ mod tests { // for LargeList let input = Arc::new(build_generic_list_nullable::(data)); - let sorted = match limit { - Some(_) => sort_limit(&(input as ArrayRef), options, limit).unwrap(), - _ => sort(&(input as ArrayRef), options).unwrap(), - }; + let options = combine_options(options, limit); + let sorted = sort(&(input as ArrayRef), options).unwrap(); let expected = Arc::new(build_generic_list_nullable::(expected_data)) as ArrayRef; @@ -1101,6 +1142,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![2, 1, 4, 3, 5, 0], // [2, 4, 1, 3, 5, 0] @@ -1111,6 +1153,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![2, 1, 4, 3, 5, 0], @@ -1121,6 +1164,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![2, 1, 4, 3, 5, 0], @@ -1131,6 +1175,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![2, 1, 4, 3, 5, 0], @@ -1148,6 +1193,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![2, 1, 4, 3, 5, 0], @@ -1158,6 +1204,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![2, 1, 4, 3, 5, 0], @@ -1169,6 +1216,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![5, 0, 2, 1, 4, 3], // [5, 0, 2, 4, 1, 3] @@ -1179,6 +1227,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![5, 0, 2, 1, 4, 3], // [5, 0, 2, 4, 1, 3] @@ -1189,6 +1238,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![5, 0, 2, 1, 4, 3], @@ -1199,6 +1249,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![5, 0, 2, 1, 4, 3], @@ -1209,6 +1260,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![5, 0, 2, 1, 4, 3], @@ -1219,6 +1271,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![5, 0, 2, 1, 4, 3], @@ -1241,6 +1294,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![2, 3, 1, 4, 5, 0], @@ -1252,6 +1306,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![5, 0, 2, 3, 1, 4], @@ -1263,6 +1318,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), Some(3), vec![5, 0, 2], @@ -1303,6 +1359,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![Some(2), Some(0), Some(0), Some(-1), None, None], @@ -1312,6 +1369,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![Some(2), Some(0), Some(0), Some(-1), None, None], @@ -1321,6 +1379,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![Some(2), Some(0), Some(0), Some(-1), None, None], @@ -1330,6 +1389,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![Some(2), Some(0), Some(0), Some(-1), None, None], @@ -1341,6 +1401,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![None, None, Some(2), Some(0), Some(0), Some(-1)], @@ -1350,6 +1411,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![None, None, Some(2), Some(0), Some(0), Some(-1)], @@ -1359,6 +1421,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![None, None, Some(2), Some(0), Some(0), Some(-1)], @@ -1368,6 +1431,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![None, None, Some(2), Some(0), Some(0), Some(-1)], @@ -1378,6 +1442,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), Some(3), vec![None, None, Some(2)], @@ -1388,6 +1453,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![None, None, Some(2.0), Some(0.0), Some(0.0), Some(-1.0)], @@ -1397,6 +1463,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![None, None, Some(f64::NAN), Some(2.0), Some(0.0), Some(-1.0)], @@ -1406,6 +1473,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![Some(f64::NAN), Some(f64::NAN), Some(f64::NAN), Some(1.0)], @@ -1417,6 +1485,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: true, + ..Default::default() }), None, vec![None, None, Some(-1), Some(0), Some(0), Some(2)], @@ -1426,6 +1495,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: true, + ..Default::default() }), None, vec![None, None, Some(-1), Some(0), Some(0), Some(2)], @@ -1435,6 +1505,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: true, + ..Default::default() }), None, vec![None, None, Some(-1), Some(0), Some(0), Some(2)], @@ -1444,6 +1515,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: true, + ..Default::default() }), None, vec![None, None, Some(-1), Some(0), Some(0), Some(2)], @@ -1453,6 +1525,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: true, + ..Default::default() }), None, vec![None, None, Some(-1.0), Some(0.0), Some(0.0), Some(2.0)], @@ -1462,6 +1535,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: true, + ..Default::default() }), None, vec![None, None, Some(-1.0), Some(0.0), Some(2.0), Some(f64::NAN)], @@ -1471,6 +1545,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: true, + ..Default::default() }), None, vec![Some(1.0), Some(f64::NAN), Some(f64::NAN), Some(f64::NAN)], @@ -1482,6 +1557,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: true, + ..Default::default() }), Some(2), vec![Some(1.0), Some(f64::NAN)], @@ -1493,6 +1569,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: true, + ..Default::default() }), Some(3), vec![Some(1.0), Some(2.0), Some(3.0)], @@ -1527,6 +1604,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![2, 4, 1, 5, 3, 0], @@ -1544,6 +1622,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: true, + ..Default::default() }), None, vec![0, 3, 5, 1, 4, 2], @@ -1561,6 +1640,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![3, 0, 2, 4, 1, 5], @@ -1578,6 +1658,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), Some(3), vec![3, 0, 2], @@ -1619,6 +1700,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![ @@ -1643,6 +1725,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: true, + ..Default::default() }), None, vec![ @@ -1667,6 +1750,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![ @@ -1691,6 +1775,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), Some(3), vec![None, None, Some("sad")], @@ -1732,6 +1817,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), None, vec![ @@ -1756,6 +1842,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: true, + ..Default::default() }), None, vec![ @@ -1780,6 +1867,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), None, vec![ @@ -1804,6 +1892,7 @@ mod tests { Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), Some(3), vec![None, None, Some("sad")], @@ -1822,6 +1911,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: false, + ..Default::default() }), None, vec![ @@ -1844,6 +1934,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: false, + ..Default::default() }), None, vec![ @@ -1867,6 +1958,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: false, + ..Default::default() }), None, vec![ @@ -1890,6 +1982,7 @@ mod tests { Some(SortOptions { descending: false, nulls_first: false, + ..Default::default() }), Some(2), vec![Some(vec![Some(1), Some(0)]), Some(vec![Some(1), Some(1)])], @@ -2008,6 +2101,7 @@ mod tests { options: Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), }, SortColumn { @@ -2020,6 +2114,7 @@ mod tests { options: Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), }, ]; @@ -2051,6 +2146,7 @@ mod tests { options: Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), }, SortColumn { @@ -2063,6 +2159,7 @@ mod tests { options: Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), }, ]; @@ -2094,6 +2191,7 @@ mod tests { options: Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), }, SortColumn { @@ -2106,6 +2204,7 @@ mod tests { options: Some(SortOptions { descending: true, nulls_first: false, + ..Default::default() }), }, ]; @@ -2138,6 +2237,7 @@ mod tests { options: Some(SortOptions { descending: false, nulls_first: false, + ..Default::default() }), }, SortColumn { @@ -2151,6 +2251,7 @@ mod tests { options: Some(SortOptions { descending: true, nulls_first: true, + ..Default::default() }), }, ]; diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index ef1cb1160a8..b59dbc2014b 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -310,6 +310,7 @@ impl DefaultPhysicalPlanner { SortOptions { descending: !*asc, nulls_first: *nulls_first, + ..Default::default() }, ctx_state, ), diff --git a/rust/datafusion/src/physical_plan/sort.rs b/rust/datafusion/src/physical_plan/sort.rs index 994168c2efb..4798dc89b20 100644 --- a/rust/datafusion/src/physical_plan/sort.rs +++ b/rust/datafusion/src/physical_plan/sort.rs @@ -363,6 +363,7 @@ mod tests { options: SortOptions { descending: true, nulls_first: true, + ..Default::default() }, }, PhysicalSortExpr { @@ -370,6 +371,7 @@ mod tests { options: SortOptions { descending: false, nulls_first: false, + ..Default::default() }, }, ],