diff --git a/rust/arrow/src/array/array.rs b/rust/arrow/src/array/array.rs index 3283dff6217..2b53f30ba22 100644 --- a/rust/arrow/src/array/array.rs +++ b/rust/arrow/src/array/array.rs @@ -1429,14 +1429,28 @@ impl From for LargeBinaryArray { } } +/// Like OffsetSizeTrait, but specialized for Strings +// This allow us to expose a constant datatype for the GenericStringArray +pub trait StringOffsetSizeTrait: OffsetSizeTrait { + const DATA_TYPE: DataType; +} + +impl StringOffsetSizeTrait for i32 { + const DATA_TYPE: DataType = DataType::Utf8; +} + +impl StringOffsetSizeTrait for i64 { + const DATA_TYPE: DataType = DataType::LargeUtf8; +} + /// Generic struct for \[Large\]StringArray -pub struct GenericStringArray { +pub struct GenericStringArray { data: ArrayDataRef, value_offsets: RawPtrBox, value_data: RawPtrBox, } -impl GenericStringArray { +impl GenericStringArray { /// Returns the offset for the element at index `i`. /// /// Note this doesn't do any bound checking, for performance reason. @@ -1559,7 +1573,7 @@ impl GenericStringArray { } } -impl fmt::Debug for GenericStringArray { +impl fmt::Debug for GenericStringArray { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}StringArray\n[\n", OffsetSize::prefix())?; print_long_array(self, f, |array, index, f| { @@ -1569,7 +1583,7 @@ impl fmt::Debug for GenericStringArray } } -impl Array for GenericStringArray { +impl Array for GenericStringArray { fn as_any(&self) -> &Any { self } @@ -1593,8 +1607,15 @@ impl Array for GenericStringArray { } } -impl From for GenericStringArray { +impl From + for GenericStringArray +{ fn from(data: ArrayDataRef) -> Self { + assert_eq!( + data.data_type(), + &::DATA_TYPE, + "[Large]StringArray expects Datatype::[Large]Utf8" + ); assert_eq!( data.buffers().len(), 2, @@ -1612,7 +1633,7 @@ impl From for GenericStringArray ListArrayOps +impl ListArrayOps for GenericStringArray { fn value_offset_at(&self, i: usize) -> OffsetSize { @@ -3608,6 +3629,13 @@ mod tests { } } + #[test] + #[should_panic(expected = "[Large]StringArray expects Datatype::[Large]Utf8")] + fn test_string_array_from_int() { + let array = LargeStringArray::from(vec!["a", "b"]); + StringArray::from(array.data()); + } + #[test] fn test_large_string_array_from_u8_slice() { let values: Vec<&str> = vec!["hello", "", "parquet"]; diff --git a/rust/arrow/src/array/equal.rs b/rust/arrow/src/array/equal.rs index dd7dec89064..df480fe9045 100644 --- a/rust/arrow/src/array/equal.rs +++ b/rust/arrow/src/array/equal.rs @@ -20,7 +20,7 @@ use crate::datatypes::*; use crate::util::bit_util; use array::{ Array, GenericBinaryArray, GenericListArray, GenericStringArray, ListArrayOps, - OffsetSizeTrait, + OffsetSizeTrait, StringOffsetSizeTrait, }; use hex::FromHex; use serde_json::value::Value::{Null as JNull, Object, String as JString}; @@ -141,7 +141,7 @@ impl PartialEq for BooleanArray { } } -impl PartialEq for GenericStringArray { +impl PartialEq for GenericStringArray { fn eq(&self, other: &Self) -> bool { self.equals(other) } @@ -444,7 +444,7 @@ impl ArrayEqual for GenericBinaryArray } } -impl ArrayEqual for GenericStringArray { +impl ArrayEqual for GenericStringArray { fn equals(&self, other: &dyn Array) -> bool { if !base_equal(&self.data(), &other.data()) { return false; @@ -1063,7 +1063,7 @@ impl PartialEq> for } } -impl JsonEqual for GenericStringArray { +impl JsonEqual for GenericStringArray { fn equals_json(&self, json: &[&Value]) -> bool { if self.len() != json.len() { return false; @@ -1077,7 +1077,9 @@ impl JsonEqual for GenericStringArray { } } -impl PartialEq for GenericStringArray { +impl PartialEq + for GenericStringArray +{ fn eq(&self, json: &Value) -> bool { match json { Value::Array(json_array) => self.equals_json_values(&json_array), @@ -1086,7 +1088,9 @@ impl PartialEq for GenericStringArray PartialEq> for Value { +impl PartialEq> + for Value +{ fn eq(&self, arrow: &GenericStringArray) -> bool { match self { Value::Array(json_array) => arrow.equals_json_values(&json_array), @@ -1412,7 +1416,7 @@ mod tests { // assert!(b_slice.equals(&*a_slice)); } - fn test_generic_string_equal(datatype: DataType) { + fn test_generic_string_equal(datatype: DataType) { let a = GenericStringArray::::from_vec( vec!["hello", "world"], datatype.clone(), diff --git a/rust/arrow/src/array/mod.rs b/rust/arrow/src/array/mod.rs index 9debbb6b0ad..b9bf0fed9d9 100644 --- a/rust/arrow/src/array/mod.rs +++ b/rust/arrow/src/array/mod.rs @@ -160,6 +160,7 @@ pub use self::array::GenericListArray; pub use self::array::GenericStringArray; pub use self::array::OffsetSizeTrait; pub use self::array::PrimitiveArrayOps; +pub use self::array::StringOffsetSizeTrait; // --------------------- Array Builder --------------------- diff --git a/rust/arrow/src/compute/kernels/aggregate.rs b/rust/arrow/src/compute/kernels/aggregate.rs index 78c8f2a7ace..996e1667c14 100644 --- a/rust/arrow/src/compute/kernels/aggregate.rs +++ b/rust/arrow/src/compute/kernels/aggregate.rs @@ -19,11 +19,11 @@ use std::ops::Add; -use crate::array::{Array, GenericStringArray, OffsetSizeTrait, PrimitiveArray}; +use crate::array::{Array, GenericStringArray, PrimitiveArray, StringOffsetSizeTrait}; use crate::datatypes::ArrowNumericType; /// Helper macro to perform min/max of strings -fn min_max_string bool>( +fn min_max_string bool>( array: &GenericStringArray, cmp: F, ) -> Option<&str> { @@ -73,12 +73,16 @@ where } /// Returns the maximum value in the string array, according to the natural order. -pub fn max_string(array: &GenericStringArray) -> Option<&str> { +pub fn max_string( + array: &GenericStringArray, +) -> Option<&str> { min_max_string(array, |a, b| a < b) } /// Returns the minimum value in the string array, according to the natural order. -pub fn min_string(array: &GenericStringArray) -> Option<&str> { +pub fn min_string( + array: &GenericStringArray, +) -> Option<&str> { min_max_string(array, |a, b| a > b) } diff --git a/rust/arrow/src/compute/kernels/comparison.rs b/rust/arrow/src/compute/kernels/comparison.rs index 8f81c3bcf49..72d0f6bff94 100644 --- a/rust/arrow/src/compute/kernels/comparison.rs +++ b/rust/arrow/src/compute/kernels/comparison.rs @@ -617,7 +617,7 @@ pub fn contains_utf8( right: &ListArray, ) -> Result where - OffsetSize: OffsetSizeTrait, + OffsetSize: StringOffsetSizeTrait, { let left_len = left.len(); if left_len != right.len() { diff --git a/rust/arrow/src/compute/kernels/substring.rs b/rust/arrow/src/compute/kernels/substring.rs index 68f117115e5..38c8040ef00 100644 --- a/rust/arrow/src/compute/kernels/substring.rs +++ b/rust/arrow/src/compute/kernels/substring.rs @@ -24,11 +24,10 @@ use crate::{ }; use std::sync::Arc; -fn substring1( +fn generic_substring( array: &GenericStringArray, start: OffsetSize, length: &Option, - datatype: DataType, ) -> Result { // compute current offsets let offsets = array.data_ref().clone().buffers()[0].clone(); @@ -76,7 +75,7 @@ fn substring1( }); let data = ArrayData::new( - datatype, + ::DATA_TYPE, array.len(), None, null_bit_buffer, @@ -95,23 +94,21 @@ fn substring1( /// this function errors when the passed array is not a \[Large\]String array. pub fn substring(array: &Array, start: i64, length: &Option) -> Result { match array.data_type() { - DataType::LargeUtf8 => substring1( + DataType::LargeUtf8 => generic_substring( array .as_any() .downcast_ref::() .expect("A large string is expected"), start, &length.map(|e| e as i64), - DataType::LargeUtf8, ), - DataType::Utf8 => substring1( + DataType::Utf8 => generic_substring( array .as_any() .downcast_ref::() .expect("A string is expected"), start as i32, &length.map(|e| e as i32), - DataType::Utf8, ), _ => Err(ArrowError::ComputeError(format!( "substring does not support type {:?}", diff --git a/rust/arrow/src/compute/kernels/take.rs b/rust/arrow/src/compute/kernels/take.rs index f1525375d39..a76885cc30b 100644 --- a/rust/arrow/src/compute/kernels/take.rs +++ b/rust/arrow/src/compute/kernels/take.rs @@ -125,8 +125,8 @@ pub fn take( DataType::Duration(TimeUnit::Nanosecond) => { take_primitive::(values, indices) } - DataType::Utf8 => take_string::(values, indices, DataType::Utf8), - DataType::LargeUtf8 => take_string::(values, indices, DataType::LargeUtf8), + DataType::Utf8 => take_string::(values, indices), + DataType::LargeUtf8 => take_string::(values, indices), DataType::List(_) => take_list(values, indices), DataType::Struct(fields) => { let struct_: &StructArray = @@ -262,13 +262,9 @@ fn take_boolean(values: &ArrayRef, indices: &UInt32Array) -> Result { } /// `take` implementation for string arrays -fn take_string( - values: &ArrayRef, - indices: &UInt32Array, - data_type: DataType, -) -> Result +fn take_string(values: &ArrayRef, indices: &UInt32Array) -> Result where - OffsetSize: Zero + AddAssign + OffsetSizeTrait, + OffsetSize: Zero + AddAssign + StringOffsetSizeTrait, { let data_len = indices.len(); @@ -306,7 +302,7 @@ where None => null_buf.freeze(), }; - let data = ArrayData::builder(data_type) + let data = ArrayData::builder(::DATA_TYPE) .len(data_len) .null_bit_buffer(nulls) .add_buffer(Buffer::from(offsets.to_byte_slice()))