Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
381 changes: 377 additions & 4 deletions arrow-ord/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,12 +345,88 @@ fn sort_bytes<T: ByteArrayType>(
options: SortOptions,
limit: Option<usize>,
) -> UInt32Array {
let mut valids = value_indices
// Note: Why do we use 4‑byte prefix?
// Compute the 4‑byte prefix in BE order, or left‑pad if shorter.
// Most byte‐sequences differ in their first few bytes, so by
// comparing up to 4 bytes as a single u32 we avoid the overhead
// of a full lexicographical compare for the vast majority of cases.

// 1. Build a vector of (index, prefix, length) tuples
let mut valids: Vec<(u32, u32, u64)> = value_indices
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we have to worry about cutting UTF-8 code points when lopping off the first few bytes 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point @alamb, I think it's safe for us to do the cut here because:

In our sort_bytes implementation we always pull the first up to 4 bytes as raw &[u8] (via values.value(...).as_ref()), pack them into a u32, and compare at the byte level. That means we never try to decode those 4 bytes back into a &str, so there’s no risk of a UTF-8 panic in the hot path.

And we don't need to use it, just to compare.

I will also try to add this in fuzz testing, thanks!

.into_iter()
.map(|index| (index, values.value(index as usize).as_ref()))
.collect::<Vec<(u32, &[u8])>>();
.map(|idx| unsafe {
let slice: &[u8] = values.value_unchecked(idx as usize).as_ref();
let len = slice.len() as u64;
// Compute the 4‑byte prefix in BE order, or left‑pad if shorter
let prefix = if slice.len() >= 4 {
let raw = std::ptr::read_unaligned(slice.as_ptr() as *const u32);
u32::from_be(raw)
} else if slice.is_empty() {
// Handle empty slice case to avoid shift overflow
0u32
} else {
let mut v = 0u32;
for &b in slice {
v = (v << 8) | (b as u32);
}
// Safe shift: slice.len() is in range [1, 3], so shift is in range [8, 24]
v << (8 * (4 - slice.len()))
};
(idx, prefix, len)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think instead of len it could store a bool for len < 4 ("is small"), as the actual length is not used? This would also avoid doing the la < 4 check in the sort, so it might be slightly faster.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively or additionally, we could store the &[u8] instead of the index so it doesn't have to be retrieved via values.value again in the sort.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @Dandandan for this idea. I tried it just now, but it shows a 30% performance decrease:

diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index 093c52d867..29800663a0 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -301,14 +301,15 @@ fn sort_bytes<T: ByteArrayType>(
     // comparing up to 4 bytes as a single u32 we avoid the overhead
     // of a full lexicographical compare for the vast majority of cases.
 
-    // 1. Build a vector of (index, prefix, length) tuples
-    let mut valids: Vec<(u32, u32, u64)> = value_indices
+    // 1. Build a vector of (idx, prefix, is_small, slice) tuples
+    let mut valids: Vec<(u32, u32, bool, &[u8])> = value_indices
         .into_iter()
         .map(|idx| {
             let slice: &[u8] = values.value(idx as usize).as_ref();
-            let len = slice.len() as u64;
-            // Compute the 4‑byte prefix in BE order, or left‑pad if shorter
-            let prefix = if slice.len() >= 4 {
+            // store the bool for whether the slice is smaller than 4 bytes
+            let is_small = slice.len() < 4;
+            // prefix: if the slice is smaller than 4 bytes, left-pad it with zeros,
+            let prefix = if !is_small {
                 let raw = unsafe { std::ptr::read_unaligned(slice.as_ptr() as *const u32) };
                 u32::from_be(raw)
             } else {
@@ -318,7 +319,7 @@ fn sort_bytes<T: ByteArrayType>(
                 }
                 v << (8 * (4 - slice.len()))
             };
-            (idx, prefix, len)
+            (idx, prefix, is_small, slice)
         })
         .collect();
 
@@ -328,27 +329,24 @@ fn sort_bytes<T: ByteArrayType>(
         _ => valids.len(),
     };
 
-    // 3. Comparator: compare prefix, then (when both slices shorter than 4) length, otherwise full slice
-    let cmp_bytes = |a: &(u32, u32, u64), b: &(u32, u32, u64)| {
-        let (ia, pa, la) = *a;
-        let (ib, pb, lb) = *b;
-        // 3.1 prefix (first 4 bytes)
-        let ord = pa.cmp(&pb);
-        if ord != Ordering::Equal {
-            return ord;
+    // 3. Comparator: compare prefix first, then for both “small” slices compare length, and finally full lexicographical compare
+    let cmp_bytes = |a: &(u32, u32, bool, &[u8]), b: &(u32, u32, bool, &[u8])| {
+        let (_ia, pa, sa_small, sa) = a;
+        let (_ib, pb, sb_small, sb) = b;
+        // 3.1 Compare the 4‑byte prefix
+        match pa.cmp(&pb) {
+            Ordering::Equal => (),
+            non_eq => return non_eq,
         }
-        // 3.2 only if both slices had length < 4 (so prefix was padded)
-        // length compare only when prefix was padded (i.e. original length < 4)
-        if la < 4 || lb < 4 {
-            let ord = la.cmp(&lb);
-            if ord != Ordering::Equal {
-                return ord;
+        // 3.2 If both slices were shorter than 4 bytes, compare their actual lengths
+        if *sa_small && *sb_small {
+            match sa.len().cmp(&sb.len()) {
+                Ordering::Equal => (),
+                non_eq => return non_eq,
             }
         }
-        // 3.3 full lexicographical compare
-        let a_bytes: &[u8] = values.value(ia as usize).as_ref();
-        let b_bytes: &[u8] = values.value(ib as usize).as_ref();
-        a_bytes.cmp(b_bytes)
+        // 3.3 Otherwise, do a full byte‑wise lexicographical comparison
+        sa.cmp(sb)
     };
 
     // 4. Partially sort according to ascending/descending
@@ -366,9 +364,9 @@ fn sort_bytes<T: ByteArrayType>(
     if options.nulls_first {
         out.extend_from_slice(&nulls[..nulls.len().min(out_limit)]);
         let rem = out_limit - out.len();
-        out.extend(valids.iter().map(|&(i, _, _)| i).take(rem));
+        out.extend(valids.iter().map(|&(i, _, _, _)| i).take(rem));
     } else {
-        out.extend(valids.iter().map(|&(i, _, _)| i).take(out_limit));
+        out.extend(valids.iter().map(|&(i, _, _, _)| i).take(out_limit));
         let rem = out_limit - out.len();
         out.extend_from_slice(&nulls[..rem]);
     }

})
.collect();

sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into()
// 2. compute the number of non-null entries to partially sort
let vlimit = match (limit, options.nulls_first) {
(Some(l), true) => l.saturating_sub(nulls.len()).min(valids.len()),
_ => valids.len(),
};

// 3. Comparator: compare prefix, then (when either slice is shorter than 4 bytes) length, otherwise full slice
let cmp_bytes = |a: &(u32, u32, u64), b: &(u32, u32, u64)| unsafe {
let (ia, pa, la) = *a;
let (ib, pb, lb) = *b;
// 3.1 prefix (first 4 bytes)
let ord = pa.cmp(&pb);
if ord != Ordering::Equal {
return ord;
}
// 3.2 only if at least one slice had length < 4 (so its prefix was zero-padded)
if la < 4 || lb < 4 {
let ord = la.cmp(&lb);
if ord != Ordering::Equal {
return ord;
}
}
// 3.3 full lexicographical compare
let a_bytes: &[u8] = values.value_unchecked(ia as usize).as_ref();
let b_bytes: &[u8] = values.value_unchecked(ib as usize).as_ref();
a_bytes.cmp(b_bytes)
};

// 4. Partially sort according to ascending/descending
if !options.descending {
sort_unstable_by(&mut valids, vlimit, cmp_bytes);
} else {
sort_unstable_by(&mut valids, vlimit, |x, y| cmp_bytes(x, y).reverse());
}

// 5. Assemble nulls and sorted indices into final output
let total = valids.len() + nulls.len();
let out_limit = limit.unwrap_or(total).min(total);
let mut out = Vec::with_capacity(out_limit);

if options.nulls_first {
out.extend_from_slice(&nulls[..nulls.len().min(out_limit)]);
let rem = out_limit - out.len();
out.extend(valids.iter().map(|&(i, _, _)| i).take(rem));
} else {
out.extend(valids.iter().map(|&(i, _, _)| i).take(out_limit));
let rem = out_limit - out.len();
out.extend_from_slice(&nulls[..rem]);
}

out.into()
}

fn sort_byte_view<T: ByteViewType>(
Expand Down Expand Up @@ -4841,4 +4917,301 @@ mod tests {
assert_eq!(valid, vec![0, 2]);
assert_eq!(nulls, vec![1, 3]);
}

// Hand-picked strings that exercise the 4-byte prefix comparison in `sort_bytes`:
// the result must match the ordering produced by the standard library sort.
#[test]
fn test_specific_edge_cases() {
    let inputs = vec![
        // Lengths 1-4, which exercise prefix padding
        "a", "ab", "ba", "baa", "abba", "abbc", "abc", "cda",
        // Identical first 4 bytes, differing only in the bytes that follow
        "abcd", "abcde", "abcdf", "abcdaaa", "abcdbbb",
        // Lengths below 4 that require padding, next to longer strings sharing those bytes
        "z", "za", "zaa", "zaaa", "zaaab",
        // Empty string
        "",
        // Same prefix with various lengths
        "test", "test1", "test12", "test123", "test1234",
    ];

    // Reference ordering: plain standard-library sort
    let expected = {
        let mut reference = inputs.clone();
        reference.sort();
        reference
    };

    // Ordering under test: sort every index, with no nulls and default options
    let array = StringArray::from(inputs.clone());
    let all_indices = (0..inputs.len() as u32).collect::<Vec<u32>>();
    let sorted_indices = sort_bytes(&array, all_indices, vec![], SortOptions::default(), None);

    // Map the returned indices back to the strings they refer to
    let actual = sorted_indices
        .values()
        .iter()
        .map(|&i| inputs[i as usize])
        .collect::<Vec<&str>>();

    assert_eq!(actual, expected);
}

// Test sorting correctness for different length combinations
#[test]
fn test_length_combinations() {
    // Each entry pairs a string with its length in bytes. The lengths document
    // which side of the 4-byte prefix boundary every string falls on.
    let test_cases = vec![
        // Focus on testing strings of length 1-4, as these affect padding logic
        ("", 0),
        ("a", 1),
        ("ab", 2),
        ("abc", 3),
        ("abcd", 4),
        ("abcde", 5),
        ("b", 1),
        ("ba", 2),
        ("bab", 3),
        ("babc", 4),
        ("babcd", 5),
        // Test same prefix with different lengths
        ("test", 4),
        ("test1", 5),
        ("test12", 6),
        ("test123", 7),
    ];

    // The annotated lengths were previously never read. Check them so the
    // fixture cannot silently drift away from the boundary cases it is meant
    // to cover.
    for (s, len) in &test_cases {
        assert_eq!(s.len(), *len, "annotated length is wrong for {s:?}");
    }

    // Reference ordering: plain standard-library sort of the strings
    let strings: Vec<&str> = test_cases.iter().map(|(s, _)| *s).collect();
    let mut expected = strings.clone();
    expected.sort();

    // Ordering under test: sort every index, with no nulls and default options
    let string_array = StringArray::from(strings.clone());
    let indices: Vec<u32> = (0..strings.len() as u32).collect();
    let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None);

    // Map the returned indices back to the strings they refer to
    let sorted_strings: Vec<&str> = result
        .values()
        .iter()
        .map(|&idx| strings[idx as usize])
        .collect();

    assert_eq!(sorted_strings, expected);
}

// Multi-byte UTF-8 input: the ordering must still agree with the standard
// library's `str` ordering.
#[test]
fn test_utf8_strings() {
    let inputs = vec![
        "a",
        "你",       // 3-byte UTF-8 character
        "你好",     // 6 bytes
        "你好世界", // 12 bytes
        "🎉",       // 4-byte emoji
        "🎉🎊",     // 8 bytes
        "café",     // Contains accent character
        "naïve",
        "Москва", // Cyrillic script
        "東京",   // Japanese kanji
        "한국",   // Korean
    ];

    // Reference ordering: plain standard-library sort
    let expected = {
        let mut reference = inputs.clone();
        reference.sort();
        reference
    };

    // Ordering under test: sort every index, with no nulls and default options
    let array = StringArray::from(inputs.clone());
    let all_indices = (0..inputs.len() as u32).collect::<Vec<u32>>();
    let sorted_indices = sort_bytes(&array, all_indices, vec![], SortOptions::default(), None);

    // Map the returned indices back to the strings they refer to
    let actual = sorted_indices
        .values()
        .iter()
        .map(|&i| inputs[i as usize])
        .collect::<Vec<&str>>();

    assert_eq!(actual, expected);
}

// Fuzz test: sort batches of randomly generated UTF-8 strings and compare the
// result against the standard library sort.
#[test]
fn test_fuzz_random_strings() {
    // Fixed seed so that a failure can be reproduced
    let mut rng = StdRng::seed_from_u64(42);

    // 100 independent rounds
    for _ in 0..100 {
        // Each round sorts between 20 and 50 random strings
        let num_strings = rng.random_range(20..=50);
        let test_strings = (0..num_strings)
            .map(|_| generate_random_string(&mut rng))
            .collect::<Vec<String>>();

        // Reference ordering: plain standard-library sort
        let mut expected = test_strings.clone();
        expected.sort();

        // Ordering under test: sort every index, with no nulls and default options
        let array = StringArray::from(test_strings.clone());
        let all_indices = (0..test_strings.len() as u32).collect::<Vec<u32>>();
        let sorted_indices =
            sort_bytes(&array, all_indices, vec![], SortOptions::default(), None);

        // Map the returned indices back to the strings they refer to
        let sorted_strings = sorted_indices
            .values()
            .iter()
            .map(|&i| test_strings[i as usize].clone())
            .collect::<Vec<String>>();

        assert_eq!(
            sorted_strings, expected,
            "Fuzz test failed with input: {test_strings:?}"
        );
    }
}

// Helper: build a random UTF-8 string whose byte length is drawn from a
// distribution biased towards short strings (0-4 bytes).
fn generate_random_string(rng: &mut StdRng) -> String {
    // With probability 0.6 pick a length in 0..=4, otherwise one in 5..=20
    let prefer_short = rng.random_bool(0.6);
    let target_len: usize = if prefer_short {
        rng.random_range(0..=4)
    } else {
        rng.random_range(5..=20)
    };

    // `out.len()` is the number of bytes written so far. A target of 0 skips
    // the loop entirely and yields the empty string.
    let mut out = String::new();
    while out.len() < target_len {
        let candidate = generate_random_char(rng);

        if out.len() + candidate.len_utf8() > target_len {
            // The character would overshoot the target: discard it, fill the
            // remaining bytes with single-byte ASCII letters and stop.
            for _ in out.len()..target_len {
                out.push(rng.random_range('a'..='z'));
            }
            break;
        }

        out.push(candidate);
    }

    out
}

// Helper: pick one random character, drawn from ASCII letters and digits plus
// a few fixed multi-byte UTF-8 characters.
fn generate_random_char(rng: &mut StdRng) -> char {
    // Fixed pools for the two non-ASCII buckets
    const CHINESE_CHARS: [char; 8] = ['你', '好', '世', '界', '测', '试', '中', '文'];
    const SPECIAL_CHARS: [char; 7] = ['é', 'ï', '🎉', '🎊', 'α', 'β', 'γ'];

    // One of ten buckets decides the character class
    let bucket = rng.random_range(0..10);
    match bucket {
        // Bucket 9: other Unicode characters (single `char`s)
        9 => SPECIAL_CHARS[rng.random_range(0..SPECIAL_CHARS.len())],
        // Bucket 8: Chinese characters
        8 => CHINESE_CHARS[rng.random_range(0..CHINESE_CHARS.len())],
        // Bucket 7: digits
        7 => rng.random_range('0'..='9'),
        // Bucket 6: ASCII uppercase
        6 => rng.random_range('A'..='Z'),
        // Buckets 0-5: ASCII lowercase
        0..=5 => rng.random_range('a'..='z'),
        _ => unreachable!(),
    }
}

// Test descending sort order
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

#[test]
fn test_descending_sort() {
    let inputs = vec!["a", "ab", "ba", "baa", "abba", "abbc", "abc", "cda"];

    // Reference ordering: standard-library sort, then reversed
    let expected = {
        let mut reference = inputs.clone();
        reference.sort();
        reference.reverse();
        reference
    };

    // Ordering under test: descending, nulls last (there are no nulls here)
    let descending = SortOptions {
        descending: true,
        nulls_first: false,
    };
    let array = StringArray::from(inputs.clone());
    let all_indices = (0..inputs.len() as u32).collect::<Vec<u32>>();
    let sorted_indices = sort_bytes(&array, all_indices, vec![], descending, None);

    // Map the returned indices back to the strings they refer to
    let actual = sorted_indices
        .values()
        .iter()
        .map(|&i| inputs[i as usize])
        .collect::<Vec<&str>>();

    assert_eq!(actual, expected);
}

// Stress test: 1000 strings that all share the same first 4 bytes, so the
// prefix alone can never decide the order.
#[test]
fn test_same_prefix_stress() {
    let prefix = "same";

    // "same0000" ..= "same0999"
    let inputs = (0..1000)
        .map(|i| format!("{prefix}{i:04}"))
        .collect::<Vec<String>>();

    // Reference ordering: plain standard-library sort
    let expected = {
        let mut reference = inputs.clone();
        reference.sort();
        reference
    };

    // Ordering under test: sort every index, with no nulls and default options
    let array = StringArray::from(inputs.clone());
    let all_indices = (0..inputs.len() as u32).collect::<Vec<u32>>();
    let sorted_indices = sort_bytes(&array, all_indices, vec![], SortOptions::default(), None);

    // Map the returned indices back to the strings they refer to
    let actual = sorted_indices
        .values()
        .iter()
        .map(|&i| inputs[i as usize].clone())
        .collect::<Vec<String>>();

    assert_eq!(actual, expected);
}

// With `limit` set, the output must be exactly the first `limit` entries of
// the fully sorted order.
#[test]
fn test_with_limit() {
    let limit = 3;
    let inputs = vec!["z", "y", "x", "w", "v", "u", "t", "s"];

    // Reference: full standard-library sort, cut down to `limit` entries
    let expected = {
        let mut reference = inputs.clone();
        reference.sort();
        reference.truncate(limit);
        reference
    };

    // Ordering under test: sort every index, no nulls, default options, with the limit
    let array = StringArray::from(inputs.clone());
    let all_indices = (0..inputs.len() as u32).collect::<Vec<u32>>();
    let sorted_indices = sort_bytes(
        &array,
        all_indices,
        vec![],
        SortOptions::default(),
        Some(limit),
    );

    // Map the returned indices back to the strings they refer to
    let actual = sorted_indices
        .values()
        .iter()
        .map(|&i| inputs[i as usize])
        .collect::<Vec<&str>>();

    assert_eq!(actual, expected);
    assert_eq!(actual.len(), limit);
}
}
Loading