diff --git a/rust/arrow/src/array/array_binary.rs b/rust/arrow/src/array/array_binary.rs index c59e540551a..0af194ec37d 100644 --- a/rust/arrow/src/array/array_binary.rs +++ b/rust/arrow/src/array/array_binary.rs @@ -776,7 +776,9 @@ mod tests { .build(); let binary_array1 = BinaryArray::from(array_data1); - let array_data2 = ArrayData::builder(DataType::Binary) + let data_type = + DataType::List(Box::new(Field::new("item", DataType::UInt8, false))); + let array_data2 = ArrayData::builder(data_type) .len(3) .add_buffer(Buffer::from_slice_ref(&offsets)) .add_child_data(values_data) @@ -818,7 +820,9 @@ mod tests { .build(); let binary_array1 = LargeBinaryArray::from(array_data1); - let array_data2 = ArrayData::builder(DataType::Binary) + let data_type = + DataType::LargeList(Box::new(Field::new("item", DataType::UInt8, false))); + let array_data2 = ArrayData::builder(data_type) .len(3) .add_buffer(Buffer::from_slice_ref(&offsets)) .add_child_data(values_data) @@ -869,41 +873,21 @@ mod tests { #[test] #[should_panic( - expected = "BinaryArray can only be created from List arrays, mismatched \ - data types." - )] - fn test_binary_array_from_incorrect_list_array_type() { - let values: [u32; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; - let values_data = ArrayData::builder(DataType::UInt32) - .len(12) - .add_buffer(Buffer::from_slice_ref(&values)) - .build(); - let offsets: [i32; 4] = [0, 5, 5, 12]; - - let array_data = ArrayData::builder(DataType::Utf8) - .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_child_data(values_data) - .build(); - let list_array = ListArray::from(array_data); - BinaryArray::from(list_array); - } - - #[test] - #[should_panic( - expected = "BinaryArray can only be created from list array of u8 values \ - (i.e. List>)." + expected = "assertion failed: `(left == right)`\n left: `UInt32`,\n \ + right: `UInt8`: BinaryArray can only be created from List arrays, \ + mismatched data types." )] fn test_binary_array_from_incorrect_list_array() { let values: [u32; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; let values_data = ArrayData::builder(DataType::UInt32) .len(12) .add_buffer(Buffer::from_slice_ref(&values)) - .add_child_data(ArrayData::builder(DataType::Boolean).build()) .build(); let offsets: [i32; 4] = [0, 5, 5, 12]; - let array_data = ArrayData::builder(DataType::Utf8) + let data_type = + DataType::List(Box::new(Field::new("item", DataType::UInt32, false))); + let array_data = ArrayData::builder(data_type) .len(3) .add_buffer(Buffer::from_slice_ref(&offsets)) .add_child_data(values_data) diff --git a/rust/arrow/src/array/array_list.rs b/rust/arrow/src/array/array_list.rs index f2076b3e86d..fc67d95a5a6 100644 --- a/rust/arrow/src/array/array_list.rs +++ b/rust/arrow/src/array/array_list.rs @@ -31,12 +31,19 @@ use crate::datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType, Field}; /// trait declaring an offset size, relevant for i32 vs i64 array types. pub trait OffsetSizeTrait: ArrowNativeType + Num + Ord + std::ops::AddAssign { + fn is_large() -> bool; + fn prefix() -> &'static str; fn to_isize(&self) -> isize; } impl OffsetSizeTrait for i32 { + #[inline] + fn is_large() -> bool { + false + } + fn prefix() -> &'static str { "" } @@ -47,6 +54,11 @@ impl OffsetSizeTrait for i32 { } impl OffsetSizeTrait for i64 { + #[inline] + fn is_large() -> bool { + true + } + fn prefix() -> &'static str { "Large" } @@ -117,6 +129,21 @@ impl GenericListArray { GenericListArrayIter::<'a, OffsetSize>::new(&self) } + #[inline] + fn get_type(data_type: &DataType) -> Option<&DataType> { + if OffsetSize::is_large() { + if let DataType::LargeList(child) = data_type { + Some(child.data_type()) + } else { + None + } + } else if let DataType::List(child) = data_type { + Some(child.data_type()) + } else { + None + } + } + /// Creates a [`GenericListArray`] from an iterator of primitive values /// # Example /// ``` @@ -193,7 +220,19 @@ impl From for GenericListArray::new(value_offsets) }; diff --git a/rust/arrow/src/array/builder.rs b/rust/arrow/src/array/builder.rs index 6979a9887ca..32eea9c9c7a 100644 --- a/rust/arrow/src/array/builder.rs +++ b/rust/arrow/src/array/builder.rs @@ -764,7 +764,7 @@ where values_data.data_type().clone(), true, // TODO: find a consistent way of getting this )); - let data_type = if OffsetSize::prefix() == "Large" { + let data_type = if OffsetSize::is_large() { DataType::LargeList(field) } else { DataType::List(field) diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs index 25592c657ae..9a547bdefaf 100644 --- a/rust/arrow/src/compute/kernels/cast.rs +++ b/rust/arrow/src/compute/kernels/cast.rs @@ -264,8 +264,10 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { (_, Struct(_)) => Err(ArrowError::ComputeError( "Cannot cast to struct from other types".to_string(), )), - (List(_), List(ref to)) => cast_list_inner::(&**array, to), - (LargeList(_), LargeList(ref to)) => cast_list_inner::(&**array, to), + (List(_), List(ref to)) => cast_list_inner::(&**array, to, to_type), + (LargeList(_), LargeList(ref to)) => { + cast_list_inner::(&**array, to, to_type) + } (List(list_from), LargeList(list_to)) => { if list_to.data_type() != list_from.data_type() { Err(ArrowError::ComputeError( @@ -287,8 +289,8 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { (List(_), _) => Err(ArrowError::ComputeError( "Cannot cast list to non-list data types".to_string(), )), - (_, List(ref to)) => cast_primitive_to_list::(array, to), - (_, LargeList(ref to)) => cast_primitive_to_list::(array, to), + (_, List(ref to)) => cast_primitive_to_list::(array, to, to_type), + (_, LargeList(ref to)) => cast_primitive_to_list::(array, to, to_type), (Dictionary(index_type, _), _) => match **index_type { DataType::Int8 => dictionary_cast::(array, to_type), DataType::Int16 => dictionary_cast::(array, to_type), @@ -1243,6 +1245,7 @@ where fn cast_primitive_to_list( array: &ArrayRef, to: &Field, + to_type: &DataType, ) -> Result { // cast primitive to list's primitive let cast_array = cast(array, to.data_type())?; @@ -1257,7 +1260,7 @@ fn cast_primitive_to_list( }; let list_data = ArrayData::new( - to.data_type().clone(), + to_type.clone(), array.len(), Some(cast_array.null_count()), cast_array @@ -1279,12 +1282,13 @@ fn cast_primitive_to_list( fn cast_list_inner( array: &dyn Array, to: &Field, + to_type: &DataType, ) -> Result { let data = array.data_ref(); let underlying_array = make_array(data.child_data()[0].clone()); let cast_array = cast(&underlying_array, to.data_type())?; let array_data = ArrayData::new( - to.data_type().clone(), + to_type.clone(), array.len(), Some(cast_array.null_count()), cast_array diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index c6b28944d24..fd104e66416 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -984,7 +984,6 @@ mod tests { "generated_dictionary", // "generated_duplicate_fieldnames", "generated_interval", - "generated_large_batch", "generated_nested", // "generated_nested_large_offsets", "generated_null_trivial", @@ -1048,7 +1047,6 @@ mod tests { "generated_dictionary", // "generated_duplicate_fieldnames", "generated_interval", - "generated_large_batch", "generated_nested", // "generated_nested_large_offsets", "generated_null_trivial",