Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
320 changes: 316 additions & 4 deletions arrow-ord/src/ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::{ArrowNativeType, NullBuffer};
use arrow_schema::{ArrowError, SortOptions};
use std::cmp::Ordering;
use arrow_schema::{ArrowError, DataType, SortOptions};
use std::{cmp::Ordering, collections::HashMap};

/// Compare the values at two arbitrary indices in two arrays.
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
Expand Down Expand Up @@ -296,6 +296,78 @@ fn compare_struct(
Ok(f)
}

fn compare_union(
left: &dyn Array,
right: &dyn Array,
opts: SortOptions,
) -> Result<DynComparator, ArrowError> {
let left = left.as_union();
let right = right.as_union();

let (left_fields, left_mode) = match left.data_type() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is weird to have to re-check the DataTypes.

What would you think about adding UnionArray::fields() and UnionArray::mode() methods to make the code easier to work with?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be super quick to review: #8884

Somewhat related but it feels a bit weird that the following works without any notice to the user:

#[test]
fn test_union_fields() {
    let ids = vec![0, 1, 2];
    let field = Field::new("a", DataType::Binary, true);

    // different length of ids and fields (we zip so we truncate the longer vec)
    let _out = UnionFields::new(ids.clone(), vec![field.clone()]);

    // duplicate fields associated with different type ids!
    let _out = UnionFields::new(ids, vec![field.clone(), field]);
}

I feel like we could benefit from a bit more validation? We could leave UnionFields::new but also have a UnionFields::try_new that checks the above 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think that sounds like a good idea to me

We can even deprecate UnionFields::new to help people migrate over

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here it is: #8891

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is another minor convenience improvement: #8895

DataType::Union(fields, mode) => (fields, mode),
_ => unreachable!(),
};
let (right_fields, right_mode) = match right.data_type() {
DataType::Union(fields, mode) => (fields, mode),
_ => unreachable!(),
};

if left_fields != right_fields {
return Err(ArrowError::InvalidArgumentError(format!(
"Cannot compare UnionArrays with different fields: left={:?}, right={:?}",
left_fields, right_fields
)));
}

if left_mode != right_mode {
return Err(ArrowError::InvalidArgumentError(format!(
"Cannot compare UnionArrays with different modes: left={:?}, right={:?}",
left_mode, right_mode
)));
}

let c_opts = child_opts(opts);

let mut field_comparators = HashMap::with_capacity(left_fields.len());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather than a hash map you could potentially just use a 128 valued Vec<> indexed by the typeids -- since typeid is i8 you know there can be at most 128 values that might be faster to lookup than hashing/hash table

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm so this was my first thought/approach as well, but I decided to use a hashmap because it avoids superfluous memory usage for sparse sets

Plus, I don't think this is a very hot path, so any perf differences wouldn't be super meaningful


for (type_id, _field) in left_fields.iter() {
let left_child = left.child(type_id);
let right_child = right.child(type_id);
let cmp = make_comparator(left_child.as_ref(), right_child.as_ref(), c_opts)?;

field_comparators.insert(type_id, cmp);
}

let left_type_ids = left.type_ids().clone();
let right_type_ids = right.type_ids().clone();

let left_offsets = left.offsets().cloned();
let right_offsets = right.offsets().cloned();

let f = compare(left, right, opts, move |i, j| {
let left_type_id = left_type_ids[i];
let right_type_id = right_type_ids[j];

// first, compare by type_id
match left_type_id.cmp(&right_type_id) {
Ordering::Equal => {
// second, compare by values
let left_offset = left_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i);
let right_offset = right_offsets.as_ref().map(|o| o[j] as usize).unwrap_or(j);

let cmp = field_comparators
.get(&left_type_id)
.expect("type id not found in field_comparators");

cmp(left_offset, right_offset)
}
other => other,
}
});
Ok(f)
}

