From 1664f95ffc32889d4734e0fe12700321b914870b Mon Sep 17 00:00:00 2001
From: Alfonso Subiotto Marques
Date: Fri, 16 May 2025 11:39:21 +0200
Subject: [PATCH] arrow-select: add support for merging primitive dictionary
 values

Previously, should_merge_dictionaries would always return false in the
ptr_eq closure creation match arm for types that were not
{Large}{Utf8,Binary}. This could lead to excessive memory usage.
---
 arrow-select/src/concat.rs     | 43 +++++++++++++++++++++++++
 arrow-select/src/dictionary.rs | 58 ++++++++++++++++++++++++++--------
 2 files changed, 88 insertions(+), 13 deletions(-)

diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs
index 486afbd14467..1c99caef5ad9 100644
--- a/arrow-select/src/concat.rs
+++ b/arrow-select/src/concat.rs
@@ -1081,6 +1081,49 @@ mod tests {
         assert!((30..40).contains(&values_len), "{values_len}")
     }
 
+    #[test]
+    fn test_primitive_dictionary_merge() {
+        // Same value repeated 5 times.
+        let keys = vec![1; 5];
+        let values = (10..20).collect::<Vec<_>>();
+        let dict = DictionaryArray::new(
+            Int8Array::from(keys.clone()),
+            Arc::new(Int32Array::from(values.clone())),
+        );
+        let other = DictionaryArray::new(
+            Int8Array::from(keys.clone()),
+            Arc::new(Int32Array::from(values.clone())),
+        );
+
+        let result_same_dictionary = concat(&[&dict, &dict]).unwrap();
+        // Verify pointer equality check succeeds, and therefore the
+        // dictionaries are not merged. A single values buffer should be reused
+        // in this case.
+        assert!(dict.values().to_data().ptr_eq(
+            &result_same_dictionary
+                .as_dictionary::<Int8Type>()
+                .values()
+                .to_data()
+        ));
+        assert_eq!(
+            result_same_dictionary
+                .as_dictionary::<Int8Type>()
+                .values()
+                .len(),
+            values.len(),
+        );
+
+        let result_cloned_dictionary = concat(&[&dict, &other]).unwrap();
+        // Should have only 1 underlying value since all keys reference it.
+        assert_eq!(
+            result_cloned_dictionary
+                .as_dictionary::<Int8Type>()
+                .values()
+                .len(),
+            1
+        );
+    }
+
     #[test]
     fn test_concat_string_sizes() {
         let a: LargeStringArray = ((0..150).map(|_| Some("foo"))).collect();
diff --git a/arrow-select/src/dictionary.rs b/arrow-select/src/dictionary.rs
index 57aed644fe0c..c5773b16a486 100644
--- a/arrow-select/src/dictionary.rs
+++ b/arrow-select/src/dictionary.rs
@@ -18,12 +18,13 @@
 use crate::interleave::interleave;
 use ahash::RandomState;
 use arrow_array::builder::BooleanBufferBuilder;
-use arrow_array::cast::AsArray;
 use arrow_array::types::{
-    ArrowDictionaryKeyType, BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type,
+    ArrowDictionaryKeyType, ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType,
+    LargeUtf8Type, Utf8Type,
 };
-use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray};
-use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer};
+use arrow_array::{cast::AsArray, downcast_primitive};
+use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray, PrimitiveArray};
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice};
 use arrow_schema::{ArrowError, DataType};
 
 /// A best effort interner that maintains a fixed number of buckets
@@ -102,7 +103,7 @@ fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
 }
 /// A type-erased function that compares two array for pointer equality
-type PtrEq = dyn Fn(&dyn Array, &dyn Array) -> bool;
+type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
 
 /// A weak heuristic of whether to merge dictionary values that aims to only
 /// perform the expensive merge computation when it is likely to yield at least
 ///
@@ -115,12 +116,17 @@ pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
 ) -> bool {
     use DataType::*;
     let first_values = dictionaries[0].values().as_ref();
-    let ptr_eq: Box<PtrEq> = match first_values.data_type() {
-        Utf8 => Box::new(bytes_ptr_eq::<Utf8Type>),
-        LargeUtf8 => Box::new(bytes_ptr_eq::<LargeUtf8Type>),
-        Binary => Box::new(bytes_ptr_eq::<BinaryType>),
-        LargeBinary => Box::new(bytes_ptr_eq::<LargeBinaryType>),
-        _ => return false,
+    let ptr_eq: PtrEq = match first_values.data_type() {
+        Utf8 => bytes_ptr_eq::<Utf8Type>,
+        LargeUtf8 => bytes_ptr_eq::<LargeUtf8Type>,
+        Binary => bytes_ptr_eq::<BinaryType>,
+        LargeBinary => bytes_ptr_eq::<LargeBinaryType>,
+        dt => {
+            if !dt.is_primitive() {
+                return false;
+            }
+            |a, b| a.to_data().ptr_eq(&b.to_data())
+        }
     };
 
     let mut single_dictionary = true;
@@ -233,17 +239,43 @@ fn compute_values_mask<K: ArrowNativeType>(
     builder.finish()
 }
 
+/// Process primitive array values to bytes
+fn masked_primitives_to_bytes<'a, T: ArrowPrimitiveType>(
+    array: &'a PrimitiveArray<T>,
+    mask: &BooleanBuffer,
+) -> Vec<(usize, Option<&'a [u8]>)>
+where
+    T::Native: ToByteSlice,
+{
+    let mut out = Vec::with_capacity(mask.count_set_bits());
+    let values = array.values();
+    for idx in mask.set_indices() {
+        out.push((
+            idx,
+            array.is_valid(idx).then_some(values[idx].to_byte_slice()),
+        ))
+    }
+    out
+}
+
+macro_rules! masked_primitive_to_bytes_helper {
+    ($t:ty, $array:expr, $mask:expr) => {
+        masked_primitives_to_bytes::<$t>($array.as_primitive(), $mask)
+    };
+}
+
 /// Return a Vec containing for each set index in `mask`, the index and byte value of that index
 fn get_masked_values<'a>(
     array: &'a dyn Array,
     mask: &BooleanBuffer,
 ) -> Vec<(usize, Option<&'a [u8]>)> {
-    match array.data_type() {
+    downcast_primitive! {
+        array.data_type() => (masked_primitive_to_bytes_helper, array, mask),
         DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
         DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
         DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
         DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
-        _ => unimplemented!(),
+        _ => unimplemented!("Dictionary merging for type {} is not implemented", array.data_type()),
     }
 }
 