diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs
index 6e3025576c69..b12a06732d42 100644
--- a/arrow-ord/src/ord.rs
+++ b/arrow-ord/src/ord.rs
@@ -21,8 +21,8 @@ use arrow_array::cast::AsArray;
 use arrow_array::types::*;
 use arrow_array::*;
 use arrow_buffer::{ArrowNativeType, NullBuffer};
-use arrow_schema::{ArrowError, SortOptions};
-use std::cmp::Ordering;
+use arrow_schema::{ArrowError, DataType, SortOptions};
+use std::{cmp::Ordering, collections::HashMap};
 
 /// Compare the values at two arbitrary indices in two arrays.
 pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
@@ -296,6 +296,93 @@ fn compare_struct(
     Ok(f)
 }
 
+/// Builds a comparator for two [`UnionArray`]s that share the same fields and mode.
+///
+/// Values are ordered by `type_id` first. When both sides carry the same `type_id`,
+/// the comparator built for that child decides. When the array has an offsets buffer
+/// (dense layout) it selects the child slot; otherwise (sparse layout) the row index
+/// is used directly.
+///
+/// # Errors
+///
+/// Returns [`ArrowError::InvalidArgumentError`] when the union fields or the union
+/// modes differ, and forwards any error raised while building a child comparator.
+///
+/// # Panics
+///
+/// The returned comparator panics when called with an out-of-bounds index.
+fn compare_union(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
+    let left = left.as_union();
+    let right = right.as_union();
+
+    let (left_fields, left_mode) = match left.data_type() {
+        DataType::Union(fields, mode) => (fields, mode),
+        _ => unreachable!(),
+    };
+    let (right_fields, right_mode) = match right.data_type() {
+        DataType::Union(fields, mode) => (fields, mode),
+        _ => unreachable!(),
+    };
+
+    if left_fields != right_fields {
+        return Err(ArrowError::InvalidArgumentError(format!(
+            "Cannot compare UnionArrays with different fields: left={:?}, right={:?}",
+            left_fields, right_fields
+        )));
+    }
+
+    if left_mode != right_mode {
+        return Err(ArrowError::InvalidArgumentError(format!(
+            "Cannot compare UnionArrays with different modes: left={:?}, right={:?}",
+            left_mode, right_mode
+        )));
+    }
+
+    let c_opts = child_opts(opts);
+
+    let mut field_comparators = HashMap::with_capacity(left_fields.len());
+
+    for (type_id, _field) in left_fields.iter() {
+        let left_child = left.child(type_id);
+        let right_child = right.child(type_id);
+        let cmp = make_comparator(left_child.as_ref(), right_child.as_ref(), c_opts)?;
+
+        field_comparators.insert(type_id, cmp);
+    }
+
+    let left_type_ids = left.type_ids().clone();
+    let right_type_ids = right.type_ids().clone();
+
+    let left_offsets = left.offsets().cloned();
+    let right_offsets = right.offsets().cloned();
+
+    let f = compare(left, right, opts, move |i, j| {
+        let left_type_id = left_type_ids[i];
+        let right_type_id = right_type_ids[j];
+
+        // first, compare by type_id
+        match left_type_id.cmp(&right_type_id) {
+            Ordering::Equal => {
+                // second, compare by values
+                let left_offset = left_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i);
+                let right_offset = right_offsets.as_ref().map(|o| o[j] as usize).unwrap_or(j);
+
+                let cmp = field_comparators
+                    .get(&left_type_id)
+                    .expect("type id not found in field_comparators");
+
+                cmp(left_offset, right_offset)
+            }
+            other => other,
+        }
+    });
+    Ok(f)
+}
+
 /// Returns a comparison function that compares two values at two different positions
 /// between the two arrays.
 ///
@@ -412,6 +499,7 @@ pub fn make_comparator(
             }
         },
         (Map(_, _), Map(_, _)) => compare_map(left, right, opts),