/// Returns a comparison function that compares two values at two different positions
/// between the two arrays.
///
Expand Down Expand Up @@ -412,6 +484,7 @@ pub fn make_comparator(
}
},
(Map(_, _), Map(_, _)) => compare_map(left, right, opts),
(Union(_, _), Union(_, _)) => compare_union(left, right, opts),
(lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
true => format!("The data type type {lhs:?} has no natural order"),
false => "Can't compare arrays of different types".to_string(),
Expand All @@ -423,8 +496,8 @@ pub fn make_comparator(
mod tests {
use super::*;
use arrow_array::builder::{Int32Builder, ListBuilder, MapBuilder, StringBuilder};
use arrow_buffer::{IntervalDayTime, OffsetBuffer, i256};
use arrow_schema::{DataType, Field, Fields};
use arrow_buffer::{IntervalDayTime, OffsetBuffer, ScalarBuffer, i256};
use arrow_schema::{DataType, Field, Fields, UnionFields};
use half::f16;
use std::sync::Arc;

Expand Down Expand Up @@ -1189,4 +1262,243 @@ mod tests {
}
}
}

#[test]
fn test_dense_union() {
// create a dense union array with Int32 (type_id = 0) and Utf8 (type_id=1)
// the values are: [1, "b", 2, "a", 3]
// type_ids are: [0, 1, 0, 1, 0]
// offsets are: [0, 0, 1, 1, 2] from [1, 2, 3] and ["b", "a"]
let int_array = Int32Array::from(vec![1, 2, 3]);
let str_array = StringArray::from(vec!["b", "a"]);

let type_ids = [0, 1, 0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
let offsets = [0, 0, 1, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();

let union_fields = [
(0, Arc::new(Field::new("A", DataType::Int32, false))),
(1, Arc::new(Field::new("B", DataType::Utf8, false))),
]
.into_iter()
.collect::<UnionFields>();

let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];

let array1 =
UnionArray::try_new(union_fields.clone(), type_ids, Some(offsets), children).unwrap();

// create a second array: [2, "a", 1, "c"]
// type ids are: [0, 1, 0, 1]
// offsets are: [0, 0, 1, 1] from [2, 1] and ["a", "c"]
let int_array2 = Int32Array::from(vec![2, 1]);
let str_array2 = StringArray::from(vec!["a", "c"]);
let type_ids2 = [0, 1, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();
let offsets2 = [0, 0, 1, 1].into_iter().collect::<ScalarBuffer<i32>>();

let children2 = vec![Arc::new(int_array2) as ArrayRef, Arc::new(str_array2)];

let array2 =
UnionArray::try_new(union_fields, type_ids2, Some(offsets2), children2).unwrap();

let opts = SortOptions {
descending: false,
nulls_first: true,
};

// comparing
// [1, "b", 2, "a", 3]
// [2, "a", 1, "c"]
let cmp = make_comparator(&array1, &array2, opts).unwrap();

// array1[0] = (type_id=0, value=1)
// array2[0] = (type_id=0, value=2)
assert_eq!(cmp(0, 0), Ordering::Less); // 1 < 2

// array1[0] = (type_id=0, value=1)
// array2[1] = (type_id=1, value="a")
assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1

// array1[1] = (type_id=1, value="b")
// array2[1] = (type_id=1, value="a")
assert_eq!(cmp(1, 1), Ordering::Greater); // "b" > "a"

// array1[2] = (type_id=0, value=2)
// array2[0] = (type_id=0, value=2)
assert_eq!(cmp(2, 0), Ordering::Equal); // 2 == 2

// array1[3] = (type_id=1, value="a")
// array2[1] = (type_id=1, value="a")
assert_eq!(cmp(3, 1), Ordering::Equal); // "a" == "a"

// array1[1] = (type_id=1, value="b")
// array2[3] = (type_id=1, value="c")
assert_eq!(cmp(1, 3), Ordering::Less); // "b" < "c"

let opts_desc = SortOptions {
descending: true,
nulls_first: true,
};
let cmp_desc = make_comparator(&array1, &array2, opts_desc).unwrap();

assert_eq!(cmp_desc(0, 0), Ordering::Greater); // 1 > 2 (reversed)
assert_eq!(cmp_desc(0, 1), Ordering::Greater); // type_id 0 < 1, reversed to Greater
assert_eq!(cmp_desc(1, 1), Ordering::Less); // "b" < "a" (reversed)
}

#[test]
fn test_sparse_union() {
// create a sparse union array with Int32 (type_id=0) and Utf8 (type_id=1)
// values: [1, "b", 3]
// note, in sparse unions, child arrays have the same length as the union
let int_array = Int32Array::from(vec![Some(1), None, Some(3)]);
let str_array = StringArray::from(vec![None, Some("b"), None]);
let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();

let union_fields = [
(0, Arc::new(Field::new("a", DataType::Int32, false))),
(1, Arc::new(Field::new("b", DataType::Utf8, false))),
]
.into_iter()
.collect::<UnionFields>();

let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];

let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();

let opts = SortOptions::default();
let cmp = make_comparator(&array, &array, opts).unwrap();

// array[0] = (type_id=0, value=1), array[2] = (type_id=0, value=3)
assert_eq!(cmp(0, 2), Ordering::Less); // 1 < 3
// array[0] = (type_id=0, value=1), array[1] = (type_id=1, value="b")
assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1
}

#[test]
#[should_panic(expected = "index out of bounds")]
fn test_union_out_of_bounds() {
// create a dense union array with 3 elements
let int_array = Int32Array::from(vec![1, 2]);
let str_array = StringArray::from(vec!["a"]);

let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
let offsets = [0, 0, 1].into_iter().collect::<ScalarBuffer<i32>>();

let union_fields = [
(0, Arc::new(Field::new("A", DataType::Int32, false))),
(1, Arc::new(Field::new("B", DataType::Utf8, false))),
]
.into_iter()
.collect::<UnionFields>();

let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];

let array = UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();

let opts = SortOptions::default();
let cmp = make_comparator(&array, &array, opts).unwrap();

// oob
cmp(0, 3);
}

#[test]
fn test_union_incompatible_fields() {
// create first union with Int32 and Utf8
let int_array1 = Int32Array::from(vec![1, 2]);
let str_array1 = StringArray::from(vec!["a", "b"]);

let type_ids1 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
let offsets1 = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();

let union_fields1 = [
(0, Arc::new(Field::new("A", DataType::Int32, false))),
(1, Arc::new(Field::new("B", DataType::Utf8, false))),
]
.into_iter()
.collect::<UnionFields>();

let children1 = vec![Arc::new(int_array1) as ArrayRef, Arc::new(str_array1)];

let array1 =
UnionArray::try_new(union_fields1, type_ids1, Some(offsets1), children1).unwrap();

// create second union with Int32 and Float64 (incompatible with first)
let int_array2 = Int32Array::from(vec![3, 4]);
let float_array2 = Float64Array::from(vec![1.0, 2.0]);

let type_ids2 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
let offsets2 = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();

let union_fields2 = [
(0, Arc::new(Field::new("A", DataType::Int32, false))),
(1, Arc::new(Field::new("C", DataType::Float64, false))),
]
.into_iter()
.collect::<UnionFields>();

let children2 = vec![Arc::new(int_array2) as ArrayRef, Arc::new(float_array2)];

let array2 =
UnionArray::try_new(union_fields2, type_ids2, Some(offsets2), children2).unwrap();

let opts = SortOptions::default();

let Result::Err(ArrowError::InvalidArgumentError(out)) =
make_comparator(&array1, &array2, opts)
else {
panic!("expected error when making comparator of incompatible union arrays");
};

assert_eq!(
&out,
"Cannot compare UnionArrays with different fields: left=[(0, Field { name: \"A\", data_type: Int32 }), (1, Field { name: \"B\", data_type: Utf8 })], right=[(0, Field { name: \"A\", data_type: Int32 }), (1, Field { name: \"C\", data_type: Float64 })]"
);
}

#[test]
fn test_union_incompatible_modes() {
// create first union as Dense with Int32 and Utf8
let int_array1 = Int32Array::from(vec![1, 2]);
let str_array1 = StringArray::from(vec!["a", "b"]);

let type_ids1 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
let offsets1 = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();

let union_fields1 = [
(0, Arc::new(Field::new("A", DataType::Int32, false))),
(1, Arc::new(Field::new("B", DataType::Utf8, false))),
]
.into_iter()
.collect::<UnionFields>();

let children1 = vec![Arc::new(int_array1) as ArrayRef, Arc::new(str_array1)];

let array1 =
UnionArray::try_new(union_fields1.clone(), type_ids1, Some(offsets1), children1)
.unwrap();

// create second union as Sparse with same fields (Int32 and Utf8)
let int_array2 = Int32Array::from(vec![Some(3), None]);
let str_array2 = StringArray::from(vec![None, Some("c")]);

let type_ids2 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();

let children2 = vec![Arc::new(int_array2) as ArrayRef, Arc::new(str_array2)];

let array2 = UnionArray::try_new(union_fields1, type_ids2, None, children2).unwrap();

let opts = SortOptions::default();

let Result::Err(ArrowError::InvalidArgumentError(out)) =
make_comparator(&array1, &array2, opts)
else {
panic!("expected error when making comparator of union arrays with different modes");
};

assert_eq!(
&out,
"Cannot compare UnionArrays with different modes: left=Dense, right=Sparse"
);
}
}
Loading