Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 62 additions & 41 deletions rust/arrow/benches/cast_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,100 +29,121 @@ use arrow::array::*;
use arrow::compute::cast;
use arrow::datatypes::*;

// cast array from specified primitive array type to desired data type
fn cast_array<FROM>(size: usize, to_type: DataType)
fn build_array<FROM>(size: usize) -> ArrayRef
where
FROM: ArrowNumericType,
Standard: Distribution<FROM::Native>,
PrimitiveArray<FROM>: std::convert::From<Vec<FROM::Native>>,
PrimitiveArray<FROM>: std::convert::From<Vec<Option<FROM::Native>>>,
{
let array = Arc::new(PrimitiveArray::<FROM>::from(vec![
random::<FROM::Native>();
size
])) as ArrayRef;
criterion::black_box(cast(&array, &to_type).unwrap());
let values = (0..size)
.map(|_| {
// 10% nulls, i.e. dense.
if random::<f64>() < 0.1 {
None
} else {
Some(random::<FROM::Native>())
}
})
.collect();

Arc::new(PrimitiveArray::<FROM>::from(values))
}

// cast timestamp array from specified primitive array type to desired data type
fn cast_timestamp_array<FROM>(size: usize, to_type: DataType)
fn build_timestamp_array<FROM>(size: usize) -> ArrayRef
where
FROM: ArrowTimestampType,
Standard: Distribution<i64>,
Standard: Distribution<FROM::Native>,
{
let array = Arc::new(PrimitiveArray::<FROM>::from_vec(
vec![random::<i64>(); size],
None,
)) as ArrayRef;
criterion::black_box(cast(&array, &to_type).unwrap());
let values = (0..size)
.map(|_| {
if random::<f64>() < 0.5 {
None
} else {
Some(random::<i64>())
}
})
.collect::<Vec<Option<i64>>>();

Arc::new(PrimitiveArray::<FROM>::from_opt_vec(values, None))
}

// Runs the cast kernel on an already-built array, converting it to `to_type`.
// The result is routed through `black_box` so the optimizer cannot discard the work.
fn cast_array(array: &ArrayRef, to_type: DataType) {
    let casted = cast(array, &to_type).unwrap();
    criterion::black_box(casted);
}

