-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Improve StringArray(Utf8) sort performance (~2-4x faster) #7860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0a8db05
5019b69
0c665fb
3f30e41
0bbb117
0f7e353
e77f208
8ce60b1
360ca41
481a2a3
85d9218
84646e5
b0c7448
e5fefef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -345,12 +345,88 @@ fn sort_bytes<T: ByteArrayType>( | |
| options: SortOptions, | ||
| limit: Option<usize>, | ||
| ) -> UInt32Array { | ||
| let mut valids = value_indices | ||
| // Note: why do we use a 4-byte prefix? | ||
| // Each value's first 4 bytes are packed into a u32 in BE order; shorter values | ||
| // are left-aligned and zero-padded on the right. | ||
| // Most byte sequences differ in their first few bytes, so by | ||
| // comparing up to 4 bytes as a single u32 we avoid the overhead | ||
| // of a full lexicographical compare for the vast majority of cases. | ||
|
|
||
| // 1. Build a vector of (index, prefix, length) tuples | ||
| let mut valids: Vec<(u32, u32, u64)> = value_indices | ||
| .into_iter() | ||
| .map(|index| (index, values.value(index as usize).as_ref())) | ||
| .collect::<Vec<(u32, &[u8])>>(); | ||
| .map(|idx| unsafe { | ||
| let slice: &[u8] = values.value_unchecked(idx as usize).as_ref(); | ||
| let len = slice.len() as u64; | ||
| // Compute the 4-byte prefix in BE order; shorter slices are left-aligned and zero-padded on the right | ||
| let prefix = if slice.len() >= 4 { | ||
| let raw = std::ptr::read_unaligned(slice.as_ptr() as *const u32); | ||
| u32::from_be(raw) | ||
| } else if slice.is_empty() { | ||
| // Handle empty slice case to avoid shift overflow | ||
| 0u32 | ||
| } else { | ||
| let mut v = 0u32; | ||
| for &b in slice { | ||
| v = (v << 8) | (b as u32); | ||
| } | ||
| // Safe shift: slice.len() is in range [1, 3], so shift is in range [8, 24] | ||
| v << (8 * (4 - slice.len())) | ||
| }; | ||
| (idx, prefix, len) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Alternatively or additionally we could store the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you @Dandandan for this idea. I tried it just now, but it shows a 30% performance decrease:
diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index 093c52d867..29800663a0 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -301,14 +301,15 @@ fn sort_bytes<T: ByteArrayType>(
// comparing up to 4 bytes as a single u32 we avoid the overhead
// of a full lexicographical compare for the vast majority of cases.
- // 1. Build a vector of (index, prefix, length) tuples
- let mut valids: Vec<(u32, u32, u64)> = value_indices
+ // 1. Build a vector of (idx, prefix, is_small, slice) tuples
+ let mut valids: Vec<(u32, u32, bool, &[u8])> = value_indices
.into_iter()
.map(|idx| {
let slice: &[u8] = values.value(idx as usize).as_ref();
- let len = slice.len() as u64;
- // Compute the 4‑byte prefix in BE order, or left‑pad if shorter
- let prefix = if slice.len() >= 4 {
+ // store the bool for whether the slice is smaller than 4 bytes
+ let is_small = slice.len() < 4;
+ // prefix: if the slice is smaller than 4 bytes, left-pad it with zeros,
+ let prefix = if !is_small {
let raw = unsafe { std::ptr::read_unaligned(slice.as_ptr() as *const u32) };
u32::from_be(raw)
} else {
@@ -318,7 +319,7 @@ fn sort_bytes<T: ByteArrayType>(
}
v << (8 * (4 - slice.len()))
};
- (idx, prefix, len)
+ (idx, prefix, is_small, slice)
})
.collect();
@@ -328,27 +329,24 @@ fn sort_bytes<T: ByteArrayType>(
_ => valids.len(),
};
- // 3. Comparator: compare prefix, then (when both slices shorter than 4) length, otherwise full slice
- let cmp_bytes = |a: &(u32, u32, u64), b: &(u32, u32, u64)| {
- let (ia, pa, la) = *a;
- let (ib, pb, lb) = *b;
- // 3.1 prefix (first 4 bytes)
- let ord = pa.cmp(&pb);
- if ord != Ordering::Equal {
- return ord;
+ // 3. Comparator: compare prefix first, then for both “small” slices compare length, and finally full lexicographical compare
+ let cmp_bytes = |a: &(u32, u32, bool, &[u8]), b: &(u32, u32, bool, &[u8])| {
+ let (_ia, pa, sa_small, sa) = a;
+ let (_ib, pb, sb_small, sb) = b;
+ // 3.1 Compare the 4‑byte prefix
+ match pa.cmp(&pb) {
+ Ordering::Equal => (),
+ non_eq => return non_eq,
}
- // 3.2 only if both slices had length < 4 (so prefix was padded)
- // length compare only when prefix was padded (i.e. original length < 4)
- if la < 4 || lb < 4 {
- let ord = la.cmp(&lb);
- if ord != Ordering::Equal {
- return ord;
+ // 3.2 If both slices were shorter than 4 bytes, compare their actual lengths
+ if *sa_small && *sb_small {
+ match sa.len().cmp(&sb.len()) {
+ Ordering::Equal => (),
+ non_eq => return non_eq,
}
}
- // 3.3 full lexicographical compare
- let a_bytes: &[u8] = values.value(ia as usize).as_ref();
- let b_bytes: &[u8] = values.value(ib as usize).as_ref();
- a_bytes.cmp(b_bytes)
+ // 3.3 Otherwise, do a full byte‑wise lexicographical comparison
+ sa.cmp(sb)
};
// 4. Partially sort according to ascending/descending
@@ -366,9 +364,9 @@ fn sort_bytes<T: ByteArrayType>(
if options.nulls_first {
out.extend_from_slice(&nulls[..nulls.len().min(out_limit)]);
let rem = out_limit - out.len();
- out.extend(valids.iter().map(|&(i, _, _)| i).take(rem));
+ out.extend(valids.iter().map(|&(i, _, _, _)| i).take(rem));
} else {
- out.extend(valids.iter().map(|&(i, _, _)| i).take(out_limit));
+ out.extend(valids.iter().map(|&(i, _, _, _)| i).take(out_limit));
let rem = out_limit - out.len();
out.extend_from_slice(&nulls[..rem]);
} |
||
| }) | ||
| .collect(); | ||
|
|
||
| sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into() | ||
| // 2. compute the number of non-null entries to partially sort | ||
| let vlimit = match (limit, options.nulls_first) { | ||
| (Some(l), true) => l.saturating_sub(nulls.len()).min(valids.len()), | ||
| _ => valids.len(), | ||
| }; | ||
|
|
||
| // 3. Comparator: compare prefix, then (when either slice is shorter than 4 bytes) length, then the full slice | ||
| let cmp_bytes = |a: &(u32, u32, u64), b: &(u32, u32, u64)| unsafe { | ||
| let (ia, pa, la) = *a; | ||
| let (ib, pb, lb) = *b; | ||
| // 3.1 prefix (first 4 bytes) | ||
| let ord = pa.cmp(&pb); | ||
| if ord != Ordering::Equal { | ||
| return ord; | ||
| } | ||
| // 3.2 if either slice has length < 4 (so its prefix was zero-padded), compare lengths: with equal prefixes the shorter slice is a prefix of the longer one | ||
| if la < 4 || lb < 4 { | ||
| let ord = la.cmp(&lb); | ||
| if ord != Ordering::Equal { | ||
| return ord; | ||
| } | ||
| } | ||
| // 3.3 full lexicographical compare | ||
| let a_bytes: &[u8] = values.value_unchecked(ia as usize).as_ref(); | ||
| let b_bytes: &[u8] = values.value_unchecked(ib as usize).as_ref(); | ||
| a_bytes.cmp(b_bytes) | ||
| }; | ||
|
|
||
| // 4. Partially sort according to ascending/descending | ||
| if !options.descending { | ||
| sort_unstable_by(&mut valids, vlimit, cmp_bytes); | ||
| } else { | ||
| sort_unstable_by(&mut valids, vlimit, |x, y| cmp_bytes(x, y).reverse()); | ||
| } | ||
|
|
||
| // 5. Assemble nulls and sorted indices into final output | ||
| let total = valids.len() + nulls.len(); | ||
| let out_limit = limit.unwrap_or(total).min(total); | ||
| let mut out = Vec::with_capacity(out_limit); | ||
|
|
||
| if options.nulls_first { | ||
| out.extend_from_slice(&nulls[..nulls.len().min(out_limit)]); | ||
| let rem = out_limit - out.len(); | ||
| out.extend(valids.iter().map(|&(i, _, _)| i).take(rem)); | ||
| } else { | ||
| out.extend(valids.iter().map(|&(i, _, _)| i).take(out_limit)); | ||
| let rem = out_limit - out.len(); | ||
| out.extend_from_slice(&nulls[..rem]); | ||
| } | ||
|
|
||
| out.into() | ||
| } | ||
|
|
||
| fn sort_byte_view<T: ByteViewType>( | ||
|
|
@@ -4841,4 +4917,301 @@ mod tests { | |
| assert_eq!(valid, vec![0, 2]); | ||
| assert_eq!(nulls, vec![1, 3]); | ||
| } | ||
|
|
||
| // Sorts a fixed list of short and shared-prefix strings with `sort_bytes` (no nulls, default options, no limit) and checks the order against the standard library's `sort` on the same list; targets the 4-byte prefix logic | ||
| #[test] | ||
| fn test_specific_edge_cases() { | ||
| let test_cases = vec![ | ||
| // Strings of length 1-4, which exercise prefix padding | ||
| "a", "ab", "ba", "baa", "abba", "abbc", "abc", "cda", | ||
| // Strings whose first 4 bytes are the same ("abcd") but that differ in length or in later bytes | ||
| "abcd", "abcde", "abcdf", "abcdaaa", "abcdbbb", | ||
| // Strings of length 1-5 starting with 'z': covers padded (< 4 bytes) and unpadded prefixes | ||
| "z", "za", "zaa", "zaaa", "zaaab", // The next entry is the empty string | ||
| "", // The entries below share the prefix "test" and differ only in length | ||
| "test", "test1", "test12", "test123", "test1234", | ||
| ]; | ||
|
|
||
||
| // Reference order: the standard library's sort on the same strings | ||
| let mut expected = test_cases.clone(); | ||
| expected.sort(); | ||
|
|
||
||
| // Order under test: sort_bytes over every index of the array | ||
| let string_array = StringArray::from(test_cases.clone()); | ||
| let indices: Vec<u32> = (0..test_cases.len() as u32).collect(); | ||
| let result = sort_bytes( | ||
| &string_array, | ||
| indices, | ||
| vec![], // no nulls | ||
| SortOptions::default(), | ||
| None, | ||
| ); | ||
|
|
||
||
| // Map the returned indices back to strings and compare with the reference | ||
| let sorted_strings: Vec<&str> = result | ||
| .values() | ||
| .iter() | ||
| .map(|&idx| test_cases[idx as usize]) | ||
| .collect(); | ||
|
|
||
||
| assert_eq!(sorted_strings, expected); | ||
| } | ||
|
|
||
| // Checks that `sort_bytes` orders strings of many different lengths the same way as the standard library's `sort` | ||
| #[test] | ||
| fn test_length_combinations() { | ||
| let test_cases = vec![ | ||
| // Each tuple is (string, byte length); the length is informational only and is dropped before sorting. Lengths below 4 exercise the padding logic | ||
| ("", 0), | ||
| ("a", 1), | ||
| ("ab", 2), | ||
| ("abc", 3), | ||
| ("abcd", 4), | ||
| ("abcde", 5), | ||
| ("b", 1), | ||
| ("ba", 2), | ||
| ("bab", 3), | ||
| ("babc", 4), | ||
| ("babcd", 5), | ||
| // Same prefix ("test") with different lengths | ||
| ("test", 4), | ||
| ("test1", 5), | ||
| ("test12", 6), | ||
| ("test123", 7), | ||
| ]; | ||
|
|
||
||
| let strings: Vec<&str> = test_cases.iter().map(|(s, _)| *s).collect(); | ||
| let mut expected = strings.clone(); | ||
| expected.sort(); | ||
|
|
||
||
| let string_array = StringArray::from(strings.clone()); | ||
| let indices: Vec<u32> = (0..strings.len() as u32).collect(); | ||
| let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None); | ||
|
|
||
||
| let sorted_strings: Vec<&str> = result | ||
| .values() | ||
| .iter() | ||
| .map(|&idx| strings[idx as usize]) | ||
| .collect(); | ||
|
|
||
||
| assert_eq!(sorted_strings, expected); | ||
| } | ||
|
|
||
| // Checks that `sort_bytes` orders multi-byte UTF-8 strings the same way as the standard library's `sort` on `&str` | ||
| #[test] | ||
| fn test_utf8_strings() { | ||
| let test_cases = vec![ | ||
| "a", | ||
| "你", // One 3-byte UTF-8 character | ||
| "你好", // 6 bytes | ||
| "你好世界", // 12 bytes | ||
| "🎉", // One 4-byte emoji | ||
| "🎉🎊", // 8 bytes | ||
| "café", // Contains an accented character (é is 2 bytes in UTF-8) | ||
| "naïve", | ||
| "Москва", // Cyrillic script | ||
| "東京", // Japanese kanji | ||
| "한국", // Korean | ||
| ]; | ||
|
|
||
||
| let mut expected = test_cases.clone(); | ||
| expected.sort(); | ||
|
|
||
||
| let string_array = StringArray::from(test_cases.clone()); | ||
| let indices: Vec<u32> = (0..test_cases.len() as u32).collect(); | ||
| let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None); | ||
|
|
||
||
| let sorted_strings: Vec<&str> = result | ||
| .values() | ||
| .iter() | ||
| .map(|&idx| test_cases[idx as usize]) | ||
| .collect(); | ||
|
|
||
||
| assert_eq!(sorted_strings, expected); | ||
| } | ||
|
|
||
| // Fuzz test: repeatedly generates random UTF-8 strings with `generate_random_string` and checks that `sort_bytes` orders them the same way as the standard library's `sort` | ||
| #[test] | ||
| fn test_fuzz_random_strings() { | ||
| let mut rng = StdRng::seed_from_u64(42); // Fixed seed so every run sees the same inputs | ||
|
|
||
||
| for _ in 0..100 { | ||
| // 100 rounds, each with a freshly generated batch of strings | ||
| let mut test_strings = Vec::new(); | ||
|
|
||
||
| // Each round uses between 20 and 50 strings (inclusive) | ||
| let num_strings = rng.random_range(20..=50); | ||
|
|
||
||
| for _ in 0..num_strings { | ||
| let string = generate_random_string(&mut rng); | ||
| test_strings.push(string); | ||
| } | ||
|
|
||
||
| // Reference order: the standard library's sort on the same strings | ||
| let mut expected = test_strings.clone(); | ||
| expected.sort(); | ||
|
|
||
||
| // Order under test: sort_bytes over every index, no nulls, no limit | ||
| let string_array = StringArray::from(test_strings.clone()); | ||
| let indices: Vec<u32> = (0..test_strings.len() as u32).collect(); | ||
| let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None); | ||
|
|
||
||
| let sorted_strings: Vec<String> = result | ||
| .values() | ||
| .iter() | ||
| .map(|&idx| test_strings[idx as usize].clone()) | ||
| .collect(); | ||
|
|
||
||
| assert_eq!( | ||
| sorted_strings, expected, | ||
| "Fuzz test failed with input: {test_strings:?}" | ||
| ); | ||
| } | ||
| } | ||
|
|
||
| // Helper: builds a random UTF-8 string whose length in bytes equals a randomly chosen target; characters come from `generate_random_char` | ||
| fn generate_random_string(rng: &mut StdRng) -> String { | ||
| // Bias towards short strings: byte lengths 0-4 are where the 4-byte prefix logic has to pad | ||
| let length = if rng.random_bool(0.6) { | ||
| rng.random_range(0..=4) // 60% probability: target length of 0-4 bytes | ||
| } else { | ||
| rng.random_range(5..=20) // 40% probability: target length of 5-20 bytes | ||
| }; | ||
|
|
||
||
| if length == 0 { | ||
| return String::new(); | ||
| } | ||
|
|
||
||
| let mut result = String::new(); | ||
| let mut current_len = 0; | ||
|
|
||
||
| while current_len < length { | ||
| let c = generate_random_char(rng); | ||
| let char_len = c.len_utf8(); | ||
|
|
||
||
| // Only append the character if its UTF-8 encoding still fits in the target byte length | ||
| if current_len + char_len <= length { | ||
| result.push(c); | ||
| current_len += char_len; | ||
| } else { | ||
| // The character does not fit: fill the remaining bytes with 1-byte ASCII lowercase letters and stop | ||
| let remaining = length - current_len; | ||
| for _ in 0..remaining { | ||
| result.push(rng.random_range('a'..='z')); | ||
| current_len += 1; | ||
| } | ||
| break; | ||
| } | ||
| } | ||
|
|
||
||
| result | ||
| } | ||
|
|
||
| // Helper: picks one random character; a draw from 0..10 selects the character class, so the result may be 1 to 4 bytes long in UTF-8 | ||
| fn generate_random_char(rng: &mut StdRng) -> char { | ||
| match rng.random_range(0..10) { | ||
| 0..=5 => rng.random_range('a'..='z'), // 60% ASCII lowercase | ||
| 6 => rng.random_range('A'..='Z'), // 10% ASCII uppercase | ||
| 7 => rng.random_range('0'..='9'), // 10% digits | ||
| 8 => { | ||
| // 10% Chinese characters, picked from a fixed list | ||
| let chinese_chars = ['你', '好', '世', '界', '测', '试', '中', '文']; | ||
| chinese_chars[rng.random_range(0..chinese_chars.len())] | ||
| } | ||
| 9 => { | ||
| // 10% other Unicode characters (single `char`s): accented Latin, emoji and Greek, picked from a fixed list | ||
| let special_chars = ['é', 'ï', '🎉', '🎊', 'α', 'β', 'γ']; | ||
| special_chars[rng.random_range(0..special_chars.len())] | ||
| } | ||
| _ => unreachable!(), | ||
| } | ||
| } | ||
|
|
||
| // Test descending sort order: `sort_bytes` with `descending: true` must match the reversed standard library sort | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
| #[test] | ||
| fn test_descending_sort() { | ||
| let test_cases = vec!["a", "ab", "ba", "baa", "abba", "abbc", "abc", "cda"]; | ||
|
|
||
||
| let mut expected = test_cases.clone(); | ||
| expected.sort(); | ||
| expected.reverse(); // Reference for descending order: the ascending std sort, reversed | ||
|
|
||
||
| let string_array = StringArray::from(test_cases.clone()); | ||
| let indices: Vec<u32> = (0..test_cases.len() as u32).collect(); | ||
| let result = sort_bytes( | ||
| &string_array, | ||
| indices, | ||
| vec![], | ||
| SortOptions { | ||
| descending: true, | ||
| nulls_first: false, | ||
| }, | ||
| None, | ||
| ); | ||
|
|
||
||
| let sorted_strings: Vec<&str> = result | ||
| .values() | ||
| .iter() | ||
| .map(|&idx| test_cases[idx as usize]) | ||
| .collect(); | ||
|
|
||
||
| assert_eq!(sorted_strings, expected); | ||
| } | ||
|
|
||
| // Stress test: 1000 strings that all start with the same 4 bytes ("same"), so the order is decided by the bytes after the prefix | ||
| #[test] | ||
| fn test_same_prefix_stress() { | ||
| let mut test_cases = Vec::new(); | ||
| let prefix = "same"; | ||
|
|
||
||
| // Generate "same0000" through "same0999": same prefix, zero-padded 4-digit suffix | ||
| for i in 0..1000 { | ||
| test_cases.push(format!("{prefix}{i:04}")); | ||
| } | ||
|
|
||
||
| let mut expected = test_cases.clone(); | ||
| expected.sort(); | ||
|
|
||
||
| let string_array = StringArray::from(test_cases.clone()); | ||
| let indices: Vec<u32> = (0..test_cases.len() as u32).collect(); | ||
| let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None); | ||
|
|
||
||
| let sorted_strings: Vec<String> = result | ||
| .values() | ||
| .iter() | ||
| .map(|&idx| test_cases[idx as usize].clone()) | ||
| .collect(); | ||
|
|
||
||
| assert_eq!(sorted_strings, expected); | ||
| } | ||
|
|
||
| // Test the limit parameter: with `Some(3)` and no nulls, `sort_bytes` must return exactly the first 3 entries of the fully sorted order | ||
| #[test] | ||
| fn test_with_limit() { | ||
| let test_cases = vec!["z", "y", "x", "w", "v", "u", "t", "s"]; | ||
| let limit = 3; | ||
|
|
||
||
| let mut expected = test_cases.clone(); | ||
| expected.sort(); | ||
| expected.truncate(limit); | ||
|
|
||
||
| let string_array = StringArray::from(test_cases.clone()); | ||
| let indices: Vec<u32> = (0..test_cases.len() as u32).collect(); | ||
| let result = sort_bytes( | ||
| &string_array, | ||
| indices, | ||
| vec![], | ||
| SortOptions::default(), | ||
| Some(limit), | ||
| ); | ||
|
|
||
||
| let sorted_strings: Vec<&str> = result | ||
| .values() | ||
| .iter() | ||
| .map(|&idx| test_cases[idx as usize]) | ||
| .collect(); | ||
|
|
||
||
| assert_eq!(sorted_strings, expected); | ||
| assert_eq!(sorted_strings.len(), limit); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we have to worry about cutting UTF-8 code points when lopping off the first few bytes 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point @alamb, I think it's safe for us to cut here because:
In our sort_bytes implementation we always pull up to the first 4 bytes as raw &[u8] (via values.value(...).as_ref()), pack them into a u32, and compare at the byte level. That means we never try to decode those 4 bytes back into a &str, so there’s no risk of a UTF-8 panic in the hot path.
We also never use the prefix as text; it is only compared. Rust orders str values byte-wise, so comparing the raw bytes gives the same order as comparing the strings, even when the cut falls inside a code point.
I will also try to cover this in the fuzz tests, thanks!