diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index a405aa7a3735..ba026af637d7 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -345,12 +345,88 @@ fn sort_bytes( options: SortOptions, limit: Option, ) -> UInt32Array { - let mut valids = value_indices + // Note: Why do we use 4‑byte prefix? + // Compute the 4‑byte prefix in BE order, or left‑pad if shorter. + // Most byte‐sequences differ in their first few bytes, so by + // comparing up to 4 bytes as a single u32 we avoid the overhead + // of a full lexicographical compare for the vast majority of cases. + + // 1. Build a vector of (index, prefix, length) tuples + let mut valids: Vec<(u32, u32, u64)> = value_indices .into_iter() - .map(|index| (index, values.value(index as usize).as_ref())) - .collect::>(); + .map(|idx| unsafe { + let slice: &[u8] = values.value_unchecked(idx as usize).as_ref(); + let len = slice.len() as u64; + // Compute the 4‑byte prefix in BE order, or left‑pad if shorter + let prefix = if slice.len() >= 4 { + let raw = std::ptr::read_unaligned(slice.as_ptr() as *const u32); + u32::from_be(raw) + } else if slice.is_empty() { + // Handle empty slice case to avoid shift overflow + 0u32 + } else { + let mut v = 0u32; + for &b in slice { + v = (v << 8) | (b as u32); + } + // Safe shift: slice.len() is in range [1, 3], so shift is in range [8, 24] + v << (8 * (4 - slice.len())) + }; + (idx, prefix, len) + }) + .collect(); - sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into() + // 2. compute the number of non-null entries to partially sort + let vlimit = match (limit, options.nulls_first) { + (Some(l), true) => l.saturating_sub(nulls.len()).min(valids.len()), + _ => valids.len(), + }; + + // 3. Comparator: compare prefix, then (when both slices shorter than 4) length, otherwise full slice + let cmp_bytes = |a: &(u32, u32, u64), b: &(u32, u32, u64)| unsafe { + let (ia, pa, la) = *a; + let (ib, pb, lb) = *b; + // 3.1 prefix (first 4 bytes) + let ord = pa.cmp(&pb); + if ord != Ordering::Equal { + return ord; + } + // 3.2 only if both slices had length < 4 (so prefix was padded) + if la < 4 || lb < 4 { + let ord = la.cmp(&lb); + if ord != Ordering::Equal { + return ord; + } + } + // 3.3 full lexicographical compare + let a_bytes: &[u8] = values.value_unchecked(ia as usize).as_ref(); + let b_bytes: &[u8] = values.value_unchecked(ib as usize).as_ref(); + a_bytes.cmp(b_bytes) + }; + + // 4. Partially sort according to ascending/descending + if !options.descending { + sort_unstable_by(&mut valids, vlimit, cmp_bytes); + } else { + sort_unstable_by(&mut valids, vlimit, |x, y| cmp_bytes(x, y).reverse()); + } + + // 5. Assemble nulls and sorted indices into final output + let total = valids.len() + nulls.len(); + let out_limit = limit.unwrap_or(total).min(total); + let mut out = Vec::with_capacity(out_limit); + + if options.nulls_first { + out.extend_from_slice(&nulls[..nulls.len().min(out_limit)]); + let rem = out_limit - out.len(); + out.extend(valids.iter().map(|&(i, _, _)| i).take(rem)); + } else { + out.extend(valids.iter().map(|&(i, _, _)| i).take(out_limit)); + let rem = out_limit - out.len(); + out.extend_from_slice(&nulls[..rem]); + } + + out.into() } fn sort_byte_view( @@ -4841,4 +4917,301 @@ mod tests { assert_eq!(valid, vec![0, 2]); assert_eq!(nulls, vec![1, 3]); } + + // Test specific edge case strings that exercise the 4-byte prefix logic + #[test] + fn test_specific_edge_cases() { + let test_cases = vec![ + // Key test cases for lengths 1-4 that test prefix padding + "a", "ab", "ba", "baa", "abba", "abbc", "abc", "cda", + // Test cases where first 4 bytes are same but subsequent bytes differ + "abcd", "abcde", "abcdf", "abcdaaa", "abcdbbb", + // Test cases with length < 4 that require padding + "z", "za", "zaa", "zaaa", "zaaab", // Empty string + "", // Test various length combinations with same prefix + "test", "test1", "test12", "test123", "test1234", + ]; + + // Use standard library sort as reference + let mut expected = test_cases.clone(); + expected.sort(); + + // Use our sorting algorithm + let string_array = StringArray::from(test_cases.clone()); + let indices: Vec = (0..test_cases.len() as u32).collect(); + let result = sort_bytes( + &string_array, + indices, + vec![], // no nulls + SortOptions::default(), + None, + ); + + // Verify results + let sorted_strings: Vec<&str> = result + .values() + .iter() + .map(|&idx| test_cases[idx as usize]) + .collect(); + + assert_eq!(sorted_strings, expected); + } + + // Test sorting correctness for different length combinations + #[test] + fn test_length_combinations() { + let test_cases = vec![ + // Focus on testing strings of length 1-4, as these affect padding logic + ("", 0), + ("a", 1), + ("ab", 2), + ("abc", 3), + ("abcd", 4), + ("abcde", 5), + ("b", 1), + ("ba", 2), + ("bab", 3), + ("babc", 4), + ("babcd", 5), + // Test same prefix with different lengths + ("test", 4), + ("test1", 5), + ("test12", 6), + ("test123", 7), + ]; + + let strings: Vec<&str> = test_cases.iter().map(|(s, _)| *s).collect(); + let mut expected = strings.clone(); + expected.sort(); + + let string_array = StringArray::from(strings.clone()); + let indices: Vec = (0..strings.len() as u32).collect(); + let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None); + + let sorted_strings: Vec<&str> = result + .values() + .iter() + .map(|&idx| strings[idx as usize]) + .collect(); + + assert_eq!(sorted_strings, expected); + } + + // Test UTF-8 string handling + #[test] + fn test_utf8_strings() { + let test_cases = vec![ + "a", + "你", // 3-byte UTF-8 character + "你好", // 6 bytes + "你好世界", // 12 bytes + "🎉", // 4-byte emoji + "🎉🎊", // 8 bytes + "café", // Contains accent character + "naïve", + "Москва", // Cyrillic script + "東京", // Japanese kanji + "한국", // Korean + ]; + + let mut expected = test_cases.clone(); + expected.sort(); + + let string_array = StringArray::from(test_cases.clone()); + let indices: Vec = (0..test_cases.len() as u32).collect(); + let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None); + + let sorted_strings: Vec<&str> = result + .values() + .iter() + .map(|&idx| test_cases[idx as usize]) + .collect(); + + assert_eq!(sorted_strings, expected); + } + + // Fuzz testing: generate random UTF-8 strings and verify sort correctness + #[test] + fn test_fuzz_random_strings() { + let mut rng = StdRng::seed_from_u64(42); // Fixed seed for reproducibility + + for _ in 0..100 { + // Run 100 rounds of fuzz testing + let mut test_strings = Vec::new(); + + // Generate 20-50 random strings + let num_strings = rng.random_range(20..=50); + + for _ in 0..num_strings { + let string = generate_random_string(&mut rng); + test_strings.push(string); + } + + // Use standard library sort as reference + let mut expected = test_strings.clone(); + expected.sort(); + + // Use our sorting algorithm + let string_array = StringArray::from(test_strings.clone()); + let indices: Vec = (0..test_strings.len() as u32).collect(); + let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None); + + let sorted_strings: Vec = result + .values() + .iter() + .map(|&idx| test_strings[idx as usize].clone()) + .collect(); + + assert_eq!( + sorted_strings, expected, + "Fuzz test failed with input: {test_strings:?}" + ); + } + } + + // Helper function to generate random UTF-8 strings + fn generate_random_string(rng: &mut StdRng) -> String { + // Bias towards generating short strings, especially length 1-4 + let length = if rng.random_bool(0.6) { + rng.random_range(0..=4) // 60% probability for 0-4 length strings + } else { + rng.random_range(5..=20) // 40% probability for longer strings + }; + + if length == 0 { + return String::new(); + } + + let mut result = String::new(); + let mut current_len = 0; + + while current_len < length { + let c = generate_random_char(rng); + let char_len = c.len_utf8(); + + // Ensure we don't exceed target length + if current_len + char_len <= length { + result.push(c); + current_len += char_len; + } else { + // If adding this character would exceed length, fill with ASCII + let remaining = length - current_len; + for _ in 0..remaining { + result.push(rng.random_range('a'..='z')); + current_len += 1; + } + break; + } + } + + result + } + + // Generate random characters (including various UTF-8 characters) + fn generate_random_char(rng: &mut StdRng) -> char { + match rng.random_range(0..10) { + 0..=5 => rng.random_range('a'..='z'), // 60% ASCII lowercase + 6 => rng.random_range('A'..='Z'), // 10% ASCII uppercase + 7 => rng.random_range('0'..='9'), // 10% digits + 8 => { + // 10% Chinese characters + let chinese_chars = ['你', '好', '世', '界', '测', '试', '中', '文']; + chinese_chars[rng.random_range(0..chinese_chars.len())] + } + 9 => { + // 10% other Unicode characters (single `char`s) + let special_chars = ['é', 'ï', '🎉', '🎊', 'α', 'β', 'γ']; + special_chars[rng.random_range(0..special_chars.len())] + } + _ => unreachable!(), + } + } + + // Test descending sort order + #[test] + fn test_descending_sort() { + let test_cases = vec!["a", "ab", "ba", "baa", "abba", "abbc", "abc", "cda"]; + + let mut expected = test_cases.clone(); + expected.sort(); + expected.reverse(); // Descending order + + let string_array = StringArray::from(test_cases.clone()); + let indices: Vec = (0..test_cases.len() as u32).collect(); + let result = sort_bytes( + &string_array, + indices, + vec![], + SortOptions { + descending: true, + nulls_first: false, + }, + None, + ); + + let sorted_strings: Vec<&str> = result + .values() + .iter() + .map(|&idx| test_cases[idx as usize]) + .collect(); + + assert_eq!(sorted_strings, expected); + } + + // Stress test: large number of strings with same prefix + #[test] + fn test_same_prefix_stress() { + let mut test_cases = Vec::new(); + let prefix = "same"; + + // Generate many strings with the same prefix + for i in 0..1000 { + test_cases.push(format!("{prefix}{i:04}")); + } + + let mut expected = test_cases.clone(); + expected.sort(); + + let string_array = StringArray::from(test_cases.clone()); + let indices: Vec = (0..test_cases.len() as u32).collect(); + let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None); + + let sorted_strings: Vec = result + .values() + .iter() + .map(|&idx| test_cases[idx as usize].clone()) + .collect(); + + assert_eq!(sorted_strings, expected); + } + + // Test limit parameter + #[test] + fn test_with_limit() { + let test_cases = vec!["z", "y", "x", "w", "v", "u", "t", "s"]; + let limit = 3; + + let mut expected = test_cases.clone(); + expected.sort(); + expected.truncate(limit); + + let string_array = StringArray::from(test_cases.clone()); + let indices: Vec = (0..test_cases.len() as u32).collect(); + let result = sort_bytes( + &string_array, + indices, + vec![], + SortOptions::default(), + Some(limit), + ); + + let sorted_strings: Vec<&str> = result + .values() + .iter() + .map(|&idx| test_cases[idx as usize]) + .collect(); + + assert_eq!(sorted_strings, expected); + assert_eq!(sorted_strings.len(), limit); + } }