diff --git a/arrow-ord/src/rank.rs b/arrow-ord/src/rank.rs
index e61cebef38ec..1b0d2a7e6349 100644
--- a/arrow-ord/src/rank.rs
+++ b/arrow-ord/src/rank.rs
@@ -19,7 +19,9 @@

 use arrow_array::cast::AsArray;
 use arrow_array::types::*;
-use arrow_array::{downcast_primitive_array, Array, ArrowNativeTypeOp, GenericByteArray};
+use arrow_array::{
+    downcast_primitive_array, Array, ArrowNativeTypeOp, BooleanArray, GenericByteArray,
+};
 use arrow_buffer::NullBuffer;
 use arrow_schema::{ArrowError, DataType, SortOptions};
 use std::cmp::Ordering;
@@ -29,7 +31,11 @@ pub(crate) fn can_rank(data_type: &DataType) -> bool {
     data_type.is_primitive()
         || matches!(
             data_type,
-            DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary
+            DataType::Boolean
+                | DataType::Utf8
+                | DataType::LargeUtf8
+                | DataType::Binary
+                | DataType::LargeBinary
         )
 }

@@ -49,6 +55,7 @@ pub fn rank(array: &dyn Array, options: Option) -> Result,
     let options = options.unwrap_or_default();
     let ranks = downcast_primitive_array! {
         array => primitive_rank(array.values(), array.nulls(), options),
+        DataType::Boolean => boolean_rank(array.as_boolean(), options),
         DataType::Utf8 => bytes_rank(array.as_bytes::(), options),
         DataType::LargeUtf8 => bytes_rank(array.as_bytes::(), options),
         DataType::Binary => bytes_rank(array.as_bytes::(), options),
@@ -135,6 +142,84 @@ where
     out
 }