+        (Union(_, _), Union(_, _)) => compare_union(left, right, opts),
         (lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
             true => format!("The data type type {lhs:?} has no natural order"),
             false => "Can't compare arrays of different types".to_string(),
@@ -423,8 +511,8 @@ pub fn make_comparator(
 mod tests {
     use super::*;
     use arrow_array::builder::{Int32Builder, ListBuilder, MapBuilder, StringBuilder};
-    use arrow_buffer::{IntervalDayTime, OffsetBuffer, i256};
-    use arrow_schema::{DataType, Field, Fields};
+    use arrow_buffer::{IntervalDayTime, OffsetBuffer, ScalarBuffer, i256};
+    use arrow_schema::{DataType, Field, Fields, UnionFields};
     use half::f16;
     use std::sync::Arc;
 
@@ -1189,4 +1277,243 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn test_dense_union() {
+        // create a dense union array with Int32 (type_id = 0) and Utf8 (type_id=1)
+        // the values are: [1, "b", 2, "a", 3]
+        // type_ids are: [0, 1, 0, 1, 0]
+        // offsets are: [0, 0, 1, 1, 2] from [1, 2, 3] and ["b", "a"]
+        let int_array = Int32Array::from(vec![1, 2, 3]);
+        let str_array = StringArray::from(vec!["b", "a"]);
+
+        let type_ids = [0, 1, 0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets = [0, 0, 1, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("A", DataType::Int32, false))),
+            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
+
+        let array1 =
+            UnionArray::try_new(union_fields.clone(), type_ids, Some(offsets), children).unwrap();
+
+        // create a second array: [2, "a", 1, "c"]
+        // type ids are: [0, 1, 0, 1]
+        // offsets are: [0, 0, 1, 1] from [2, 1] and ["a", "c"]
+        let int_array2 = Int32Array::from(vec![2, 1]);
+        let str_array2 = StringArray::from(vec!["a", "c"]);
+        let type_ids2 = [0, 1, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets2 = [0, 0, 1, 1].into_iter().collect::<ScalarBuffer<i32>>();
+
+        let children2 = vec![Arc::new(int_array2) as ArrayRef, Arc::new(str_array2)];
+
+        let array2 =
+            UnionArray::try_new(union_fields, type_ids2, Some(offsets2), children2).unwrap();
+
+        let opts = SortOptions {
+            descending: false,
+            nulls_first: true,
+        };
+
+        // comparing
+        // [1, "b", 2, "a", 3]
+        // [2, "a", 1, "c"]
+        let cmp = make_comparator(&array1, &array2, opts).unwrap();
+
+        // array1[0] = (type_id=0, value=1)
+        // array2[0] = (type_id=0, value=2)
+        assert_eq!(cmp(0, 0), Ordering::Less); // 1 < 2
+
+        // array1[0] = (type_id=0, value=1)
+        // array2[1] = (type_id=1, value="a")
+        assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1
+
+        // array1[1] = (type_id=1, value="b")
+        // array2[1] = (type_id=1, value="a")
+        assert_eq!(cmp(1, 1), Ordering::Greater); // "b" > "a"
+
+        // array1[2] = (type_id=0, value=2)
+        // array2[0] = (type_id=0, value=2)
+        assert_eq!(cmp(2, 0), Ordering::Equal); // 2 == 2
+
+        // array1[3] = (type_id=1, value="a")
+        // array2[1] = (type_id=1, value="a")
+        assert_eq!(cmp(3, 1), Ordering::Equal); // "a" == "a"
+
+        // array1[1] = (type_id=1, value="b")
+        // array2[3] = (type_id=1, value="c")
+        assert_eq!(cmp(1, 3), Ordering::Less); // "b" < "c"
+
+        let opts_desc = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+        let cmp_desc = make_comparator(&array1, &array2, opts_desc).unwrap();
+
+        assert_eq!(cmp_desc(0, 0), Ordering::Greater); // 1 > 2 (reversed)
+        assert_eq!(cmp_desc(0, 1), Ordering::Greater); // type_id 0 < 1, reversed to Greater
+        assert_eq!(cmp_desc(1, 1), Ordering::Less); // "b" < "a" (reversed)
+    }
+
+    #[test]
+    fn test_sparse_union() {
+        // create a sparse union array with Int32 (type_id=0) and Utf8 (type_id=1)
+        // values: [1, "b", 3]
+        // note, in sparse unions, child arrays have the same length as the union
+        let int_array = Int32Array::from(vec![Some(1), None, Some(3)]);
+        let str_array = StringArray::from(vec![None, Some("b"), None]);
+        let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("a", DataType::Int32, false))),
+            (1, Arc::new(Field::new("b", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
+
+        let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
+
+        let opts = SortOptions::default();
+        let cmp = make_comparator(&array, &array, opts).unwrap();
+
+        // array[0] = (type_id=0, value=1), array[2] = (type_id=0, value=3)
+        assert_eq!(cmp(0, 2), Ordering::Less); // 1 < 3
+        // array[0] = (type_id=0, value=1), array[1] = (type_id=1, value="b")
+        assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1
+    }
+
+    #[test]
+    #[should_panic(expected = "index out of bounds")]
+    fn test_union_out_of_bounds() {
+        // create a dense union array with 3 elements
+        let int_array = Int32Array::from(vec![1, 2]);
+        let str_array = StringArray::from(vec!["a"]);
+
+        let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets = [0, 0, 1].into_iter().collect::<ScalarBuffer<i32>>();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("A", DataType::Int32, false))),
+            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
+
+        let array = UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
+
+        let opts = SortOptions::default();
+        let cmp = make_comparator(&array, &array, opts).unwrap();
+
+        // oob
+        cmp(0, 3);
+    }
+
+    #[test]
+    fn test_union_incompatible_fields() {
+        // create first union with Int32 and Utf8
+        let int_array1 = Int32Array::from(vec![1, 2]);
+        let str_array1 = StringArray::from(vec!["a", "b"]);
+
+        let type_ids1 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets1 = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();
+
+        let union_fields1 = [
+            (0, Arc::new(Field::new("A", DataType::Int32, false))),
+            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children1 = vec![Arc::new(int_array1) as ArrayRef, Arc::new(str_array1)];
+
+        let array1 =
+            UnionArray::try_new(union_fields1, type_ids1, Some(offsets1), children1).unwrap();
+
+        // create second union with Int32 and Float64 (incompatible with first)
+        let int_array2 = Int32Array::from(vec![3, 4]);
+        let float_array2 = Float64Array::from(vec![1.0, 2.0]);
+
+        let type_ids2 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets2 = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();
+
+        let union_fields2 = [
+            (0, Arc::new(Field::new("A", DataType::Int32, false))),
+            (1, Arc::new(Field::new("C", DataType::Float64, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children2 = vec![Arc::new(int_array2) as ArrayRef, Arc::new(float_array2)];
+
+        let array2 =
+            UnionArray::try_new(union_fields2, type_ids2, Some(offsets2), children2).unwrap();
+
+        let opts = SortOptions::default();
+
+        let Result::Err(ArrowError::InvalidArgumentError(out)) =
+            make_comparator(&array1, &array2, opts)
+        else {
+            panic!("expected error when making comparator of incompatible union arrays");
+        };
+
+        assert_eq!(
+            &out,
+            "Cannot compare UnionArrays with different fields: left=[(0, Field { name: \"A\", data_type: Int32 }), (1, Field { name: \"B\", data_type: Utf8 })], right=[(0, Field { name: \"A\", data_type: Int32 }), (1, Field { name: \"C\", data_type: Float64 })]"
+        );
+    }
+
+    #[test]
+    fn test_union_incompatible_modes() {
+        // create first union as Dense with Int32 and Utf8
+        let int_array1 = Int32Array::from(vec![1, 2]);
+        let str_array1 = StringArray::from(vec!["a", "b"]);
+
+        let type_ids1 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets1 = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();
+
+        let union_fields1 = [
+            (0, Arc::new(Field::new("A", DataType::Int32, false))),
+            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children1 = vec![Arc::new(int_array1) as ArrayRef, Arc::new(str_array1)];
+
+        let array1 =
+            UnionArray::try_new(union_fields1.clone(), type_ids1, Some(offsets1), children1)
+                .unwrap();
+
+        // create second union as Sparse with same fields (Int32 and Utf8)
+        let int_array2 = Int32Array::from(vec![Some(3), None]);
+        let str_array2 = StringArray::from(vec![None, Some("c")]);
+
+        let type_ids2 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
+
+        let children2 = vec![Arc::new(int_array2) as ArrayRef, Arc::new(str_array2)];
+
+        let array2 = UnionArray::try_new(union_fields1, type_ids2, None, children2).unwrap();
+
+        let opts = SortOptions::default();
+
+        let Result::Err(ArrowError::InvalidArgumentError(out)) =
+            make_comparator(&array1, &array2, opts)
+        else {
+            panic!("expected error when making comparator of union arrays with different modes");
+        };
+
+        assert_eq!(
+            &out,
+            "Cannot compare UnionArrays with different modes: left=Dense, right=Sparse"
+        );
+    }
 }