Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 43 additions & 172 deletions rust/arrow/src/array/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2115,28 +2115,32 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer, usize)> for StructArray {
/// Example **with nullable** data:
///
/// ```
/// use arrow::array::DictionaryArray;
/// use arrow::array::{DictionaryArray, Int8Array};
/// use arrow::datatypes::Int8Type;
/// let test = vec!["a", "a", "b", "c"];
/// let array : DictionaryArray<Int8Type> = test.iter().map(|&x| if x == "b" {None} else {Some(x)}).collect();
/// assert_eq!(array.keys().collect::<Vec<Option<i8>>>(), vec![Some(0), Some(0), None, Some(1)]);
/// assert_eq!(array.keys(), &Int8Array::from(vec![Some(0), Some(0), None, Some(1)]));
/// ```
///
/// Example **without nullable** data:
///
/// ```
/// use arrow::array::DictionaryArray;
/// use arrow::array::{DictionaryArray, Int8Array};
/// use arrow::datatypes::Int8Type;
/// let test = vec!["a", "a", "b", "c"];
/// let array : DictionaryArray<Int8Type> = test.into_iter().collect();
/// assert_eq!(array.keys().collect::<Vec<Option<i8>>>(), vec![Some(0), Some(0), Some(1), Some(2)]);
/// assert_eq!(array.keys(), &Int8Array::from(vec![0, 0, 1, 2]));
/// ```
pub struct DictionaryArray<K: ArrowPrimitiveType> {
/// Array of keys, stored as a PrimitiveArray<K>.
/// Data of this dictionary. Note that this is _not_ compatible with the C Data interface,
/// as, in the current implementation, `values` below are the first child of this struct.
data: ArrayDataRef,

/// Pointer to the key values.
raw_values: RawPtrBox<K::Native>,
/// The keys of this dictionary. These are constructed from the buffer and null bitmap
/// of `data`.
/// Also, note that these do not correspond to the true values of this array. Rather, they map
/// to the real values.
keys: PrimitiveArray<K>,

/// Array of dictionary values (can by any DataType).
values: ArrayRef,
Expand All @@ -2145,112 +2149,10 @@ pub struct DictionaryArray<K: ArrowPrimitiveType> {
is_ordered: bool,
}

#[derive(Debug)]
enum Draining {
Ready,
Iterating,
Finished,
}

#[derive(Debug)]
pub struct NullableIter<'a, T> {
data: &'a ArrayDataRef, // TODO: Use a pointer to the null bitmap.
ptr: *const T,
i: usize,
len: usize,
draining: Draining,
}

impl<'a, T> std::iter::Iterator for NullableIter<'a, T>
where
T: Clone,
{
type Item = Option<T>;

fn next(&mut self) -> Option<Self::Item> {
let i = self.i;
if i >= self.len {
None
} else if self.data.is_null(i) {
self.i += 1;
Some(None)
} else {
self.i += 1;
unsafe { Some(Some((&*self.ptr.add(i)).clone())) }
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
(self.len, Some(self.len))
}

fn nth(&mut self, n: usize) -> Option<Self::Item> {
let i = self.i;
if i + n >= self.len {
self.i = self.len;
None
} else if self.data.is_null(i + n) {
self.i += n + 1;
Some(None)
} else {
self.i += n + 1;
unsafe { Some(Some((&*self.ptr.add(i + n)).clone())) }
}
}
}

impl<'a, T> std::iter::DoubleEndedIterator for NullableIter<'a, T>
where
T: Clone,
{
fn next_back(&mut self) -> Option<Self::Item> {
match self.draining {
Draining::Ready => {
self.draining = Draining::Iterating;
self.i = self.len - 1;
self.next_back()
}
Draining::Iterating => {
let i = self.i;
if i >= self.len {
None
} else if self.data.is_null(i) {
self.i = self.i.checked_sub(1).unwrap_or_else(|| {
self.draining = Draining::Finished;
0_usize
});
Some(None)
} else {
match i.checked_sub(1) {
Some(idx) => {
self.i = idx;
unsafe { Some(Some((&*self.ptr.add(i)).clone())) }
}
_ => {
self.draining = Draining::Finished;
unsafe { Some(Some((&*self.ptr).clone())) }
}
}
}
}
Draining::Finished => {
self.draining = Draining::Ready;
None
}
}
}
}

impl<'a, K: ArrowPrimitiveType> DictionaryArray<K> {
/// Return an iterator to the keys of this dictionary.
pub fn keys(&self) -> NullableIter<'_, K::Native> {
NullableIter::<'_, K::Native> {
data: &self.data,
ptr: unsafe { self.raw_values.get().add(self.data.offset()) },
i: 0,
len: self.data.len(),
draining: Draining::Ready,
}
pub fn keys(&self) -> &PrimitiveArray<K> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Jorge, I like this option more

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

me too. that's way better.

&self.keys
}

