diff --git a/rust/arrow/src/array.rs b/rust/arrow/src/array.rs index dc0a5090ee8..09a8dd39fa5 100644 --- a/rust/arrow/src/array.rs +++ b/rust/arrow/src/array.rs @@ -609,8 +609,8 @@ impl From for BinaryArray { impl<'a> From> for BinaryArray { fn from(v: Vec<&'a str>) -> Self { - let mut offsets = vec![]; - let mut values = vec![]; + let mut offsets = Vec::with_capacity(v.len() + 1); + let mut values = Vec::new(); let mut length_so_far = 0; offsets.push(length_so_far); for s in &v { @@ -627,6 +627,26 @@ impl<'a> From> for BinaryArray { } } +impl<'a> From> for BinaryArray { + fn from(v: Vec<&[u8]>) -> Self { + let mut offsets = Vec::with_capacity(v.len() + 1); + let mut values = Vec::new(); + let mut length_so_far = 0; + offsets.push(length_so_far); + for s in &v { + length_so_far += s.len() as i32; + offsets.push(length_so_far as i32); + values.extend_from_slice(s); + } + let array_data = ArrayData::builder(DataType::Utf8) + .len(v.len()) + .add_buffer(Buffer::from(offsets.to_byte_slice())) + .add_buffer(Buffer::from(&values[..])) + .build(); + BinaryArray::from(array_data) + } +} + /// Creates a `BinaryArray` from `List` array impl From for BinaryArray { fn from(v: ListArray) -> Self { @@ -1155,6 +1175,36 @@ mod tests { } } + #[test] + fn test_binary_array_from_u8_slice() { + let values: Vec<&[u8]> = vec![ + &[b'h', b'e', b'l', b'l', b'o'], + &[], + &[b'p', b'a', b'r', b'q', b'u', b'e', b't'], + ]; + + // Array data: ["hello", "", "parquet"] + let binary_array = BinaryArray::from(values); + + assert_eq!(3, binary_array.len()); + assert_eq!(0, binary_array.null_count()); + assert_eq!([b'h', b'e', b'l', b'l', b'o'], binary_array.value(0)); + assert_eq!("hello", binary_array.get_string(0)); + assert_eq!([] as [u8; 0], binary_array.value(1)); + assert_eq!("", binary_array.get_string(1)); + assert_eq!( + [b'p', b'a', b'r', b'q', b'u', b'e', b't'], + binary_array.value(2) + ); + assert_eq!("parquet", binary_array.get_string(2)); + assert_eq!(5, binary_array.value_offset(2)); + assert_eq!(7, binary_array.value_length(2)); + for i in 0..3 { + assert!(binary_array.is_valid(i)); + assert!(!binary_array.is_null(i)); + } + } + #[test] #[should_panic( expected = "BinaryArray can only be created from List arrays, mismatched \ diff --git a/rust/arrow/src/compute/array_ops.rs b/rust/arrow/src/compute/array_ops.rs index 0d6ccbe1678..89e1667d87b 100644 --- a/rust/arrow/src/compute/array_ops.rs +++ b/rust/arrow/src/compute/array_ops.rs @@ -17,10 +17,16 @@ //! Defines primitive computations on arrays, e.g. addition, equality, boolean logic. +use std::cmp; use std::ops::Add; +use std::sync::Arc; -use crate::array::{Array, BooleanArray, PrimitiveArray}; -use crate::datatypes::ArrowNumericType; +use crate::array::{ + Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, PrimitiveArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, +}; +use crate::datatypes::{ArrowNumericType, DataType}; use crate::error::{ArrowError, Result}; /// Returns the minimum value in the array, according to the natural order. @@ -204,6 +210,101 @@ where Ok(b.finish()) } +macro_rules! filter_array { + ($array:expr, $filter:expr, $array_type:ident) => {{ + let b = $array.as_any().downcast_ref::<$array_type>().unwrap(); + let mut builder = $array_type::builder(b.len()); + for i in 0..b.len() { + if $filter.value(i) { + if b.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(b.value(i))?; + } + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +pub fn filter(array: &Array, filter: &BooleanArray) -> Result { + match array.data_type() { + DataType::UInt8 => filter_array!(array, filter, UInt8Array), + DataType::UInt16 => filter_array!(array, filter, UInt16Array), + DataType::UInt32 => filter_array!(array, filter, UInt32Array), + DataType::UInt64 => filter_array!(array, filter, UInt64Array), + DataType::Int8 => filter_array!(array, filter, Int8Array), + DataType::Int16 => filter_array!(array, filter, Int16Array), + DataType::Int32 => filter_array!(array, filter, Int32Array), + DataType::Int64 => filter_array!(array, filter, Int64Array), + DataType::Float32 => filter_array!(array, filter, Float32Array), + DataType::Float64 => filter_array!(array, filter, Float64Array), + DataType::Boolean => filter_array!(array, filter, BooleanArray), + DataType::Utf8 => { + let b = array.as_any().downcast_ref::().unwrap(); + let mut values: Vec<&[u8]> = Vec::with_capacity(b.len()); + for i in 0..b.len() { + if filter.value(i) { + values.push(b.value(i)); + } + } + Ok(Arc::new(BinaryArray::from(values))) + } + other => Err(ArrowError::ComputeError(format!( + "filter not supported for {:?}", + other + ))), + } +} + +macro_rules! limit_array { + ($array:expr, $num_elements:expr, $array_type:ident) => {{ + let b = $array.as_any().downcast_ref::<$array_type>().unwrap(); + let mut builder = $array_type::builder($num_elements); + for i in 0..$num_elements { + if b.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(b.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +/// Returns the array, taking only the number of elements specified +/// +/// Returns the whole array if the number of elements specified is larger than the length of the array +pub fn limit(array: &Array, num_elements: usize) -> Result { + let num_elements_safe: usize = cmp::min(array.len(), num_elements); + + match array.data_type() { + DataType::UInt8 => limit_array!(array, num_elements_safe, UInt8Array), + DataType::UInt16 => limit_array!(array, num_elements_safe, UInt16Array), + DataType::UInt32 => limit_array!(array, num_elements_safe, UInt32Array), + DataType::UInt64 => limit_array!(array, num_elements_safe, UInt64Array), + DataType::Int8 => limit_array!(array, num_elements_safe, Int8Array), + DataType::Int16 => limit_array!(array, num_elements_safe, Int16Array), + DataType::Int32 => limit_array!(array, num_elements_safe, Int32Array), + DataType::Int64 => limit_array!(array, num_elements_safe, Int64Array), + DataType::Float32 => limit_array!(array, num_elements_safe, Float32Array), + DataType::Float64 => limit_array!(array, num_elements_safe, Float64Array), + DataType::Boolean => limit_array!(array, num_elements_safe, BooleanArray), + DataType::Utf8 => { + let b = array.as_any().downcast_ref::().unwrap(); + let mut values: Vec<&[u8]> = Vec::with_capacity(num_elements_safe); + for i in 0..num_elements_safe { + values.push(b.value(i)); + } + Ok(Arc::new(BinaryArray::from(values))) + } + other => Err(ArrowError::ComputeError(format!( + "limit not supported for {:?}", + other + ))), + } +} + #[cfg(test)] mod tests { use super::*; @@ -358,4 +459,80 @@ mod tests { assert_eq!(5, min(&a).unwrap()); assert_eq!(9, max(&a).unwrap()); } + + #[test] + fn test_filter_array() { + let a = Int32Array::from(vec![5, 6, 7, 8, 9]); + let b = BooleanArray::from(vec![true, false, false, true, false]); + let c = filter(&a, &b).unwrap(); + let d = c.as_ref().as_any().downcast_ref::().unwrap(); + assert_eq!(2, d.len()); + assert_eq!(5, d.value(0)); + assert_eq!(8, d.value(1)); + } + + #[test] + fn test_filter_binary_array() { + let a = BinaryArray::from(vec!["hello", " ", "world", "!"]); + let b = BooleanArray::from(vec![true, false, true, false]); + let c = filter(&a, &b).unwrap(); + let d = c.as_ref().as_any().downcast_ref::().unwrap(); + assert_eq!(2, d.len()); + assert_eq!("hello", d.get_string(0)); + assert_eq!("world", d.get_string(1)); + } + + #[test] + fn test_filter_array_with_null() { + let a = Int32Array::from(vec![Some(5), None]); + let b = BooleanArray::from(vec![false, true]); + let c = filter(&a, &b).unwrap(); + let d = c.as_ref().as_any().downcast_ref::().unwrap(); + assert_eq!(1, d.len()); + assert_eq!(true, d.is_null(0)); + } + + #[test] + fn test_limit_array() { + let a = Int32Array::from(vec![5, 6, 7, 8, 9]); + let b = limit(&a, 3).unwrap(); + let c = b.as_ref().as_any().downcast_ref::().unwrap(); + assert_eq!(3, c.len()); + assert_eq!(5, c.value(0)); + assert_eq!(6, c.value(1)); + assert_eq!(7, c.value(2)); + } + + #[test] + fn test_limit_binary_array() { + let a = BinaryArray::from(vec!["hello", " ", "world", "!"]); + let b = limit(&a, 2).unwrap(); + let c = b.as_ref().as_any().downcast_ref::().unwrap(); + assert_eq!(2, c.len()); + assert_eq!("hello", c.get_string(0)); + assert_eq!(" ", c.get_string(1)); + } + + #[test] + fn test_limit_array_with_null() { + let a = Int32Array::from(vec![None, Some(5)]); + let b = limit(&a, 1).unwrap(); + let c = b.as_ref().as_any().downcast_ref::().unwrap(); + assert_eq!(1, c.len()); + assert_eq!(true, c.is_null(0)); + } + + #[test] + fn test_limit_array_with_limit_too_large() { + let a = Int32Array::from(vec![5, 6, 7, 8, 9]); + let b = limit(&a, 6).unwrap(); + let c = b.as_ref().as_any().downcast_ref::().unwrap(); + + assert_eq!(5, c.len()); + assert_eq!(a.value(0), c.value(0)); + assert_eq!(a.value(1), c.value(1)); + assert_eq!(a.value(2), c.value(2)); + assert_eq!(a.value(3), c.value(3)); + assert_eq!(a.value(4), c.value(4)); + } } diff --git a/rust/datafusion/src/execution/filter.rs b/rust/datafusion/src/execution/filter.rs index 32c1628fecf..8e706705b2e 100644 --- a/rust/datafusion/src/execution/filter.rs +++ b/rust/datafusion/src/execution/filter.rs @@ -22,7 +22,8 @@ use std::rc::Rc; use std::sync::Arc; use arrow::array::*; -use arrow::datatypes::{DataType, Schema}; +use arrow::compute::array_ops::filter; +use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use super::error::{ExecutionError, Result}; @@ -61,7 +62,12 @@ impl Relation for FilterRelation { Some(filter_bools) => { let filtered_columns: Result> = (0..batch .num_columns()) - .map(|i| filter(batch.column(i), &filter_bools)) + .map(|i| { + match filter(batch.column(i).as_ref(), &filter_bools) { + Ok(result) => Ok(result), + Err(error) => Err(ExecutionError::from(error)), + } + }) .collect(); let filtered_batch: RecordBatch = RecordBatch::new( @@ -84,129 +90,3 @@ impl Relation for FilterRelation { &self.schema } } - -//TODO: move into Arrow array_ops -fn filter(array: &Arc, filter: &BooleanArray) -> Result { - let a = array.as_ref(); - - //TODO use macros - match a.data_type() { - DataType::UInt8 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = UInt8Array::builder(b.len()); - for i in 0..b.len() { - if filter.value(i) { - builder.append_value(b.value(i))?; - } - } - Ok(Arc::new(builder.finish())) - } - DataType::UInt16 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = UInt16Array::builder(b.len()); - for i in 0..b.len() { - if filter.value(i) { - builder.append_value(b.value(i))?; - } - } - Ok(Arc::new(builder.finish())) - } - DataType::UInt32 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = UInt32Array::builder(b.len()); - for i in 0..b.len() { - if filter.value(i) { - builder.append_value(b.value(i))?; - } - } - Ok(Arc::new(builder.finish())) - } - DataType::UInt64 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = UInt64Array::builder(b.len()); - for i in 0..b.len() { - if filter.value(i) { - builder.append_value(b.value(i))?; - } - } - Ok(Arc::new(builder.finish())) - } - DataType::Int8 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = Int8Array::builder(b.len()); - for i in 0..b.len() { - if filter.value(i) { - builder.append_value(b.value(i))?; - } - } - Ok(Arc::new(builder.finish())) - } - DataType::Int16 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = Int16Array::builder(b.len()); - for i in 0..b.len() { - if filter.value(i) { - builder.append_value(b.value(i))?; - } - } - Ok(Arc::new(builder.finish())) - } - DataType::Int32 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = Int32Array::builder(b.len()); - for i in 0..b.len() { - if filter.value(i) { - builder.append_value(b.value(i))?; - } - } - Ok(Arc::new(builder.finish())) - } - DataType::Int64 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = Int64Array::builder(b.len()); - for i in 0..b.len() { - if filter.value(i) { - builder.append_value(b.value(i))?; - } - } - Ok(Arc::new(builder.finish())) - } - DataType::Float32 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = Float32Array::builder(b.len()); - for i in 0..b.len() { - if filter.value(i) { - builder.append_value(b.value(i))?; - } - } - Ok(Arc::new(builder.finish())) - } - DataType::Float64 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = Float64Array::builder(b.len()); - for i in 0..b.len() { - if filter.value(i) { - builder.append_value(b.value(i))?; - } - } - Ok(Arc::new(builder.finish())) - } - DataType::Utf8 => { - //TODO: this is inefficient and we should improve the Arrow impl to help make - // this more concise - let b = a.as_any().downcast_ref::().unwrap(); - let mut values: Vec = Vec::with_capacity(b.len()); - for i in 0..b.len() { - if filter.value(i) { - values.push(b.get_string(i)); - } - } - let tmp: Vec<&str> = values.iter().map(|s| s.as_str()).collect(); - Ok(Arc::new(BinaryArray::from(tmp))) - } - other => Err(ExecutionError::ExecutionError(format!( - "filter not supported for {:?}", - other - ))), - } -} diff --git a/rust/datafusion/src/execution/limit.rs b/rust/datafusion/src/execution/limit.rs index d6258d63db9..888fac5045e 100644 --- a/rust/datafusion/src/execution/limit.rs +++ b/rust/datafusion/src/execution/limit.rs @@ -22,7 +22,8 @@ use std::rc::Rc; use std::sync::Arc; use arrow::array::*; -use arrow::datatypes::{DataType, Schema}; +use arrow::compute::array_ops::limit; +use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use super::error::{ExecutionError, Result}; @@ -58,7 +59,10 @@ impl Relation for LimitRelation { if batch.num_rows() >= capacity { let limited_columns: Result> = (0..batch.num_columns()) - .map(|i| limit(batch.column(i).as_ref(), capacity)) + .map(|i| match limit(batch.column(i).as_ref(), capacity) { + Ok(result) => Ok(result), + Err(error) => Err(ExecutionError::from(error)), + }) .collect(); let limited_batch: RecordBatch = @@ -79,104 +83,3 @@ impl Relation for LimitRelation { &self.schema } } - -//TODO: move into Arrow array_ops -fn limit(a: &Array, num_rows_to_read: usize) -> Result { - //TODO use macros - match a.data_type() { - DataType::UInt8 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = UInt8Array::builder(num_rows_to_read as usize); - for i in 0..num_rows_to_read { - builder.append_value(b.value(i as usize))?; - } - Ok(Arc::new(builder.finish())) - } - DataType::UInt16 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = UInt16Array::builder(num_rows_to_read as usize); - for i in 0..num_rows_to_read { - builder.append_value(b.value(i as usize))?; - } - Ok(Arc::new(builder.finish())) - } - DataType::UInt32 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = UInt32Array::builder(num_rows_to_read as usize); - for i in 0..num_rows_to_read { - builder.append_value(b.value(i as usize))?; - } - Ok(Arc::new(builder.finish())) - } - DataType::UInt64 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = UInt64Array::builder(num_rows_to_read as usize); - for i in 0..num_rows_to_read { - builder.append_value(b.value(i as usize))?; - } - Ok(Arc::new(builder.finish())) - } - DataType::Int8 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = Int8Array::builder(num_rows_to_read as usize); - for i in 0..num_rows_to_read { - builder.append_value(b.value(i as usize))?; - } - Ok(Arc::new(builder.finish())) - } - DataType::Int16 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = Int16Array::builder(num_rows_to_read as usize); - for i in 0..num_rows_to_read { - builder.append_value(b.value(i as usize))?; - } - Ok(Arc::new(builder.finish())) - } - DataType::Int32 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = Int32Array::builder(num_rows_to_read as usize); - for i in 0..num_rows_to_read { - builder.append_value(b.value(i as usize))?; - } - Ok(Arc::new(builder.finish())) - } - DataType::Int64 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = Int64Array::builder(num_rows_to_read as usize); - for i in 0..num_rows_to_read { - builder.append_value(b.value(i as usize))?; - } - Ok(Arc::new(builder.finish())) - } - DataType::Float32 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = Float32Array::builder(num_rows_to_read as usize); - for i in 0..num_rows_to_read { - builder.append_value(b.value(i as usize))?; - } - Ok(Arc::new(builder.finish())) - } - DataType::Float64 => { - let b = a.as_any().downcast_ref::().unwrap(); - let mut builder = Float64Array::builder(num_rows_to_read as usize); - for i in 0..num_rows_to_read { - builder.append_value(b.value(i as usize))?; - } - Ok(Arc::new(builder.finish())) - } - DataType::Utf8 => { - //TODO: this is inefficient and we should improve the Arrow impl to help make this more concise - let b = a.as_any().downcast_ref::().unwrap(); - let mut values: Vec = Vec::with_capacity(num_rows_to_read as usize); - for i in 0..num_rows_to_read { - values.push(b.get_string(i as usize)); - } - let tmp: Vec<&str> = values.iter().map(|s| s.as_str()).collect(); - Ok(Arc::new(BinaryArray::from(tmp))) - } - other => Err(ExecutionError::ExecutionError(format!( - "filter not supported for {:?}", - other - ))), - } -}