+/// Returns the index into the `[false, true, null]` rank table for one boolean array element
+///
+/// The index is calculated as follows:
+/// - if `is_null` is true, the index is 2 (regardless of `value`)
+/// - if `is_null` is false and `value` is true, the index is 1
+/// - otherwise, the index is 0
+///
+/// `false` maps to 0 and `true` maps to 1 because those are their values when cast to an integer
+#[inline]
+fn get_boolean_rank_index(value: bool, is_null: bool) -> usize {
+    let is_null_num = is_null as usize;
+    (is_null_num << 1) | (value as usize & !is_null_num)
+}
+
+#[inline(never)]
+fn boolean_rank(array: &BooleanArray, options: SortOptions) -> Vec {
+    let null_count = array.null_count() as u32;
+    let true_count = array.true_count() as u32;
+    let false_count = array.len() as u32 - null_count - true_count;
+
+    // Rank values for [false, true, null] in that order
+    //
+    // The rank of a value is the rank of the value sorted just before it plus its own count,
+    // which means that for the order `false`, `true` and then `null`
+    // the ranks will be:
+    // - false: false_count
+    // - true: false_count + true_count
+    // - null: false_count + true_count + null_count
+    //
+    // For the order `null`, `false` and then `true`
+    // the ranks will be:
+    // - false: null_count + false_count
+    // - true: null_count + false_count + true_count
+    // - null: null_count
+    //
+    // Note that the last rank always equals the array length; it is still written as a sum so the calculation stays readable
+    let ranks_index: [u32; 3] = match (options.descending, options.nulls_first) {
+        // The order is null, true, false
+        (true, true) => [
+            null_count + true_count + false_count,
+            null_count + true_count,
+            null_count,
+        ],
+        // The order is true, false, null
+        (true, false) => [
+            true_count + false_count,
+            true_count,
+            true_count + false_count + null_count,
+        ],
+        // The order is null, false, true
+        (false, true) => [
+            null_count + false_count,
+            null_count + false_count + true_count,
+            null_count,
+        ],
+        // The order is false, true, null
+        (false, false) => [
+            false_count,
+            false_count + true_count,
+            false_count + true_count + null_count,
+        ],
+    };
+
+    match array.nulls().filter(|n| n.null_count() > 0) {
+        Some(n) => array
+            .values()
+            .iter()
+            .zip(n.iter())
+            .map(|(value, is_valid)| ranks_index[get_boolean_rank_index(value, !is_valid)])
+            .collect::>(),
+        None => array
+            .values()
+            .iter()
+            .map(|value| ranks_index[value as usize])
+            .collect::>(),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -177,6 +262,82 @@ mod tests {
         assert_eq!(res, &[4, 6, 3, 6, 3, 3]);
     }

+    #[test]
+    fn test_get_boolean_rank_index() {
+        assert_eq!(get_boolean_rank_index(true, true), 2);
+        assert_eq!(get_boolean_rank_index(false, true), 2);
+        assert_eq!(get_boolean_rank_index(true, false), 1);
+        assert_eq!(get_boolean_rank_index(false, false), 0);
+    }
+
+    #[test]
+    fn test_nullable_booleans() {
+        let descending = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+
+        let nulls_last = SortOptions {
+            descending: false,
+            nulls_first: false,
+        };
+
+        let nulls_last_descending = SortOptions {
+            descending: true,
+            nulls_first: false,
+        };
+
+        let a = BooleanArray::from(vec![Some(true), Some(true), None, Some(false), Some(false)]);
+        let res = rank(&a, None).unwrap();
+        assert_eq!(res, &[5, 5, 1, 3, 3]);
+
+        let res = rank(&a, Some(descending)).unwrap();
+        assert_eq!(res, &[3, 3, 1, 5, 5]);
+
+        let res = rank(&a, Some(nulls_last)).unwrap();
+        assert_eq!(res, &[4, 4, 5, 2, 2]);
+
+        let res = rank(&a, Some(nulls_last_descending)).unwrap();
+        assert_eq!(res, &[2, 2, 5, 4, 4]);
+
+        // Test that the value stored under a null slot (here `true` at index 2) is ignored
+        let nulls = NullBuffer::from(vec![true, true, false, true, true]);
+        let a = BooleanArray::new(vec![true, true, true, false, false].into(), Some(nulls));
+        let res = rank(&a, None).unwrap();
+        assert_eq!(res, &[5, 5, 1, 3, 3]);
+    }
+
+    #[test]
+    fn test_booleans() {
+        let descending = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+
+        let nulls_last = SortOptions {
+            descending: false,
+            nulls_first: false,
+        };
+
+        let nulls_last_descending = SortOptions {
+            descending: true,
+            nulls_first: false,
+        };
+
+        let a = BooleanArray::from(vec![true, false, false, false, true]);
+        let res = rank(&a, None).unwrap();
+        assert_eq!(res, &[5, 3, 3, 3, 5]);
+
+        let res = rank(&a, Some(descending)).unwrap();
+        assert_eq!(res, &[2, 5, 5, 5, 2]);
+
+        let res = rank(&a, Some(nulls_last)).unwrap();
+        assert_eq!(res, &[5, 3, 3, 3, 5]);
+
+        let res = rank(&a, Some(nulls_last_descending)).unwrap();
+        assert_eq!(res, &[2, 5, 5, 5, 2]);
+    }
+
     #[test]
     fn test_bytes() {
         let v = vec!["foo", "fo", "bar", "bar"];
diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index 51a6659e631b..fa5e2b8b2f7e 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -785,12 +785,14 @@ impl LexicographicalComparator {
 mod tests {
     use super::*;
     use arrow_array::builder::{
-        FixedSizeListBuilder, Int64Builder, ListBuilder, PrimitiveRunBuilder,
+        BooleanBuilder, FixedSizeListBuilder, GenericListBuilder, Int64Builder, ListBuilder,
+        PrimitiveRunBuilder,
     };
     use arrow_buffer::{i256, NullBuffer};
     use arrow_schema::Field;
     use half::f16;
     use rand::rngs::StdRng;
+    use rand::seq::SliceRandom;
     use rand::{Rng, RngCore, SeedableRng};

     fn create_decimal128_array(data: Vec>) -> Decimal128Array {
@@ -1541,6 +1543,384 @@ mod tests {
         );
     }

+    /// Tests sorting lists of booleans for every combination of with/without limit and GenericListArray/FixedSizeListArray
+    ///
+    /// The input data must have the same length for all list items so that we can test FixedSizeListArray
+    ///
+    fn test_every_config_sort_boolean_list_arrays(
+        data: Vec>>>,
+        options: Option,
+        expected_data: Vec>>>,
+    ) {
+        let first_length = data
+            .iter()
+            .find_map(|x| x.as_ref().map(|x| x.len()))
+            .unwrap_or(0);
+        let first_non_match_length = data
+            .iter()
+            .map(|x| x.as_ref().map(|x| x.len()).unwrap_or(first_length))
+            .position(|x| x != first_length);
+
+        assert_eq!(
+            first_non_match_length, None,
+            "All list items should have the same length {first_length}, input data is invalid"
+        );
+
+        let first_non_match_length = expected_data
+            .iter()
+            .map(|x| x.as_ref().map(|x| x.len()).unwrap_or(first_length))
+            .position(|x| x != first_length);
+
+        assert_eq!(
+            first_non_match_length, None,
+            "All list items should have the same length {first_length}, expected data is invalid"
+        );
+
+        let limit = expected_data.len().saturating_div(2);
+
+        for &with_limit in &[false, true] {
+            let (limit, expected_data) = if with_limit {
+                (
+                    Some(limit),
+                    expected_data.iter().take(limit).cloned().collect(),
+                )
+            } else {
+                (None, expected_data.clone())
+            };
+
+            for &fixed_length in &[None, Some(first_length as i32)] {
+                test_sort_boolean_list_arrays(
+                    data.clone(),
+                    options,
+                    limit,
+                    expected_data.clone(),
+                    fixed_length,
+                );
+            }
+        }
+    }
+
+    fn test_sort_boolean_list_arrays(
+        data: Vec>>>,
+        options: Option,
+        limit: Option,
+        expected_data: Vec>>>,
+        fixed_length: Option,
+    ) {
+        fn build_fixed_boolean_list_array(
+            data: Vec>>>,
+            fixed_length: i32,
+        ) -> ArrayRef {
+            let mut builder = FixedSizeListBuilder::new(
+                BooleanBuilder::with_capacity(fixed_length as usize),
+                fixed_length,
+            );
+            for sublist in data {
+                match sublist {
+                    Some(sublist) => {
+                        builder.values().extend(sublist);
+                        builder.append(true);
+                    }
+                    None => {
+                        builder
+                            .values()
+                            .extend(std::iter::repeat(None).take(fixed_length as usize));
+                        builder.append(false);
+                    }
+                }
+            }
+            Arc::new(builder.finish()) as ArrayRef
+        }
+
+        fn build_generic_boolean_list_array(
+            data: Vec>>>,
+        ) -> ArrayRef {
+            let mut builder = GenericListBuilder::::new(BooleanBuilder::new());
+            builder.extend(data);
+            Arc::new(builder.finish()) as ArrayRef
+        }
+
+        // for FixedSizedList
+        if let Some(length) = fixed_length {
+            let input = build_fixed_boolean_list_array(data.clone(), length);
+            let sorted = match limit {
+                Some(_) => sort_limit(&(input as ArrayRef), options, limit).unwrap(),
+                _ => sort(&(input as ArrayRef), options).unwrap(),
+            };
+            let expected = build_fixed_boolean_list_array(expected_data.clone(), length);
+
+            assert_eq!(&sorted, &expected);
+        }
+
+        // for List
+        let input = build_generic_boolean_list_array::(data.clone());
+        let sorted = match limit {
+            Some(_) => sort_limit(&(input as ArrayRef), options, limit).unwrap(),
+            _ => sort(&(input as ArrayRef), options).unwrap(),
+        };
+        let expected = build_generic_boolean_list_array::(expected_data.clone());
+
+        assert_eq!(&sorted, &expected);
+
+        // for LargeList
+        let input = build_generic_boolean_list_array::(data.clone());
+        let sorted = match limit {
+            Some(_) => sort_limit(&(input as ArrayRef), options, limit).unwrap(),
+            _ => sort(&(input as ArrayRef), options).unwrap(),
+        };
+        let expected = build_generic_boolean_list_array::(expected_data.clone());
+
+        assert_eq!(&sorted, &expected);
+    }
+
+    #[test]
+    fn test_sort_list_of_booleans() {
+        // These are all the possible lists of three nullable booleans, plus a null list
+        // There are 3^3 + 1 = 28 possible combinations (3 positions, each one of [true, false, null], and 1 None value)
+        #[rustfmt::skip]
+        let mut cases = vec![
+            Some(vec![Some(true),  Some(true),  Some(true)]),
+            Some(vec![Some(true),  Some(true),  Some(false)]),
+            Some(vec![Some(true),  Some(true),  None]),
+
+            Some(vec![Some(true),  Some(false), Some(true)]),
+            Some(vec![Some(true),  Some(false), Some(false)]),
+            Some(vec![Some(true),  Some(false), None]),
+
+            Some(vec![Some(true),  None,        Some(true)]),
+            Some(vec![Some(true),  None,        Some(false)]),
+            Some(vec![Some(true),  None,        None]),
+
+            Some(vec![Some(false), Some(true),  Some(true)]),
+            Some(vec![Some(false), Some(true),  Some(false)]),
+            Some(vec![Some(false), Some(true),  None]),
+
+            Some(vec![Some(false), Some(false), Some(true)]),
+            Some(vec![Some(false), Some(false), Some(false)]),
+            Some(vec![Some(false), Some(false), None]),
+
+            Some(vec![Some(false), None,        Some(true)]),
+            Some(vec![Some(false), None,        Some(false)]),
+            Some(vec![Some(false), None,        None]),
+
+            Some(vec![None,        Some(true),  Some(true)]),
+            Some(vec![None,        Some(true),  Some(false)]),
+            Some(vec![None,        Some(true),  None]),
+
+            Some(vec![None,        Some(false), Some(true)]),
+            Some(vec![None,        Some(false), Some(false)]),
+            Some(vec![None,        Some(false), None]),
+
+            Some(vec![None,        None,        Some(true)]),
+            Some(vec![None,        None,        Some(false)]),
+            Some(vec![None,        None,        None]),
+            None,
+        ];
+
+        cases.shuffle(&mut StdRng::seed_from_u64(42));
+
+        // The order is false, true, null
+        #[rustfmt::skip]
+        let expected_descending_false_nulls_first_false = vec![
+            Some(vec![Some(false), Some(false), Some(false)]),
+            Some(vec![Some(false), Some(false), Some(true)]),
+            Some(vec![Some(false), Some(false), None]),
+
+            Some(vec![Some(false), Some(true),  Some(false)]),
+            Some(vec![Some(false), Some(true),  Some(true)]),
+            Some(vec![Some(false), Some(true),  None]),
+
+            Some(vec![Some(false), None,        Some(false)]),
+            Some(vec![Some(false), None,        Some(true)]),
+            Some(vec![Some(false), None,        None]),
+
+            Some(vec![Some(true),  Some(false), Some(false)]),
+            Some(vec![Some(true),  Some(false), Some(true)]),
+            Some(vec![Some(true),  Some(false), None]),
+
+            Some(vec![Some(true),  Some(true),  Some(false)]),
+            Some(vec![Some(true),  Some(true),  Some(true)]),
+            Some(vec![Some(true),  Some(true),  None]),
+
+            Some(vec![Some(true),  None,        Some(false)]),
+            Some(vec![Some(true),  None,        Some(true)]),
+            Some(vec![Some(true),  None,        None]),
+
+            Some(vec![None,        Some(false), Some(false)]),
+            Some(vec![None,        Some(false), Some(true)]),
+            Some(vec![None,        Some(false), None]),
+
+            Some(vec![None,        Some(true),  Some(false)]),
+            Some(vec![None,        Some(true),  Some(true)]),
+            Some(vec![None,        Some(true),  None]),
+
+            Some(vec![None,        None,        Some(false)]),
+            Some(vec![None,        None,        Some(true)]),
+            Some(vec![None,        None,        None]),
+            None,
+        ];
+        test_every_config_sort_boolean_list_arrays(
+            cases.clone(),
+            Some(SortOptions {
+                descending: false,
+                nulls_first: false,
+            }),
+            expected_descending_false_nulls_first_false,
+        );
+
+        // The order is null, false, true
+        #[rustfmt::skip]
+        let expected_descending_false_nulls_first_true = vec![
+            None,
+
+            Some(vec![None,        None,        None]),
+            Some(vec![None,        None,        Some(false)]),
+            Some(vec![None,        None,        Some(true)]),
+
+            Some(vec![None,        Some(false), None]),
+            Some(vec![None,        Some(false), Some(false)]),
+            Some(vec![None,        Some(false), Some(true)]),
+
+            Some(vec![None,        Some(true),  None]),
+            Some(vec![None,        Some(true),  Some(false)]),
+            Some(vec![None,        Some(true),  Some(true)]),
+
+            Some(vec![Some(false), None,        None]),
+            Some(vec![Some(false), None,        Some(false)]),
+            Some(vec![Some(false), None,        Some(true)]),
+
+            Some(vec![Some(false), Some(false), None]),
+            Some(vec![Some(false), Some(false), Some(false)]),
+            Some(vec![Some(false), Some(false), Some(true)]),
+
+            Some(vec![Some(false), Some(true),  None]),
+            Some(vec![Some(false), Some(true),  Some(false)]),
+            Some(vec![Some(false), Some(true),  Some(true)]),
+
+            Some(vec![Some(true),  None,        None]),
+            Some(vec![Some(true),  None,        Some(false)]),
+            Some(vec![Some(true),  None,        Some(true)]),
+
+            Some(vec![Some(true),  Some(false), None]),
+            Some(vec![Some(true),  Some(false), Some(false)]),
+            Some(vec![Some(true),  Some(false), Some(true)]),
+
+            Some(vec![Some(true),  Some(true),  None]),
+            Some(vec![Some(true),  Some(true),  Some(false)]),
+            Some(vec![Some(true),  Some(true),  Some(true)]),
+        ];
+
+        test_every_config_sort_boolean_list_arrays(
+            cases.clone(),
+            Some(SortOptions {
+                descending: false,
+                nulls_first: true,
+            }),
+            expected_descending_false_nulls_first_true,
+        );
+
+        // The order is true, false, null
+        #[rustfmt::skip]
+        let expected_descending_true_nulls_first_false = vec![
+            Some(vec![Some(true),  Some(true),  Some(true)]),
+            Some(vec![Some(true),  Some(true),  Some(false)]),
+            Some(vec![Some(true),  Some(true),  None]),
+
+            Some(vec![Some(true),  Some(false), Some(true)]),
+            Some(vec![Some(true),  Some(false), Some(false)]),
+            Some(vec![Some(true),  Some(false), None]),
+
+            Some(vec![Some(true),  None,        Some(true)]),
+            Some(vec![Some(true),  None,        Some(false)]),
+            Some(vec![Some(true),  None,        None]),
+
+            Some(vec![Some(false), Some(true),  Some(true)]),
+            Some(vec![Some(false), Some(true),  Some(false)]),
+            Some(vec![Some(false), Some(true),  None]),
+
+            Some(vec![Some(false), Some(false), Some(true)]),
+            Some(vec![Some(false), Some(false), Some(false)]),
+            Some(vec![Some(false), Some(false), None]),
+
+            Some(vec![Some(false), None,        Some(true)]),
+            Some(vec![Some(false), None,        Some(false)]),
+            Some(vec![Some(false), None,        None]),
+
+            Some(vec![None,        Some(true),  Some(true)]),
+            Some(vec![None,        Some(true),  Some(false)]),
+            Some(vec![None,        Some(true),  None]),
+
+            Some(vec![None,        Some(false), Some(true)]),
+            Some(vec![None,        Some(false), Some(false)]),
+            Some(vec![None,        Some(false), None]),
+
+            Some(vec![None,        None,        Some(true)]),
+            Some(vec![None,        None,        Some(false)]),
+            Some(vec![None,        None,        None]),
+
+            None,
+        ];
+        test_every_config_sort_boolean_list_arrays(
+            cases.clone(),
+            Some(SortOptions {
+                descending: true,
+                nulls_first: false,
+            }),
+            expected_descending_true_nulls_first_false,
+        );
+
+        // The order is null, true, false
+        #[rustfmt::skip]
+        let expected_descending_true_nulls_first_true = vec![
+            None,
+
+            Some(vec![None,        None,        None]),
+            Some(vec![None,        None,        Some(true)]),
+            Some(vec![None,        None,        Some(false)]),
+
+            Some(vec![None,        Some(true),  None]),
+            Some(vec![None,        Some(true),  Some(true)]),
+            Some(vec![None,        Some(true),  Some(false)]),
+
+            Some(vec![None,        Some(false), None]),
+            Some(vec![None,        Some(false), Some(true)]),
+            Some(vec![None,        Some(false), Some(false)]),
+
+            Some(vec![Some(true),  None,        None]),
+            Some(vec![Some(true),  None,        Some(true)]),
+            Some(vec![Some(true),  None,        Some(false)]),
+
+            Some(vec![Some(true),  Some(true),  None]),
+            Some(vec![Some(true),  Some(true),  Some(true)]),
+            Some(vec![Some(true),  Some(true),  Some(false)]),
+
+            Some(vec![Some(true),  Some(false), None]),
+            Some(vec![Some(true),  Some(false), Some(true)]),
+            Some(vec![Some(true),  Some(false), Some(false)]),
+
+            Some(vec![Some(false), None,        None]),
+            Some(vec![Some(false), None,        Some(true)]),
+            Some(vec![Some(false), None,        Some(false)]),
+
+            Some(vec![Some(false), Some(true),  None]),
+            Some(vec![Some(false), Some(true),  Some(true)]),
+            Some(vec![Some(false), Some(true),  Some(false)]),
+
+            Some(vec![Some(false), Some(false), None]),
+            Some(vec![Some(false), Some(false), Some(true)]),
+            Some(vec![Some(false), Some(false), Some(false)]),
+        ];
+        // Like the cases above, this runs every limit / list-type configuration
+        test_every_config_sort_boolean_list_arrays(
+            cases.clone(),
+            Some(SortOptions {
+                descending: true,
+                nulls_first: true,
+            }),
+            expected_descending_true_nulls_first_true,
+        );
+    }
+
     #[test]
     fn test_sort_indices_decimal128() {
         // decimal default