fn add_benchmark(c: &mut Criterion) {
let i32_array = build_array::<Int32Type>(512);
let i64_array = build_array::<Int64Type>(512);
let f32_array = build_array::<Float32Type>(512);
let f64_array = build_array::<Float64Type>(512);
let date64_array = build_array::<Date64Type>(512);
let date32_array = build_array::<Date32Type>(512);
let time32s_array = build_array::<Time32SecondType>(512);
let time64ns_array = build_array::<Time64NanosecondType>(512);
let time_ns_array = build_timestamp_array::<TimestampNanosecondType>(512);
let time_ms_array = build_timestamp_array::<TimestampMillisecondType>(512);

c.bench_function("cast int32 to int32 512", |b| {
b.iter(|| cast_array::<Int32Type>(512, DataType::Int32))
b.iter(|| cast_array(&i32_array, DataType::Int32))
});
c.bench_function("cast int32 to uint32 512", |b| {
b.iter(|| cast_array::<Int32Type>(512, DataType::UInt32))
b.iter(|| cast_array(&i32_array, DataType::UInt32))
});
c.bench_function("cast int32 to float32 512", |b| {
b.iter(|| cast_array::<Int32Type>(512, DataType::Float32))
b.iter(|| cast_array(&i32_array, DataType::Float32))
});
c.bench_function("cast int32 to float64 512", |b| {
b.iter(|| cast_array::<Int32Type>(512, DataType::Float64))
b.iter(|| cast_array(&i32_array, DataType::Float64))
});
c.bench_function("cast int32 to int64 512", |b| {
b.iter(|| cast_array::<Int32Type>(512, DataType::Int64))
b.iter(|| cast_array(&i32_array, DataType::Int64))
});
c.bench_function("cast float32 to int32 512", |b| {
b.iter(|| cast_array::<Float32Type>(512, DataType::Int32))
b.iter(|| cast_array(&f32_array, DataType::Int32))
});
c.bench_function("cast float64 to float32 512", |b| {
b.iter(|| cast_array::<Float64Type>(512, DataType::Float32))
b.iter(|| cast_array(&f64_array, DataType::Float32))
});
c.bench_function("cast float64 to uint64 512", |b| {
b.iter(|| cast_array::<Float64Type>(512, DataType::UInt64))
b.iter(|| cast_array(&f64_array, DataType::UInt64))
});
c.bench_function("cast int64 to int32 512", |b| {
b.iter(|| cast_array::<Int64Type>(512, DataType::Int32))
b.iter(|| cast_array(&i64_array, DataType::Int32))
});
c.bench_function("cast date64 to date32 512", |b| {
b.iter(|| cast_array::<Date64Type>(512, DataType::Date32(DateUnit::Day)))
b.iter(|| cast_array(&date64_array, DataType::Date32(DateUnit::Day)))
});
c.bench_function("cast date32 to date64 512", |b| {
b.iter(|| cast_array::<Date32Type>(512, DataType::Date64(DateUnit::Millisecond)))
b.iter(|| cast_array(&date32_array, DataType::Date64(DateUnit::Millisecond)))
});
c.bench_function("cast time32s to time32ms 512", |b| {
b.iter(|| {
cast_array::<Time32SecondType>(512, DataType::Time32(TimeUnit::Millisecond))
})
b.iter(|| cast_array(&time32s_array, DataType::Time32(TimeUnit::Millisecond)))
});
c.bench_function("cast time32s to time64us 512", |b| {
b.iter(|| {
cast_array::<Time32SecondType>(512, DataType::Time64(TimeUnit::Microsecond))
})
b.iter(|| cast_array(&time32s_array, DataType::Time64(TimeUnit::Microsecond)))
});
c.bench_function("cast time64ns to time32s 512", |b| {
b.iter(|| {
cast_array::<Time64NanosecondType>(512, DataType::Time32(TimeUnit::Second))
})
b.iter(|| cast_array(&time64ns_array, DataType::Time32(TimeUnit::Second)))
});
c.bench_function("cast timestamp_ns to timestamp_s 512", |b| {
b.iter(|| {
cast_timestamp_array::<TimestampNanosecondType>(
512,
cast_array(
&time_ns_array,
DataType::Timestamp(TimeUnit::Nanosecond, None),
)
})
});
c.bench_function("cast timestamp_ms to timestamp_ns 512", |b| {
b.iter(|| {
cast_timestamp_array::<TimestampMillisecondType>(
512,
cast_array(
&time_ms_array,
DataType::Timestamp(TimeUnit::Nanosecond, None),
)
})
});
c.bench_function("cast timestamp_ms to i64 512", |b| {
b.iter(|| cast_timestamp_array::<TimestampMillisecondType>(512, DataType::Int64))
b.iter(|| cast_array(&time_ms_array, DataType::Int64))
});
}

Expand Down
85 changes: 57 additions & 28 deletions rust/arrow/src/array/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use std::any::Any;
use std::borrow::Borrow;
use std::convert::{From, TryFrom};
use std::fmt;
use std::io::Write;
Expand Down Expand Up @@ -693,6 +694,61 @@ impl fmt::Debug for PrimitiveArray<BooleanType> {
}
}

/// Lets a borrowed `PrimitiveArray` be used directly in a `for` loop.
/// Each item is an `Option` of the array's native value type.
impl<'a, T: ArrowPrimitiveType> IntoIterator for &'a PrimitiveArray<T> {
    type Item = Option<T::Native>;
    type IntoIter = PrimitiveIter<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        PrimitiveIter::new(self)
    }
}

