diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index cbbc56a2456a..22d850dc118f 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -194,7 +194,16 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff MutableBuffer::new(capacity * mem::size_of::()), empty_buffer, ], - DataType::Union(_, _) => unimplemented!(), + DataType::Union(_, mode) => { + let type_ids = MutableBuffer::new(capacity * mem::size_of::()); + match mode { + UnionMode::Sparse => [type_ids, empty_buffer], + UnionMode::Dense => { + let offsets = MutableBuffer::new(capacity * mem::size_of::()); + [type_ids, offsets] + } + } + } } } @@ -210,7 +219,8 @@ pub(crate) fn into_buffers( DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 - | DataType::LargeBinary => vec![buffer1.into(), buffer2.into()], + | DataType::LargeBinary + | DataType::Union(_, _) => vec![buffer1.into(), buffer2.into()], _ => vec![buffer1.into()], } } @@ -559,7 +569,10 @@ impl ArrayData { DataType::Map(field, _) => { vec![Self::new_empty(field.data_type())] } - DataType::Union(_, _) => unimplemented!(), + DataType::Union(fields, _) => fields + .iter() + .map(|field| Self::new_empty(field.data_type())) + .collect(), DataType::Dictionary(_, data_type) => { vec![Self::new_empty(data_type)] } diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs index fe16aeeb95a5..b34a469f7ba1 100644 --- a/arrow/src/array/transform/mod.rs +++ b/arrow/src/array/transform/mod.rs @@ -35,6 +35,7 @@ mod list; mod null; mod primitive; mod structure; +mod union; mod utils; mod variable_size; @@ -272,9 +273,12 @@ fn build_extend(array: &ArrayData) -> Extend { DataType::Struct(_) => structure::build_extend(array), DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), DataType::Float16 => primitive::build_extend::(array), + DataType::Union(_, mode) => match mode { + UnionMode::Sparse => union::build_extend_sparse(array), + UnionMode::Dense => union::build_extend_dense(array), + }, /* DataType::FixedSizeList(_, _) => {} - DataType::Union(_) => {} */ ty => todo!( "Take and filter operations still not supported for this datatype: `{:?}`", @@ -326,9 +330,12 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { DataType::Struct(_) => structure::extend_nulls, DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls, DataType::Float16 => primitive::extend_nulls::, + DataType::Union(_, mode) => match mode { + UnionMode::Sparse => union::extend_nulls_sparse, + UnionMode::Dense => union::extend_nulls_dense, + }, /* DataType::FixedSizeList(_, _) => {} - DataType::Union(_) => {} */ ty => todo!( "Take and filter operations still not supported for this datatype: `{:?}`", @@ -522,6 +529,15 @@ impl<'a> MutableArrayData<'a> { }) .collect::>(), }, + DataType::Union(fields, _) => (0..fields.len()) + .map(|i| { + let child_arrays = arrays + .iter() + .map(|array| &array.child_data()[i]) + .collect::>(); + MutableArrayData::new(child_arrays, use_nulls, array_capacity) + }) + .collect::>(), ty => { todo!("Take and filter operations still not supported for this datatype: `{:?}`", ty) } diff --git a/arrow/src/array/transform/union.rs b/arrow/src/array/transform/union.rs new file mode 100644 index 000000000000..ec672daf4d99 --- /dev/null +++ b/arrow/src/array/transform/union.rs @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::ArrayData; + +use super::{Extend, _MutableArrayData}; + +pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend { + let type_ids = array.buffer::(0); + + if array.null_count() == 0 { + Box::new( + move |mutable: &mut _MutableArrayData, + index: usize, + start: usize, + len: usize| { + // extends type_ids + mutable + .buffer1 + .extend_from_slice(&type_ids[start..start + len]); + + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend(index, start, start + len)) + }, + ) + } else { + Box::new( + move |mutable: &mut _MutableArrayData, + index: usize, + start: usize, + len: usize| { + // extends type_ids + mutable + .buffer1 + .extend_from_slice(&type_ids[start..start + len]); + + (start..start + len).for_each(|i| { + if array.is_valid(i) { + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend(index, i, i + 1)) + } else { + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend_nulls(1)) + } + }) + }, + ) + } +} + +pub(super) fn build_extend_dense(array: &ArrayData) -> Extend { + let type_ids = array.buffer::(0); + let offsets = array.buffer::(1); + + if array.null_count() == 0 { + Box::new( + move |mutable: &mut _MutableArrayData, + index: usize, + start: usize, + len: usize| { + // extends type_ids + mutable + .buffer1 + .extend_from_slice(&type_ids[start..start + len]); + // extends offsets + mutable + .buffer2 + .extend_from_slice(&offsets[start..start + len]); + + (start..start + len).for_each(|i| { + let type_id = type_ids[i] as usize; + let offset_start = offsets[start] as usize; + + mutable.child_data[type_id].extend( + index, + offset_start, + offset_start + 1, + ) + }) + }, + ) + } else { + Box::new( + move |mutable: &mut _MutableArrayData, + index: usize, + start: usize, + len: usize| { + // extends type_ids + mutable + .buffer1 + .extend_from_slice(&type_ids[start..start + len]); + // extends offsets + mutable + .buffer2 + .extend_from_slice(&offsets[start..start + len]); + + (start..start + len).for_each(|i| { + let type_id = type_ids[i] as usize; + let offset_start = offsets[start] as usize; + + if array.is_valid(i) { + mutable.child_data[type_id].extend( + index, + offset_start, + offset_start + 1, + ) + } else { + mutable.child_data[type_id].extend_nulls(1) + } + }) + }, + ) + } +} + +pub(super) fn extend_nulls_dense(mutable: &mut _MutableArrayData, len: usize) { + let mut count: usize = 0; + let num = len / mutable.child_data.len(); + mutable + .child_data + .iter_mut() + .enumerate() + .for_each(|(idx, child)| { + let n = if count + num > len { len - count } else { num }; + count += n; + mutable + .buffer1 + .extend_from_slice(vec![idx as i8; n].as_slice()); + mutable + .buffer2 + .extend_from_slice(vec![child.len() as i32; n].as_slice()); + child.extend_nulls(n) + }) +} + +pub(super) fn extend_nulls_sparse(mutable: &mut _MutableArrayData, len: usize) { + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend_nulls(len)) +} diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index d0bfc91a98e1..85511b990b5f 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -1521,4 +1521,143 @@ mod tests { assert_eq!(&expected, &got); } + + fn test_filter_union_array(array: UnionArray) { + let filter_array = BooleanArray::from(vec![true, false, false]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(1); + builder.append::("A", 1).unwrap(); + let expected_array = builder.build().unwrap(); + + compare_union_arrays(filtered, &expected_array); + + let filter_array = BooleanArray::from(vec![true, false, true]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(2); + builder.append::("A", 1).unwrap(); + builder.append::("A", 34).unwrap(); + let expected_array = builder.build().unwrap(); + + compare_union_arrays(filtered, &expected_array); + + let filter_array = BooleanArray::from(vec![true, true, false]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(2); + builder.append::("A", 1).unwrap(); + builder.append::("B", 3.2).unwrap(); + let expected_array = builder.build().unwrap(); + + compare_union_arrays(filtered, &expected_array); + } + + #[test] + fn test_filter_union_array_dense() { + let mut builder = UnionBuilder::new_dense(3); + builder.append::("A", 1).unwrap(); + builder.append::("B", 3.2).unwrap(); + builder.append::("A", 34).unwrap(); + let array = builder.build().unwrap(); + + test_filter_union_array(array); + } + + #[test] + fn test_filter_union_array_dense_with_nulls() { + let mut builder = UnionBuilder::new_dense(4); + builder.append::("A", 1).unwrap(); + builder.append::("B", 3.2).unwrap(); + builder.append_null().unwrap(); + builder.append::("A", 34).unwrap(); + let array = builder.build().unwrap(); + + let filter_array = BooleanArray::from(vec![true, false, true, false]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(1); + builder.append::("A", 1).unwrap(); + builder.append_null().unwrap(); + let expected_array = builder.build().unwrap(); + + compare_union_arrays(filtered, &expected_array); + } + + #[test] + fn test_filter_union_array_sparse() { + let mut builder = UnionBuilder::new_sparse(3); + builder.append::("A", 1).unwrap(); + builder.append::("B", 3.2).unwrap(); + builder.append::("A", 34).unwrap(); + let array = builder.build().unwrap(); + + test_filter_union_array(array); + } + + #[test] + fn test_filter_union_array_sparse_with_nulls() { + let mut builder = UnionBuilder::new_sparse(4); + builder.append::("A", 1).unwrap(); + builder.append::("B", 3.2).unwrap(); + builder.append_null().unwrap(); + builder.append::("A", 34).unwrap(); + let array = builder.build().unwrap(); + + let filter_array = BooleanArray::from(vec![true, false, true, false]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(1); + builder.append::("A", 1).unwrap(); + builder.append_null().unwrap(); + let expected_array = builder.build().unwrap(); + + compare_union_arrays(filtered, &expected_array); + } + + fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) { + assert_eq!(union1.len(), union2.len()); + + for i in 0..union1.len() { + let type_id = union1.type_id(i); + + let slot1 = union1.value(i); + let slot2 = union2.value(i); + + assert_eq!(union1.is_null(i), union2.is_null(i)); + + if !union1.is_null(i) && !union2.is_null(i) { + match type_id { + 0 => { + let slot1 = slot1.as_any().downcast_ref::().unwrap(); + assert_eq!(slot1.len(), 1); + let value1 = slot1.value(0); + + let slot2 = slot2.as_any().downcast_ref::().unwrap(); + assert_eq!(slot2.len(), 1); + let value2 = slot2.value(0); + assert_eq!(value1, value2); + } + 1 => { + let slot1 = + slot1.as_any().downcast_ref::().unwrap(); + assert_eq!(slot1.len(), 1); + let value1 = slot1.value(0); + + let slot2 = + slot2.as_any().downcast_ref::().unwrap(); + assert_eq!(slot2.len(), 1); + let value2 = slot2.value(0); + assert_eq!(value1, value2); + } + _ => unreachable!(), + } + } + } + } }