/// Returns an array view of the keys of this dictionary
Expand Down Expand Up @@ -2291,12 +2193,12 @@ impl<'a, K: ArrowPrimitiveType> DictionaryArray<K> {

/// The length of the dictionary is the length of the keys array.
pub fn len(&self) -> usize {
self.data.len()
self.keys.len()
}

/// Whether this dictionary is empty
pub fn is_empty(&self) -> bool {
self.data.is_empty()
self.keys.is_empty()
}

// Currently exists for compatibility purposes with Arrow IPC.
Expand All @@ -2319,13 +2221,24 @@ impl<T: ArrowPrimitiveType> From<ArrayDataRef> for DictionaryArray<T> {
"DictionaryArray should contain a single child array (values)."
);

let raw_values = data.buffers()[0].raw_data();
let dtype: &DataType = data.data_type();
let values = make_array(data.child_data()[0].clone());
if let DataType::Dictionary(_, _) = dtype {
if let DataType::Dictionary(key_data_type, _) = data.data_type() {
if key_data_type.as_ref() != &T::DATA_TYPE {
panic!("DictionaryArray's data type must match.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
panic!("DictionaryArray's data type must match.")
unreachable!("DictionaryArray's data type must match.")

Since the former is good for defensive programming but doesn't convey the idea.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestions.

Isn't unreachable used when the program arrives at an inconsistent state?

IMO, in this case, we are checking user input (this function is public) and ensure that we will not reach an inconsistent state (in the same way assert_eq does). assert_eq calls panic!, which is why I also used panic! here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that unreachable is only for the inconsistent state. But we can leave it as panic here. I was also unsure about the user-facing API's panicking behavior. Especially in array methods with forced asserts. We should prefer Result than direct asserts, but you know... That's also yet another topic.

};
// create a zero-copy of the keys' data
let keys = PrimitiveArray::<T>::from(Arc::new(ArrayData::new(
T::DATA_TYPE,
data.len(),
Some(data.null_count()),
data.null_buffer().cloned(),
data.offset(),
data.buffers().to_vec(),
vec![],
)));
let values = make_array(data.child_data()[0].clone());
Self {
data,
raw_values: RawPtrBox::new(raw_values as *const T::Native),
keys,
values,
is_ordered: false,
}
Expand Down Expand Up @@ -2396,32 +2309,25 @@ impl<T: ArrowPrimitiveType> Array for DictionaryArray<T> {
&self.data
}

/// Returns the total number of bytes of memory occupied by the buffers owned by this [DictionaryArray].
fn get_buffer_memory_size(&self) -> usize {
self.data.get_buffer_memory_size() + self.values().get_buffer_memory_size()
// Since both `keys` and `values` derive (are references from) `data`, we only need to account for `data`.
self.data.get_buffer_memory_size()
}

/// Returns the total number of bytes of memory occupied physically by this [DictionaryArray].
fn get_array_memory_size(&self) -> usize {
self.data.get_array_memory_size()
+ self.values().get_array_memory_size()
+ self.keys.get_array_memory_size()
+ self.values.get_array_memory_size()
+ mem::size_of_val(self)
}
}

impl<T: ArrowPrimitiveType> fmt::Debug for DictionaryArray<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
const MAX_LEN: usize = 10;
let keys: Vec<_> = self.keys().take(MAX_LEN).collect();
let elipsis = if self.keys().count() > MAX_LEN {
"..."
} else {
""
};
writeln!(
f,
"DictionaryArray {{keys: {:?}{} values: {:?}}}",
keys, elipsis, self.values
"DictionaryArray {{keys: {:?} values: {:?}}}",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I'm assuming that we're relying on the keys formatting now that it's a PrimitiveArray

self.keys, self.values
)
}
}
Expand Down Expand Up @@ -3084,23 +2990,7 @@ mod tests {
// Null count only makes sense in terms of the component arrays.
assert_eq!(0, dict_array.null_count());
assert_eq!(0, dict_array.values().null_count());
assert_eq!(Some(Some(3)), dict_array.keys().nth(1));
assert_eq!(Some(Some(4)), dict_array.keys().nth(2));

assert_eq!(
dict_array.keys().collect::<Vec<Option<i16>>>(),
vec![Some(2), Some(3), Some(4)]
);

assert_eq!(
dict_array.keys().rev().collect::<Vec<Option<i16>>>(),
vec![Some(4), Some(3), Some(2)]
);

assert_eq!(
dict_array.keys().rev().rev().collect::<Vec<Option<i16>>>(),
vec![Some(2), Some(3), Some(4)]
);
assert_eq!(dict_array.keys(), &Int16Array::from(vec![2_i16, 3, 4]));

// Now test with a non-zero offset
let dict_data = ArrayData::builder(dict_data_type)
Expand All @@ -3115,26 +3005,7 @@ mod tests {
assert_eq!(value_data, values.data());
assert_eq!(DataType::Int8, dict_array.value_type());
assert_eq!(2, dict_array.len());
assert_eq!(Some(Some(3)), dict_array.keys().nth(0));
assert_eq!(Some(Some(4)), dict_array.keys().nth(1));

assert_eq!(
dict_array.keys().collect::<Vec<Option<i16>>>(),
vec![Some(3), Some(4)]
);
}

#[test]
fn test_dictionary_array_key_reverse() {
let test = vec!["a", "a", "b", "c"];
let array: DictionaryArray<Int8Type> = test
.iter()
.map(|&x| if x == "b" { None } else { Some(x) })
.collect();
assert_eq!(
array.keys().rev().collect::<Vec<Option<i8>>>(),
vec![Some(1), None, Some(0), Some(0)]
);
assert_eq!(dict_array.keys(), &Int16Array::from(vec![3_i16, 4]));
}

#[test]
Expand Down Expand Up @@ -4239,7 +4110,7 @@ mod tests {
builder.append(22345678).unwrap();
let array = builder.finish();
assert_eq!(
"DictionaryArray {keys: [Some(0), None, Some(1)] values: PrimitiveArray<UInt32>\n[\n 12345678,\n 22345678,\n]}\n",
"DictionaryArray {keys: PrimitiveArray<UInt8>\n[\n 0,\n null,\n 1,\n] values: PrimitiveArray<UInt32>\n[\n 12345678,\n 22345678,\n]}\n",
format!("{:?}", array)
);

Expand All @@ -4251,7 +4122,7 @@ mod tests {
}
let array = builder.finish();
assert_eq!(
"DictionaryArray {keys: [Some(0), Some(0), Some(0), Some(0), Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)]... values: PrimitiveArray<UInt32>\n[\n 1,\n]}\n",
"DictionaryArray {keys: PrimitiveArray<UInt8>\n[\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n] values: PrimitiveArray<UInt32>\n[\n 1,\n]}\n",
format!("{:?}", array)
);
}
Expand All @@ -4264,13 +4135,13 @@ mod tests {
.map(|&x| if x == "b" { None } else { Some(x) })
.collect();
assert_eq!(
"DictionaryArray {keys: [Some(0), Some(0), None, Some(1)] values: StringArray\n[\n \"a\",\n \"c\",\n]}\n",
"DictionaryArray {keys: PrimitiveArray<Int8>\n[\n 0,\n 0,\n null,\n 1,\n] values: StringArray\n[\n \"a\",\n \"c\",\n]}\n",
format!("{:?}", array)
);

let array: DictionaryArray<Int8Type> = test.into_iter().collect();
assert_eq!(
"DictionaryArray {keys: [Some(0), Some(0), Some(1), Some(2)] values: StringArray\n[\n \"a\",\n \"b\",\n \"c\",\n]}\n",
"DictionaryArray {keys: PrimitiveArray<Int8>\n[\n 0,\n 0,\n 1,\n 2,\n] values: StringArray\n[\n \"a\",\n \"b\",\n \"c\",\n]}\n",
format!("{:?}", array)
);
}
Expand Down
Loading