impl<'a, T: ArrowPrimitiveType> PrimitiveArray<T> {
    /// Returns a [`PrimitiveIter`] borrowing this array, positioned at the first element.
    pub fn iter(&'a self) -> PrimitiveIter<'a, T> {
        PrimitiveIter::new(self)
    }
}

/// Builds a `PrimitiveArray` from an iterator of optional native values
/// (or of anything that borrows as one, e.g. `&Option<T::Native>`).
///
/// `Some(v)` items become valid slots holding `v`; `None` items become null
/// slots whose value bytes are zero-filled.
///
/// # Panics
///
/// Panics if the iterator does not report an upper bound in `size_hint`,
/// since that bound is used to size the buffers up front.
impl<T: ArrowPrimitiveType, Ptr: Borrow<Option<<T as ArrowPrimitiveType>::Native>>>
    FromIterator<Ptr> for PrimitiveArray<T>
{
    fn from_iter<I: IntoIterator<Item = Ptr>>(iter: I) -> Self {
        let iter = iter.into_iter();
        // The upper bound is only a capacity estimate: adaptors such as `filter`
        // may yield fewer items than they advertise.
        let (_, upper) = iter.size_hint();
        let capacity = upper.expect("Iterator must be sized"); // panic if no upper bound.

        let value_size = mem::size_of::<<T as ArrowPrimitiveType>::Native>();
        let num_bytes = bit_util::ceil(capacity, 8);
        // Validity bitmap starts all-unset (null); bits are set for `Some` items.
        let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false);
        let mut val_buf = MutableBuffer::new(capacity * value_size);

        // Placeholder bytes written into the values buffer for null slots.
        let null = vec![0; value_size];

        // Number of items actually produced by the iterator; this, not the
        // size hint, is the length of the resulting array.
        let mut len = 0;
        {
            let null_slice = null_buf.data_mut();
            for (i, item) in iter.enumerate() {
                if let Some(a) = item.borrow() {
                    bit_util::set_bit(null_slice, i);
                    // NOTE(review): presumably cannot fail because the buffer was
                    // sized from the iterator's upper bound -- confirm against
                    // `MutableBuffer`'s `Write` impl.
                    val_buf.write_all(a.to_byte_slice()).unwrap();
                } else {
                    val_buf.write_all(&null).unwrap();
                }
                len = i + 1;
            }
        }

        let data = ArrayData::new(
            T::get_data_type(),
            len,
            None,
            Some(null_buf.freeze()),
            0,
            vec![val_buf.freeze()],
            vec![],
        );
        PrimitiveArray::from(Arc::new(data))
    }
}

// TODO: the macro is needed here because we'd get "conflicting implementations" error
// otherwise with both `From<Vec<T::Native>>` and `From<Vec<Option<T::Native>>>`.
// We should revisit this in future.
Expand All @@ -713,34 +769,7 @@ macro_rules! def_numeric_from_vec {
for PrimitiveArray<$ty>
{
fn from(data: Vec<Option<<$ty as ArrowPrimitiveType>::Native>>) -> Self {
let data_len = data.len();
let mut null_buf = make_null_buffer(data_len);
let mut val_buf = MutableBuffer::new(
data_len * mem::size_of::<<$ty as ArrowPrimitiveType>::Native>(),
);

{
let null =
vec![0; mem::size_of::<<$ty as ArrowPrimitiveType>::Native>()];
let null_slice = null_buf.data_mut();
for (i, v) in data.iter().enumerate() {
if let Some(n) = v {
bit_util::set_bit(null_slice, i);
// unwrap() in the following should be safe here since we've
// made sure enough space is allocated for the values.
val_buf.write_all(&n.to_byte_slice()).unwrap();
} else {
val_buf.write_all(&null).unwrap();
}
}
}

let array_data = ArrayData::builder($ty::get_data_type())
.len(data_len)
.add_buffer(val_buf.freeze())
.null_bit_buffer(null_buf.freeze())
.build();
PrimitiveArray::from(array_data)
PrimitiveArray::from_iter(data.iter())
}
}
};
Expand Down
85 changes: 85 additions & 0 deletions rust/arrow/src/array/iterator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::datatypes::ArrowPrimitiveType;

