-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Add comparison support for Union arrays #8838
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,8 +21,8 @@ use arrow_array::cast::AsArray; | |
| use arrow_array::types::*; | ||
| use arrow_array::*; | ||
| use arrow_buffer::{ArrowNativeType, NullBuffer}; | ||
| use arrow_schema::{ArrowError, SortOptions}; | ||
| use std::cmp::Ordering; | ||
| use arrow_schema::{ArrowError, DataType, SortOptions}; | ||
| use std::{cmp::Ordering, collections::HashMap}; | ||
|
|
||
| /// Compare the values at two arbitrary indices in two arrays. | ||
| pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>; | ||
|
|
@@ -296,6 +296,78 @@ fn compare_struct( | |
| Ok(f) | ||
| } | ||
|
|
||
| fn compare_union( | ||
| left: &dyn Array, | ||
| right: &dyn Array, | ||
| opts: SortOptions, | ||
| ) -> Result<DynComparator, ArrowError> { | ||
| let left = left.as_union(); | ||
| let right = right.as_union(); | ||
|
|
||
| let (left_fields, left_mode) = match left.data_type() { | ||
| DataType::Union(fields, mode) => (fields, mode), | ||
| _ => unreachable!(), | ||
| }; | ||
| let (right_fields, right_mode) = match right.data_type() { | ||
| DataType::Union(fields, mode) => (fields, mode), | ||
| _ => unreachable!(), | ||
| }; | ||
|
|
||
| if left_fields != right_fields { | ||
| return Err(ArrowError::InvalidArgumentError(format!( | ||
| "Cannot compare UnionArrays with different fields: left={:?}, right={:?}", | ||
| left_fields, right_fields | ||
| ))); | ||
| } | ||
|
|
||
| if left_mode != right_mode { | ||
| return Err(ArrowError::InvalidArgumentError(format!( | ||
| "Cannot compare UnionArrays with different modes: left={:?}, right={:?}", | ||
| left_mode, right_mode | ||
| ))); | ||
| } | ||
|
|
||
| let c_opts = child_opts(opts); | ||
|
|
||
| let mut field_comparators = HashMap::with_capacity(left_fields.len()); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rather than a hash map you could potentially just use a 128 valued
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm so this was my first thought/approach as well, but I decided to use a hashmap because it avoids superfluous memory usage for sparse sets Plus, I don't think this is a very hot path, so any perf differences wouldn't be super meaningful |
||
|
|
||
| for (type_id, _field) in left_fields.iter() { | ||
| let left_child = left.child(type_id); | ||
| let right_child = right.child(type_id); | ||
| let cmp = make_comparator(left_child.as_ref(), right_child.as_ref(), c_opts)?; | ||
|
|
||
| field_comparators.insert(type_id, cmp); | ||
| } | ||
|
|
||
| let left_type_ids = left.type_ids().clone(); | ||
| let right_type_ids = right.type_ids().clone(); | ||
|
|
||
| let left_offsets = left.offsets().cloned(); | ||
| let right_offsets = right.offsets().cloned(); | ||
|
|
||
| let f = compare(left, right, opts, move |i, j| { | ||
| let left_type_id = left_type_ids[i]; | ||
| let right_type_id = right_type_ids[j]; | ||
|
|
||
| // first, compare by type_id | ||
| match left_type_id.cmp(&right_type_id) { | ||
| Ordering::Equal => { | ||
| // second, compare by values | ||
| let left_offset = left_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i); | ||
| let right_offset = right_offsets.as_ref().map(|o| o[j] as usize).unwrap_or(j); | ||
|
|
||
| let cmp = field_comparators | ||
| .get(&left_type_id) | ||
| .expect("type id not found in field_comparators"); | ||
|
|
||
| cmp(left_offset, right_offset) | ||
| } | ||
| other => other, | ||
| } | ||
| }); | ||
| Ok(f) | ||
| } | ||
|
|
||
| /// Returns a comparison function that compares two values at two different positions | ||
| /// between the two arrays. | ||
| /// | ||
|
|
@@ -412,6 +484,7 @@ pub fn make_comparator( | |
| } | ||
| }, | ||
| (Map(_, _), Map(_, _)) => compare_map(left, right, opts), | ||
| (Union(_, _), Union(_, _)) => compare_union(left, right, opts), | ||
| (lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs { | ||
| true => format!("The data type type {lhs:?} has no natural order"), | ||
| false => "Can't compare arrays of different types".to_string(), | ||
|
|
@@ -423,8 +496,8 @@ pub fn make_comparator( | |
| mod tests { | ||
| use super::*; | ||
| use arrow_array::builder::{Int32Builder, ListBuilder, MapBuilder, StringBuilder}; | ||
| use arrow_buffer::{IntervalDayTime, OffsetBuffer, i256}; | ||
| use arrow_schema::{DataType, Field, Fields}; | ||
| use arrow_buffer::{IntervalDayTime, OffsetBuffer, ScalarBuffer, i256}; | ||
| use arrow_schema::{DataType, Field, Fields, UnionFields}; | ||
| use half::f16; | ||
| use std::sync::Arc; | ||
|
|
||
|
|
@@ -1189,4 +1262,243 @@ mod tests { | |
| } | ||
| } | ||
| } | ||
|
|
||
#[test]
fn test_dense_union() {
    // Shared schema: Int32 under type_id 0, Utf8 under type_id 1.
    let fields = UnionFields::from_iter([
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("B", DataType::Utf8, false))),
    ]);

    // lhs holds [1, "b", 2, "a", 3]:
    //   type_ids [0, 1, 0, 1, 0], offsets [0, 0, 1, 1, 2]
    //   children [1, 2, 3] and ["b", "a"]
    let lhs = UnionArray::try_new(
        fields.clone(),
        ScalarBuffer::<i8>::from_iter([0, 1, 0, 1, 0]),
        Some(ScalarBuffer::<i32>::from_iter([0, 0, 1, 1, 2])),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
            Arc::new(StringArray::from(vec!["b", "a"])),
        ],
    )
    .unwrap();

    // rhs holds [2, "a", 1, "c"]:
    //   type_ids [0, 1, 0, 1], offsets [0, 0, 1, 1]
    //   children [2, 1] and ["a", "c"]
    let rhs = UnionArray::try_new(
        fields,
        ScalarBuffer::<i8>::from_iter([0, 1, 0, 1]),
        Some(ScalarBuffer::<i32>::from_iter([0, 0, 1, 1])),
        vec![
            Arc::new(Int32Array::from(vec![2, 1])) as ArrayRef,
            Arc::new(StringArray::from(vec!["a", "c"])),
        ],
    )
    .unwrap();

    let ascending = SortOptions {
        descending: false,
        nulls_first: true,
    };
    let cmp = make_comparator(&lhs, &rhs, ascending).unwrap();

    // (lhs index, rhs index) -> expected ordering in ascending order
    let ascending_cases = [
        ((0, 0), Ordering::Less),    // 1 < 2
        ((0, 1), Ordering::Less),    // type_id 0 < type_id 1
        ((1, 1), Ordering::Greater), // "b" > "a"
        ((2, 0), Ordering::Equal),   // 2 == 2
        ((3, 1), Ordering::Equal),   // "a" == "a"
        ((1, 3), Ordering::Less),    // "b" < "c"
    ];
    for ((i, j), expected) in ascending_cases {
        assert_eq!(cmp(i, j), expected);
    }

    let descending = SortOptions {
        descending: true,
        ..ascending
    };
    let cmp_desc = make_comparator(&lhs, &rhs, descending).unwrap();

    // Every non-equal result flips under descending order.
    let descending_cases = [
        ((0, 0), Ordering::Greater), // 1 vs 2, reversed
        ((0, 1), Ordering::Greater), // type_id 0 vs 1, reversed
        ((1, 1), Ordering::Less),    // "b" vs "a", reversed
    ];
    for ((i, j), expected) in descending_cases {
        assert_eq!(cmp_desc(i, j), expected);
    }
}
|
|
||
#[test]
fn test_sparse_union() {
    // Sparse union holding [1, "b", 3]: Int32 under type_id 0, Utf8 under type_id 1.
    // No offsets are supplied, and every child is as long as the union itself.
    let fields = UnionFields::from_iter([
        (0, Arc::new(Field::new("a", DataType::Int32, false))),
        (1, Arc::new(Field::new("b", DataType::Utf8, false))),
    ]);
    let ints = Int32Array::from(vec![Some(1), None, Some(3)]);
    let strs = StringArray::from(vec![None, Some("b"), None]);

    let union = UnionArray::try_new(
        fields,
        ScalarBuffer::<i8>::from_iter([0, 1, 0]),
        None,
        vec![Arc::new(ints) as ArrayRef, Arc::new(strs)],
    )
    .unwrap();

    let cmp = make_comparator(&union, &union, SortOptions::default()).unwrap();

    // Same type id on both sides: falls through to the Int32 values, 1 < 3.
    assert_eq!(cmp(0, 2), Ordering::Less);
    // Different type ids: type_id 0 sorts before type_id 1.
    assert_eq!(cmp(0, 1), Ordering::Less);
}
|
|
||
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_union_out_of_bounds() {
    // Dense union with three rows, so the valid row indices are 0..=2.
    let fields = UnionFields::from_iter([
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("B", DataType::Utf8, false))),
    ]);

    let union = UnionArray::try_new(
        fields,
        ScalarBuffer::<i8>::from_iter([0, 1, 0]),
        Some(ScalarBuffer::<i32>::from_iter([0, 0, 1])),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
            Arc::new(StringArray::from(vec!["a"])),
        ],
    )
    .unwrap();

    let cmp = make_comparator(&union, &union, SortOptions::default()).unwrap();

    // Row 3 is one past the end of the right-hand side, so the comparator must panic.
    let _ = cmp(0, 3);
}
|
|
||
#[test]
fn test_union_incompatible_fields() {
    // Left union: Int32 "A" under type_id 0 and Utf8 "B" under type_id 1.
    let lhs = UnionArray::try_new(
        UnionFields::from_iter([
            (0, Arc::new(Field::new("A", DataType::Int32, false))),
            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
        ]),
        ScalarBuffer::<i8>::from_iter([0, 1]),
        Some(ScalarBuffer::<i32>::from_iter([0, 0])),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
            Arc::new(StringArray::from(vec!["a", "b"])),
        ],
    )
    .unwrap();

    // Right union: same first field, but type_id 1 is Float64 "C", so the schemas differ.
    let rhs = UnionArray::try_new(
        UnionFields::from_iter([
            (0, Arc::new(Field::new("A", DataType::Int32, false))),
            (1, Arc::new(Field::new("C", DataType::Float64, false))),
        ]),
        ScalarBuffer::<i8>::from_iter([0, 1]),
        Some(ScalarBuffer::<i32>::from_iter([0, 0])),
        vec![
            Arc::new(Int32Array::from(vec![3, 4])) as ArrayRef,
            Arc::new(Float64Array::from(vec![1.0, 2.0])),
        ],
    )
    .unwrap();

    // Building the comparator must fail with an InvalidArgumentError.
    let message = match make_comparator(&lhs, &rhs, SortOptions::default()) {
        Err(ArrowError::InvalidArgumentError(message)) => message,
        _ => panic!("expected error when making comparator of incompatible union arrays"),
    };

    assert_eq!(
        message.as_str(),
        "Cannot compare UnionArrays with different fields: left=[(0, Field { name: \"A\", data_type: Int32 }), (1, Field { name: \"B\", data_type: Utf8 })], right=[(0, Field { name: \"A\", data_type: Int32 }), (1, Field { name: \"C\", data_type: Float64 })]"
    );
}
|
|
||
#[test]
fn test_union_incompatible_modes() {
    // Both unions declare the same fields; only the presence of offsets differs.
    let fields = UnionFields::from_iter([
        (0, Arc::new(Field::new("A", DataType::Int32, false))),
        (1, Arc::new(Field::new("B", DataType::Utf8, false))),
    ]);

    // Built with an offsets buffer (dense).
    let dense = UnionArray::try_new(
        fields.clone(),
        ScalarBuffer::<i8>::from_iter([0, 1]),
        Some(ScalarBuffer::<i32>::from_iter([0, 0])),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
            Arc::new(StringArray::from(vec!["a", "b"])),
        ],
    )
    .unwrap();

    // Built without offsets (sparse); children are as long as the union.
    let sparse = UnionArray::try_new(
        fields,
        ScalarBuffer::<i8>::from_iter([0, 1]),
        None,
        vec![
            Arc::new(Int32Array::from(vec![Some(3), None])) as ArrayRef,
            Arc::new(StringArray::from(vec![None, Some("c")])),
        ],
    )
    .unwrap();

    // Building the comparator must fail with an InvalidArgumentError.
    let message = match make_comparator(&dense, &sparse, SortOptions::default()) {
        Err(ArrowError::InvalidArgumentError(message)) => message,
        _ => panic!("expected error when making comparator of union arrays with different modes"),
    };

    assert_eq!(
        message.as_str(),
        "Cannot compare UnionArrays with different modes: left=Dense, right=Sparse"
    );
}
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is weird to have to re-check the DataTypes here.
What would you think about adding `UnionArray::fields()` and `UnionArray::mode()` methods to make the code easier to work with?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be super quick to review: #8884
Somewhat related, but it feels a bit weird that the following works without any notice to the user:
I feel like we could benefit from a bit more validation. We could leave `UnionFields::new` but also have a `UnionFields::try_new` that checks the above 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think that sounds like a good idea to me.
We can even deprecate `UnionFields::new` to help people migrate over.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here it is: #8891
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is another minor convenience improvement: #8895