use super::{Array, PrimitiveArray, PrimitiveArrayOps};

/// An iterator over a borrowed [`PrimitiveArray`] that yields `Some(value)` for
/// non-null slots and `None` for null slots. It can be used on any non-boolean
/// `PrimitiveArray`.
#[derive(Debug)]
pub struct PrimitiveIter<'a, T: ArrowPrimitiveType> {
// The array being iterated; borrowed for the lifetime of the iterator.
array: &'a PrimitiveArray<T>,
// Index of the next element to yield.
i: usize,
// Length of `array`, captured once when the iterator is created.
len: usize,
}

impl<'a, T: ArrowPrimitiveType> PrimitiveIter<'a, T> {
    /// Creates an iterator over `array`, starting at index 0 and remembering
    /// the array's current length as the end position.
    pub fn new(array: &'a PrimitiveArray<T>) -> Self {
        let len = array.len();
        Self { array, i: 0, len }
    }
}

impl<'a, T: ArrowPrimitiveType> std::iter::Iterator for PrimitiveIter<'a, T> {
    type Item = Option<T::Native>;

    /// Yields the next slot of the array: `Some(None)` for a null slot,
    /// `Some(Some(v))` for a valid one, and `None` once every slot was visited.
    fn next(&mut self) -> Option<Self::Item> {
        if self.i >= self.len {
            return None;
        }
        let i = self.i;
        self.i += 1;
        if self.array.is_null(i) {
            Some(None)
        } else {
            Some(Some(self.array.value(i)))
        }
    }

    /// Reports exactly how many elements are still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `i` never exceeds `len` (see `next`), so this cannot underflow.
        // The hint must shrink as elements are consumed; returning the full
        // length would overstate what is left after the first `next()` call.
        let remaining = self.len - self.i;
        (remaining, Some(remaining))
    }
}

/// All arrays have a known size, so the iterator is marked as exactly sized.
/// This is a marker impl with no methods of its own: `ExactSizeIterator::len`
/// is derived from `size_hint`, which therefore has to report the number of
/// elements not yet yielded.
impl<'a, T: ArrowPrimitiveType> std::iter::ExactSizeIterator for PrimitiveIter<'a, T> {}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::array::{ArrayRef, Int32Array};

    /// Iterating an array, adding one to every non-null value and collecting
    /// the result must keep nulls in place and update only the valid slots.
    #[test]
    fn test_primitive_array_iter_round_trip() {
        // Go through `ArrayRef` and back to exercise the usual downcast path.
        let source: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(0), None, Some(2), None, Some(4)]));
        let typed = source.as_any().downcast_ref::<Int32Array>().unwrap();

        // to and from iter, with a +1
        let incremented: Int32Array =
            typed.iter().map(|slot| slot.map(|v| v + 1)).collect();

        let expected = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        assert_eq!(incremented, expected);
    }
}
5 changes: 5 additions & 0 deletions rust/arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ mod builder;
mod cast;
mod data;
mod equal;
mod iterator;
mod null;
mod ord;
mod union;
Expand Down Expand Up @@ -239,6 +240,10 @@ pub type DurationMillisecondBuilder = PrimitiveBuilder<DurationMillisecondType>;
pub type DurationMicrosecondBuilder = PrimitiveBuilder<DurationMicrosecondType>;
pub type DurationNanosecondBuilder = PrimitiveBuilder<DurationNanosecondType>;

// --------------------- Array Iterator ---------------------

pub use self::iterator::*;

// --------------------- Array Equality ---------------------

pub use self::equal::ArrayEqual;
Expand Down
Loading