diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000000..96ef6c0b944 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000000..2b843408db5 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "arrow2" +version = "0.1.0" +authors = ["Jorge C. Leitao "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +serde = { version = "1.0", features = ["rc"] } +serde_derive = "1.0" +serde_json = { version = "1.0", features = ["preserve_order"] } +rand = "0.7" +num = "0.3" +chrono = "0.4" + +csv = { version = "1.1", optional = true } +regex = { version = "1.3", optional = true } +lazy_static = { version = "1.4", optional = true } +lexical-core = { version = "^0.7", optional = true } + +[dev-dependencies] +criterion = "0.3" +tempfile = "3" + +[features] +default = ["io_csv"] +io_csv = ["csv", "lazy_static", "lexical-core", "regex"] + +[[bench]] +name = "take_kernels" +harness = false diff --git a/README.md b/README.md index 48302afe14d..7001c0b0775 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,219 @@ # Proposal: Safety by design in the Arrow crate -This document proposes a major redesign of the arrow crate to correctly handle +This document and repository propose a major redesign of the arrow crate to correctly handle memory safety, offsets and type safety. + +TL;DR: this repo reproduces the main parts of the arrow crate using the proposed design. What it demonstrates: + +1. allocations along cache lines, buffers and memory management +2. import and export using the FFI / C data interface +3. implementation of nested types (Dict, List, Struct) +4. `dyn Array` and dynamic typing +5. array equality +6. one kernel (`take`) for primitives (1.3x faster than current master). + +Not demonstrated (but deemed feasible with the proposed design): + +1. SIMD +2. 
IO (CSV / JSON) +3. `transform/` module (that would need to be migrated) + +## Background + +The arrow crate uses `Buffer`, a generic struct to store contiguous memory regions (of bytes). This construct is used to store data from all arrays in the Rust implementation. The simplest example is a buffer containing `1i32`, that is represented as `&[0u8, 0u8, 0u8, 1u8]` or `&[1u8, 0u8, 0u8, 0u8]` depending on endianness. + +When a user wishes to read from a buffer, e.g. to perform a mathematical operation with its values, they need to interpret the buffer in the target type. Because `Buffer` is a contiguous region of bytes with no information about its underlying type, users must transmute its data into the respective type. + +Arrow currently transmutes buffers on almost all operations, and very often does not verify that there is type alignment nor correct length when we transmute it to a slice of type `&[T]`. + +Just as an example, the following code compiles, does not panic, and is unsound, resulting in undefined behavior: + +```rust +let buffer = Buffer::from(&[0i32, 2i32]); +let data = ArrayData::new(DataType::Int64, 10, 0, None, 0, vec![buffer], vec![]); +let array = Float64Array::from(Arc::new(data)); + +println!("{:?}", array.value(1)); +``` + +Note how this initializes a buffer with bytes from `i32`, initializes an `ArrayData` with dynamic type +`Int64`, and then a `Float64Array` from `Arc<ArrayData>`. `Float64Array`'s internals will essentially consume the pointer from the buffer, re-interpret it as `f64`, and offset it by `1`. + +Still within this example, if we were to use `ArrayData`'s datatype, `Int64`, to transmute the buffer, we would be creating `&[i64]` out of a buffer created out of `i32`. + +Any Rust developer acknowledges that this behavior goes very much against Rust's core premise that a function's behavior must not be undefined depending on whether the arguments are correct. 
The obvious observation is that transmute is one of the most `unsafe` Rust operations, and not allowing the compiler to verify the necessary invariants places a large burden on users and developers. + +This simple example indicates a broader problem with the current design, which we now explore in detail. + +### Root cause analysis + +At its core, Arrow's current design is centered around two main `structs`: + +1. untyped `Buffer` +2. untyped `ArrayData` + +#### 1. untyped `Buffer` + +The crate's `Buffer` is untyped, which implies that once created, the type +information used to create it is lost. Consequently, the compiler has no way of verifying that a certain read can be performed. Thus, any read from it requires an alignment and size check at runtime. This is not only detrimental to performance, but also cumbersome. + +Over the past 4 months, I have identified and fixed more than 10 instances of unsound code derived from the misuse, within the crate itself, of `Buffer`. This hints that there may be a design problem. + +#### 2. untyped `ArrayData` + +`ArrayData` is a `struct` containing buffers and child data that does not differentiate which type of array it represents at compile time. + +Consequently, all buffer reads from `ArrayData`'s buffers are effectively `unsafe`, as they require certain invariants to hold. These invariants are strictly related to `ArrayData::datatype`: this `enum` differentiates how to transmute the `ArrayData::buffers`. For example, an `ArrayData::datatype` equal to `DataType::UInt32` implies that the buffer should be interpreted / transmuted as `u32`. + +The challenge with the above struct is that it is not possible to prove that `ArrayData`'s creation and reads +are sound at compile time. As the sample above shows, there is nothing wrong, during compilation, with passing a buffer with `i32` to a `PrimitiveArray` expecting `i64` (via `ArrayData`). 
We could of course check it at runtime, and we should, but we are defeating the whole purpose of using a type system as powerful as the one Rust offers. + +The main consequence of this observation is that the current code has a significant maintenance cost, as we have to rigorously check the types of the buffers we are working with. The example above shows +that, even with that rigour, we fail to identify obvious problems at runtime. + +Overall, there are many instances of our code where we expose public APIs marked as `safe` that are `unsafe` and lead to undefined behavior if used incorrectly. This goes against the core goals of the Rust language, and significantly weakens Arrow Rust's implementation core premise that the compiler and borrow checker prove many of the memory safety concerns that we may have. + +Equally important, the inability of the compiler to prove certain invariants is detrimental to performance. As an example, the implementation of the `take` kernel in this repo is semantically equivalent to the current master, but 1.3x faster. + +## Proposal + +The proposal is to redesign the Arrow crate to address the design limitation described above. +This has a major impact on the whole ecosystem that relies on `Buffer`, `MutableBuffer`, `bytes`, +and has limited impact on high-level `Array` APIs that rely on iterators and other higher abstractions. + +Broadly speaking, this proposes the following changes: + +1. Replace `Buffer` by `Buffer<T>` +2. Replace `MutableBuffer` by `MutableBuffer<T>` +3. Replace `Bytes` by `Bytes<T>` +4. Remove `RawPointer` +5. Remove `ArrayData` and place its contents directly on the corresponding arrays +6. Make children be `Arc<dyn Array>` +7. Remove `Array::data` and `Array::data_ref` +8. Redesign `bitmap` to hold offsets +9. Replace `Array::slice` by concrete implementations +10. Make `PrimitiveArray` instead of `PrimitiveType` + +### 1-4. 
Replace `Buffer` by `Buffer<T>` + +This is one of the core changes and is a major design change: `Buffer`s must be typed. There will be +an `unsafe` trait, `NativeType`, implemented for `u8, u16, u32, u64, i8, i16, i32, i64, f32, f64` corresponding to the only types that can be represented in a buffer. + +Create a generic `Buffer<T>`, `Bytes<T>`, `MutableBuffer<T>`, that correspond to byte-aligned, cache line-aligned contiguous memory regions. + +This allows us to only have to deal with `transmute` at FFI boundaries. Effectively, it allows us to not +have to rely on the highly `unsafe` `RawPointer` in array implementations, or on the `as_typed` function that transmutes buffers. + +[Here](src/buffer/immutable.rs) you can find the concrete implementation proposed in this repo. + +### 5. Remove `ArrayData` and place its contents directly on the corresponding arrays + +For example, for primitive types, such as `Float64` and `Date32`, declare a `PrimitiveArray` as follows: + +```rust +#[derive(Debug, Clone)] +pub struct PrimitiveArray<T: NativeType> { + data_type: DataType, + values: Buffer<T>, + validity: Option<Bitmap>, + offset: usize, +} +``` + +Note how `T` denotes the _physical_ representation, while `data_type` corresponds to the _logical_ representation. This is so that `Timestamp` with timezones becomes a first-class citizen (it currently isn't). + +### 6. Child data is stored as `Arc<dyn Array>` + +For example, the struct holding a `ListArray` is [defined](src/array/list.rs) as + +```rust +#[derive(Debug, Clone)] +pub struct ListArray<O: Offset> { + data_type: DataType, + offsets: Buffer<O>, + values: Arc<dyn Array>, + validity: Option<Bitmap>, + offset: usize, +} +``` + +This greatly simplifies creating nested structures, as there is no longer any `ArrayData`. + +Accessing individual (nested) values of this array, e.g. 
for iterations, works as before: + +```rust +impl ListArray { + pub fn value(&self, i: usize) -> Box { + let offsets = self.offsets.as_slice(); + let offset = offsets[i]; + let offset_1 = offsets[i + 1]; + let length = (offset_1 - offset).to_usize().unwrap(); + + self.values.slice(offset.to_usize().unwrap(), length) + } +} +``` + +Note the usage of `Array::slice`, an abstract method that each specific implementation must know how to perform. This method has been problematic in the past because its implementation is type-specific, but +the current implementation is type-agnostic (i.e. a bug). + +In the case of a list array: + +```rust +impl ListArray { + pub fn slice(&self, offset: usize, length: usize) -> Self { + let validity = self.validity.as_ref().map(|x| x.slice(offset, length)); + Self { + data_type: self.data_type.clone(), + offsets: self.offsets.slice(offset, length), + values: self.values.clone(), + validity, + offset, + } + } +} + +impl Array for ListArray { + fn slice(&self, offset: usize, length: usize) -> Box { + Box::new(self.slice(offset, length)) + } +} +``` + +Note how the `offsets` were sliced, but the `values` were not. In the current master, both get sliced, which +is semantically incorrect. + +Also note that the choice of `Arc` over `Box` is solely for the purposes of enabling a cheap `Clone`. + +### 7. Remove `Array::data` and `Array::data_ref` + +Without `ArrayData`, these methods are no longer required. Required traits to enable FFI are instead +provided. This repo supports FFI (import and export), which demonstrates that `ArrayData` is not needed. + +### 8. Redesign bitmap + +This implementation redesigns `Bitmap` to allow it to hold `Bytes` and an offset in `bits`. +`Bitmap` is the only struct that holds bitmaps, and has methods to efficiently `get` bits. +Because it has an offset in bits, it contains all information required to correctly offset itself. 
+ +This way, users no longer have to use `MutableBuffer` to handle `bitmaps`, call the `unsafe` `get_bit_raw`, +or juggle offsets in bits vs bytes. + +### 9. Replace `Array::slice` by concrete implementations + +Slice is an operation whose implementation depends on the particular logical type being implemented. +This proposes that we move `slice` to be a type-specific implementation. + +### 10. Make `PrimitiveArray` instead of `PrimitiveType` + +Currently, `PrimitiveArray` depends on an `ArrowPrimitiveType`, which has an associated `DataType`. +This makes it difficult to distinguish the physical representation from its logical one. I.e. `Int64Type` is both +a physical (`i64`) and logical type (`DataType::Int64`). There are logical types whose physical representation +is the same (e.g. `Timestamp(_, _)`). Hard-coding the logical representation in the type takes away this fundamental +separation. + +This proposal separates the two aspects: the generic argument, `T`, is used to declare the physical layout, which, within Rust, is used for type-safety. +The `DataType` is used for a logical representation which, in the context of Rust, is used for dynamic typing, i.e. it enables the trait object `Array` to implement `as_any()` and use `Array::data_type()` to decide to which concrete +implementation `&dyn Array` should be `downcast_ref`ed. + +With this design, an incorrect `DataType` only causes `downcast_ref` to fail and cannot cause undefined behavior. The only possible undefined behavior in this new design is at FFI boundaries: a byte buffer that is incorrect for a `DataType` causes the library to interpret bytes of type `x` as type `y`, which is undefined behavior. diff --git a/benches/take_kernels.rs b/benches/take_kernels.rs new file mode 100644 index 00000000000..2855fb9d91c --- /dev/null +++ b/benches/take_kernels.rs @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +use criterion::Criterion; + +use rand::distributions::{Distribution, Standard}; +use rand::{rngs::StdRng, Rng, SeedableRng}; + +use arrow2::{array::*, datatypes::DataType, datatypes::PrimitiveType}; +use arrow2::{compute::take, datatypes::Int32Type}; + +/// Returns fixed seedable RNG +pub fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +// cast array from specified primitive array type to desired data type +fn create_primitive(size: usize) -> PrimitiveArray +where + T: PrimitiveType, + Standard: Distribution, +{ + seedable_rng() + .sample_iter(&Standard) + .take(size) + .map(Some) + .collect::>() + .to(T::DATA_TYPE) +} + +fn create_random_index(size: usize, null_density: f32) -> PrimitiveArray { + let mut rng = seedable_rng(); + (0..size) + .map(|_| { + if rng.gen::() > null_density { + let value = rng.gen_range::(0i32, size as i32); + Some(value) + } else { + None + } + }) + .collect::>() + .to(DataType::Int32) +} + +fn bench_take(values: &dyn Array, indices: &PrimitiveArray) { + criterion::black_box(take::take(values, &indices, None).unwrap()); +} + +fn add_benchmark(c: &mut Criterion) { + let values = create_primitive::(512); + let indices = create_random_index(512, 0.0); + c.bench_function("take i32 512", |b| 
b.iter(|| bench_take(&values, &indices))); + let values = create_primitive::(1024); + let indices = create_random_index(1024, 0.0); + c.bench_function("take i32 1024", |b| { + b.iter(|| bench_take(&values, &indices)) + }); + + let indices = create_random_index(512, 0.5); + c.bench_function("take i32 nulls 512", |b| { + b.iter(|| bench_take(&values, &indices)) + }); + let values = create_primitive::(1024); + let indices = create_random_index(1024, 0.5); + c.bench_function("take i32 nulls 1024", |b| { + b.iter(|| bench_take(&values, &indices)) + }); +} + +criterion_group!(benches, add_benchmark); +criterion_main!(benches); diff --git a/src/array/binary/iterator.rs b/src/array/binary/iterator.rs new file mode 100644 index 00000000000..cc2aa29b99b --- /dev/null +++ b/src/array/binary/iterator.rs @@ -0,0 +1,66 @@ +use crate::array::Array; +use crate::array::Offset; + +use super::BinaryArray; + +impl<'a, O: Offset> IntoIterator for &'a BinaryArray { + type Item = Option<&'a [u8]>; + type IntoIter = BinaryIter<'a, O>; + + fn into_iter(self) -> Self::IntoIter { + BinaryIter::new(self) + } +} + +impl<'a, O: Offset> BinaryArray { + /// constructs a new iterator + pub fn iter(&'a self) -> BinaryIter<'a, O> { + BinaryIter::new(&self) + } +} + +/// an iterator that returns `Some(&[u8])` or `None`, for binary arrays +#[derive(Debug)] +pub struct BinaryIter<'a, O> +where + O: Offset, +{ + array: &'a BinaryArray, + i: usize, + len: usize, +} + +impl<'a, O: Offset> BinaryIter<'a, O> { + /// create a new iterator + pub fn new(array: &'a BinaryArray) -> Self { + BinaryIter:: { + array, + i: 0, + len: array.len(), + } + } +} + +impl<'a, O: Offset> std::iter::Iterator for BinaryIter<'a, O> { + type Item = Option<&'a [u8]>; + + fn next(&mut self) -> Option { + let i = self.i; + if i >= self.len { + None + } else if self.array.is_null(i) { + self.i += 1; + Some(None) + } else { + self.i += 1; + Some(Some(unsafe { self.array.value_unchecked(i) })) + } + } + + fn size_hint(&self) -> 
(usize, Option) { + (self.len - self.i, Some(self.len - self.i)) + } +} + +/// all arrays have known size. +impl<'a, O: Offset> std::iter::ExactSizeIterator for BinaryIter<'a, O> {} diff --git a/src/array/binary/mod.rs b/src/array/binary/mod.rs new file mode 100644 index 00000000000..0a61773410a --- /dev/null +++ b/src/array/binary/mod.rs @@ -0,0 +1,152 @@ +use crate::{ + buffer::{Bitmap, Buffer}, + datatypes::DataType, + ffi::ArrowArray, +}; + +use super::{ffi::ToFFI, specification::check_offsets, specification::Offset, Array, FromFFI}; + +use crate::error::Result; + +#[derive(Debug, Clone)] +pub struct BinaryArray { + data_type: DataType, + offsets: Buffer, + values: Buffer, + validity: Option, + offset: usize, +} + +impl BinaryArray { + pub fn new_empty() -> Self { + Self::from_data(Buffer::from(&[O::zero()]), Buffer::new(), None) + } + + pub fn from_data(offsets: Buffer, values: Buffer, validity: Option) -> Self { + check_offsets(&offsets, values.len()); + + Self { + data_type: if O::is_large() { + DataType::LargeBinary + } else { + DataType::Binary + }, + offsets, + values, + validity, + offset: 0, + } + } + + pub fn slice(&self, offset: usize, length: usize) -> Self { + let validity = self.validity.clone().map(|x| x.slice(offset, length)); + Self { + data_type: self.data_type.clone(), + offsets: self.offsets.clone().slice(offset, length), + values: self.values.clone(), + validity, + offset: self.offset + offset, + } + } + + /// Returns the element at index `i` as &str + /// # Safety + /// Assumes that the `i < self.len`. 
+ pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + let offset = *self.offsets.as_ptr().add(i); + let offset_1 = *self.offsets.as_ptr().add(i + 1); + let length = (offset_1 - offset).to_usize().unwrap(); + let offset = offset.to_usize().unwrap(); + + std::slice::from_raw_parts(self.values.as_ptr().add(offset), length) + } + + #[inline] + pub fn offsets(&self) -> &[O] { + self.offsets.as_slice() + } + + #[inline] + pub fn values(&self) -> &[u8] { + self.values.as_slice() + } +} + +impl Array for BinaryArray { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn nulls(&self) -> &Option { + &self.validity + } + + fn slice(&self, offset: usize, length: usize) -> Box { + Box::new(self.slice(offset, length)) + } +} + +unsafe impl ToFFI for BinaryArray { + fn buffers(&self) -> [Option>; 3] { + unsafe { + [ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(std::ptr::NonNull::new_unchecked( + self.offsets.as_ptr() as *mut u8 + )), + Some(std::ptr::NonNull::new_unchecked( + self.values.as_ptr() as *mut u8 + )), + ] + } + } + + #[inline] + fn offset(&self) -> usize { + self.offset + } +} + +unsafe impl FromFFI for BinaryArray { + fn try_from_ffi(data_type: DataType, array: ArrowArray) -> Result { + let expected = if O::is_large() { + DataType::LargeBinary + } else { + DataType::Binary + }; + assert_eq!(data_type, expected); + + let length = array.len(); + let offset = array.offset(); + let mut validity = array.null_bit_buffer(); + let mut offsets = unsafe { array.buffer::(0)? }; + let values = unsafe { array.buffer::(1)? 
}; + + if offset > 0 { + offsets = offsets.slice(offset, length); + validity = validity.map(|x| x.slice(offset, length)) + } + + Ok(Self { + data_type, + offsets, + values, + validity, + offset: 0, + }) + } +} + +mod iterator; +pub use iterator::*; diff --git a/src/array/boolean/from.rs b/src/array/boolean/from.rs new file mode 100644 index 00000000000..edd8c73208d --- /dev/null +++ b/src/array/boolean/from.rs @@ -0,0 +1,130 @@ +use crate::buffer::{Bitmap, MutableBitmap}; + +use super::BooleanArray; + +impl BooleanArray { + pub fn from_slice>(slice: P) -> Self { + unsafe { Self::from_trusted_len_iter(slice.as_ref().iter().map(Some)) } + } +} + +impl BooleanArray { + /// Creates a [`BooleanArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter(iter: I) -> Self + where + P: std::borrow::Borrow, + I: IntoIterator>, + { + let iterator = iter.into_iter(); + + let (validity, values) = trusted_len_unzip(iterator); + + Self::from_data(values, validity) + } + + /// Creates a [`BooleanArray`] from a fallible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter(iter: I) -> Result + where + P: std::borrow::Borrow, + I: IntoIterator, E>>, + { + let iterator = iter.into_iter(); + + let (validity, values) = try_trusted_len_unzip(iterator)?; + + Ok(Self::from_data(values, validity)) + } +} + +/// Creates an optional validity [`Bitmap`] and a values [`Bitmap`] from an iterator of `Option`. +/// The first element of the returned tuple is the validity bitmap (`None` when there are no nulls), +/// the second one is the values bitmap. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. 
+#[inline] +pub(crate) unsafe fn trusted_len_unzip(iterator: I) -> (Option, Bitmap) +where + P: std::borrow::Borrow, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut values = MutableBitmap::with_capacity(len); + + for item in iterator { + let item = if let Some(item) = item { + null.push_unchecked(true); + *item.borrow() + } else { + null.push_unchecked(false); + false + }; + values.push_unchecked(item); + } + assert_eq!( + values.len(), + len, + "Trusted iterator length was not accurately reported" + ); + values.set_len(len); + null.set_len(len); + + let bitmap = if null.null_count() > 0 { + Some(null.into()) + } else { + None + }; + (bitmap, values.into()) +} + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> Result<(Option, Bitmap), E> +where + P: std::borrow::Borrow, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut values = MutableBitmap::with_capacity(len); + + for item in iterator { + let item = if let Some(item) = item? 
{ + null.push_unchecked(true); + *item.borrow() + } else { + null.push_unchecked(false); + false + }; + values.push_unchecked(item); + } + assert_eq!( + values.len(), + len, + "Trusted iterator length was not accurately reported" + ); + values.set_len(len); + null.set_len(len); + + let bitmap = if null.null_count() > 0 { + Some(null.into()) + } else { + None + }; + Ok((bitmap, values.into())) +} diff --git a/src/array/boolean/iterator.rs b/src/array/boolean/iterator.rs new file mode 100644 index 00000000000..05a15e17cc3 --- /dev/null +++ b/src/array/boolean/iterator.rs @@ -0,0 +1,81 @@ +use crate::array::Array; + +use super::BooleanArray; + +impl<'a> IntoIterator for &'a BooleanArray { + type Item = Option; + type IntoIter = BooleanIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + BooleanIter::<'a>::new(self) + } +} + +impl<'a> BooleanArray { + /// constructs a new iterator + pub fn iter(&'a self) -> BooleanIter<'a> { + BooleanIter::<'a>::new(&self) + } +} + +/// an iterator that returns Some(bool) or None. +// Note: This implementation is based on std's [Vec]s' [IntoIter]. 
+#[derive(Debug)] +pub struct BooleanIter<'a> { + array: &'a BooleanArray, + current: usize, + current_end: usize, +} + +impl<'a> BooleanIter<'a> { + /// create a new iterator + pub fn new(array: &'a BooleanArray) -> Self { + BooleanIter { + array, + current: 0, + current_end: array.len(), + } + } +} + +impl<'a> std::iter::Iterator for BooleanIter<'a> { + type Item = Option; + + fn next(&mut self) -> Option { + if self.current == self.current_end { + None + } else if self.array.is_null(self.current) { + self.current += 1; + Some(None) + } else { + let old = self.current; + self.current += 1; + Some(Some(self.array.value(old))) + } + } + + fn size_hint(&self) -> (usize, Option) { + ( + self.array.len() - self.current, + Some(self.array.len() - self.current), + ) + } +} + +impl<'a> std::iter::DoubleEndedIterator for BooleanIter<'a> { + fn next_back(&mut self) -> Option { + if self.current_end == self.current { + None + } else { + self.current_end -= 1; + Some(if self.array.is_null(self.current_end) { + None + } else { + Some(self.array.value(self.current_end)) + }) + } + } +} + +/// all arrays have known size. 
+impl<'a> std::iter::ExactSizeIterator for BooleanIter<'a> {} diff --git a/src/array/boolean/mod.rs b/src/array/boolean/mod.rs new file mode 100644 index 00000000000..cb2ae305492 --- /dev/null +++ b/src/array/boolean/mod.rs @@ -0,0 +1,117 @@ +use crate::{buffer::Bitmap, datatypes::DataType, ffi::ArrowArray}; + +use super::{ffi::ToFFI, Array, FromFFI}; + +use crate::error::Result; + +#[derive(Debug, Clone)] +pub struct BooleanArray { + data_type: DataType, + values: Bitmap, + validity: Option, + offset: usize, +} + +impl BooleanArray { + pub fn new_empty() -> Self { + Self::from_data(Bitmap::new(), None) + } + + pub fn from_data(values: Bitmap, validity: Option) -> Self { + Self { + data_type: DataType::Boolean, + values, + validity, + offset: 0, + } + } + + pub fn slice(&self, offset: usize, length: usize) -> Self { + let validity = self.validity.clone().map(|x| x.slice(offset, length)); + Self { + data_type: self.data_type.clone(), + values: self.values.clone().slice(offset, length), + validity, + offset: self.offset + offset, + } + } + + /// Returns the element at index `i` as a `bool` + pub fn value(&self, i: usize) -> bool { + self.values.get_bit(i) + } + + pub fn values(&self) -> &Bitmap { + &self.values + } +} + +impl Array for BooleanArray { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn len(&self) -> usize { + self.values.len() + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn nulls(&self) -> &Option { + &self.validity + } + + fn slice(&self, offset: usize, length: usize) -> Box { + Box::new(self.slice(offset, length)) + } +} + +unsafe impl ToFFI for BooleanArray { + fn buffers(&self) -> [Option>; 3] { + [ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.values.as_ptr()), + None, + ] + } + + fn offset(&self) -> usize { + self.offset + } +} + +unsafe impl FromFFI for BooleanArray { + fn try_from_ffi(data_type: DataType, array: ArrowArray) -> Result { + let length = array.len(); + 
let offset = array.offset(); + let mut validity = array.null_bit_buffer(); + let mut values = unsafe { array.bitmap(0)? }; + + if offset > 0 { + values = values.slice(offset, length); + validity = validity.map(|x| x.slice(offset, length)) + } + Ok(Self { + data_type, + values, + validity, + offset: 0, + }) + } +} + +impl]>> From

for BooleanArray { + fn from(slice: P) -> Self { + unsafe { Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) } + } +} + +mod iterator; +pub use iterator::*; + +mod from; diff --git a/src/array/dictionary/from.rs b/src/array/dictionary/from.rs new file mode 100644 index 00000000000..4656a5f352b --- /dev/null +++ b/src/array/dictionary/from.rs @@ -0,0 +1,66 @@ +use std::collections::HashMap; +use std::hash::Hash; + +use crate::{ + array::{Builder, Primitive, ToArray}, + datatypes::DataType, + error::{ArrowError, Result}, +}; + +use super::{DictionaryArray, DictionaryKey}; + +#[derive(Debug)] +pub struct DictionaryPrimitive { + keys: Primitive, + values: A, +} + +impl DictionaryPrimitive { + pub fn to(self, data_type: DataType) -> DictionaryArray { + let (keys, values) = DictionaryArray::::get_child(&data_type); + let values = self.values.to_arc(values); + DictionaryArray::from_data(self.keys.to(keys.clone()), values) + } +} + +pub fn dict_from_iter(iter: I) -> Result> +where + K: DictionaryKey, + B: Builder, + T: Eq + Hash + Clone, + P: std::borrow::Borrow>, + I: IntoIterator, +{ + let mut map = HashMap::::new(); + + let iterator = iter.into_iter(); + + // [10, 20, 10, 30] + // keys = [0, 1, 0, 2] + // values = [10, 20, 30] + + // if value not in set { + // key.push(set.len()); + // set.insert(value) + // } else { + // key.push(set.len()) + // } + let mut values = B::with_capacity(0); + let keys: Primitive = iterator + .map(|item| match item.borrow() { + Some(v) => match map.get(v) { + Some(key) => Ok(Some(*key)), + None => { + let key = + K::from_usize(map.len()).ok_or(ArrowError::DictionaryKeyOverflowError)?; + values.push(Some(v)); + map.insert(v.clone(), key); + Ok(Some(key)) + } + }, + None => Ok(None), + }) + .collect::>()?; + + Ok(DictionaryPrimitive:: { keys, values }) +} diff --git a/src/array/dictionary/mod.rs b/src/array/dictionary/mod.rs new file mode 100644 index 00000000000..0e9ba4278e8 --- /dev/null +++ 
b/src/array/dictionary/mod.rs @@ -0,0 +1,134 @@ +use std::sync::Arc; + +use crate::{ + buffer::{types::NativeType, Bitmap}, + datatypes::DataType, +}; + +use super::{ffi::ToFFI, new_empty_array, primitive::PrimitiveArray, Array}; + +pub unsafe trait DictionaryKey: NativeType + num::NumCast + num::FromPrimitive { + const DATA_TYPE: DataType; +} + +unsafe impl DictionaryKey for i8 { + const DATA_TYPE: DataType = DataType::Int8; +} +unsafe impl DictionaryKey for i16 { + const DATA_TYPE: DataType = DataType::Int16; +} +unsafe impl DictionaryKey for i32 { + const DATA_TYPE: DataType = DataType::Int32; +} +unsafe impl DictionaryKey for i64 { + const DATA_TYPE: DataType = DataType::Int64; +} +unsafe impl DictionaryKey for u8 { + const DATA_TYPE: DataType = DataType::UInt8; +} +unsafe impl DictionaryKey for u16 { + const DATA_TYPE: DataType = DataType::UInt16; +} +unsafe impl DictionaryKey for u32 { + const DATA_TYPE: DataType = DataType::UInt32; +} +unsafe impl DictionaryKey for u64 { + const DATA_TYPE: DataType = DataType::UInt64; +} + +mod from; +pub use from::*; + +#[derive(Debug, Clone)] +pub struct DictionaryArray { + data_type: DataType, + keys: PrimitiveArray, + values: Arc, + offset: usize, +} + +impl DictionaryArray { + pub fn new_empty(datatype: DataType) -> Self { + let values = new_empty_array(datatype).into(); + Self::from_data(PrimitiveArray::::new_empty(K::DATA_TYPE), values) + } + + pub fn from_data(keys: PrimitiveArray, values: Arc) -> Self { + let data_type = DataType::Dictionary( + Box::new(keys.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + Self { + data_type, + keys, + values, + offset: 0, + } + } + + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + keys: self.keys.clone().slice(offset, length), + values: self.values.clone(), + offset: self.offset + offset, + } + } + + #[inline] + pub fn keys(&self) -> &PrimitiveArray { + &self.keys + } + + #[inline] + pub fn 
values(&self) -> &Arc { + &self.values + } +} + +impl DictionaryArray { + pub(crate) fn get_child(data_type: &DataType) -> (&DataType, &DataType) { + if let DataType::Dictionary(keys, values) = data_type { + (keys.as_ref(), values.as_ref()) + } else { + panic!("Wrong DataType") + } + } +} + +impl Array for DictionaryArray { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + // The logical length is the number of keys (one per slot), not the number of + // distinct values in the dictionary; `nulls` and `slice` follow the keys as well. + #[inline] + fn len(&self) -> usize { + self.keys.len() + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn nulls(&self) -> &Option { + self.keys.nulls() + } + + fn slice(&self, offset: usize, length: usize) -> Box { + Box::new(self.slice(offset, length)) + } +} + +unsafe impl ToFFI for DictionaryArray { + fn buffers(&self) -> [Option>; 3] { + [self.keys.nulls().as_ref().map(|x| x.as_ptr()), None, None] + } + + #[inline] + fn offset(&self) -> usize { + self.offset + } +} diff --git a/src/array/equal/boolean.rs b/src/array/equal/boolean.rs new file mode 100644 index 00000000000..19ab69b956a --- /dev/null +++ b/src/array/equal/boolean.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::{array::BooleanArray, buffer::Bitmap}; + +use super::utils::{count_nulls, equal_bits}; + +pub(super) fn equal( + lhs: &BooleanArray, + rhs: &BooleanArray, + lhs_nulls: &Option, + rhs_nulls: &Option, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_values = lhs.values(); + let rhs_values = rhs.values(); + + let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); + let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + + if lhs_null_count == 0 && rhs_null_count == 0 { + equal_bits(lhs_values, rhs_values, lhs_start, rhs_start, len) + } else { + // get a ref of the null buffer bytes, to use in testing for nullness + let lhs_null_bytes = lhs_nulls.as_ref().unwrap(); + let rhs_null_bytes = rhs_nulls.as_ref().unwrap(); + + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + + let lhs_is_null = !lhs_null_bytes.get_bit(lhs_pos); + let rhs_is_null = !rhs_null_bytes.get_bit(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && equal_bits(lhs_values, rhs_values, lhs_pos, rhs_pos, 1) + }) + } +} diff --git a/src/array/equal/list.rs b/src/array/equal/list.rs new file mode 100644 index 00000000000..a3de9d48161 --- /dev/null +++ b/src/array/equal/list.rs @@ -0,0 +1,225 @@ +use crate::{ + array::{Array, ListArray, Offset}, + buffer::Bitmap, +}; + +use super::{ + equal_range, + utils::{child_logical_null_buffer, count_nulls}, +}; + +fn lengths_equal(lhs: &[O], rhs: &[O]) -> bool { + // invariant from `base_equal` + debug_assert_eq!(lhs.len(), rhs.len()); + + if lhs.is_empty() { + return true; + } + + if lhs[0] == O::zero() && rhs[0] == O::zero() { + return lhs == rhs; + }; + + // The expensive case, e.g. 
+ // [0, 2, 4, 6, 9] == [4, 6, 8, 10, 13] + lhs.windows(2) + .zip(rhs.windows(2)) + .all(|(lhs_offsets, rhs_offsets)| { + // length of left == length of right + (lhs_offsets[1] - lhs_offsets[0]) == (rhs_offsets[1] - rhs_offsets[0]) + }) +} + +#[allow(clippy::too_many_arguments)] +#[inline] +fn offset_value_equal( + lhs_values: &dyn Array, + rhs_values: &dyn Array, + lhs_nulls: &Option, + rhs_nulls: &Option, + lhs_offsets: &[O], + rhs_offsets: &[O], + lhs_pos: usize, + rhs_pos: usize, + len: usize, +) -> bool { + let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap(); + let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap(); + let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos]; + let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos]; + + lhs_len == rhs_len + && equal_range( + lhs_values, + rhs_values, + lhs_nulls, + rhs_nulls, + lhs_start, + rhs_start, + lhs_len.to_usize().unwrap(), + ) +} + +pub(super) fn equal( + lhs: &ListArray, + rhs: &ListArray, + lhs_nulls: &Option, + rhs_nulls: &Option, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_offsets = lhs.offsets().as_slice(); + let rhs_offsets = rhs.offsets().as_slice(); + + // There is an edge-case where a n-length list that has 0 children, results in panics. + // For example; an array with offsets [0, 0, 0, 0, 0] has 4 slots, but will have + // no valid children. + // Under logical equality, the child null bitmap will be an empty buffer, as there are + // no child values. This causes panics when trying to count set bits. + // + // We caught this by chance from an accidental test-case, but due to the nature of this + // crash only occuring on list equality checks, we are adding a check here, instead of + // on the buffer/bitmap utilities, as a length check would incur a penalty for almost all + // other use-cases. + // + // The solution is to check the number of child values from offsets, and return `true` if + // they = 0. 
Empty arrays are equal, so this is correct. + // + // It's unlikely that one would create a n-length list array with no values, where n > 0, + // however, one is more likely to slice into a list array and get a region that has 0 + // child values. + // The test that triggered this behaviour had [4, 4] as a slice of 1 value slot. + // + // The compared region is the slots `[start, start + len)`, so its children span + // `offsets[start]..offsets[start + len]`. + let lhs_child_length = lhs_offsets[lhs_start + len].to_usize().unwrap() + - lhs_offsets[lhs_start].to_usize().unwrap(); + let rhs_child_length = rhs_offsets[rhs_start + len].to_usize().unwrap() + - rhs_offsets[rhs_start].to_usize().unwrap(); + + if lhs_child_length == 0 && lhs_child_length == rhs_child_length { + return true; + } + + let lhs_values = lhs.values().as_ref(); + let rhs_values = rhs.values().as_ref(); + + let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); + let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + + // compute the child logical bitmap + let child_lhs_nulls = child_logical_null_buffer(lhs, lhs_nulls, lhs_values); + let child_rhs_nulls = child_logical_null_buffer(rhs, rhs_nulls, rhs_values); + + if lhs_null_count == 0 && rhs_null_count == 0 { + // `len` slots are delimited by `len + 1` offsets: the range must be inclusive, + // otherwise the length of the last list is never compared. + lengths_equal( + &lhs_offsets[lhs_start..=lhs_start + len], + &rhs_offsets[rhs_start..=rhs_start + len], + ) && equal_range( + lhs_values, + rhs_values, + &child_lhs_nulls, + &child_rhs_nulls, + lhs_offsets[lhs_start].to_usize().unwrap(), + rhs_offsets[rhs_start].to_usize().unwrap(), + (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start]) + .to_usize() + .unwrap(), + ) + } else { + // get a ref of the parent null buffer bytes, to use in testing for nullness + let lhs_null_bytes = lhs_nulls.as_ref().unwrap(); + let rhs_null_bytes = rhs_nulls.as_ref().unwrap(); + // with nulls, we need to compare item by item whenever it is not null + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + + let lhs_is_null = !lhs_null_bytes.get_bit(lhs_pos); + let rhs_is_null = !rhs_null_bytes.get_bit(rhs_pos); + + lhs_is_null + || (lhs_is_null 
== rhs_is_null) + && offset_value_equal::( + lhs_values, + rhs_values, + &child_lhs_nulls, + &child_rhs_nulls, + lhs_offsets, + rhs_offsets, + lhs_pos, + rhs_pos, + 1, + ) + }) + } +} + +#[cfg(test)] +mod tests { + use crate::array::{ListPrimitive, Primitive}; + use crate::{array::equal::tests::test_equal, datatypes::DataType}; + + use super::*; + + fn create_list_array, T: AsRef<[Option]>>(data: T) -> ListArray { + let data_type = ListArray::::default_datatype(DataType::Int32); + let list = data + .as_ref() + .into_iter() + .map(|x| { + x.as_ref() + .map(|x| x.as_ref().iter().map(|x| Some(*x)).collect::>()) + }) + .collect::>>(); + list.to(data_type) + } + + #[test] + fn test_list_equal() { + let a = create_list_array(&[Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + let b = create_list_array(&[Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + test_equal(&a, &b, true); + + let b = create_list_array(&[Some(&[1, 2, 3]), Some(&[4, 5, 7])]); + test_equal(&a, &b, false); + } + + // Test the case where null_count > 0 + #[test] + fn test_list_null() { + let a = create_list_array(&[Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + let b = create_list_array(&[Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + test_equal(&a, &b, true); + + let b = create_list_array(&[ + Some(&[1, 2]), + None, + Some(&[5, 6]), + Some(&[3, 4]), + None, + None, + ]); + test_equal(&a, &b, false); + + let b = create_list_array(&[Some(&[1, 2]), None, None, Some(&[3, 5]), None, None]); + test_equal(&a, &b, false); + } + + // Test the case where offset != 0 + #[test] + fn test_list_offsets() { + let a = create_list_array(&[Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + let b = create_list_array(&[Some(&[1, 2]), None, None, Some(&[3, 5]), None, None]); + + let a_slice = a.slice(0, 3); + let b_slice = b.slice(0, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.slice(0, 5); + let b_slice = b.slice(0, 5); + test_equal(&a_slice, &b_slice, false); + + let a_slice = a.slice(4, 1); + 
let b_slice = b.slice(4, 1); + test_equal(&a_slice, &b_slice, true); + } +} diff --git a/src/array/equal/mod.rs b/src/array/equal/mod.rs new file mode 100644 index 00000000000..0b57148f7d2 --- /dev/null +++ b/src/array/equal/mod.rs @@ -0,0 +1,451 @@ +use std::unimplemented; + +use crate::{ + buffer::{Bitmap, NativeType}, + datatypes::{DataType, IntervalUnit}, +}; + +use super::{ + primitive::PrimitiveArray, Array, BinaryArray, BooleanArray, ListArray, NullArray, Utf8Array, +}; + +mod boolean; +mod list; +mod null; +mod primitive; +mod utils; +mod variable_size; + +impl PartialEq for &dyn Array { + fn eq(&self, other: &Self) -> bool { + equal(*self, *other) + } +} + +impl PartialEq<&dyn Array> for PrimitiveArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for PrimitiveArray { + fn eq(&self, other: &Self) -> bool { + equal(self, other) + } +} + +fn equal_range( + lhs: &dyn Array, + rhs: &dyn Array, + lhs_nulls: &Option, + rhs_nulls: &Option, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + utils::base_equal(lhs, rhs) + && utils::equal_nulls(lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) +} + +/// Compares the values of two [`Array`] starting at `lhs_start` and `rhs_start` respectively +/// for `len` slots. The null buffers `lhs_nulls` and `rhs_nulls` are inherit parent nullability. +/// +/// If an array is a child of a struct or list, the array's nulls have to be merged with the parent. +/// This then affects the null count of the array, thus the merged nulls are passed separately +/// as `lhs_nulls` and `rhs_nulls` variables to functions. +/// The nulls are merged with a bitwise AND, and null counts are recomputed where necessary. 
+#[inline] +fn equal_values( + lhs: &dyn Array, + rhs: &dyn Array, + lhs_nulls: &Option, + rhs_nulls: &Option, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + match lhs.data_type() { + DataType::Null => { + let lhs = lhs.as_any().downcast_ref::().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + null::equal(lhs, rhs) + } + DataType::Boolean => { + let lhs = lhs.as_any().downcast_ref::().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + boolean::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::UInt8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::UInt16 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::UInt32 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::UInt64 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::Int8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::Int16 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + let lhs = 
lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::Int64 + | DataType::Date64 + | DataType::Interval(IntervalUnit::DayTime) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::Float16 => unreachable!(), + DataType::Float32 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::Float64 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::Decimal(_, _) => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::Utf8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + variable_size::equal( + lhs.offsets(), + rhs.offsets(), + lhs.values(), + rhs.values(), + lhs_nulls, + rhs_nulls, + lhs_start, + rhs_start, + len, + ) + } + DataType::LargeUtf8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + variable_size::equal( + lhs.offsets(), + rhs.offsets(), + lhs.values(), + rhs.values(), + lhs_nulls, + rhs_nulls, + lhs_start, + rhs_start, + len, + ) + } + DataType::Binary => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + variable_size::equal( + lhs.offsets(), + rhs.offsets(), + lhs.values(), + 
rhs.values(), + lhs_nulls, + rhs_nulls, + lhs_start, + rhs_start, + len, + ) + } + DataType::LargeBinary => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + variable_size::equal( + lhs.offsets(), + rhs.offsets(), + lhs.values(), + rhs.values(), + lhs_nulls, + rhs_nulls, + lhs_start, + rhs_start, + len, + ) + } + DataType::List(_) => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + list::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + DataType::LargeList(_) => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + list::equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + _ => unimplemented!(), + /* + DataType::FixedSizeBinary(_) => {} + DataType::FixedSizeList(_, _) => {} + DataType::Struct(_) => {} + DataType::Union(_) => {} + DataType::Dictionary(_, _) => {} + */ + } +} + +/// Logically compares two [ArrayData]. +/// Two arrays are logically equal if and only if: +/// * their data types are equal +/// * their lengths are equal +/// * their null counts are equal +/// * their null bitmaps are equal +/// * each of their items are equal +/// two items are equal when their in-memory representation is physically equal (i.e. same bit content). +/// The physical comparison depend on the data type. +/// # Panics +/// This function may panic whenever any of the [ArrayData] does not follow the Arrow specification. +/// (e.g. 
wrong number of buffers, buffer `len` does not correspond to the declared `len`) +pub fn equal(lhs: &dyn Array, rhs: &dyn Array) -> bool { + let lhs_nulls = lhs.nulls(); + let rhs_nulls = rhs.nulls(); + utils::base_equal(lhs, rhs) + && lhs.null_count() == rhs.null_count() + && utils::equal_nulls(lhs_nulls, rhs_nulls, 0, 0, lhs.len()) + && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, 0, 0, lhs.len()) +} + +#[cfg(test)] +mod tests { + use crate::array::{BooleanArray, Offset, Primitive}; + + use super::*; + + #[test] + fn test_primitive() { + let cases = vec![ + ( + vec![Some(1), Some(2), Some(3)], + vec![Some(1), Some(2), Some(3)], + true, + ), + ( + vec![Some(1), Some(2), Some(3)], + vec![Some(1), Some(2), Some(4)], + false, + ), + ( + vec![Some(1), Some(2), None], + vec![Some(1), Some(2), None], + true, + ), + ( + vec![Some(1), None, Some(3)], + vec![Some(1), Some(2), None], + false, + ), + ( + vec![Some(1), None, None], + vec![Some(1), Some(2), None], + false, + ), + ]; + + for (lhs, rhs, expected) in cases { + let lhs = Primitive::::from(&lhs).to(DataType::Int32); + let rhs = Primitive::::from(&rhs).to(DataType::Int32); + test_equal(&lhs, &rhs, expected); + } + } + + #[test] + fn test_primitive_slice() { + let cases = vec![ + ( + vec![Some(1), Some(2), Some(3)], + (0, 1), + vec![Some(1), Some(2), Some(3)], + (0, 1), + true, + ), + ( + vec![Some(1), Some(2), Some(3)], + (1, 1), + vec![Some(1), Some(2), Some(3)], + (2, 1), + false, + ), + ( + vec![Some(1), Some(2), None], + (1, 1), + vec![Some(1), None, Some(2)], + (2, 1), + true, + ), + ( + vec![None, Some(2), None], + (1, 1), + vec![None, None, Some(2)], + (2, 1), + true, + ), + ( + vec![Some(1), None, Some(2), None, Some(3)], + (2, 2), + vec![None, Some(2), None, Some(3)], + (1, 2), + true, + ), + ]; + + for (lhs, slice_lhs, rhs, slice_rhs, expected) in cases { + let lhs = Primitive::::from(&lhs).to(DataType::Int32); + let lhs = lhs.slice(slice_lhs.0, slice_lhs.1); + let rhs = 
Primitive::::from(&rhs).to(DataType::Int32); + let rhs = rhs.slice(slice_rhs.0, slice_rhs.1); + + test_equal(&lhs, &rhs, expected); + } + } + + pub(super) fn test_equal(lhs: &dyn Array, rhs: &dyn Array, expected: bool) { + // equality is symmetric + assert_eq!(equal(lhs, lhs), true, "\n{:?}\n{:?}", lhs, lhs); + assert_eq!(equal(rhs, rhs), true, "\n{:?}\n{:?}", rhs, rhs); + + assert_eq!(equal(lhs, rhs), expected, "\n{:?}\n{:?}", lhs, rhs); + assert_eq!(equal(rhs, lhs), expected, "\n{:?}\n{:?}", rhs, lhs); + } + + #[test] + fn test_boolean_equal() { + let a = BooleanArray::from_slice([false, false, true]); + let b = BooleanArray::from_slice([false, false, true]); + test_equal(&a, &b, true); + + let b = BooleanArray::from_slice([false, false, false]); + test_equal(&a, &b, false); + } + + #[test] + fn test_boolean_equal_null() { + let a = BooleanArray::from(vec![Some(false), None, None, Some(true)]); + let b = BooleanArray::from(vec![Some(false), None, None, Some(true)]); + test_equal(&a, &b, true); + + let b = BooleanArray::from(vec![None, None, None, Some(true)]); + test_equal(&a, &b, false); + + let b = BooleanArray::from(vec![Some(true), None, None, Some(true)]); + test_equal(&a, &b, false); + } + + #[test] + fn test_boolean_equal_offset() { + let a = BooleanArray::from_slice(vec![false, true, false, true, false, false, true]); + let b = BooleanArray::from_slice(vec![true, false, false, false, true, false, true, true]); + test_equal(&a, &b, false); + + let a_slice = a.slice(2, 3); + let b_slice = b.slice(3, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.slice(3, 4); + let b_slice = b.slice(4, 4); + test_equal(&a_slice, &b_slice, false); + + // Elements fill in `u8`'s exactly. + let mut vector = vec![false, false, true, true, true, true, true, true]; + let a = BooleanArray::from_slice(vector.clone()); + let b = BooleanArray::from_slice(vector.clone()); + test_equal(&a, &b, true); + + // Elements fill in `u8`s + suffix bits. 
+ vector.push(true); + let a = BooleanArray::from_slice(vector.clone()); + let b = BooleanArray::from_slice(vector); + test_equal(&a, &b, true); + } + + fn binary_cases() -> Vec<(Vec>, Vec>, bool)> { + let base = vec![ + Some("hello".to_owned()), + None, + None, + Some("world".to_owned()), + None, + None, + ]; + let not_base = vec![ + Some("hello".to_owned()), + Some("foo".to_owned()), + None, + Some("world".to_owned()), + None, + None, + ]; + vec![ + ( + vec![Some("hello".to_owned()), Some("world".to_owned())], + vec![Some("hello".to_owned()), Some("world".to_owned())], + true, + ), + ( + vec![Some("hello".to_owned()), Some("world".to_owned())], + vec![Some("hello".to_owned()), Some("arrow".to_owned())], + false, + ), + (base.clone(), base.clone(), true), + (base, not_base, false), + ] + } + + fn test_generic_string_equal() { + let cases = binary_cases(); + + for (lhs, rhs, expected) in cases { + let lhs = lhs.iter().map(|x| x.as_deref()).collect::>(); + let rhs = rhs.iter().map(|x| x.as_deref()).collect::>(); + let lhs = Utf8Array::::from(&lhs); + let rhs = Utf8Array::::from(&rhs); + test_equal(&lhs, &rhs, expected); + } + } + + #[test] + fn test_string_equal() { + test_generic_string_equal::() + } + + #[test] + fn test_large_string_equal() { + test_generic_string_equal::() + } +} diff --git a/src/array/equal/null.rs b/src/array/equal/null.rs new file mode 100644 index 00000000000..bf10b1aae58 --- /dev/null +++ b/src/array/equal/null.rs @@ -0,0 +1,8 @@ +use crate::array::NullArray; + +#[inline] +pub(super) fn equal(_lhs: &NullArray, _rhs: &NullArray) -> bool { + // a null buffer's range is always true, as every element is by definition equal (to null). 
+ // We only need to compare data_types + true +} diff --git a/src/array/equal/primitive.rs b/src/array/equal/primitive.rs new file mode 100644 index 00000000000..eea12f7d916 --- /dev/null +++ b/src/array/equal/primitive.rs @@ -0,0 +1,42 @@ +use crate::{ + array::primitive::PrimitiveArray, + buffer::{types::NativeType, Bitmap}, +}; + +use super::utils::{count_nulls, equal_len}; + +pub(super) fn equal( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + lhs_nulls: &Option, + rhs_nulls: &Option, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_values = &lhs.values().as_slice(); + let rhs_values = &rhs.values().as_slice(); + + let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); + let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + + if lhs_null_count == 0 && rhs_null_count == 0 { + // without nulls, we just need to compare slices + equal_len(lhs_values, rhs_values, lhs_start, rhs_start, len) + } else { + // get a ref of the null buffer bytes, to use in testing for nullness + let lhs_bitmap = lhs_nulls.as_ref().unwrap(); + let rhs_bitmap = rhs_nulls.as_ref().unwrap(); + // with nulls, we need to compare item by item whenever it is not null + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + let lhs_is_null = !lhs_bitmap.get_bit(lhs_pos); + let rhs_is_null = !rhs_bitmap.get_bit(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && equal_len(lhs_values, rhs_values, lhs_pos, rhs_pos, 1) + }) + } +} diff --git a/src/array/equal/utils.rs b/src/array/equal/utils.rs new file mode 100644 index 00000000000..701d9caeb2b --- /dev/null +++ b/src/array/equal/utils.rs @@ -0,0 +1,173 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ + array::{Array, ListArray, Offset}, + buffer::{Bitmap, MutableBitmap}, + datatypes::DataType, +}; + +// whether bits along the positions are equal +// `lhs_start`, `rhs_start` and `len` are _measured in bits_. +#[inline] +pub(super) fn equal_bits( + lhs_values: &Bitmap, + rhs_values: &Bitmap, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + // todo: safely iterate over both in one go + (0..len).all(|i| lhs_values.get_bit(lhs_start + i) == rhs_values.get_bit(rhs_start + i)) +} + +// whether the validity of both sides agrees over `len` slots starting at +// `lhs_start` / `rhs_start`. A missing bitmap means "every slot is valid". +#[inline] +pub(super) fn equal_nulls( + lhs_nulls: &Option, + rhs_nulls: &Option, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_null_count = lhs_nulls + .as_ref() + .map(|x| x.null_count_range(lhs_start, len)) + .unwrap_or(0); + let rhs_null_count = rhs_nulls + .as_ref() + .map(|x| x.null_count_range(rhs_start, len)) + .unwrap_or(0); + + if lhs_null_count > 0 || rhs_null_count > 0 { + match (lhs_nulls.as_ref(), rhs_nulls.as_ref()) { + (Some(lhs_values), Some(rhs_values)) => { + equal_bits(lhs_values, rhs_values, lhs_start, rhs_start, len) + } + // one side has nulls while the other has no bitmap (all valid): they differ. + // Previously this case panicked on `unwrap`. + _ => false, + } + } else { + true + } +} + +#[inline] +pub(super) fn base_equal(lhs: &dyn Array, rhs: &dyn Array) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() +} + +// whether the two memory regions are equal +#[inline] +pub(super) fn equal_len( + lhs_values: &[T], + rhs_values: &[T], + lhs_start: 
usize, + rhs_start: usize, + len: usize, +) -> bool { + lhs_values[lhs_start..(lhs_start + len)] == rhs_values[rhs_start..(rhs_start + len)] +} + +#[inline] +pub(super) fn count_nulls(nulls: &Option, offset: usize, length: usize) -> usize { + nulls + .as_ref() + .map(|x| x.null_count_range(offset, length)) + .unwrap_or(0) +} + +// Calculate a list child's logical bitmap/buffer +// `[[1, None, 3], None, [None]]` +// offsets = [0, 3, 3, 4] +// parent validity = [1, 0, 1] +// child validity = [1, 0, 1, 0] +// logical_list_bitmap = [1, 0, 1, 0] +// +// Every window of `offsets` is one parent slot; all of them are visited so that the +// result always holds one bit per child, even when there are more slots than children. +#[inline] +fn logical_list_bitmap( + offsets: &[O], + parent_bitmap: &Option, + child_bitmap: &Option, +) -> Option { + let first_offset = offsets.first().unwrap().to_usize().unwrap(); + let last_offset = offsets.get(offsets.len() - 1).unwrap().to_usize().unwrap(); + let length = last_offset - first_offset; + + match (parent_bitmap, child_bitmap) { + (Some(parent_bitmap), Some(child_bitmap)) => { + let mut buffer = MutableBitmap::with_capacity(length); + offsets + .windows(2) + .enumerate() + .for_each(|(index, window)| { + let start = window[0].to_usize().unwrap(); + let end = window[1].to_usize().unwrap(); + let mask = parent_bitmap.get_bit(index); + (start..end).for_each(|child_index| { + let is_set = mask && child_bitmap.get_bit(child_index); + buffer.push(is_set); + }); + }); + Some(buffer.into()) + } + (None, Some(child_bitmap)) => { + let mut buffer = MutableBitmap::with_capacity(length); + offsets.windows(2).for_each(|window| { + let start = window[0].to_usize().unwrap(); + let end = window[1].to_usize().unwrap(); + (start..end).for_each(|child_index| { + buffer.push(child_bitmap.get_bit(child_index)); + }); + }); + Some(buffer.into()) + } + (Some(parent_bitmap), None) => { + let mut buffer = MutableBitmap::with_capacity(length); + offsets.windows(2).enumerate().for_each(|(index, window)| { + let start = window[0].to_usize().unwrap(); + let end = window[1].to_usize().unwrap(); +
// the parent bitmap is indexed by slot, not by child: every child inherits its slot's validity + let mask = parent_bitmap.get_bit(index); + (start..end).for_each(|_| { + buffer.push(mask); + }); + }); + Some(buffer.into()) + } + (None, None) => None, + } +} + +/// Computes the logical validity bitmap of the array data using the +/// parent's array data. The parent should be a list or struct, else +/// the logical bitmap of the array is returned unaltered. +/// +/// Parent data is passed along with the parent's logical bitmap, as +/// nested arrays could have a logical bitmap different to the physical one. +pub(super) fn child_logical_null_buffer( + parent: &dyn Array, + logical_null_buffer: &Option, + child: &dyn Array, +) -> Option { + let parent_bitmap = logical_null_buffer; + let self_null_bitmap = child.nulls(); + match parent.data_type() { + DataType::List(_) => { + let parent = parent.as_any().downcast_ref::>().unwrap(); + logical_list_bitmap(parent.offsets().as_slice(), parent_bitmap, self_null_bitmap) + } + DataType::LargeList(_) => { + let parent = parent.as_any().downcast_ref::>().unwrap(); + logical_list_bitmap(parent.offsets().as_slice(), parent_bitmap, self_null_bitmap) + } + data_type => panic!("Data type {:?} is not a supported nested type", data_type), + } +} diff --git a/src/array/equal/variable_size.rs b/src/array/equal/variable_size.rs new file mode 100644 index 00000000000..657925f04ef --- /dev/null +++ b/src/array/equal/variable_size.rs @@ -0,0 +1,83 @@ +use crate::{array::Offset, buffer::Bitmap}; + +use super::utils::{count_nulls, equal_len}; + +fn offset_value_equal( + lhs_values: &[u8], + rhs_values: &[u8], + lhs_offsets: &[O], + rhs_offsets: &[O], + lhs_pos: usize, + rhs_pos: usize, + len: usize, +) -> bool { + let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap(); + let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap(); + let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos]; + let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos]; + + lhs_len == rhs_len + && equal_len( + lhs_values, + rhs_values, + lhs_start, + 
rhs_start, + lhs_len.to_usize().unwrap(), + ) +} + +// whether `len` variable-sized slots (utf8 / binary) of both sides are equal. +// `lhs_nulls` / `rhs_nulls` are the logical validities; `None` means every slot is valid. +pub(super) fn equal( + lhs_offsets: &[O], + rhs_offsets: &[O], + lhs_values: &[u8], + rhs_values: &[u8], + lhs_nulls: &Option, + rhs_nulls: &Option, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + // the offsets of the `ArrayData` are ignored as they are only applied to the offset buffer. + + let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); + let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + + if lhs_null_count == 0 + && rhs_null_count == 0 + && !lhs_values.is_empty() + && !rhs_values.is_empty() + { + offset_value_equal( + lhs_values, + rhs_values, + lhs_offsets, + rhs_offsets, + lhs_start, + rhs_start, + len, + ) + } else { + // This branch is also reached without any bitmap (no nulls but an empty values buffer), + // so the bitmaps must not be unwrapped: a missing bitmap means the slot is valid. + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + + let lhs_is_null = lhs_nulls.as_ref().map_or(false, |x| !x.get_bit(lhs_pos)); + let rhs_is_null = rhs_nulls.as_ref().map_or(false, |x| !x.get_bit(rhs_pos)); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && offset_value_equal( + lhs_values, + rhs_values, + lhs_offsets, + rhs_offsets, + lhs_pos, + rhs_pos, + 1, + ) + }) + } +} diff --git a/src/array/ffi.rs b/src/array/ffi.rs new file mode 100644 index 00000000000..39db10bc094 --- /dev/null +++ b/src/array/ffi.rs @@ -0,0 +1,13 @@ +use crate::{datatypes::DataType, ffi::ArrowArray}; + +use crate::error::Result; + +pub unsafe trait ToFFI { + fn buffers(&self) -> [Option>; 3]; + + fn offset(&self) -> usize; +} + +pub unsafe trait FromFFI: Sized { + fn try_from_ffi(data_type: DataType, array: ArrowArray) -> Result; +} diff --git a/src/array/fixed_binary.rs b/src/array/fixed_binary.rs new file mode 100644 index 00000000000..42decc9a2c0 --- /dev/null +++ b/src/array/fixed_binary.rs @@ -0,0 +1,101 @@ +use crate::{ + buffer::{Bitmap, Buffer}, + datatypes::DataType, +}; + +use super::{ffi::ToFFI, Array}; + +#[derive(Debug, Clone)] +pub struct FixedSizeBinaryArray { + size: i32, // 
this is redundant with `data_type`, but useful to not have to deconstruct the data_type. + data_type: DataType, + values: Buffer, + validity: Option, + offset: usize, +} + +impl FixedSizeBinaryArray { + pub fn new_empty(data_type: DataType) -> Self { + Self::from_data(data_type, Buffer::new(), None) + } + + pub fn from_data(data_type: DataType, values: Buffer, validity: Option) -> Self { + let size = *Self::get_size(&data_type); + + assert_eq!(values.len() % (size as usize), 0); + + Self { + size, + data_type: DataType::FixedSizeBinary(size), + values, + validity, + offset: 0, + } + } + + pub fn slice(&self, offset: usize, length: usize) -> Self { + let validity = self.validity.clone().map(|x| x.slice(offset, length)); + let offset = offset * self.size as usize; + let length = offset * self.size as usize; + Self { + data_type: self.data_type.clone(), + size: self.size, + values: self.values.clone().slice(offset, length), + validity, + offset: 0, + } + } +} + +impl FixedSizeBinaryArray { + pub(crate) fn get_size(data_type: &DataType) -> &i32 { + if let DataType::FixedSizeBinary(size) = data_type { + size + } else { + panic!("Wrong DataType") + } + } +} + +impl Array for FixedSizeBinaryArray { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn len(&self) -> usize { + self.values.len() / self.size as usize + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn nulls(&self) -> &Option { + &self.validity + } + + fn slice(&self, offset: usize, length: usize) -> Box { + Box::new(self.slice(offset, length)) + } +} + +unsafe impl ToFFI for FixedSizeBinaryArray { + fn buffers(&self) -> [Option>; 3] { + unsafe { + [ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(std::ptr::NonNull::new_unchecked( + self.values.as_ptr() as *mut u8 + )), + None, + ] + } + } + + fn offset(&self) -> usize { + self.offset + } +} diff --git a/src/array/fixed_list.rs b/src/array/fixed_list.rs new file mode 100644 index 
00000000000..ca7fb562421 --- /dev/null +++ b/src/array/fixed_list.rs @@ -0,0 +1,97 @@ +use std::sync::Arc; + +use crate::{buffer::Bitmap, datatypes::DataType}; + +use super::{ffi::ToFFI, new_empty_array, Array}; + +#[derive(Debug, Clone)] +pub struct FixedSizeListArray { + size: i32, // this is redundant with `data_type`, but useful to not have to deconstruct the data_type. + data_type: DataType, + values: Arc, + validity: Option, + offset: usize, +} + +impl FixedSizeListArray { + pub fn new_empty(data_type: DataType) -> Self { + let values = new_empty_array(Self::get_child_and_size(&data_type).0.clone()).into(); + Self::from_data(data_type, values, None) + } + + pub fn from_data( + data_type: DataType, + values: Arc, + validity: Option, + ) -> Self { + let (_, size) = Self::get_child_and_size(&data_type); + + assert_eq!(values.len() % (*size as usize), 0); + + Self { + size: *size, + data_type, + values, + validity, + offset: 0, + } + } + + pub fn slice(&self, offset: usize, length: usize) -> Self { + let validity = self.validity.clone().map(|x| x.slice(offset, length)); + let offset = offset * self.size as usize; + let length = offset * self.size as usize; + Self { + data_type: self.data_type.clone(), + size: self.size, + values: self.values.clone().slice(offset, length).into(), + validity, + offset: 0, + } + } +} + +impl FixedSizeListArray { + pub(crate) fn get_child_and_size(data_type: &DataType) -> (&DataType, &i32) { + if let DataType::FixedSizeList(field, size) = data_type { + (field.data_type(), size) + } else { + panic!("Wrong DataType") + } + } +} + +impl Array for FixedSizeListArray { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn len(&self) -> usize { + self.values.len() / self.size as usize + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn nulls(&self) -> &Option { + &self.validity + } + + fn slice(&self, offset: usize, length: usize) -> Box { + Box::new(self.slice(offset, length)) 
+ } +} + +unsafe impl ToFFI for FixedSizeListArray { + fn buffers(&self) -> [Option>; 3] { + [self.validity.as_ref().map(|x| x.as_ptr()), None, None] + } + + fn offset(&self) -> usize { + self.offset + } +} diff --git a/src/array/list/from.rs b/src/array/list/from.rs new file mode 100644 index 00000000000..00b5981b599 --- /dev/null +++ b/src/array/list/from.rs @@ -0,0 +1,112 @@ +use std::{iter::FromIterator, sync::Arc}; + +use crate::{ + array::{Array, Offset, Primitive, ToArray}, + buffer::{Bitmap, Buffer, MutableBitmap, MutableBuffer, NativeType}, + datatypes::DataType, +}; + +use super::ListArray; + +#[derive(Debug)] +pub struct ListPrimitive { + offsets: Buffer, + values: A, + validity: Option, +} + +impl ListPrimitive { + pub fn to(self, data_type: DataType) -> ListArray { + let values = self.values.to_arc(ListArray::::get_child(&data_type)); + ListArray::from_data(data_type, self.offsets, values, self.validity) + } +} + +pub fn list_from_iter_primitive(iter: I) -> (Buffer, Primitive, Option) +where + O: Offset, + T: NativeType, + P: AsRef<[Option]> + IntoIterator>, + I: IntoIterator>, +{ + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + + let mut offsets = MutableBuffer::::with_capacity(lower + 1); + let mut length_so_far = O::zero(); + offsets.push(length_so_far); + + let mut nulls = MutableBitmap::with_capacity(lower); + + let values: Primitive = iterator + .filter_map(|maybe_slice| { + // regardless of whether the item is Some, the offsets and null buffers must be updated. 
+ match &maybe_slice { + Some(x) => { + length_so_far += O::from_usize(x.as_ref().len()).unwrap(); + nulls.push(true); + } + None => nulls.push(false), + }; + offsets.push(length_so_far); + maybe_slice + }) + .flatten() + .collect(); + + let bitmap = if nulls.null_count() > 0 { + Some(nulls.into()) + } else { + None + }; + + (offsets.into(), values, bitmap) +} + +impl FromIterator> for ListPrimitive> +where + O: Offset, + T: NativeType, + P: AsRef<[Option]> + IntoIterator>, +{ + fn from_iter>>(iter: I) -> Self { + let (offsets, values, validity) = list_from_iter_primitive(iter); + Self { + offsets, + values, + validity, + } + } +} + +impl ToArray for ListPrimitive> { + fn to_arc(self, data_type: &DataType) -> Arc { + let list = self.to(data_type.clone()); + Arc::new(list) + } +} + +#[cfg(test)] +mod tests { + use crate::array::{Primitive, PrimitiveArray}; + + use super::*; + + #[test] + fn primitive() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(4), None, Some(6)]), + ]; + + let a: ListPrimitive> = data.into_iter().collect(); + let a = a.to(ListArray::::default_datatype(DataType::Int32)); + let a = a.value(0); + let a = a.as_any().downcast_ref::>().unwrap(); + + let expected = + Primitive::::from(vec![Some(1i32), Some(2), Some(3)]).to(DataType::Int32); + assert_eq!(a, &expected) + } +} diff --git a/src/array/list/mod.rs b/src/array/list/mod.rs new file mode 100644 index 00000000000..d241d85f88e --- /dev/null +++ b/src/array/list/mod.rs @@ -0,0 +1,187 @@ +use std::sync::Arc; + +use crate::{ + buffer::{Bitmap, Buffer}, + datatypes::{DataType, Field}, +}; + +use super::{ + ffi::ToFFI, + new_empty_array, + specification::{check_offsets, Offset}, + Array, +}; + +#[derive(Debug, Clone)] +pub struct ListArray { + data_type: DataType, + offsets: Buffer, + values: Arc, + validity: Option, + offset: usize, +} + +impl ListArray { + pub fn new_empty(data_type: DataType) -> Self { + let values = 
new_empty_array(Self::get_child(&data_type).clone()).into(); + Self::from_data(data_type, Buffer::from(&[O::zero()]), values, None) + } + + pub fn from_data( + data_type: DataType, + offsets: Buffer, + values: Arc, + validity: Option, + ) -> Self { + check_offsets(&offsets, values.len()); + + // validate data_type + let _ = Self::get_child(&data_type); + + Self { + data_type, + offsets, + values, + validity, + offset: 0, + } + } + + /// Returns the element at index `i` as &str + pub fn value(&self, i: usize) -> Box { + let offsets = self.offsets.as_slice(); + let offset = offsets[i]; + let offset_1 = offsets[i + 1]; + let length = (offset_1 - offset).to_usize().unwrap(); + + self.values.slice(offset.to_usize().unwrap(), length) + } + + /// Returns the element at index `i` as &str + /// # Safety + /// Assumes that the `i < self.len`. + pub unsafe fn value_unchecked(&self, i: usize) -> Box { + let offset = *self.offsets.as_ptr().add(i); + let offset_1 = *self.offsets.as_ptr().add(i + 1); + let length = (offset_1 - offset).to_usize().unwrap(); + + self.values.slice(offset.to_usize().unwrap(), length) + } + + pub fn slice(&self, offset: usize, length: usize) -> Self { + let validity = self.validity.clone().map(|x| x.slice(offset, length)); + Self { + data_type: self.data_type.clone(), + offsets: self.offsets.clone().slice(offset, length), + values: self.values.clone(), + validity, + offset: self.offset + offset, + } + } + + #[inline] + pub fn offsets(&self) -> &Buffer { + &self.offsets + } + + #[inline] + pub fn values(&self) -> &Arc { + &self.values + } +} + +impl ListArray { + pub fn default_datatype(data_type: DataType) -> DataType { + let field = Box::new(Field::new("item", data_type, true)); + if O::is_large() { + DataType::LargeList(field) + } else { + DataType::List(field) + } + } + + pub(crate) fn get_child(data_type: &DataType) -> &DataType { + if O::is_large() { + if let DataType::LargeList(child) = data_type { + child.data_type() + } else { + panic!("Wrong 
DataType") + } + } else { + if let DataType::List(child) = data_type { + child.data_type() + } else { + panic!("Wrong DataType") + } + } + } +} + +impl Array for ListArray { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } + + #[inline] + fn nulls(&self) -> &Option { + &self.validity + } + + fn slice(&self, offset: usize, length: usize) -> Box { + Box::new(self.slice(offset, length)) + } +} + +unsafe impl ToFFI for ListArray { + fn buffers(&self) -> [Option>; 3] { + unsafe { + [ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(std::ptr::NonNull::new_unchecked( + self.offsets.as_ptr() as *mut u8 + )), + None, + ] + } + } + + fn offset(&self) -> usize { + self.offset + } +} + +mod from; + +pub use from::ListPrimitive; + +#[cfg(test)] +mod tests { + use crate::array::primitive::PrimitiveArray; + + use super::*; + + #[test] + fn test_create() { + let values = Buffer::from([1, 2, 3, 4, 5]); + let values = PrimitiveArray::::from_data(DataType::Int32, values, None); + + let data_type = ListArray::::default_datatype(DataType::Int32); + ListArray::::from_data( + data_type, + Buffer::from([0, 2, 2, 3, 5]), + Arc::new(values), + None, + ); + } +} diff --git a/src/array/mod.rs b/src/array/mod.rs new file mode 100644 index 00000000000..2cee7d93b15 --- /dev/null +++ b/src/array/mod.rs @@ -0,0 +1,176 @@ +use std::any::Any; + +use crate::{buffer::Bitmap, datatypes::DataType}; + +pub trait Array: std::fmt::Debug + Send + Sync + ToFFI { + fn as_any(&self) -> &dyn Any; + + fn len(&self) -> usize; + + fn data_type(&self) -> &DataType; + + fn nulls(&self) -> &Option; + + #[inline] + fn null_count(&self) -> usize { + self.nulls().as_ref().map(|x| x.null_count()).unwrap_or(0) + } + + #[inline] + fn is_null(&self, i: usize) -> bool { + self.nulls() + .as_ref() + .map(|x| !x.get_bit(i)) + .unwrap_or(false) + } + + fn 
slice(&self, offset: usize, length: usize) -> Box; +} + +/// Creates a new empty dynamic array +pub fn new_empty_array(data_type: DataType) -> Box { + match data_type { + DataType::Null => Box::new(NullArray::new_empty()), + DataType::Boolean => Box::new(BooleanArray::new_empty()), + DataType::Int8 => Box::new(PrimitiveArray::::new_empty(data_type)), + DataType::Int16 => Box::new(PrimitiveArray::::new_empty(data_type)), + DataType::Int32 | DataType::Date32 | DataType::Time32(_) => { + Box::new(PrimitiveArray::::new_empty(data_type)) + } + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(_) => Box::new(PrimitiveArray::::new_empty(data_type)), + DataType::UInt8 => Box::new(PrimitiveArray::::new_empty(data_type)), + DataType::UInt16 => Box::new(PrimitiveArray::::new_empty(data_type)), + DataType::UInt32 => Box::new(PrimitiveArray::::new_empty(data_type)), + DataType::UInt64 => Box::new(PrimitiveArray::::new_empty(data_type)), + DataType::Float16 => unreachable!(), + DataType::Float32 => Box::new(PrimitiveArray::::new_empty(data_type)), + DataType::Float64 => Box::new(PrimitiveArray::::new_empty(data_type)), + DataType::Binary => Box::new(BinaryArray::::new_empty()), + DataType::LargeBinary => Box::new(BinaryArray::::new_empty()), + DataType::FixedSizeBinary(_) => Box::new(FixedSizeBinaryArray::new_empty(data_type)), + DataType::Utf8 => Box::new(Utf8Array::::new_empty()), + DataType::LargeUtf8 => Box::new(Utf8Array::::new_empty()), + DataType::List(_) => Box::new(ListArray::::new_empty(data_type)), + DataType::LargeList(_) => Box::new(ListArray::::new_empty(data_type)), + DataType::FixedSizeList(_, _) => Box::new(FixedSizeListArray::new_empty(data_type)), + DataType::Struct(fields) => Box::new(StructArray::new_empty(&fields)), + DataType::Union(_) => unimplemented!(), + DataType::Dictionary(key_type, value_type) => match key_type.as_ref() { + DataType::Int8 => 
Box::new(DictionaryArray::::new_empty(*value_type)), + DataType::Int16 => Box::new(DictionaryArray::::new_empty(*value_type)), + DataType::Int32 => Box::new(DictionaryArray::::new_empty(*value_type)), + DataType::Int64 => Box::new(DictionaryArray::::new_empty(*value_type)), + DataType::UInt8 => Box::new(DictionaryArray::::new_empty(*value_type)), + DataType::UInt16 => Box::new(DictionaryArray::::new_empty(*value_type)), + DataType::UInt32 => Box::new(DictionaryArray::::new_empty(*value_type)), + DataType::UInt64 => Box::new(DictionaryArray::::new_empty(*value_type)), + _ => unreachable!(), + }, + DataType::Decimal(_, _) => Box::new(PrimitiveArray::::new_empty(data_type)), + } +} + +macro_rules! clone_dyn { + ($array:expr, $ty:ty) => {{ + let array = $array.as_any().downcast_ref::<$ty>().unwrap(); + Box::new(array.clone()) + }}; +} + +/// Clones `array`. +pub fn clone(array: &dyn Array) -> Box { + match array.data_type() { + DataType::Null => clone_dyn!(array, NullArray), + DataType::Boolean => clone_dyn!(array, BooleanArray), + DataType::Int8 => clone_dyn!(array, PrimitiveArray), + DataType::Int16 => clone_dyn!(array, PrimitiveArray), + DataType::Int32 | DataType::Date32 | DataType::Time32(_) => { + clone_dyn!(array, PrimitiveArray) + } + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(_) => clone_dyn!(array, PrimitiveArray), + DataType::UInt8 => clone_dyn!(array, PrimitiveArray), + DataType::UInt16 => clone_dyn!(array, PrimitiveArray), + DataType::UInt32 => clone_dyn!(array, PrimitiveArray), + DataType::UInt64 => clone_dyn!(array, PrimitiveArray), + DataType::Float16 => unreachable!(), + DataType::Float32 => clone_dyn!(array, PrimitiveArray), + DataType::Float64 => clone_dyn!(array, PrimitiveArray), + DataType::Binary => clone_dyn!(array, BinaryArray), + DataType::LargeBinary => clone_dyn!(array, BinaryArray), + DataType::FixedSizeBinary(_) => clone_dyn!(array, 
FixedSizeBinaryArray), + DataType::Utf8 => clone_dyn!(array, Utf8Array::), + DataType::LargeUtf8 => clone_dyn!(array, Utf8Array::), + DataType::List(_) => clone_dyn!(array, ListArray::), + DataType::LargeList(_) => clone_dyn!(array, ListArray::), + DataType::FixedSizeList(_, _) => clone_dyn!(array, FixedSizeListArray), + DataType::Struct(_) => clone_dyn!(array, StructArray), + DataType::Union(_) => unimplemented!(), + DataType::Dictionary(key_type, _) => match key_type.as_ref() { + DataType::Int8 => clone_dyn!(array, DictionaryArray::), + DataType::Int16 => clone_dyn!(array, DictionaryArray::), + DataType::Int32 => clone_dyn!(array, DictionaryArray::), + DataType::Int64 => clone_dyn!(array, DictionaryArray::), + DataType::UInt8 => clone_dyn!(array, DictionaryArray::), + DataType::UInt16 => clone_dyn!(array, DictionaryArray::), + DataType::UInt32 => clone_dyn!(array, DictionaryArray::), + DataType::UInt64 => clone_dyn!(array, DictionaryArray::), + _ => unreachable!(), + }, + DataType::Decimal(_, _) => clone_dyn!(array, PrimitiveArray::), + } +} + +mod binary; +mod boolean; +mod dictionary; +mod fixed_binary; +mod fixed_list; +mod list; +mod null; +mod primitive; +mod specification; +mod string; +mod struct_; + +mod equal; +mod ffi; + +pub use binary::BinaryArray; +pub use boolean::BooleanArray; +pub use dictionary::{dict_from_iter, DictionaryArray, DictionaryKey, DictionaryPrimitive}; +pub use fixed_binary::FixedSizeBinaryArray; +pub use fixed_list::FixedSizeListArray; +pub use list::{ListArray, ListPrimitive}; +pub use null::NullArray; +pub use primitive::{Primitive, PrimitiveArray}; +pub use specification::Offset; +pub use string::{Utf8Array, Utf8Primitive}; +pub use struct_::StructArray; + +pub use self::ffi::FromFFI; +use self::ffi::ToFFI; + +pub type Float32Array = PrimitiveArray; +pub type Float64Array = PrimitiveArray; +pub type StringArray = Utf8Array; +pub type UInt32Array = PrimitiveArray; + +pub trait ToArray: std::fmt::Debug { + fn to_arc(self, 
data_type: &DataType) -> std::sync::Arc; +} + +pub trait Builder: ToArray { + fn with_capacity(capacity: usize) -> Self; + + fn push(&mut self, item: Option<&T>); +} diff --git a/src/array/null.rs b/src/array/null.rs new file mode 100644 index 00000000000..18c861f47dd --- /dev/null +++ b/src/array/null.rs @@ -0,0 +1,68 @@ +use crate::{buffer::Bitmap, datatypes::DataType}; + +use super::{ffi::ToFFI, Array}; + +#[derive(Debug, Clone)] +pub struct NullArray { + data_type: DataType, + length: usize, + offset: usize, +} + +impl NullArray { + pub fn new_empty() -> Self { + Self::from_data(0) + } + + pub fn from_data(length: usize) -> Self { + Self { + data_type: DataType::Null, + length, + offset: 0, + } + } + + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + length, + offset: self.offset + offset, + } + } +} + +impl Array for NullArray { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn len(&self) -> usize { + self.length + } + + #[inline] + fn data_type(&self) -> &DataType { + &DataType::Null + } + + fn nulls(&self) -> &Option { + &None + } + + fn slice(&self, offset: usize, length: usize) -> Box { + Box::new(self.slice(offset, length)) + } +} + +unsafe impl ToFFI for NullArray { + fn buffers(&self) -> [Option>; 3] { + [None, None, None] + } + + #[inline] + fn offset(&self) -> usize { + self.offset + } +} diff --git a/src/array/primitive/from.rs b/src/array/primitive/from.rs new file mode 100644 index 00000000000..a42f546c618 --- /dev/null +++ b/src/array/primitive/from.rs @@ -0,0 +1,219 @@ +use std::{iter::FromIterator, sync::Arc}; + +use crate::{ + array::{Array, Builder, ToArray}, + buffer::{types::NativeType, MutableBitmap, MutableBuffer}, + datatypes::DataType, +}; + +use super::PrimitiveArray; + +impl Primitive { + pub fn from_slice>(slice: P) -> Self { + unsafe { Self::from_trusted_len_iter(slice.as_ref().iter().map(Some)) } + } +} + +impl]>> From

for Primitive { + fn from(slice: P) -> Self { + unsafe { Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) } + } +} + +impl Primitive { + pub fn from_values>(slice: P) -> Self { + Self::from_iter(slice.as_ref().iter().map(|x| Some(*x))) + } +} + +impl Primitive { + /// Creates a [`PrimitiveArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter(iter: I) -> Self + where + P: std::borrow::Borrow, + I: IntoIterator>, + { + let iterator = iter.into_iter(); + + let (validity, values) = trusted_len_unzip(iterator); + + Self { values, validity } + } + + /// Creates a [`PrimitiveArray`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter(iter: I) -> Result + where + P: std::borrow::Borrow, + I: IntoIterator, E>>, + { + let iterator = iter.into_iter(); + + let (validity, values) = try_trusted_len_unzip(iterator)?; + + Ok(Self { values, validity }) + } +} + +/// Creates a Bitmap and a [`Buffer`] from an iterator of `Option`. +/// The first buffer corresponds to a bitmap buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. 
+#[inline] +pub(crate) unsafe fn trusted_len_unzip(iterator: I) -> (MutableBitmap, MutableBuffer) +where + T: NativeType, + P: std::borrow::Borrow, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut buffer = MutableBuffer::::with_capacity(len); + + let mut dst = buffer.as_mut_ptr(); + for item in iterator { + let item = if let Some(item) = item { + null.push_unchecked(true); + *item.borrow() + } else { + null.push_unchecked(false); + T::default() + }; + std::ptr::write(dst, item); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(buffer.as_ptr()) as usize, + len, + "Trusted iterator length was not accurately reported" + ); + buffer.set_len(len); + null.set_len(len); + + (null, buffer) +} + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> Result<(MutableBitmap, MutableBuffer), E> +where + T: NativeType, + P: std::borrow::Borrow, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut buffer = MutableBuffer::::with_capacity(len); + + let mut dst = buffer.as_mut_ptr(); + for item in iterator { + let item = if let Some(item) = item? 
{ + null.push_unchecked(true); + *item.borrow() + } else { + null.push_unchecked(false); + T::default() + }; + std::ptr::write(dst, item); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(buffer.as_ptr()) as usize, + len, + "Trusted iterator length was not accurately reported" + ); + buffer.set_len(len); + null.set_len(len); + + Ok((null, buffer)) +} + +/// auxiliary struct used to create a [`PrimitiveArray`] out of an iterator +#[derive(Debug)] +pub struct Primitive { + values: MutableBuffer, + validity: MutableBitmap, +} + +impl Builder for Primitive { + #[inline] + fn with_capacity(capacity: usize) -> Self { + Self { + values: MutableBuffer::::with_capacity(capacity), + validity: MutableBitmap::with_capacity(capacity), + } + } + + #[inline] + fn push(&mut self, value: Option<&T>) { + match value { + Some(v) => { + self.values.push(*v); + self.validity.push(true); + } + None => { + self.values.push(T::default()); + self.validity.push(false); + } + } + } +} + +impl Primitive { + pub fn to(self, data_type: DataType) -> PrimitiveArray { + let validity = if self.validity.null_count() > 0 { + Some(self.validity.into()) + } else { + None + }; + + PrimitiveArray::::from_data(data_type, self.values.into(), validity) + } +} + +impl>> FromIterator for Primitive { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lower, _) = iter.size_hint(); + + let mut validity = MutableBitmap::with_capacity(lower); + + let values: MutableBuffer = iter + .map(|item| { + if let Some(a) = item.borrow() { + validity.push(true); + *a + } else { + validity.push(false); + T::default() + } + }) + .collect(); + + Self { + values: values, + validity, + } + } +} + +impl ToArray for Primitive { + fn to_arc(self, data_type: &DataType) -> Arc { + Arc::new(self.to(data_type.clone())) + } +} diff --git a/src/array/primitive/iterator.rs b/src/array/primitive/iterator.rs new file mode 100644 index 00000000000..b8940e51dac --- /dev/null +++ b/src/array/primitive/iterator.rs @@ 
-0,0 +1,73 @@ +use crate::{ + array::Array, + buffer::{Bitmap, NativeType}, +}; + +use super::PrimitiveArray; + +/// an iterator that returns Some(T) or None, that can be used on any PrimitiveArray +// Note: This implementation is based on std's [Vec]s' [IntoIter]. +#[derive(Debug)] +pub struct PrimitiveIter<'a, T: NativeType> { + values: &'a [T], + validity: &'a Option, + current: usize, + current_end: usize, +} + +impl<'a, T: NativeType> PrimitiveIter<'a, T> { + /// create a new iterator + pub fn new(array: &'a PrimitiveArray) -> Self { + PrimitiveIter:: { + values: array.values().as_slice(), + validity: &array.validity, + current: 0, + current_end: array.len(), + } + } +} + +impl<'a, T: NativeType> std::iter::Iterator for PrimitiveIter<'a, T> { + type Item = Option; + + fn next(&mut self) -> Option { + if self.current == self.current_end { + None + } else if !self + .validity + .as_ref() + .map(|x| x.get_bit(self.current)) + .unwrap_or(true) + { + self.current += 1; + Some(None) + } else { + let old = self.current; + self.current += 1; + Some(Some(self.values[old])) + } + } + + fn size_hint(&self) -> (usize, Option) { + ( + self.values.len() - self.current, + Some(self.values.len() - self.current), + ) + } +} + +impl<'a, T: NativeType> IntoIterator for &'a PrimitiveArray { + type Item = Option; + type IntoIter = PrimitiveIter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + PrimitiveIter::<'a, T>::new(self) + } +} + +impl<'a, T: NativeType> PrimitiveArray { + /// constructs a new iterator + pub fn iter(&'a self) -> PrimitiveIter<'a, T> { + PrimitiveIter::<'a, T>::new(&self) + } +} diff --git a/src/array/primitive/mod.rs b/src/array/primitive/mod.rs new file mode 100644 index 00000000000..5d0721eab18 --- /dev/null +++ b/src/array/primitive/mod.rs @@ -0,0 +1,125 @@ +use crate::{ + buffer::{types::NativeType, Bitmap, Buffer}, + datatypes::DataType, + ffi::ArrowArray, +}; + +use crate::error::Result; + +use super::{ + ffi::{FromFFI, ToFFI}, + Array, +}; + 
+#[derive(Debug, Clone)] +pub struct PrimitiveArray { + data_type: DataType, + values: Buffer, + validity: Option, + offset: usize, +} + +impl PrimitiveArray { + pub fn new_empty(data_type: DataType) -> Self { + Self::from_data(data_type, Buffer::new(), None) + } + + pub fn from_data(data_type: DataType, values: Buffer, validity: Option) -> Self { + assert!(T::is_valid(&data_type)); + Self { + data_type, + values, + validity, + offset: 0, + } + } + + pub fn slice(&self, offset: usize, length: usize) -> Self { + let validity = self.validity.clone().map(|x| x.slice(offset, length)); + Self { + data_type: self.data_type.clone(), + values: self.values.clone().slice(offset, length), + validity, + offset: self.offset + offset, + } + } + + #[inline] + pub fn values(&self) -> &Buffer { + &self.values + } + + #[inline] + pub fn value(&self, i: usize) -> T { + self.values().as_slice()[i] + } +} + +impl Array for PrimitiveArray { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn len(&self) -> usize { + self.values.len() + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn nulls(&self) -> &Option { + &self.validity + } + + fn slice(&self, offset: usize, length: usize) -> Box { + Box::new(self.slice(offset, length)) + } +} + +unsafe impl ToFFI for PrimitiveArray { + fn buffers(&self) -> [Option>; 3] { + unsafe { + [ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(std::ptr::NonNull::new_unchecked( + self.values.as_ptr() as *mut u8 + )), + None, + ] + } + } + + #[inline] + fn offset(&self) -> usize { + self.offset + } +} + +unsafe impl FromFFI for PrimitiveArray { + fn try_from_ffi(data_type: DataType, array: ArrowArray) -> Result { + let length = array.len(); + let offset = array.offset(); + let mut validity = array.null_bit_buffer(); + let mut values = unsafe { array.buffer::(0)? 
}; + + if offset > 0 { + values = values.slice(offset, length); + validity = validity.map(|x| x.slice(offset, length)) + } + Ok(Self { + data_type, + values, + validity, + offset: 0, + }) + } +} + +mod from; +pub use from::Primitive; +mod iterator; +pub use iterator::*; diff --git a/src/array/specification.rs b/src/array/specification.rs new file mode 100644 index 00000000000..3258b3909cd --- /dev/null +++ b/src/array/specification.rs @@ -0,0 +1,86 @@ +use std::convert::TryFrom; + +use num::Num; + +use crate::buffer::{Buffer, NativeType}; + +pub unsafe trait Offset: NativeType + Num + Ord + std::ops::AddAssign { + fn is_large() -> bool; + + fn to_usize(&self) -> Option; + + fn from_usize(value: usize) -> Option; +} + +unsafe impl Offset for i32 { + #[inline] + fn is_large() -> bool { + false + } + + #[inline] + fn to_usize(&self) -> Option { + Some(*self as usize) + } + + #[inline] + fn from_usize(value: usize) -> Option { + Self::try_from(value).ok() + } +} + +unsafe impl Offset for i64 { + #[inline] + fn is_large() -> bool { + true + } + + #[inline] + fn to_usize(&self) -> Option { + usize::try_from(*self).ok() + } + + #[inline] + fn from_usize(value: usize) -> Option { + Some(value as i64) + } +} + +#[inline] +pub fn check_offsets(offsets: &Buffer, values_len: usize) -> usize { + assert!( + offsets.len() >= 1, + "The length of the offset buffer must be larger than 1" + ); + let len = offsets.len() - 1; + + let offsets = offsets.as_slice(); + + let last_offset = offsets[len]; + let last_offset = last_offset + .to_usize() + .expect("The last offset of the array is larger than usize::MAX"); + + assert_eq!( + values_len, last_offset, + "The length of the values must be equal to the last offset value" + ); + len +} + +#[inline] +pub fn check_offsets_and_utf8(offsets: &Buffer, values: &Buffer) -> usize { + let len = check_offsets(offsets, values.len()); + offsets.as_slice().windows(2).for_each(|window| { + let start = window[0] + .to_usize() + .expect("The last offset 
of the array is larger than usize::MAX"); + let end = window[1] + .to_usize() + .expect("The last offset of the array is larger than usize::MAX"); + assert!(end < values.len()); + let slice = unsafe { std::slice::from_raw_parts(values.as_ptr().add(start), end - start) }; + std::str::from_utf8(slice).expect("A non-utf8 string was passed."); + }); + len +} diff --git a/src/array/string/from.rs b/src/array/string/from.rs new file mode 100644 index 00000000000..1a61831f3cc --- /dev/null +++ b/src/array/string/from.rs @@ -0,0 +1,223 @@ +use std::sync::Arc; + +use crate::{ + array::{Array, Builder, Offset, ToArray}, + buffer::{Bitmap, Buffer, MutableBitmap, MutableBuffer}, + datatypes::DataType, +}; + +use super::Utf8Array; + +impl Utf8Array { + pub fn from_slice, P: AsRef<[T]>>(slice: P) -> Self { + unsafe { Self::from_trusted_len_iter(slice.as_ref().iter().map(Some)) } + } +} + +impl> From<&Vec>> for Utf8Array { + fn from(slice: &Vec>) -> Self { + unsafe { Self::from_trusted_len_iter(slice.iter().map(|x| x.as_ref())) } + } +} + +impl Utf8Array { + /// Creates a [`PrimitiveArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter(iter: I) -> Self + where + P: AsRef, + I: IntoIterator>, + { + let iterator = iter.into_iter(); + + let (validity, offsets, values) = trusted_len_unzip(iterator); + + // soundness: P is `str` + Self::from_data_unchecked(offsets, values, validity) + } + + /// Creates a [`PrimitiveArray`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. 
+ #[inline] + pub unsafe fn try_from_trusted_len_iter(iter: I) -> Result + where + P: AsRef, + I: IntoIterator, E>>, + { + let iterator = iter.into_iter(); + + let (validity, offsets, values) = try_trusted_len_unzip(iterator)?; + + // soundness: P is `str` + Ok(Self::from_data_unchecked(offsets, values, validity)) + } +} + +/// Creates a Bitmap and a [`Buffer`] from an iterator of `Option`. +/// The first buffer corresponds to a bitmap buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn trusted_len_unzip( + iterator: I, +) -> (Option, Buffer, Buffer) +where + O: Offset, + P: AsRef, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut offsets = MutableBuffer::::with_capacity(len + 1); + let mut values = MutableBuffer::::new(); + + let mut length = O::default(); + let mut dst = offsets.as_mut_ptr(); + std::ptr::write(dst, length); + dst = dst.add(1); + for item in iterator { + if let Some(item) = item { + null.push_unchecked(true); + let s = item.as_ref(); + length += O::from_usize(s.len()).unwrap(); + values.extend_from_slice(s.as_bytes()); + } else { + null.push_unchecked(false); + }; + + std::ptr::write(dst, length); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(offsets.as_ptr()) as usize, + len + 1, + "Trusted iterator length was not accurately reported" + ); + offsets.set_len(len + 1); + null.set_len(len); + + let bitmap = if null.null_count() > 0 { + Some(null.into()) + } else { + None + }; + (bitmap, offsets.into(), values.into()) +} + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. 
+#[inline] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> Result<(Option, Buffer, Buffer), E> +where + O: Offset, + P: AsRef, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut offsets = MutableBuffer::::with_capacity(len + 1); + offsets.push(O::default()); + let mut values = MutableBuffer::::new(); + + let mut length = O::default(); + let mut dst = offsets.as_mut_ptr(); + for item in iterator { + if let Some(item) = item? { + null.push_unchecked(true); + let s = item.as_ref(); + length += O::from_usize(s.len()).unwrap(); + values.extend_from_slice(s.as_bytes()); + } else { + null.push_unchecked(false); + }; + std::ptr::write(dst, length); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(offsets.as_ptr()) as usize, + len, + "Trusted iterator length was not accurately reported" + ); + offsets.set_len(len); + null.set_len(len); + + let bitmap = if null.null_count() > 0 { + Some(null.into()) + } else { + None + }; + Ok((bitmap, offsets.into(), values.into())) +} + +/// auxiliary struct used to create a [`PrimitiveArray`] out of an iterator +#[derive(Debug)] +pub struct Utf8Primitive { + offsets: MutableBuffer, + values: MutableBuffer, + validity: MutableBitmap, + // invariant: always equal to the last offset + length: O, +} + +impl Builder<&str> for Utf8Primitive { + #[inline] + fn with_capacity(capacity: usize) -> Self { + let mut offsets = MutableBuffer::::with_capacity(capacity + 1); + offsets.push(O::default()); + + Self { + offsets, + values: MutableBuffer::::new(), + validity: MutableBitmap::with_capacity(capacity), + length: O::default(), + } + } + + #[inline] + fn push(&mut self, value: Option<&&str>) { + match value { + Some(v) => { + self.offsets.push(O::from_usize(v.len()).unwrap()); + self.values.extend_from_slice(&v.as_bytes()); + self.validity.push(true); + } + None => { + 
self.offsets.push(self.length); + self.validity.push(false); + } + } + } +} + +impl Utf8Primitive { + pub fn to(self) -> Utf8Array { + let validity = if self.validity.null_count() > 0 { + Some(self.validity.into()) + } else { + None + }; + + // Soundness: all methods from `Utf8Primitive` receive &str + unsafe { + Utf8Array::::from_data_unchecked(self.offsets.into(), self.values.into(), validity) + } + } +} + +impl ToArray for Utf8Primitive { + fn to_arc(self, _: &DataType) -> Arc { + Arc::new(self.to()) + } +} diff --git a/src/array/string/iterator.rs b/src/array/string/iterator.rs new file mode 100644 index 00000000000..0c4129be08c --- /dev/null +++ b/src/array/string/iterator.rs @@ -0,0 +1,65 @@ +use crate::array::{Array, Offset}; + +use super::Utf8Array; + +impl<'a, O: Offset> IntoIterator for &'a Utf8Array { + type Item = Option<&'a str>; + type IntoIter = Utf8Iter<'a, O>; + + fn into_iter(self) -> Self::IntoIter { + Utf8Iter::new(self) + } +} + +impl<'a, O: Offset> Utf8Array { + /// constructs a new iterator + pub fn iter(&'a self) -> Utf8Iter<'a, O> { + Utf8Iter::new(&self) + } +} + +/// an iterator that returns `Some(&str)` or `None`, for string arrays +#[derive(Debug)] +pub struct Utf8Iter<'a, T> +where + T: Offset, +{ + array: &'a Utf8Array, + i: usize, + len: usize, +} + +impl<'a, T: Offset> Utf8Iter<'a, T> { + /// create a new iterator + pub fn new(array: &'a Utf8Array) -> Self { + Utf8Iter:: { + array, + i: 0, + len: array.len(), + } + } +} + +impl<'a, T: Offset> std::iter::Iterator for Utf8Iter<'a, T> { + type Item = Option<&'a str>; + + fn next(&mut self) -> Option { + let i = self.i; + if i >= self.len { + None + } else if self.array.is_null(i) { + self.i += 1; + Some(None) + } else { + self.i += 1; + Some(Some(unsafe { self.array.value_unchecked(i) })) + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.len - self.i, Some(self.len - self.i)) + } +} + +/// all arrays have known size. 
+impl<'a, T: Offset> std::iter::ExactSizeIterator for Utf8Iter<'a, T> {} diff --git a/src/array/string/mod.rs b/src/array/string/mod.rs new file mode 100644 index 00000000000..72268de8611 --- /dev/null +++ b/src/array/string/mod.rs @@ -0,0 +1,161 @@ +use crate::{ + buffer::{Bitmap, Buffer}, + datatypes::DataType, +}; + +use super::{ + ffi::ToFFI, + specification::{check_offsets, check_offsets_and_utf8}, + Array, Offset, +}; + +#[derive(Debug, Clone)] +pub struct Utf8Array { + data_type: DataType, + offsets: Buffer, + values: Buffer, + validity: Option, + offset: usize, +} + +impl Utf8Array { + pub fn new_empty() -> Self { + unsafe { Self::from_data_unchecked(Buffer::from(&[O::zero()]), Buffer::new(), None) } + } + + pub fn from_data(offsets: Buffer, values: Buffer, validity: Option) -> Self { + check_offsets_and_utf8(&offsets, &values); + + Self { + data_type: if O::is_large() { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }, + offsets, + values, + validity, + offset: 0, + } + } + + /// # Safety + /// `values` buffer must contain valid utf8 between every `offset` + pub unsafe fn from_data_unchecked( + offsets: Buffer, + values: Buffer, + validity: Option, + ) -> Self { + check_offsets(&offsets, values.len()); + + Self { + data_type: if O::is_large() { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }, + offsets, + values, + validity, + offset: 0, + } + } + + /// Returns the element at index `i` as &str + /// # Safety + /// Assumes that the `i < self.len`. 
+ pub unsafe fn value_unchecked(&self, i: usize) -> &str { + let offset = *self.offsets.as_ptr().add(i); + let offset_1 = *self.offsets.as_ptr().add(i + 1); + let length = (offset_1 - offset).to_usize().unwrap(); + let offset = offset.to_usize().unwrap(); + + let slice = std::slice::from_raw_parts(self.values.as_ptr().add(offset), length); + // todo: validate utf8 so that we can use the unsafe version + std::str::from_utf8(slice).unwrap() + } + + pub fn slice(&self, offset: usize, length: usize) -> Self { + let validity = self.validity.clone().map(|x| x.slice(offset, length)); + Self { + data_type: self.data_type.clone(), + offsets: self.offsets.clone().slice(offset, length + 1), // `length` elements are delimited by `length + 1` offsets + values: self.values.clone(), + validity, + offset: self.offset + offset, + } + } + + /// Returns the element at index `i` as &str + pub fn value(&self, i: usize) -> &str { + let offsets = self.offsets.as_slice(); + let offset = offsets[i]; + let offset_1 = offsets[i + 1]; + let length = (offset_1 - offset).to_usize().unwrap(); + let offset = offset.to_usize().unwrap(); + + let slice = &self.values.as_slice()[offset..offset + length]; + // todo: validate utf8 so that we can use the unsafe version + std::str::from_utf8(slice).unwrap() + } + + #[inline] + pub fn offsets(&self) -> &[O] { + self.offsets.as_slice() + } + + #[inline] + pub fn values(&self) -> &[u8] { + self.values.as_slice() + } +} + +impl Array for Utf8Array { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn nulls(&self) -> &Option { + &self.validity + } + + fn slice(&self, offset: usize, length: usize) -> Box { + Box::new(self.slice(offset, length)) + } +} + +unsafe impl ToFFI for Utf8Array { + fn buffers(&self) -> [Option>; 3] { + unsafe { + [ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(std::ptr::NonNull::new_unchecked( + self.offsets.as_ptr() as *mut u8 +
)), + Some(std::ptr::NonNull::new_unchecked( + self.values.as_ptr() as *mut u8 + )), + ] + } + } + + fn offset(&self) -> usize { + self.offset + } +} + +mod from; +pub use from::*; +mod iterator; +pub use iterator::*; diff --git a/src/array/struct_.rs b/src/array/struct_.rs new file mode 100644 index 00000000000..5c09fc95f9d --- /dev/null +++ b/src/array/struct_.rs @@ -0,0 +1,92 @@ +use std::sync::Arc; + +use crate::{ + buffer::Bitmap, + datatypes::{DataType, Field}, +}; + +use super::{ffi::ToFFI, new_empty_array, Array}; + +#[derive(Debug, Clone)] +pub struct StructArray { + data_type: DataType, + values: Vec>, + validity: Option, +} + +impl StructArray { + pub fn new_empty(fields: &[Field]) -> Self { + let values = fields + .iter() + .map(|field| new_empty_array(field.data_type().clone()).into()) + .collect(); + Self::from_data(fields.to_vec(), values, None) + } + + pub fn from_data( + fields: Vec, + values: Vec>, + validity: Option, + ) -> Self { + Self { + data_type: DataType::Struct(fields), + values, + validity, + } + } + + pub fn slice(&self, offset: usize, length: usize) -> Self { + let validity = self.validity.clone().map(|x| x.slice(offset, length)); + Self { + data_type: self.data_type.clone(), + values: self + .values + .iter() + .map(|x| x.slice(offset, length).into()) + .collect(), + validity, + } + } + + #[inline] + pub fn values(&self) -> &[Arc] { + &self.values + } +} + +impl Array for StructArray { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn len(&self) -> usize { + self.values[0].len() + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } + + #[inline] + fn nulls(&self) -> &Option { + &self.validity + } + + fn slice(&self, offset: usize, length: usize) -> Box { + Box::new(self.slice(offset, length)) + } +} + +unsafe impl ToFFI for StructArray { + fn buffers(&self) -> [Option>; 3] { + [self.validity.as_ref().map(|x| x.as_ptr()), None, None] + } + + fn offset(&self) -> usize { + // we 
do not support offsets in structs. Instead, if an FFI we slice the incoming arrays + 0 + } +} diff --git a/src/bits/chunk_iterator.rs b/src/bits/chunk_iterator.rs new file mode 100644 index 00000000000..085ef500c27 --- /dev/null +++ b/src/bits/chunk_iterator.rs @@ -0,0 +1,231 @@ +use super::bytes_for; + +#[derive(Debug)] +pub struct BitChunks<'a> { + buffer: &'a [u8], + /// offset inside a byte, guaranteed to be between 0 and 7 (inclusive) + bit_offset: usize, + /// number of complete u64 chunks + chunk_len: usize, + /// number of remaining bits, guaranteed to be between 0 and 63 (inclusive) + remainder_len: usize, +} + +impl<'a> BitChunks<'a> { + pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self { + assert!(offset + len <= buffer.len() * 8); + + let byte_offset = offset / 8; + let bit_offset = offset % 8; + + let chunk_bits = 8 * std::mem::size_of::(); + + let chunk_len = len / chunk_bits; + let remainder_len = len & (chunk_bits - 1); + + BitChunks::<'a> { + buffer: &buffer[byte_offset..], + bit_offset, + chunk_len, + remainder_len, + } + } +} + +#[derive(Debug)] +pub struct BitChunkIterator<'a> { + buffer: &'a [u8], + bit_offset: usize, + chunk_len: usize, + index: usize, +} + +impl<'a> BitChunks<'a> { + /// Returns the number of remaining bits, guaranteed to be between 0 and 63 (inclusive) + #[inline] + pub const fn remainder_len(&self) -> usize { + self.remainder_len + } + + /// Returns the number of chunks + #[inline] + pub const fn chunk_len(&self) -> usize { + self.chunk_len + } + + /// Returns the bitmask of remaining bits + #[inline] + pub fn remainder_bits(&self) -> u64 { + let bit_len = self.remainder_len; + if bit_len == 0 { + 0 + } else { + let bit_offset = self.bit_offset; + // number of bytes to read + // might be one more than sizeof(u64) if the offset is in the middle of a byte + let byte_len = bytes_for(bit_len + bit_offset); + // pointer to remainder bytes after all complete chunks + let base = unsafe { + self.buffer + .as_ptr() + 
.add(self.chunk_len * std::mem::size_of::()) + }; + + let mut bits = unsafe { std::ptr::read(base) } as u64 >> bit_offset; + for i in 1..byte_len { + let byte = unsafe { std::ptr::read(base.add(i)) }; + bits |= (byte as u64) << (i * 8 - bit_offset); + } + + bits & ((1 << bit_len) - 1) + } + } + + /// Returns an iterator over chunks of 64 bits represented as an u64 + #[inline] + pub const fn iter(&self) -> BitChunkIterator<'a> { + BitChunkIterator::<'a> { + buffer: self.buffer, + bit_offset: self.bit_offset, + chunk_len: self.chunk_len, + index: 0, + } + } +} + +impl<'a> IntoIterator for BitChunks<'a> { + type Item = u64; + type IntoIter = BitChunkIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl Iterator for BitChunkIterator<'_> { + type Item = u64; + + #[inline] + fn next(&mut self) -> Option { + let index = self.index; + if index >= self.chunk_len { + return None; + } + + // cast to *const u64 should be fine since we are using read_unaligned below + #[allow(clippy::cast_ptr_alignment)] + let raw_data = self.buffer.as_ptr() as *const u64; + + // bit-packed buffers are stored starting with the least-significant byte first + // so when reading as u64 on a big-endian machine, the bytes need to be swapped + let current = unsafe { std::ptr::read_unaligned(raw_data.add(index)).to_le() }; + + let combined = if self.bit_offset == 0 { + current + } else { + let next = unsafe { std::ptr::read(self.buffer.as_ptr().add((index + 1) * 8)) } as u64; // bit_offset < 8, so only the single byte after this chunk is needed; reading a whole u64 here could run past the end of the buffer + + current >> self.bit_offset + | (next & ((1 << self.bit_offset) - 1)) << (64 - self.bit_offset) + }; + + self.index = index + 1; + + Some(combined) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + ( + self.chunk_len - self.index, + Some(self.chunk_len - self.index), + ) + } +} + +impl ExactSizeIterator for BitChunkIterator<'_> { + #[inline] + fn len(&self) -> usize { + self.chunk_len - self.index + } +} + +#[cfg(test)] +mod tests { + use super::BitChunks; + + #[test] + fn
test_iter_aligned() { + let input: &[u8] = &[0, 1, 2, 3, 4, 5, 6, 7]; + let bitchunks = BitChunks::new(&input, 0, 64); + + let result = bitchunks.into_iter().collect::>(); + + assert_eq!(vec![0x0706050403020100], result); + } + + #[test] + fn test_iter_unaligned() { + let input: &[u8] = &[ + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, + ]; + let bitchunks = BitChunks::new(&input, 4, 64); + + assert_eq!(0, bitchunks.remainder_len()); + assert_eq!(0, bitchunks.remainder_bits()); + + let result = bitchunks.into_iter().collect::>(); + + assert_eq!( + vec![0b1111010000000010000000010000000010000000010000000010000000010000], + result + ); + } + + #[test] + fn test_iter_unaligned_remainder_1_byte() { + let input: &[u8] = &[ + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, + ]; + let bitchunks = BitChunks::new(&input, 4, 66); + + assert_eq!(2, bitchunks.remainder_len()); + assert_eq!(0b00000011, bitchunks.remainder_bits()); + + let result = bitchunks.into_iter().collect::>(); + + assert_eq!( + vec![0b1111010000000010000000010000000010000000010000000010000000010000], + result + ); + } + + #[test] + fn test_iter_unaligned_remainder_bits_across_bytes() { + let input: &[u8] = &[0b00111111, 0b11111100]; + + // remainder contains bits from both bytes + // result should be the highest 2 bits from first byte followed by lowest 5 bits of second bytes + let bitchunks = BitChunks::new(&input, 6, 7); + + assert_eq!(7, bitchunks.remainder_len()); + assert_eq!(0b1110000, bitchunks.remainder_bits()); + } + + #[test] + fn test_iter_unaligned_remainder_bits_large() { + let input: &[u8] = &[ + 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, + 0b00000000, 0b11111111, + ]; + let bitchunks = BitChunks::new(&input, 2, 63); + + assert_eq!(63, bitchunks.remainder_len()); + assert_eq!( + 
0b100_0000_0011_1111_1100_0000_0011_1111_1100_0000_0011_1111_1100_0000_0011_1111, + bitchunks.remainder_bits() + ); + } +} diff --git a/src/bits/mod.rs b/src/bits/mod.rs new file mode 100644 index 00000000000..1ef00ac85be --- /dev/null +++ b/src/bits/mod.rs @@ -0,0 +1,69 @@ +mod chunk_iterator; + +const BIT_MASK: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128]; +const UNSET_BIT_MASK: [u8; 8] = [ + 255 - 1, + 255 - 2, + 255 - 4, + 255 - 8, + 255 - 16, + 255 - 32, + 255 - 64, + 255 - 128, +]; + +/// Sets bit at position `i` for `data` +#[inline] +pub fn set_bit(data: &mut [u8], i: usize) { + data[i >> 3] |= BIT_MASK[i & 7]; +} + +/// Sets bit at position `i` for `data` to 0 +#[inline] +pub fn unset_bit(data: &mut [u8], i: usize) { + data[i >> 3] &= UNSET_BIT_MASK[i & 7]; +} + +/// Returns whether bit at position `i` in `data` is set or not +#[inline] +pub fn get_bit(data: &[u8], i: usize) -> bool { + (data[i >> 3] & BIT_MASK[i & 7]) != 0 +} + +/// Returns the number of bytes required to hold `bits` bits. +#[inline] +pub fn bytes_for(bits: usize) -> usize { + bits.saturating_add(7) / 8 +} + +#[inline] +pub(crate) fn null_count(slice: &[u8], offset: usize, len: usize) -> usize { + let chunks = chunk_iterator::BitChunks::new(slice, offset, len); + + let mut count: usize = chunks.iter().map(|c| c.count_ones() as usize).sum(); + count += chunks.remainder_bits().count_ones() as usize; + + len - count +} + +/// Sets bit at position `i` for `data` +/// +/// # Safety +/// +/// Note this doesn't do any bound checking, for performance reason. The caller is +/// responsible to guarantee that `i` is within bounds. +#[inline] +pub unsafe fn set_bit_raw(data: *mut u8, i: usize) { + *data.add(i >> 3) |= BIT_MASK[i & 7]; +} + +/// Sets bit at position `i` for `data` to 0 +/// +/// # Safety +/// +/// Note this doesn't do any bound checking, for performance reason. The caller is +/// responsible to guarantee that `i` is within bounds. 
+#[inline] +pub unsafe fn unset_bit_raw(data: *mut u8, i: usize) { + *data.add(i >> 3) &= UNSET_BIT_MASK[i & 7]; +} diff --git a/src/buffer/alignment.rs b/src/buffer/alignment.rs new file mode 100644 index 00000000000..a8f3e3b1afe --- /dev/null +++ b/src/buffer/alignment.rs @@ -0,0 +1,102 @@ +// NOTE: Below code is written for spatial/temporal prefetcher optimizations. Memory allocation +// should align well with usage pattern of cache access and block sizes on layers of storage levels from +// registers to non-volatile memory. These alignments are all cache aware alignments incorporated +// from [cuneiform](https://crates.io/crates/cuneiform) crate. This approach mimicks Intel TBB's +// cache_aligned_allocator which exploits cache locality and minimizes prefetch signals +// resulting in less round trip time between the layers of storage. +// For further info: https://software.intel.com/en-us/node/506094 + +// 32-bit architecture and things other than netburst microarchitecture are using 64 bytes. +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "x86")] +pub const ALIGNMENT: usize = 1 << 6; + +// Intel x86_64: +// L2D streamer from L1: +// Loads data or instructions from memory to the second-level cache. To use the streamer, +// organize the data or instructions in blocks of 128 bytes, aligned on 128 bytes. 
+// - https://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-optimization-manual.pdf +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "x86_64")] +pub const ALIGNMENT: usize = 1 << 7; + +// 24Kc: +// Data Line Size +// - https://s3-eu-west-1.amazonaws.com/downloads-mips/documents/MD00346-2B-24K-DTS-04.00.pdf +// - https://gitlab.e.foundation/e/devices/samsung/n7100/stable_android_kernel_samsung_smdk4412/commit/2dbac10263b2f3c561de68b4c369bc679352ccee +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "mips")] +pub const ALIGNMENT: usize = 1 << 5; +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "mips64")] +pub const ALIGNMENT: usize = 1 << 5; + +// Defaults for powerpc +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "powerpc")] +pub const ALIGNMENT: usize = 1 << 5; + +// Defaults for the ppc 64 +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "powerpc64")] +pub const ALIGNMENT: usize = 1 << 6; + +// e.g.: sifive +// - https://github.com/torvalds/linux/blob/master/Documentation/devicetree/bindings/riscv/sifive-l2-cache.txt#L41 +// in general all of them are the same. +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "riscv")] +pub const ALIGNMENT: usize = 1 << 6; + +// This size is same across all hardware for this architecture. +// - https://docs.huihoo.com/doxygen/linux/kernel/3.7/arch_2s390_2include_2asm_2cache_8h.html +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "s390x")] +pub const ALIGNMENT: usize = 1 << 8; + +// This size is same across all hardware for this architecture. 
+// - https://docs.huihoo.com/doxygen/linux/kernel/3.7/arch_2sparc_2include_2asm_2cache_8h.html#a9400cc2ba37e33279bdbc510a6311fb4 +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "sparc")] +pub const ALIGNMENT: usize = 1 << 5; +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "sparc64")] +pub const ALIGNMENT: usize = 1 << 6; + +// On ARM cache line sizes are fixed. both v6 and v7. +// Need to add board specific or platform specific things later. +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "thumbv6")] +pub const ALIGNMENT: usize = 1 << 5; +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "thumbv7")] +pub const ALIGNMENT: usize = 1 << 5; + +// Operating Systems cache size determines this. +// Currently no way to determine this without runtime inference. +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "wasm32")] +pub const ALIGNMENT: usize = FALLBACK_ALIGNMENT; + +// Same as v6 and v7. +// List goes like that: +// Cortex A, M, R, ARM v7, v7-M, Krait and NeoverseN uses this size. +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "arm")] +pub const ALIGNMENT: usize = 1 << 5; + +// Combined from 4 sectors. Volta says 128. +// Prevent chunk optimizations better to go to the default size. +// If you have smaller data with less padded functionality then use 32 with force option. +// - https://devtalk.nvidia.com/default/topic/803600/variable-cache-line-width-/ +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "nvptx")] +pub const ALIGNMENT: usize = 1 << 7; +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "nvptx64")] +pub const ALIGNMENT: usize = 1 << 7; + +// This size is same across all hardware for this architecture. 
+/// Cache and allocation multiple alignment size +#[cfg(target_arch = "aarch64")] +pub const ALIGNMENT: usize = 1 << 6; diff --git a/src/buffer/alloc.rs b/src/buffer/alloc.rs new file mode 100644 index 00000000000..4234ff16623 --- /dev/null +++ b/src/buffer/alloc.rs @@ -0,0 +1,147 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines memory-related functions, such as allocate/deallocate/reallocate memory +//! regions, cache and allocation alignments. + +use super::{alignment::ALIGNMENT, types::NativeType}; + +use std::mem::size_of; +use std::ptr::NonNull; +use std::{ + alloc::{handle_alloc_error, Layout}, + sync::atomic::AtomicIsize, +}; + +// If this number is not zero after all objects have been `drop`, there is a memory leak +pub static mut ALLOCATIONS: AtomicIsize = AtomicIsize::new(0); + +#[inline] +unsafe fn null_pointer() -> NonNull { + NonNull::new_unchecked(ALIGNMENT as *mut T) +} + +/// Allocates a cache-aligned memory region of `size` bytes with uninitialized values. +/// This is more performant than using [allocate_aligned_zeroed] when all bytes will have +/// an unknown or non-zero value and is semantically similar to `malloc`. 
+pub fn allocate_aligned(size: usize) -> NonNull { + unsafe { + if size == 0 { + null_pointer() + } else { + let size = size * size_of::(); + ALLOCATIONS.fetch_add(size as isize, std::sync::atomic::Ordering::SeqCst); + + let layout = Layout::from_size_align_unchecked(size, ALIGNMENT); + let raw_ptr = std::alloc::alloc(layout) as *mut T; + NonNull::new(raw_ptr).unwrap_or_else(|| handle_alloc_error(layout)) + } + } +} + +/// Allocates a cache-aligned memory region of `size` bytes with `0` on all of them. +/// This is more performant than using [allocate_aligned] and setting all bytes to zero +/// and is semantically similar to `calloc`. +pub fn allocate_aligned_zeroed(size: usize) -> NonNull { + unsafe { + if size == 0 { + null_pointer() + } else { + let size = size * size_of::(); + ALLOCATIONS.fetch_add(size as isize, std::sync::atomic::Ordering::SeqCst); + + let layout = Layout::from_size_align_unchecked(size, ALIGNMENT); + let raw_ptr = std::alloc::alloc_zeroed(layout) as *mut T; + NonNull::new(raw_ptr).unwrap_or_else(|| handle_alloc_error(layout)) + } + } +} + +/// # Safety +/// +/// This function is unsafe because undefined behavior can result if the caller does not ensure all +/// of the following: +/// +/// * ptr must denote a block of memory currently allocated via this allocator, +/// +/// * size must be the same size that was used to allocate that block of memory, +pub unsafe fn free_aligned(ptr: NonNull, size: usize) { + if ptr != null_pointer() { + let size = size * size_of::(); + ALLOCATIONS.fetch_sub(size as isize, std::sync::atomic::Ordering::SeqCst); + std::alloc::dealloc( + ptr.as_ptr() as *mut u8, + Layout::from_size_align_unchecked(size, ALIGNMENT), + ); + } +} + +/// # Safety +/// +/// This function is unsafe because undefined behavior can result if the caller does not ensure all +/// of the following: +/// +/// * ptr must be currently allocated via this allocator, +/// +/// * new_size must be greater than zero. 
+/// +/// * new_size, when rounded up to the nearest multiple of [ALIGNMENT], must not overflow (i.e., +/// the rounded value must be less than usize::MAX). +pub unsafe fn reallocate( + ptr: NonNull, + old_size: usize, + new_size: usize, +) -> NonNull { + if ptr == null_pointer() { + return allocate_aligned(new_size); // takes an element count and scales it to bytes itself + } + + if new_size == 0 { + free_aligned(ptr, old_size); // likewise takes an element count, so it must not be pre-scaled + return null_pointer(); + } + + let old_size = old_size * size_of::(); // from here on both sizes are measured in bytes + let new_size = new_size * size_of::(); + ALLOCATIONS.fetch_add( + new_size as isize - old_size as isize, + std::sync::atomic::Ordering::SeqCst, + ); + let raw_ptr = std::alloc::realloc( + ptr.as_ptr() as *mut u8, + Layout::from_size_align_unchecked(old_size, ALIGNMENT), + new_size, + ) as *mut T; + NonNull::new(raw_ptr).unwrap_or_else(|| { + handle_alloc_error(Layout::from_size_align_unchecked(new_size, ALIGNMENT)) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_allocate() { + for _ in 0..10 { + let p = allocate_aligned::(1024); + // make sure this is 64-byte aligned + assert_eq!(0, (p.as_ptr() as usize) % 64); + unsafe { free_aligned(p, 1024) }; + } + } +} diff --git a/src/buffer/bitmap.rs b/src/buffer/bitmap.rs new file mode 100644 index 00000000000..bed80cee5d2 --- /dev/null +++ b/src/buffer/bitmap.rs @@ -0,0 +1,250 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use crate::{ + bits::{get_bit, null_count, set_bit_raw, unset_bit_raw}, + buffer::bytes::Bytes, + ffi, +}; + +use super::{bytes::Deallocation, MutableBuffer}; + +#[derive(Debug, Clone)] +pub struct Bitmap { + bytes: Arc>, + // both are measured in bits. They are used to bound the bitmap to a region of Bytes.
+ offset: usize, + length: usize, + // this is a cache: it must be computed on initialization + null_count: usize, +} + +impl Bitmap { + pub fn new() -> Self { + MutableBitmap::new().into() + } + + #[inline] + pub fn len(&self) -> usize { + self.length + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + #[inline] + pub fn from_bytes(bytes: Bytes, length: usize) -> Self { + assert!(length <= bytes.len() * 8); + let null_count = null_count(&bytes, 0, length); + Self { + length, + offset: 0, + bytes: Arc::new(bytes), + null_count, + } + } + + #[inline] + pub fn null_count_range(&self, offset: usize, length: usize) -> usize { + null_count(&self.bytes, self.offset + offset, length) + } + + #[inline] + pub fn null_count(&self) -> usize { + self.null_count + } + + #[inline] + pub fn slice(mut self, offset: usize, length: usize) -> Self { + // `offset` is relative to the current start of this bitmap, so it is added exactly once + self.offset += offset; + self.length = length; + self.null_count = null_count(&self.bytes, self.offset, self.length); + self + } + + #[inline] + pub fn get_bit(&self, i: usize) -> bool { + get_bit(&self.bytes, self.offset + i) + } + + /// Returns a pointer to the start of this bitmap.
+ pub fn as_ptr(&self) -> std::ptr::NonNull { + self.bytes.ptr() + } +} + +#[derive(Debug)] +pub struct MutableBitmap { + buffer: MutableBuffer, + length: usize, +} + +impl MutableBitmap { + #[inline] + pub fn new() -> Self { + Self { + buffer: MutableBuffer::new(), + length: 0, + } + } + + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + buffer: MutableBuffer::from_len_zeroed(capacity.saturating_add(7) / 8), + length: 0, + } + } + + #[inline] + pub fn push(&mut self, value: bool) { + self.buffer + .resize((self.length + 1).saturating_add(7) / 8, 0); + if value { + unsafe { set_bit_raw(self.buffer.as_mut_ptr(), self.length) }; + } else { + unsafe { unset_bit_raw(self.buffer.as_mut_ptr(), self.length) }; + } + self.length += 1; + } + + /// Appends `value` by writing the bit through a raw pointer. + /// # Safety + /// The caller must ensure the buffer already has room for one more bit: unlike `push`, this method does not call `resize` first. + #[inline] + pub unsafe fn push_unchecked(&mut self, value: bool) { + if value { + set_bit_raw(self.buffer.as_mut_ptr(), self.length); + } else { + unset_bit_raw(self.buffer.as_mut_ptr(), self.length); + } + self.length += 1; + self.buffer.set_len(self.length.saturating_add(7) / 8); + } + + #[inline] + pub fn null_count(&self) -> usize { + null_count(&self.buffer, 0, self.length) + } + + /// Returns the number of bits in this bitmap + pub fn len(&self) -> usize { + self.length + } + + /// Returns whether the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// # Safety + /// The caller must ensure that the buffer was properly initialized up to `len`.
+    #[inline]
+    pub(crate) unsafe fn set_len(&mut self, len: usize) {
+        // `len` is in bits; the underlying buffer is sized in bytes, rounded up.
+        self.buffer.set_len(len.saturating_add(7) / 8);
+        self.length = len;
+    }
+}
+
+impl From<MutableBitmap> for Bitmap {
+    /// Converts the mutable bitmap's buffer into bytes and wraps them, keeping the bit length.
+    #[inline]
+    fn from(buffer: MutableBitmap) -> Self {
+        Bitmap::from_bytes(buffer.buffer.into(), buffer.length)
+    }
+}
+
+impl FromIterator<bool> for MutableBitmap {
+    /// Packs the booleans of `iter` into bytes, 8 at a time, first item in the lowest bit.
+    fn from_iter<I>(iter: I) -> Self
+    where
+        I: IntoIterator<Item = bool>,
+    {
+        let mut iterator = iter.into_iter();
+        let mut buffer = {
+            let byte_capacity: usize = iterator.size_hint().0.saturating_add(7) / 8;
+            MutableBuffer::with_capacity(byte_capacity)
+        };
+
+        let mut length = 0;
+
+        loop {
+            let mut exhausted = false;
+            let mut byte_accum: u8 = 0;
+            let mut mask: u8 = 1;
+
+            // collect (up to) 8 bits into a byte; `mask` becomes 0 once 8 bits were shifted in
+            while mask != 0 {
+                if let Some(value) = iterator.next() {
+                    length += 1;
+                    byte_accum |= match value {
+                        true => mask,
+                        false => 0,
+                    };
+                    mask <<= 1;
+                } else {
+                    exhausted = true;
+                    break;
+                }
+            }
+
+            // break if the iterator was exhausted before it provided a bool for this byte
+            if exhausted && mask == 1 {
+                break;
+            }
+
+            // ensure we have capacity to write the byte
+            if buffer.len() == buffer.capacity() {
+                // no capacity for new byte, allocate 1 byte more (plus however many more the iterator advertises)
+                let additional_byte_capacity = 1usize.saturating_add(
+                    iterator.size_hint().0.saturating_add(7) / 8, // convert bit count to byte count, rounding up
+                );
+                buffer.reserve(additional_byte_capacity)
+            }
+
+            // Soundness: capacity was allocated above
+            unsafe { buffer.push_unchecked(byte_accum) };
+            if exhausted {
+                break;
+            }
+        }
+        Self { buffer, length }
+    }
+}
+
+impl Bitmap {
+    /// Creates a [`Bitmap`] from an iterator of booleans.
+    /// # Safety
+    /// This currently collects through `MutableBitmap::from_iter` and does not itself
+    /// rely on the length reported by the iterator.
+    /// NOTE(review): the `unsafe` presumably anticipates the trusted-length
+    /// implementation mentioned in the todo below; pass an iterator with an accurate length.
+    #[inline]
+    pub unsafe fn from_trusted_len_iter<I: Iterator<Item = bool>>(iterator: I) -> Self {
+        // todo implement `from_trusted_len_iter` for MutableBitmap
+        MutableBitmap::from_iter(iterator).into()
+    }
+}
+
+impl Bitmap {
+    /// Creates a bitmap from an existing memory region (must already be byte-aligned), this
+    /// `Bitmap` **does not** free this piece of memory when
dropped. + /// + /// # Arguments + /// + /// * `ptr` - Pointer to raw parts + /// * `len` - Length of raw parts in **bytes** + /// * `data` - An [ffi::FFI_ArrowArray] with the data + /// + /// # Safety + /// + /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` + /// bytes and that the foreign deallocator frees the region. + pub unsafe fn from_unowned( + ptr: std::ptr::NonNull, + length: usize, + data: Arc, + ) -> Self { + // todo: make all kinds of assertions + let bytes = Bytes::new(ptr, length, Deallocation::Foreign(data)); + let null_count = null_count(&bytes, 0, length); + Self { + bytes: Arc::new(bytes), + offset: 0, + length, + null_count, + } + } +} diff --git a/src/buffer/bytes.rs b/src/buffer/bytes.rs new file mode 100644 index 00000000000..496b064479a --- /dev/null +++ b/src/buffer/bytes.rs @@ -0,0 +1,134 @@ +//! This module contains an implementation of a contiguous immutable memory region that knows +//! how to de-allocate itself, [`Bytes`]. + +use core::slice; +use std::{fmt::Debug, fmt::Formatter}; +use std::{ptr::NonNull, sync::Arc}; + +use crate::ffi; + +use super::{alloc, types::NativeType}; + +/// Mode of deallocating memory regions +pub enum Deallocation { + /// Native deallocation, using Rust deallocator with Arrow-specific memory aligment + Native(usize), + // Foreign interface, via a callback + Foreign(Arc), +} + +impl Debug for Deallocation { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + match self { + Deallocation::Native(capacity) => { + write!(f, "Deallocation::Native {{ capacity: {} }}", capacity) + } + Deallocation::Foreign(_) => { + write!(f, "Deallocation::Foreign {{ capacity: unknown }}") + } + } + } +} + +/// A continuous, fixed-size, immutable memory region that knows how to de-allocate itself. +/// This structs' API is inspired by the `bytes::Bytes`, but it is not limited to using rust's +/// global allocator nor u8 aligmnent. 
+/// +/// In the most common case, this buffer is allocated using [`allocate_aligned`](memory::allocate_aligned) +/// and deallocated accordingly [`free_aligned`](memory::free_aligned). +/// When the region is allocated by an foreign allocator, [Deallocation::Foreign], this calls the +/// foreign deallocator to deallocate the region when it is no longer needed. +pub struct Bytes { + /// The raw pointer to be begining of the region + ptr: NonNull, + + /// The number of bytes visible to this region. This is always smaller than its capacity (when avaliable). + len: usize, + + /// how to deallocate this region + deallocation: Deallocation, +} + +impl Bytes { + /// Takes ownership of an allocated memory region, + /// + /// # Arguments + /// + /// * `ptr` - Pointer to raw parts + /// * `len` - Length of raw parts in **bytes** + /// * `capacity` - Total allocated memory for the pointer `ptr`, in **bytes** + /// + /// # Safety + /// + /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` + /// bytes. If the `ptr` and `capacity` come from a `Buffer`, then this is guaranteed. + #[inline] + pub unsafe fn new(ptr: std::ptr::NonNull, len: usize, deallocation: Deallocation) -> Self { + Self { + ptr, + len, + deallocation, + } + } + + #[inline] + fn as_slice(&self) -> &[T] { + self + } + + #[inline] + pub fn len(&self) -> usize { + self.len + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + #[inline] + pub fn ptr(&self) -> NonNull { + self.ptr + } +} + +impl Drop for Bytes { + #[inline] + fn drop(&mut self) { + match &self.deallocation { + Deallocation::Native(capacity) => { + unsafe { alloc::free_aligned(self.ptr, *capacity) }; + } + // foreign interface knows how to deallocate itself. 
+ Deallocation::Foreign(_) => (), + } + } +} + +impl std::ops::Deref for Bytes { + type Target = [T]; + + fn deref(&self) -> &[T] { + unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) } + } +} + +impl PartialEq for Bytes { + fn eq(&self, other: &Bytes) -> bool { + self.as_slice() == other.as_slice() + } +} + +impl Debug for Bytes { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "Bytes {{ ptr: {:?}, len: {}, data: ", self.ptr, self.len,)?; + + f.debug_list().entries(self.iter()).finish()?; + + write!(f, " }}") + } +} + +// This is sound because `Bytes` is an imutable container +unsafe impl Send for Bytes {} +unsafe impl Sync for Bytes {} diff --git a/src/buffer/immutable.rs b/src/buffer/immutable.rs new file mode 100644 index 00000000000..82c8a5c6db5 --- /dev/null +++ b/src/buffer/immutable.rs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains two main structs: [Buffer] and [MutableBuffer]. A buffer represents +//! a contiguous memory region that can be shared via `offsets`. 
+ +use crate::ffi; + +use super::{ + bytes::{Bytes, Deallocation}, + types::NativeType, +}; + +use std::fmt::Debug; +use std::sync::Arc; +use std::{convert::AsRef, usize}; + +use super::mutable::MutableBuffer; + +/// Buffer represents a contiguous memory region that can be shared with other buffers and across +/// thread boundaries. +#[derive(Clone, PartialEq, Debug)] +pub struct Buffer { + /// the internal byte buffer. + data: Arc>, + + /// The offset into the buffer. + offset: usize, + + // the length of the buffer. Given a region `data` of N bytes, [offset..offset+length] is visible + // to this buffer. + length: usize, +} + +impl Buffer { + #[inline] + pub fn new() -> Self { + MutableBuffer::new().into() + } + + /// Auxiliary method to create a new Buffer + #[inline] + pub fn from_bytes(bytes: Bytes) -> Self { + let length = bytes.len(); + Buffer { + data: Arc::new(bytes), + offset: 0, + length, + } + } + + /// Returns the number of bytes in the buffer + #[inline] + pub fn len(&self) -> usize { + self.length + } + + /// Returns whether the buffer is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the byte slice stored in this buffer + #[inline] + pub fn as_slice(&self) -> &[T] { + &self.data[self.offset..self.offset + self.length] + } + + /// Returns a new [Buffer] that is a slice of this buffer starting at `offset`. + /// Doing so allows the same memory region to be shared between buffers. + /// # Panics + /// Panics iff `offset` is larger than `len`. + #[inline] + pub fn slice(mut self, offset: usize, length: usize) -> Self { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + self.offset += offset; + self.length = length; + self + } + + /// Returns a pointer to the start of this buffer. 
+ #[inline] + pub fn as_ptr(&self) -> *const T { + unsafe { self.data.ptr().as_ptr().add(self.offset) } + } +} + +impl Buffer { + #[inline] + pub unsafe fn from_trusted_len_iter>(iterator: I) -> Self { + MutableBuffer::from_trusted_len_iter(iterator).into() + } +} + +/// Creating a `Buffer` instance by copying the memory from a `AsRef<[u8]>` into a newly +/// allocated memory region. +impl> From for Buffer { + #[inline] + fn from(p: U) -> Self { + // allocate aligned memory buffer + let slice = p.as_ref(); + let len = slice.len(); + let mut buffer = MutableBuffer::with_capacity(len); + buffer.extend_from_slice(slice); + buffer.into() + } +} + +impl Buffer { + /// Creates a buffer from an existing memory region (must already be byte-aligned), this + /// `Buffer` **does not** free this piece of memory when dropped. + /// + /// # Arguments + /// + /// * `ptr` - Pointer to raw parts + /// * `len` - Length of raw parts in **bytes** + /// * `data` - An [ffi::FFI_ArrowArray] with the data + /// + /// # Safety + /// + /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` + /// bytes and that the foreign deallocator frees the region. + #[inline] + pub unsafe fn from_unowned( + ptr: std::ptr::NonNull, + length: usize, + data: Arc, + ) -> Self { + // todo: make all kinds of assertions here wrt to alignment, len, etc. 
+ let ptr = ptr.as_ptr() as *mut T; + let ptr = + std::ptr::NonNull::::new(ptr).expect("Can't cast pointer from FFI: it is null"); + let bytes = Bytes::new(ptr, length, Deallocation::Foreign(data)); + Self { + data: Arc::new(bytes), + offset: 0, + length, + } + } +} diff --git a/src/buffer/mod.rs b/src/buffer/mod.rs new file mode 100644 index 00000000000..f88769c4375 --- /dev/null +++ b/src/buffer/mod.rs @@ -0,0 +1,15 @@ +mod alignment; +mod alloc; +mod bitmap; +mod immutable; +mod mutable; + +pub(crate) mod bytes; +pub(crate) mod types; +pub(crate) mod util; + +pub use bitmap::Bitmap; +pub use bitmap::MutableBitmap; +pub use immutable::Buffer; +pub use mutable::MutableBuffer; +pub use types::NativeType; diff --git a/src/buffer/mutable.rs b/src/buffer/mutable.rs new file mode 100644 index 00000000000..177647b08a4 --- /dev/null +++ b/src/buffer/mutable.rs @@ -0,0 +1,456 @@ +use std::iter::FromIterator; +use std::ptr::NonNull; +use std::usize; +use std::{fmt::Debug, mem::size_of}; + +use super::{ + alloc, + bytes::{Bytes, Deallocation}, + types::NativeType, + util, +}; + +use super::immutable::Buffer; + +#[inline] +fn capacity_multiple_of_64(capacity: usize) -> usize { + util::round_upto_multiple_of_64(capacity * size_of::()) / size_of::() +} + +/// A [`MutableBuffer`] is Arrow's interface to build a [`Buffer`] out of items or slices of items. +/// [`Buffer`]s created from [`MutableBuffer`] (via `into`) are guaranteed to have its pointer aligned +/// along cache lines and in multiple of 64 bytes. +/// Use [MutableBuffer::push] to insert an item, [MutableBuffer::extend_from_slice] +/// to insert many items, and `into` to convert it to [`Buffer`]. 
+/// # Example +/// ``` +/// # use arrow::buffer::{Buffer, MutableBuffer}; +/// let mut buffer = MutableBuffer::new(0); +/// buffer.push(256u32); +/// buffer.extend_from_slice(&[1u32]); +/// let buffer: Buffer = buffer.into(); +/// assert_eq!(buffer.as_slice(), &[0u8, 1, 0, 0, 1, 0, 0, 0]) +/// ``` +#[derive(Debug)] +pub struct MutableBuffer { + // dangling iff capacity = 0 + ptr: NonNull, + // invariant: len <= capacity + len: usize, + capacity: usize, +} + +impl MutableBuffer { + #[inline] + pub fn new() -> Self { + let ptr = alloc::allocate_aligned(0); + Self { + ptr, + len: 0, + capacity: 0, + } + } + + /// Allocate a new [MutableBuffer] with initial capacity to be at least `capacity`. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + let capacity = capacity_multiple_of_64::(capacity); + let ptr = alloc::allocate_aligned(capacity); + Self { + ptr, + len: 0, + capacity, + } + } + + /// Allocates a new [MutableBuffer] with `len` and capacity to be at least `len` where + /// all bytes are guaranteed to be `0u8`. + /// # Example + /// ``` + /// # use arrow::buffer::{Buffer, MutableBuffer}; + /// let mut buffer = MutableBuffer::from_len_zeroed(127); + /// assert_eq!(buffer.len(), 127); + /// assert!(buffer.capacity() >= 127); + /// let data = buffer.as_slice_mut(); + /// assert_eq!(data[126], 0u8); + /// ``` + pub fn from_len_zeroed(len: usize) -> Self { + let new_capacity = capacity_multiple_of_64::(len); + let ptr = alloc::allocate_aligned_zeroed(new_capacity); + Self { + ptr, + len, + capacity: new_capacity, + } + } + + /// Ensures that this buffer has at least `self.len + additional` bytes. This re-allocates iff + /// `self.len + additional > capacity`. 
+ /// # Example + /// ``` + /// # use arrow::buffer::{Buffer, MutableBuffer}; + /// let mut buffer = MutableBuffer::new(0); + /// buffer.reserve(253); // allocates for the first time + /// (0..253u8).for_each(|i| buffer.push(i)); // no reallocation + /// let buffer: Buffer = buffer.into(); + /// assert_eq!(buffer.len(), 253); + /// ``` + // For performance reasons, this must be inlined so that the `if` is executed inside the caller, and not as an extra call that just + // exits. + #[inline(always)] + pub fn reserve(&mut self, additional: usize) { + let required_cap = self.len + additional; + if required_cap > self.capacity { + // JUSTIFICATION + // Benefit + // necessity + // Soundness + // `self.data` is valid for `self.capacity`. + let (ptr, new_capacity) = unsafe { reallocate(self.ptr, self.capacity, required_cap) }; + self.ptr = ptr; + self.capacity = new_capacity; + } + } + + /// Resizes the buffer, either truncating its contents (with no change in capacity), or + /// growing it (potentially reallocating it) and writing `value` in the newly available bytes. + /// # Example + /// ``` + /// # use arrow::buffer::{Buffer, MutableBuffer}; + /// let mut buffer = MutableBuffer::new(0); + /// buffer.resize(253, 2); // allocates for the first time + /// assert_eq!(buffer.as_slice()[252], 2u8); + /// ``` + // For performance reasons, this must be inlined so that the `if` is executed inside the caller, and not as an extra call that just + // exits. + #[inline(always)] + pub fn resize(&mut self, new_len: usize, value: T) { + if new_len > self.len { + let diff = new_len - self.len; + self.reserve(diff); + unsafe { + // write the value + let mut ptr = self.ptr.as_ptr().add(self.len); + (0..diff).for_each(|_| { + std::ptr::write(ptr, value); + ptr = ptr.add(1); + }) + } + } + // this truncates the buffer when new_len < self.len + self.len = new_len; + } + + /// Returns whether this buffer is empty or not. 
+ #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the length (the number of bytes written) in this buffer. + /// The invariant `buffer.len() <= buffer.capacity()` is always upheld. + #[inline] + pub fn len(&self) -> usize { + self.len + } + + /// Returns the total capacity in this buffer. + /// The invariant `buffer.len() <= buffer.capacity()` is always upheld. + #[inline] + pub fn capacity(&self) -> usize { + self.capacity + } + + /// Clear all existing data from this buffer. + pub fn clear(&mut self) { + self.len = 0 + } + + /// Returns the data stored in this buffer as a slice. + pub fn as_slice(&self) -> &[T] { + self + } + + /// Returns the data stored in this buffer as a mutable slice. + pub fn as_slice_mut(&mut self) -> &mut [T] { + self + } + + /// Returns a raw pointer to this buffer's internal memory + /// This pointer is guaranteed to be aligned along cache-lines. + #[inline] + pub fn as_ptr(&self) -> *const T { + self.ptr.as_ptr() + } + + /// Returns a mutable raw pointer to this buffer's internal memory + /// This pointer is guaranteed to be aligned along cache-lines. + #[inline] + pub fn as_mut_ptr(&mut self) -> *mut T { + self.ptr.as_ptr() + } + + /// Extends this buffer from a slice of items that can be represented in bytes, increasing its capacity if needed. + /// # Example + /// ``` + /// # use arrow::buffer::MutableBuffer; + /// let mut buffer = MutableBuffer::new(0); + /// buffer.extend_from_slice(&[2u32, 0]); + /// assert_eq!(buffer.len(), 8) // u32 has 4 bytes + /// ``` + pub fn extend_from_slice(&mut self, items: &[T]) { + let len = items.len(); + let additional = len; + self.reserve(additional); + unsafe { + let dst = self.ptr.as_ptr().add(self.len); + let src = items.as_ptr(); + std::ptr::copy_nonoverlapping(src, dst, additional) + } + self.len += additional; + } + + /// Extends the buffer with a new item, increasing its capacity if needed. 
+ /// # Example + /// ``` + /// # use arrow::buffer::MutableBuffer; + /// let mut buffer = MutableBuffer::new(0); + /// buffer.push(256u32); + /// assert_eq!(buffer.len(), 4) // u32 has 4 bytes + /// ``` + #[inline] + pub fn push(&mut self, item: T) { + self.reserve(1); + unsafe { + let dst = self.ptr.as_ptr().add(self.len) as *mut T; + std::ptr::write(dst, item); + } + self.len += 1; + } + + /// Extends the buffer with a new item, without checking for sufficient capacity + /// Safety + /// Caller must ensure that the capacity()-len()>=size_of() + #[inline] + pub(crate) unsafe fn push_unchecked(&mut self, item: T) { + let dst = self.ptr.as_ptr().add(self.len); + std::ptr::write(dst, item); + self.len += 1; + } + + /// # Safety + /// The caller must ensure that the buffer was properly initialized up to `len`. + #[inline] + pub(crate) unsafe fn set_len(&mut self, len: usize) { + assert!(len <= self.capacity()); + self.len = len; + } +} + +/// # Safety +/// `ptr` must be allocated for `old_capacity`. 
+#[inline] +unsafe fn reallocate( + ptr: NonNull, + old_capacity: usize, + new_capacity: usize, +) -> (NonNull, usize) { + let new_capacity = capacity_multiple_of_64::(new_capacity); + let new_capacity = std::cmp::max(new_capacity, old_capacity * 2); + let ptr = alloc::reallocate(ptr, old_capacity, new_capacity); + (ptr, new_capacity) +} + +impl Extend for MutableBuffer { + #[inline] + fn extend>(&mut self, iter: T) { + let iterator = iter.into_iter(); + self.extend_from_iter(iterator) + } +} + +impl MutableBuffer { + #[inline] + fn extend_from_iter>(&mut self, mut iterator: I) { + let (lower, _) = iterator.size_hint(); + let additional = lower; + self.reserve(additional); + + // this is necessary because of https://github.com/rust-lang/rust/issues/32155 + let mut len = SetLenOnDrop::new(&mut self.len); + let mut dst = unsafe { self.ptr.as_ptr().add(len.local_len) as *mut T }; + let capacity = self.capacity; + + while len.local_len + 1 <= capacity { + if let Some(item) = iterator.next() { + unsafe { + std::ptr::write(dst, item); + dst = dst.add(1); + } + len.local_len += 1; + } else { + break; + } + } + drop(len); + + iterator.for_each(|item| self.push(item)); + } + + /// Creates a [`MutableBuffer`] from an [`Iterator`] with a trusted (upper) length. + /// Prefer this to `collect` whenever possible, as it is faster ~60% faster. + /// # Example + /// ``` + /// # use arrow::buffer::MutableBuffer; + /// let v = vec![1u32]; + /// let iter = v.iter().map(|x| x * 2); + /// let buffer = unsafe { MutableBuffer::from_trusted_len_iter(iter) }; + /// assert_eq!(buffer.len(), 4) // u32 has 4 bytes + /// ``` + /// # Safety + /// This method assumes that the iterator's size is correct and is undefined behavior + /// to use it on an iterator that reports an incorrect length. + // This implementation is required for two reasons: + // 1. there is no trait `TrustedLen` in stable rust and therefore + // we can't specialize `extend` for `TrustedLen` like `Vec` does. + // 2. 
`from_trusted_len_iter` is faster. + pub unsafe fn from_trusted_len_iter>(iterator: I) -> Self { + let (_, upper) = iterator.size_hint(); + let upper = upper.expect("from_trusted_len_iter requires an upper limit"); + let len = upper; + + let mut buffer = MutableBuffer::with_capacity(len); + + let mut dst = buffer.ptr.as_ptr(); + for item in iterator { + // note how there is no reserve here (compared with `extend_from_iter`) + std::ptr::write(dst, item); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(buffer.ptr.as_ptr()) as usize, + upper, + "Trusted iterator length was not accurately reported" + ); + buffer.len = len; + buffer + } + + /// Creates a [`MutableBuffer`] from an [`Iterator`] with a trusted (upper) length or errors + /// if any of the items of the iterator is an error. + /// Prefer this to `collect` whenever possible, as it is faster ~60% faster. + /// # Safety + /// This method assumes that the iterator's size is correct and is undefined behavior + /// to use it on an iterator that reports an incorrect length. + pub unsafe fn try_from_trusted_len_iter>>( + iterator: I, + ) -> std::result::Result { + let (_, upper) = iterator.size_hint(); + let upper = upper.expect("try_from_trusted_len_iter requires an upper limit"); + let len = upper; + + let mut buffer = MutableBuffer::with_capacity(len); + + let mut dst = buffer.ptr.as_ptr(); + for item in iterator { + std::ptr::write(dst, item?); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(buffer.ptr.as_ptr()) as usize, + upper, + "Trusted iterator length was not accurately reported" + ); + buffer.len = len; + Ok(buffer) + } +} + +impl FromIterator for MutableBuffer { + fn from_iter>(iter: I) -> Self { + let mut iterator = iter.into_iter(); + + // first iteration, which will likely reserve sufficient space for the buffer. 
+ let mut buffer = match iterator.next() { + None => MutableBuffer::new(), + Some(element) => { + let (lower, _) = iterator.size_hint(); + let mut buffer = MutableBuffer::with_capacity(lower.saturating_add(1)); + unsafe { + std::ptr::write(buffer.as_mut_ptr(), element); + buffer.len = 1; + } + buffer + } + }; + + buffer.extend_from_iter(iterator); + buffer.into() + } +} + +impl std::ops::Deref for MutableBuffer { + type Target = [T]; + + #[inline] + fn deref(&self) -> &[T] { + unsafe { std::slice::from_raw_parts(self.as_ptr(), self.len) } + } +} + +impl std::ops::DerefMut for MutableBuffer { + #[inline] + fn deref_mut(&mut self) -> &mut [T] { + unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.len) } + } +} + +impl Drop for MutableBuffer { + fn drop(&mut self) { + unsafe { alloc::free_aligned(self.ptr, self.capacity) }; + } +} + +struct SetLenOnDrop<'a> { + len: &'a mut usize, + local_len: usize, +} + +impl<'a> SetLenOnDrop<'a> { + #[inline] + fn new(len: &'a mut usize) -> Self { + SetLenOnDrop { + local_len: *len, + len, + } + } +} + +impl Drop for SetLenOnDrop<'_> { + #[inline] + fn drop(&mut self) { + *self.len = self.local_len; + } +} + +impl From> for Buffer { + #[inline] + fn from(buffer: MutableBuffer) -> Self { + Buffer::from_bytes(buffer.into()) + } +} + +impl From> for Bytes { + #[inline] + fn from(buffer: MutableBuffer) -> Self { + let result = unsafe { + Bytes::new( + buffer.ptr, + buffer.len, + Deallocation::Native(buffer.capacity), + ) + }; + std::mem::forget(buffer); + result + } +} diff --git a/src/buffer/types.rs b/src/buffer/types.rs new file mode 100644 index 00000000000..0b79246fcd7 --- /dev/null +++ b/src/buffer/types.rs @@ -0,0 +1,85 @@ +use crate::datatypes::DataType; + +pub unsafe trait NativeType: + Sized + Copy + std::fmt::Debug + std::fmt::Display + PartialEq + Default + Sized + 'static +{ + fn is_valid(data_type: &DataType) -> bool; +} + +unsafe impl NativeType for u8 { + fn is_valid(data_type: &DataType) -> bool { + 
data_type == &DataType::UInt8 + } +} + +unsafe impl NativeType for u16 { + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::UInt16 + } +} + +unsafe impl NativeType for u32 { + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::UInt32 + } +} +unsafe impl NativeType for u64 { + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::UInt64 + } +} + +unsafe impl NativeType for i8 { + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::Int8 + } +} + +unsafe impl NativeType for i16 { + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::Int16 + } +} + +unsafe impl NativeType for i32 { + fn is_valid(data_type: &DataType) -> bool { + match data_type { + DataType::Int32 | DataType::Date32 | DataType::Time32(_) => true, + _ => false, + } + } +} + +unsafe impl NativeType for i64 { + fn is_valid(data_type: &DataType) -> bool { + match data_type { + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) => true, + _ => false, + } + } +} + +unsafe impl NativeType for f32 { + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::Float32 + } +} + +unsafe impl NativeType for f64 { + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::Float64 + } +} + +unsafe impl NativeType for i128 { + fn is_valid(data_type: &DataType) -> bool { + if let DataType::Decimal(_, _) = data_type { + true + } else { + false + } + } +} diff --git a/src/buffer/util.rs b/src/buffer/util.rs new file mode 100644 index 00000000000..7cd1324f196 --- /dev/null +++ b/src/buffer/util.rs @@ -0,0 +1,12 @@ +/// Returns the nearest number that is `>=` than `num` and is a multiple of 64 +#[inline] +pub fn round_upto_multiple_of_64(num: usize) -> usize { + round_upto_power_of_2(num, 64) +} + +/// Returns the nearest multiple of `factor` that is `>=` than `num`. Here `factor` must +/// be a power of 2. 
+pub fn round_upto_power_of_2(num: usize, factor: usize) -> usize { + debug_assert!(factor > 0 && (factor & (factor - 1)) == 0); + (num + (factor - 1)) & !(factor - 1) +} diff --git a/src/compute/arity.rs b/src/compute/arity.rs new file mode 100644 index 00000000000..e61b0a1010f --- /dev/null +++ b/src/compute/arity.rs @@ -0,0 +1,32 @@ +//! Defines kernels suitable to perform operations to primitive arrays. + +use crate::buffer::Buffer; +use crate::{ + array::{Array, PrimitiveArray}, + buffer::NativeType, + datatypes::DataType, +}; + +/// Applies an unary and infalible function to a primitive array. +/// This is the fastest way to perform an operation on a primitive array when +/// the benefits of a vectorized operation outweights the cost of branching nulls and non-nulls. +/// # Implementation +/// This will apply the function for all values, including those on null slots. +/// This implies that the operation must be infalible for any value of the corresponding type +/// or this function may panic. +pub fn unary(array: &PrimitiveArray, op: F, data_type: &DataType) -> PrimitiveArray +where + I: NativeType, + O: NativeType, + F: Fn(I) -> O, +{ + let values = array.values().as_slice().iter().map(|v| op(*v)); + // JUSTIFICATION + // Benefit + // ~60% speedup + // Soundness + // `values` is an iterator with a known size because arrays are sized. 
+ let values = unsafe { Buffer::from_trusted_len_iter(values) }; + + PrimitiveArray::::from_data(data_type.clone(), values, array.nulls().clone()) +} diff --git a/src/compute/cast/boolean_to.rs b/src/compute/cast/boolean_to.rs new file mode 100644 index 00000000000..ede849a22a7 --- /dev/null +++ b/src/compute/cast/boolean_to.rs @@ -0,0 +1,44 @@ +use crate::{ + array::{Array, BooleanArray, Primitive}, + buffer::NativeType, + datatypes::DataType, +}; +use crate::{ + array::{Offset, Utf8Array}, + error::Result, +}; + +/// Cast Boolean types to numeric +/// +/// `false` returns 0 while `true` returns 1 +pub fn cast_bool_to_numeric(array: &dyn Array, to: &DataType) -> Result> +where + T: NativeType + num::One, +{ + let array = array.as_any().downcast_ref::().unwrap(); + + let iter = array + .iter() + .map(|x| x.map(|x| if x { T::one() } else { T::default() })); + + // Soundness: + // The iterator is trustedLen + let array = unsafe { Primitive::::from_trusted_len_iter(iter) }.to(to.clone()); + + Ok(Box::new(array)) +} + +/// Cast Boolean types to numeric +/// +/// `false` returns 0 while `true` returns 1 +pub fn cast_bool_to_utf8(array: &dyn Array) -> Result> { + let array = array.as_any().downcast_ref::().unwrap(); + + let iter = array.iter().map(|x| x.map(|x| if x { "1" } else { "0" })); + + // Soundness: + // The iterator is trustedLen + let array = unsafe { Utf8Array::::from_trusted_len_iter(iter) }; + + Ok(Box::new(array)) +} diff --git a/src/compute/cast/dictionary_to.rs b/src/compute/cast/dictionary_to.rs new file mode 100644 index 00000000000..544cea145e4 --- /dev/null +++ b/src/compute/cast/dictionary_to.rs @@ -0,0 +1,85 @@ +use super::{cast, primitive_to::cast_typed_primitive}; +use crate::{ + array::{Array, DictionaryKey}, + datatypes::DataType, +}; +use crate::{ + array::{DictionaryArray, PrimitiveArray}, + compute::take::take, + error::{ArrowError, Result}, +}; + +macro_rules! 
key_cast { + ($keys:expr, $values:expr, $array:expr, $to_keys_type:expr, $to_type:ty) => {{ + let cast_keys = cast_typed_primitive::<_, $to_type>($keys, $to_keys_type); + + // Failure to cast keys (because they don't fit in the + // target type) results in NULL values; + if cast_keys.null_count() > $keys.null_count() { + return Err(ArrowError::ComputeError(format!( + "Could not convert {} dictionary indexes from {:?} to {:?}", + cast_keys.null_count() - $keys.null_count(), + $keys.data_type(), + $to_keys_type + ))); + } + Ok(Box::new(DictionaryArray::<$to_type>::from_data( + cast_keys, $values, + ))) + }}; +} + +/// Attempts to cast an `ArrayDictionary` with keys type `K` into +/// `to_type`, for supported types. +pub fn dictionary_cast( + array: &dyn Array, + to_type: &DataType, +) -> Result> { + let array = array.as_any().downcast_ref::>().unwrap(); + let keys = array.keys(); + let values = array.values(); + + match to_type { + DataType::Dictionary(to_keys_type, to_values_type) => { + let values = cast(values.as_ref(), to_values_type)?.into(); + + // create the appropriate array type + match to_keys_type.as_ref() { + DataType::Int8 => key_cast!(keys, values, array, to_keys_type, i8), + DataType::Int16 => key_cast!(keys, values, array, to_keys_type, i16), + DataType::Int32 => key_cast!(keys, values, array, to_keys_type, i32), + DataType::Int64 => key_cast!(keys, values, array, to_keys_type, i64), + DataType::UInt8 => key_cast!(keys, values, array, to_keys_type, u8), + DataType::UInt16 => key_cast!(keys, values, array, to_keys_type, u16), + DataType::UInt32 => key_cast!(keys, values, array, to_keys_type, u32), + DataType::UInt64 => key_cast!(keys, values, array, to_keys_type, u64), + _ => { + return Err(ArrowError::ComputeError(format!( + "Unsupported type {:?} for dictionary index", + to_keys_type + ))) + } + } + } + _ => unpack_dictionary::(keys, values.as_ref(), to_type), + } +} + +// Unpack a dictionary where the keys are of type into a flattened array of type 
to_type +fn unpack_dictionary( + keys: &PrimitiveArray, + values: &dyn Array, + to_type: &DataType, +) -> Result> +where + K: DictionaryKey, +{ + // attempt to cast the dict values to the target type + // use the take kernel to expand out the dictionary + let values = cast(values, to_type)?; + + // take requires first casting i32 + let indices = cast_typed_primitive::<_, i32>(keys, &DataType::UInt32); + + take(values.as_ref(), &indices, None) +} diff --git a/src/compute/cast/mod.rs b/src/compute/cast/mod.rs new file mode 100644 index 00000000000..2b76a707120 --- /dev/null +++ b/src/compute/cast/mod.rs @@ -0,0 +1,2566 @@ +use crate::{ + array::*, + buffer::Buffer, + temporal_conversions::{MILLISECONDS, MILLISECONDS_IN_DAY}, +}; +use crate::{ + datatypes::*, + temporal_conversions::{MICROSECONDS, NANOSECONDS}, +}; +use crate::{ + error::{ArrowError, Result}, + temporal_conversions::SECONDS_IN_DAY, +}; + +use self::{ + boolean_to::{cast_bool_to_numeric, cast_bool_to_utf8}, + dictionary_to::dictionary_cast, + primitive_to::{ + cast_array_data, cast_numeric_arrays, cast_numeric_to_bool, cast_numeric_to_string, + primitive_to_dictionary, + }, + string_to::{cast_string_to_numeric, string_to_dictionary, to_date32, to_date64}, +}; + +use super::arity::unary; + +mod boolean_to; +mod dictionary_to; +mod primitive_to; +mod string_to; + +/// Returns true if this type is numeric: (UInt*, Unit*, or Float*). +fn is_numeric(t: &DataType) -> bool { + use DataType::*; + matches!( + t, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 + ) +} + +/// Return true if a value of type `from_type` can be cast into a +/// value of `to_type`. Note that such as cast may be lossy. +/// +/// If this function returns true to stay consistent with the `cast` kernel below. 
+pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { + use self::DataType::*; + if from_type == to_type { + return true; + } + + match (from_type, to_type) { + (Struct(_), _) => false, + (_, Struct(_)) => false, + (List(list_from), List(list_to)) => { + can_cast_types(list_from.data_type(), list_to.data_type()) + } + (List(_), _) => false, + (_, List(list_to)) => can_cast_types(from_type, list_to.data_type()), + (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => { + can_cast_types(from_value_type, to_value_type) + } + (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type), + (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type), + + (_, Boolean) => is_numeric(from_type), + (Boolean, _) => is_numeric(to_type) || matches!(to_type, Utf8 | LargeUtf8), + + (Utf8, Date32) => true, + (Utf8, Date64) => true, + (Utf8, _) => is_numeric(to_type), + (LargeUtf8, _) => is_numeric(to_type) || matches!(to_type, Date32 | Date64), + (_, Utf8) => is_numeric(from_type) || from_type == &Binary, + + // start numeric casts + (UInt8, UInt16) => true, + (UInt8, UInt32) => true, + (UInt8, UInt64) => true, + (UInt8, Int8) => true, + (UInt8, Int16) => true, + (UInt8, Int32) => true, + (UInt8, Int64) => true, + (UInt8, Float32) => true, + (UInt8, Float64) => true, + + (UInt16, UInt8) => true, + (UInt16, UInt32) => true, + (UInt16, UInt64) => true, + (UInt16, Int8) => true, + (UInt16, Int16) => true, + (UInt16, Int32) => true, + (UInt16, Int64) => true, + (UInt16, Float32) => true, + (UInt16, Float64) => true, + + (UInt32, UInt8) => true, + (UInt32, UInt16) => true, + (UInt32, UInt64) => true, + (UInt32, Int8) => true, + (UInt32, Int16) => true, + (UInt32, Int32) => true, + (UInt32, Int64) => true, + (UInt32, Float32) => true, + (UInt32, Float64) => true, + + (UInt64, UInt8) => true, + (UInt64, UInt16) => true, + (UInt64, UInt32) => true, + (UInt64, Int8) => true, + (UInt64, Int16) => true, + (UInt64, Int32) => true, + (UInt64, Int64) => true, + (UInt64, Float32) => true, +
(UInt64, Float64) => true, + + (Int8, UInt8) => true, + (Int8, UInt16) => true, + (Int8, UInt32) => true, + (Int8, UInt64) => true, + (Int8, Int16) => true, + (Int8, Int32) => true, + (Int8, Int64) => true, + (Int8, Float32) => true, + (Int8, Float64) => true, + + (Int16, UInt8) => true, + (Int16, UInt16) => true, + (Int16, UInt32) => true, + (Int16, UInt64) => true, + (Int16, Int8) => true, + (Int16, Int32) => true, + (Int16, Int64) => true, + (Int16, Float32) => true, + (Int16, Float64) => true, + + (Int32, UInt8) => true, + (Int32, UInt16) => true, + (Int32, UInt32) => true, + (Int32, UInt64) => true, + (Int32, Int8) => true, + (Int32, Int16) => true, + (Int32, Int64) => true, + (Int32, Float32) => true, + (Int32, Float64) => true, + + (Int64, UInt8) => true, + (Int64, UInt16) => true, + (Int64, UInt32) => true, + (Int64, UInt64) => true, + (Int64, Int8) => true, + (Int64, Int16) => true, + (Int64, Int32) => true, + (Int64, Float32) => true, + (Int64, Float64) => true, + + (Float32, UInt8) => true, + (Float32, UInt16) => true, + (Float32, UInt32) => true, + (Float32, UInt64) => true, + (Float32, Int8) => true, + (Float32, Int16) => true, + (Float32, Int32) => true, + (Float32, Int64) => true, + (Float32, Float64) => true, + + (Float64, UInt8) => true, + (Float64, UInt16) => true, + (Float64, UInt32) => true, + (Float64, UInt64) => true, + (Float64, Int8) => true, + (Float64, Int16) => true, + (Float64, Int32) => true, + (Float64, Int64) => true, + (Float64, Float32) => true, + // end numeric casts + + // temporal casts + (Int32, Date32) => true, + (Int32, Time32(unit)) => matches!(unit, TimeUnit::Second | TimeUnit::Millisecond), + (Date32, Int32) => true, + (Time32(_), Int32) => true, + (Int64, Date64) => true, + (Int64, Time64(unit)) => matches!(unit, TimeUnit::Microsecond | TimeUnit::Nanosecond), + (Date64, Int64) => true, + (Time64(_), Int64) => true, + (Date32, Date64) => true, + (Date64, Date32) => true, + (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => true, + (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => true, + (Time32(_), Time64(_)) =>
true, + (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => true, + (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => true, + (Time64(_), Time32(to_unit)) => { + matches!(to_unit, TimeUnit::Second | TimeUnit::Millisecond) + } + (Timestamp(_, _), Int64) => true, + (Int64, Timestamp(_, _)) => true, + (Timestamp(_, _), Timestamp(_, _)) => true, + (Timestamp(_, _), Date32) => true, + (Timestamp(_, _), Date64) => true, + // date64 to timestamp might not make sense, + (Int64, Duration(_)) => true, + // (Null, Int32): not reported as castable, the matching arm in `cast` is commented out + (_, _) => false, + } +} + +/// Cast `array` to the provided data type and return a new Array with +/// type `to_type`, if possible. +/// +/// Behavior: +/// * Boolean to Utf8: `true` => '1', `false` => `0` +/// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings +/// in integer casts return null +/// * Numeric to boolean: 0 returns `false`, any other value returns `true` +/// * List to List: the underlying data type is cast +/// * Primitive to List: a list array with 1 value per slot is created +/// * Date32 and Date64: precision lost when going to higher interval +/// * Time32 and Time64: precision lost when going to higher interval +/// * Timestamp and Date{32|64}: precision lost when going to higher interval +/// * Temporal to/from backing primitive: zero-copy with data type change +/// +/// Unsupported Casts +/// * To or from `StructArray` +/// * List to primitive +/// * Utf8 to boolean +/// * Interval and duration +pub fn cast(array: &dyn Array, to_type: &DataType) -> Result> { + use DataType::*; + let from_type = array.data_type(); + + // clone array if types are the same + if from_type == to_type { + return Ok(clone(array)); + } + match (from_type, to_type) { + (Struct(_), _) => Err(ArrowError::ComputeError( + "Cannot cast from struct to other types".to_string(), + )), + (_, Struct(_)) => Err(ArrowError::ComputeError( + "Cannot cast to struct from other types".to_string(), + )), +
(List(_), List(to)) => { + let array = array.as_any().downcast_ref::>().unwrap(); + + let values = array.values(); + let new_values = cast(values.as_ref(), to.data_type())?.into(); + + let list = ListArray::::from_data( + to_type.clone(), + array.offsets().clone(), + new_values, + array.nulls().clone(), + ); + Ok(Box::new(list)) + } + + (List(_), _) => Err(ArrowError::ComputeError( + "Cannot cast list to non-list data types".to_string(), + )), + (_, List(to)) => { + // cast primitive to list's primitive + let values = cast(array, to.data_type())?.into(); + // create offsets, where if array.len() = 2, we have [0,1,2] + let offsets: Buffer = + unsafe { Buffer::from_trusted_len_iter(0..=array.len() as i32) }; + + let data_type = ListArray::::default_datatype(to.data_type().clone()); + let list_array = ListArray::::from_data(data_type, offsets, values, None); + + Ok(Box::new(list_array)) + } + + (Dictionary(index_type, _), _) => match **index_type { + DataType::Int8 => dictionary_cast::(array, to_type), + DataType::Int16 => dictionary_cast::(array, to_type), + DataType::Int32 => dictionary_cast::(array, to_type), + DataType::Int64 => dictionary_cast::(array, to_type), + DataType::UInt8 => dictionary_cast::(array, to_type), + DataType::UInt16 => dictionary_cast::(array, to_type), + DataType::UInt32 => dictionary_cast::(array, to_type), + DataType::UInt64 => dictionary_cast::(array, to_type), + _ => unreachable!(), + }, + (_, Dictionary(index_type, value_type)) => match **index_type { + DataType::Int8 => cast_to_dictionary::(array, value_type), + DataType::Int16 => cast_to_dictionary::(array, value_type), + DataType::Int32 => cast_to_dictionary::(array, value_type), + DataType::Int64 => cast_to_dictionary::(array, value_type), + DataType::UInt8 => cast_to_dictionary::(array, value_type), + DataType::UInt16 => cast_to_dictionary::(array, value_type), + DataType::UInt32 => cast_to_dictionary::(array, value_type), + DataType::UInt64 => cast_to_dictionary::(array, 
value_type), + _ => Err(ArrowError::ComputeError(format!( + "Casting from type {:?} to dictionary type {:?} not supported", + from_type, to_type, + ))), + }, + (_, Boolean) => match from_type { + UInt8 => cast_numeric_to_bool::(array), + UInt16 => cast_numeric_to_bool::(array), + UInt32 => cast_numeric_to_bool::(array), + UInt64 => cast_numeric_to_bool::(array), + Int8 => cast_numeric_to_bool::(array), + Int16 => cast_numeric_to_bool::(array), + Int32 => cast_numeric_to_bool::(array), + Int64 => cast_numeric_to_bool::(array), + Float32 => cast_numeric_to_bool::(array), + Float64 => cast_numeric_to_bool::(array), + Utf8 => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + _ => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + (Boolean, _) => match to_type { + UInt8 => cast_bool_to_numeric::(array, to_type), + UInt16 => cast_bool_to_numeric::(array, to_type), + UInt32 => cast_bool_to_numeric::(array, to_type), + UInt64 => cast_bool_to_numeric::(array, to_type), + Int8 => cast_bool_to_numeric::(array, to_type), + Int16 => cast_bool_to_numeric::(array, to_type), + Int32 => cast_bool_to_numeric::(array, to_type), + Int64 => cast_bool_to_numeric::(array, to_type), + Float32 => cast_bool_to_numeric::(array, to_type), + Float64 => cast_bool_to_numeric::(array, to_type), + Utf8 => cast_bool_to_utf8::(array), + LargeUtf8 => cast_bool_to_utf8::(array), + _ => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + + (Utf8, _) => match to_type { + UInt8 => cast_string_to_numeric::(array, to_type), + UInt16 => cast_string_to_numeric::(array, to_type), + UInt32 => cast_string_to_numeric::(array, to_type), + UInt64 => cast_string_to_numeric::(array, to_type), + Int8 => cast_string_to_numeric::(array, to_type), + Int16 => cast_string_to_numeric::(array, to_type), + Int32 => 
cast_string_to_numeric::(array, to_type), + Int64 => cast_string_to_numeric::(array, to_type), + Float32 => cast_string_to_numeric::(array, to_type), + Float64 => cast_string_to_numeric::(array, to_type), + Date32 => Ok(Box::new(to_date32::(array, to_type))), + Date64 => Ok(Box::new(to_date64::(array, to_type))), + _ => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + (LargeUtf8, _) => match to_type { + UInt8 => cast_string_to_numeric::(array, to_type), + UInt16 => cast_string_to_numeric::(array, to_type), + UInt32 => cast_string_to_numeric::(array, to_type), + UInt64 => cast_string_to_numeric::(array, to_type), + Int8 => cast_string_to_numeric::(array, to_type), + Int16 => cast_string_to_numeric::(array, to_type), + Int32 => cast_string_to_numeric::(array, to_type), + Int64 => cast_string_to_numeric::(array, to_type), + Float32 => cast_string_to_numeric::(array, to_type), + Float64 => cast_string_to_numeric::(array, to_type), + Date32 => Ok(Box::new(to_date32::(array, to_type))), + Date64 => Ok(Box::new(to_date64::(array, to_type))), + _ => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + + (_, Utf8) => match from_type { + UInt8 => cast_numeric_to_string::(array), + UInt16 => cast_numeric_to_string::(array), + UInt32 => cast_numeric_to_string::(array), + UInt64 => cast_numeric_to_string::(array), + Int8 => cast_numeric_to_string::(array), + Int16 => cast_numeric_to_string::(array), + Int32 => cast_numeric_to_string::(array), + Int64 => cast_numeric_to_string::(array), + Float32 => cast_numeric_to_string::(array), + Float64 => cast_numeric_to_string::(array), + Binary => { + let array = array.as_any().downcast_ref::>().unwrap(); + + // perf todo: the offsets are equal; we can speed-up this + let iter = array + .iter() + .map(|x| x.and_then(|x| std::str::from_utf8(x).ok())); + + let array = unsafe { 
Utf8Array::::from_trusted_len_iter(iter) }; + Ok(Box::new(array)) + } + _ => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + + // start numeric casts + (UInt8, UInt16) => cast_numeric_arrays::(array, to_type), + (UInt8, UInt32) => cast_numeric_arrays::(array, to_type), + (UInt8, UInt64) => cast_numeric_arrays::(array, to_type), + (UInt8, Int8) => cast_numeric_arrays::(array, to_type), + (UInt8, Int16) => cast_numeric_arrays::(array, to_type), + (UInt8, Int32) => cast_numeric_arrays::(array, to_type), + (UInt8, Int64) => cast_numeric_arrays::(array, to_type), + (UInt8, Float32) => cast_numeric_arrays::(array, to_type), + (UInt8, Float64) => cast_numeric_arrays::(array, to_type), + + (UInt16, UInt8) => cast_numeric_arrays::(array, to_type), + (UInt16, UInt32) => cast_numeric_arrays::(array, to_type), + (UInt16, UInt64) => cast_numeric_arrays::(array, to_type), + (UInt16, Int8) => cast_numeric_arrays::(array, to_type), + (UInt16, Int16) => cast_numeric_arrays::(array, to_type), + (UInt16, Int32) => cast_numeric_arrays::(array, to_type), + (UInt16, Int64) => cast_numeric_arrays::(array, to_type), + (UInt16, Float32) => cast_numeric_arrays::(array, to_type), + (UInt16, Float64) => cast_numeric_arrays::(array, to_type), + + (UInt32, UInt8) => cast_numeric_arrays::(array, to_type), + (UInt32, UInt16) => cast_numeric_arrays::(array, to_type), + (UInt32, UInt64) => cast_numeric_arrays::(array, to_type), + (UInt32, Int8) => cast_numeric_arrays::(array, to_type), + (UInt32, Int16) => cast_numeric_arrays::(array, to_type), + (UInt32, Int32) => cast_numeric_arrays::(array, to_type), + (UInt32, Int64) => cast_numeric_arrays::(array, to_type), + (UInt32, Float32) => cast_numeric_arrays::(array, to_type), + (UInt32, Float64) => cast_numeric_arrays::(array, to_type), + + (UInt64, UInt8) => cast_numeric_arrays::(array, to_type), + (UInt64, UInt16) => cast_numeric_arrays::(array, to_type), + (UInt64, UInt32) => 
cast_numeric_arrays::(array, to_type), + (UInt64, Int8) => cast_numeric_arrays::(array, to_type), + (UInt64, Int16) => cast_numeric_arrays::(array, to_type), + (UInt64, Int32) => cast_numeric_arrays::(array, to_type), + (UInt64, Int64) => cast_numeric_arrays::(array, to_type), + (UInt64, Float32) => cast_numeric_arrays::(array, to_type), + (UInt64, Float64) => cast_numeric_arrays::(array, to_type), + + (Int8, UInt8) => cast_numeric_arrays::(array, to_type), + (Int8, UInt16) => cast_numeric_arrays::(array, to_type), + (Int8, UInt32) => cast_numeric_arrays::(array, to_type), + (Int8, UInt64) => cast_numeric_arrays::(array, to_type), + (Int8, Int16) => cast_numeric_arrays::(array, to_type), + (Int8, Int32) => cast_numeric_arrays::(array, to_type), + (Int8, Int64) => cast_numeric_arrays::(array, to_type), + (Int8, Float32) => cast_numeric_arrays::(array, to_type), + (Int8, Float64) => cast_numeric_arrays::(array, to_type), + + (Int16, UInt8) => cast_numeric_arrays::(array, to_type), + (Int16, UInt16) => cast_numeric_arrays::(array, to_type), + (Int16, UInt32) => cast_numeric_arrays::(array, to_type), + (Int16, UInt64) => cast_numeric_arrays::(array, to_type), + (Int16, Int8) => cast_numeric_arrays::(array, to_type), + (Int16, Int32) => cast_numeric_arrays::(array, to_type), + (Int16, Int64) => cast_numeric_arrays::(array, to_type), + (Int16, Float32) => cast_numeric_arrays::(array, to_type), + (Int16, Float64) => cast_numeric_arrays::(array, to_type), + + (Int32, UInt8) => cast_numeric_arrays::(array, to_type), + (Int32, UInt16) => cast_numeric_arrays::(array, to_type), + (Int32, UInt32) => cast_numeric_arrays::(array, to_type), + (Int32, UInt64) => cast_numeric_arrays::(array, to_type), + (Int32, Int8) => cast_numeric_arrays::(array, to_type), + (Int32, Int16) => cast_numeric_arrays::(array, to_type), + (Int32, Int64) => cast_numeric_arrays::(array, to_type), + (Int32, Float32) => cast_numeric_arrays::(array, to_type), + (Int32, Float64) => 
cast_numeric_arrays::(array, to_type), + + (Int64, UInt8) => cast_numeric_arrays::(array, to_type), + (Int64, UInt16) => cast_numeric_arrays::(array, to_type), + (Int64, UInt32) => cast_numeric_arrays::(array, to_type), + (Int64, UInt64) => cast_numeric_arrays::(array, to_type), + (Int64, Int8) => cast_numeric_arrays::(array, to_type), + (Int64, Int16) => cast_numeric_arrays::(array, to_type), + (Int64, Int32) => cast_numeric_arrays::(array, to_type), + (Int64, Float32) => cast_numeric_arrays::(array, to_type), + (Int64, Float64) => cast_numeric_arrays::(array, to_type), + + (Float32, UInt8) => cast_numeric_arrays::(array, to_type), + (Float32, UInt16) => cast_numeric_arrays::(array, to_type), + (Float32, UInt32) => cast_numeric_arrays::(array, to_type), + (Float32, UInt64) => cast_numeric_arrays::(array, to_type), + (Float32, Int8) => cast_numeric_arrays::(array, to_type), + (Float32, Int16) => cast_numeric_arrays::(array, to_type), + (Float32, Int32) => cast_numeric_arrays::(array, to_type), + (Float32, Int64) => cast_numeric_arrays::(array, to_type), + (Float32, Float64) => cast_numeric_arrays::(array, to_type), + + (Float64, UInt8) => cast_numeric_arrays::(array, to_type), + (Float64, UInt16) => cast_numeric_arrays::(array, to_type), + (Float64, UInt32) => cast_numeric_arrays::(array, to_type), + (Float64, UInt64) => cast_numeric_arrays::(array, to_type), + (Float64, Int8) => cast_numeric_arrays::(array, to_type), + (Float64, Int16) => cast_numeric_arrays::(array, to_type), + (Float64, Int32) => cast_numeric_arrays::(array, to_type), + (Float64, Int64) => cast_numeric_arrays::(array, to_type), + (Float64, Float32) => cast_numeric_arrays::(array, to_type), + // end numeric casts + + // temporal casts + (Int32, Date32) => cast_array_data::(array, to_type), + (Int32, Time32(TimeUnit::Second)) => cast_array_data::(array, to_type), + (Int32, Time32(TimeUnit::Millisecond)) => cast_array_data::(array, to_type), + // No support for microsecond/nanosecond with i32 + 
(Date32, Int32) => cast_array_data::(array, to_type), + (Time32(_), Int32) => cast_array_data::(array, to_type), + (Int64, Date64) => cast_array_data::(array, to_type), + // No support for second/milliseconds with i64 + (Int64, Time64(TimeUnit::Microsecond)) => cast_array_data::(array, to_type), + (Int64, Time64(TimeUnit::Nanosecond)) => cast_array_data::(array, to_type), + + (Date64, Int64) => cast_array_data::(array, to_type), + (Time64(_), Int64) => cast_array_data::(array, to_type), + (Date32, Date64) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let values = unary::<_, _, i64>(array, |x| x as i64 * MILLISECONDS_IN_DAY, to_type); + + Ok(Box::new(values)) + } + (Date64, Date32) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let values = unary::<_, _, i32>(array, |x| (x / MILLISECONDS_IN_DAY) as i32, to_type); + + Ok(Box::new(values)) + } + (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let values = unary::<_, _, i32>(array, |x| x * MILLISECONDS as i32, to_type); + + Ok(Box::new(values)) + } + (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let values = unary::<_, _, i32>(array, |x| x / (MILLISECONDS as i32), to_type); + + Ok(Box::new(values)) + } + (Time32(from_unit), Time64(to_unit)) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let from_size = time_unit_multiple(&from_unit); + let to_size = time_unit_multiple(&to_unit); + let divisor = to_size / from_size; + let values = unary::<_, _, i64>(&array, |x| (x as i64 * divisor), to_type); + Ok(Box::new(values)) + } + (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let values = unary::<_, _, i64>(array, |x| x * MILLISECONDS, to_type); + Ok(Box::new(values)) + } + 
(Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let values = unary::<_, _, i64>(array, |x| x / MILLISECONDS, to_type); + Ok(Box::new(values)) + } + (Time64(from_unit), Time32(to_unit)) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let from_size = time_unit_multiple(&from_unit); + let to_size = time_unit_multiple(&to_unit); + let divisor = from_size / to_size; + let values = unary::<_, _, i32>(&array, |x| (x as i64 / divisor) as i32, to_type); + Ok(Box::new(values)) + } + (Timestamp(_, _), Int64) => cast_array_data::(array, &to_type), + (Int64, Timestamp(_, _)) => cast_array_data::(array, to_type), + (Timestamp(from_unit, _), Timestamp(to_unit, _)) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let from_size = time_unit_multiple(&from_unit); + let to_size = time_unit_multiple(&to_unit); + // we either divide or multiply, depending on size of each unit + let array = if from_size >= to_size { + unary::<_, _, i64>(&array, |x| (x / (from_size / to_size)), to_type) + } else { + unary::<_, _, i64>(&array, |x| (x * (to_size / from_size)), to_type) + }; + Ok(Box::new(array)) + } + (Timestamp(from_unit, _), Date32) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let from_size = time_unit_multiple(&from_unit) * SECONDS_IN_DAY; + let array = unary::<_, _, i32>(&array, |x| (x / from_size) as i32, to_type); + + Ok(Box::new(array)) + } + (Timestamp(from_unit, _), Date64) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let from_size = time_unit_multiple(&from_unit); + let to_size = MILLISECONDS; + + // Scale time_array by (to_size / from_size) using a + // single integer operation, but need to avoid integer + // math rounding down to zero + + match to_size.cmp(&from_size) { + std::cmp::Ordering::Less => { + let array = + unary::<_, _, i64>(&array, |x| (x / (from_size / to_size)), 
to_type); + Ok(Box::new(array)) + } + std::cmp::Ordering::Equal => cast_array_data::(array, &to_type), + std::cmp::Ordering::Greater => { + let array = + unary::<_, _, i64>(&array, |x| (x * (to_size / from_size)), to_type); + Ok(Box::new(array)) + } + } + } + + // date64 to timestamp might not make sense, + (Int64, Duration(_)) => cast_array_data::(array, to_type), + + // null to primitive/flat types + //(Null, Int32) => Ok(Box::new(Int32Array::from(vec![None; array.len()]))), + (_, _) => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + } +} + +/// Returns the number of `unit`s in one second (1 for `TimeUnit::Second`) +const fn time_unit_multiple(unit: &TimeUnit) -> i64 { + match unit { + TimeUnit::Second => 1, + TimeUnit::Millisecond => MILLISECONDS, + TimeUnit::Microsecond => MICROSECONDS, + TimeUnit::Nanosecond => NANOSECONDS, + } +} + +/// Attempts to encode an array into a dictionary array with index +/// type `K` and value (dictionary) type `dict_value_type` +/// +/// `K` is the key type; unsupported value types return a `ComputeError` +fn cast_to_dictionary( + array: &dyn Array, + dict_value_type: &DataType, +) -> Result> { + match *dict_value_type { + DataType::Int8 => primitive_to_dictionary::(array, dict_value_type), + DataType::Int16 => primitive_to_dictionary::(array, dict_value_type), + DataType::Int32 => primitive_to_dictionary::(array, dict_value_type), + DataType::Int64 => primitive_to_dictionary::(array, dict_value_type), + DataType::UInt8 => primitive_to_dictionary::(array, dict_value_type), + DataType::UInt16 => primitive_to_dictionary::(array, dict_value_type), + DataType::UInt32 => primitive_to_dictionary::(array, dict_value_type), + DataType::UInt64 => primitive_to_dictionary::(array, dict_value_type), + DataType::Utf8 => string_to_dictionary::(array), + DataType::LargeUtf8 => string_to_dictionary::(array), + _ => Err(ArrowError::ComputeError(format!( + "Internal Error: Unsupported output type for dictionary packing: {:?}", + dict_value_type + ))), + } +} +
+#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cast_i32_to_f64() { + let array = Primitive::::from_values(&[5, 6, 7, 8, 9]).to(DataType::Int32); + let b = cast(&array, &DataType::Float64).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + // compare the absolute difference so values above the expectation also fail + assert!((5.0 - c.value(0)).abs() < f64::EPSILON); + assert!((6.0 - c.value(1)).abs() < f64::EPSILON); + assert!((7.0 - c.value(2)).abs() < f64::EPSILON); + assert!((8.0 - c.value(3)).abs() < f64::EPSILON); + assert!((9.0 - c.value(4)).abs() < f64::EPSILON); + } + + #[test] + fn test_cast_i32_to_u8() { + let array = Primitive::::from_values(&[-5, 6, -7, 8, 100000000]).to(DataType::Int32); + let b = cast(&array, &DataType::UInt8).unwrap(); + let expected = + Primitive::::from(&[None, Some(6), None, Some(8), None]).to(DataType::UInt8); + let c = b.as_any().downcast_ref::>().unwrap(); + assert_eq!(c, &expected); + } + + #[test] + fn test_cast_i32_to_u8_sliced() { + let array = Primitive::::from_values(&[-5, 6, -7, 8, 100000000]).to(DataType::Int32); + let array = array.slice(2, 3); + let b = cast(&array, &DataType::UInt8).unwrap(); + let expected = Primitive::::from(&[None, Some(8), None]).to(DataType::UInt8); + let c = b.as_any().downcast_ref::>().unwrap(); + assert_eq!(c, &expected); + } + + #[test] + fn test_cast_i32_to_i32() { + let input = &[5, 6, 7, 8, 9]; + let array = Primitive::::from_values(input).to(DataType::Int32); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = &[5, 6, 7, 8, 9]; + let expected = Primitive::::from_values(expected).to(DataType::Int32); + assert_eq!(c, &expected); + } + + #[test] + fn test_cast_i32_to_list_i32() { + let input = &[5, 6, 7, 8, 9]; + let array = Primitive::::from_values(input).to(DataType::Int32); + let b = cast( + &array, + &DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + ) + .unwrap(); + + let arr = b.as_any().downcast_ref::>().unwrap(); + assert_eq!(&[0, 1, 2, 3, 4, 5], arr.offsets().as_slice()); + let values =
arr.values(); + let c = values + .as_any() + .downcast_ref::>() + .unwrap(); + + let expected = &[5, 6, 7, 8, 9]; + let expected = Primitive::::from_values(expected).to(DataType::Int32); + assert_eq!(c, &expected); + } + + #[test] + fn test_cast_i32_to_list_i32_nullable() { + let input = [Some(5), None, Some(7), Some(8), Some(9)]; + + let array = Primitive::::from(input).to(DataType::Int32); + let b = cast( + &array, + &DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + ) + .unwrap(); + + let arr = b.as_any().downcast_ref::>().unwrap(); + assert_eq!(&[0, 1, 2, 3, 4, 5], arr.offsets().as_slice()); + let values = arr.values(); + let c = values + .as_any() + .downcast_ref::>() + .unwrap(); + + let expected = &[Some(5), None, Some(7), Some(8), Some(9)]; + let expected = Primitive::::from(expected).to(DataType::Int32); + assert_eq!(c, &expected); + } + + #[test] + fn test_cast_i32_to_list_f64_nullable_sliced() { + let input = [Some(5), None, Some(7), Some(8), None, Some(10)]; + + let array = Primitive::::from(input).to(DataType::Int32); + + let array = array.slice(2, 4); + let b = cast( + &array, + &DataType::List(Box::new(Field::new("item", DataType::Float64, true))), + ) + .unwrap(); + + let arr = b.as_any().downcast_ref::>().unwrap(); + assert_eq!(&[0, 1, 2, 3, 4], arr.offsets().as_slice()); + let values = arr.values(); + let c = values + .as_any() + .downcast_ref::>() + .unwrap(); + + let expected = &[Some(7.0), Some(8.0), None, Some(10.0)]; + let expected = Primitive::::from(expected).to(DataType::Float64); + assert_eq!(c, &expected); + } + + #[test] + fn test_cast_utf8_to_i32() { + let array = Utf8Array::::from_slice(&["5", "6", "seven", "8", "9.1"]); + println!("{:#?}", array); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = &[Some(5), Some(6), None, Some(8), None]; + let expected = Primitive::::from(expected).to(DataType::Int32); + assert_eq!(c, &expected); + } + + #[test] 
+ fn test_cast_bool_to_i32() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = &[Some(1), Some(0), None]; + let expected = Primitive::::from(expected).to(DataType::Int32); + assert_eq!(c, &expected); + } + + #[test] + fn test_cast_bool_to_f64() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Float64).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + + let expected = &[Some(1.0), Some(0.0), None]; + let expected = Primitive::::from(expected).to(DataType::Float64); + assert_eq!(c, &expected); + } + + #[test] + #[should_panic(expected = "Casting from Int32 to Timestamp(Microsecond, None) not supported")] + fn test_cast_int32_to_timestamp() { + let array = Primitive::::from(&[Some(2), Some(10), None]).to(DataType::Int32); + cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + } + + /* + #[test] + fn test_cast_list_i32_to_list_u16() { + // Construct a value array + let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 100000000]).data(); + + let value_offsets = Buffer::from_slice_ref(&[0, 3, 6, 8]); + + // Construct a list array from the above two + let list_data_type = + DataType::List(Box::new(Field::new("item", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build(); + let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef; + + let cast_array = cast( + &list_array, + &DataType::List(Box::new(Field::new("item", DataType::UInt16, true))), + ) + .unwrap(); + // 3 negative values should get lost when casting to unsigned, + // 1 value should overflow + assert_eq!(4, cast_array.null_count()); + // offsets should be the same + assert_eq!( + list_array.data().buffers().to_vec(), + cast_array.data().buffers().to_vec() + ); + 
let array = cast_array + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(DataType::UInt16, array.value_type()); + assert_eq!(4, array.values().null_count()); + assert_eq!(3, array.value_length(0)); + assert_eq!(3, array.value_length(1)); + assert_eq!(2, array.value_length(2)); + let values = array.values(); + let u16arr = values.as_any().downcast_ref::().unwrap(); + assert_eq!(8, u16arr.len()); + assert_eq!(4, u16arr.null_count()); + + assert_eq!(0, u16arr.value(0)); + assert_eq!(0, u16arr.value(1)); + assert_eq!(0, u16arr.value(2)); + assert_eq!(false, u16arr.is_valid(3)); + assert_eq!(false, u16arr.is_valid(4)); + assert_eq!(false, u16arr.is_valid(5)); + assert_eq!(2, u16arr.value(6)); + assert_eq!(false, u16arr.is_valid(7)); + } + + #[test] + #[should_panic( + expected = "Casting from Int32 to Timestamp(Microsecond, None) not supported" + )] + fn test_cast_list_i32_to_list_timestamp() { + // Construct a value array + let value_data = + Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 8, 100000000]).data(); + + let value_offsets = Buffer::from_slice_ref(&[0, 3, 6, 9]); + + // Construct a list array from the above two + let list_data_type = + DataType::List(Box::new(Field::new("item", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build(); + let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef; + + cast( + &list_array, + &DataType::List(Box::new(Field::new( + "item", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ))), + ) + .unwrap(); + } + + #[test] + fn test_cast_date32_to_date64() { + let a = Date32Array::from(vec![10000, 17890]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date64).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(864000000000, c.value(0)); + assert_eq!(1545696000000, c.value(1)); + } + + #[test] + fn test_cast_date64_to_date32() { + let a = 
Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date32).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(10000, c.value(0)); + assert_eq!(17890, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date32_to_int32() { + let a = Date32Array::from(vec![10000, 17890]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(10000, c.value(0)); + assert_eq!(17890, c.value(1)); + } + + #[test] + fn test_cast_int32_to_date32() { + let a = Int32Array::from(vec![10000, 17890]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date32).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(10000, c.value(0)); + assert_eq!(17890, c.value(1)); + } + + #[test] + fn test_cast_timestamp_to_date32() { + let a = TimestampMillisecondArray::from_opt_vec( + vec![Some(864000000005), Some(1545696000001), None], + Some(String::from("UTC")), + ); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date32).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(10000, c.value(0)); + assert_eq!(17890, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_timestamp_to_date64() { + let a = TimestampMillisecondArray::from_opt_vec( + vec![Some(864000000005), Some(1545696000001), None], + None, + ); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date64).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(864000000005, c.value(0)); + assert_eq!(1545696000001, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_timestamp_to_i64() { + let a = TimestampMillisecondArray::from_opt_vec( + vec![Some(864000000005), Some(1545696000001), None], + Some("UTC".to_string()), + ); + let array = Arc::new(a) as ArrayRef; + let b = 
cast(&array, &DataType::Int64).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(&DataType::Int64, c.data_type()); + assert_eq!(864000000005, c.value(0)); + assert_eq!(1545696000001, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_between_timestamps() { + let a = TimestampMillisecondArray::from_opt_vec( + vec![Some(864000003005), Some(1545696002001), None], + None, + ); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Timestamp(TimeUnit::Second, None)).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(864000003, c.value(0)); + assert_eq!(1545696002, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_from_f64() { + let f64_values: Vec = vec![ + std::i64::MIN as f64, + std::i32::MIN as f64, + std::i16::MIN as f64, + std::i8::MIN as f64, + 0_f64, + std::u8::MAX as f64, + std::u16::MAX as f64, + std::u32::MAX as f64, + std::u64::MAX as f64, + ]; + let f64_array: ArrayRef = Arc::new(Float64Array::from(f64_values)); + + let f64_expected = vec![ + "-9223372036854776000.0", + "-2147483648.0", + "-32768.0", + "-128.0", + "0.0", + "255.0", + "65535.0", + "4294967295.0", + "18446744073709552000.0", + ]; + assert_eq!( + f64_expected, + get_cast_values::(&f64_array, &DataType::Float64) + ); + + let f32_expected = vec![ + "-9223372000000000000.0", + "-2147483600.0", + "-32768.0", + "-128.0", + "0.0", + "255.0", + "65535.0", + "4294967300.0", + "18446744000000000000.0", + ]; + assert_eq!( + f32_expected, + get_cast_values::(&f64_array, &DataType::Float32) + ); + + let i64_expected = vec![ + "-9223372036854775808", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + "4294967295", + "null", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&f64_array, &DataType::Int64) + ); + + let i32_expected = vec![ + "null", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + "null", + "null", + ]; + assert_eq!( + i32_expected, + 
get_cast_values::(&f64_array, &DataType::Int32) + ); + + let i16_expected = vec![ + "null", "null", "-32768", "-128", "0", "255", "null", "null", "null", + ]; + assert_eq!( + i16_expected, + get_cast_values::(&f64_array, &DataType::Int16) + ); + + let i8_expected = vec![ + "null", "null", "null", "-128", "0", "null", "null", "null", "null", + ]; + assert_eq!( + i8_expected, + get_cast_values::(&f64_array, &DataType::Int8) + ); + + let u64_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "255", + "65535", + "4294967295", + "null", + ]; + assert_eq!( + u64_expected, + get_cast_values::(&f64_array, &DataType::UInt64) + ); + + let u32_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "255", + "65535", + "4294967295", + "null", + ]; + assert_eq!( + u32_expected, + get_cast_values::(&f64_array, &DataType::UInt32) + ); + + let u16_expected = vec![ + "null", "null", "null", "null", "0", "255", "65535", "null", "null", + ]; + assert_eq!( + u16_expected, + get_cast_values::(&f64_array, &DataType::UInt16) + ); + + let u8_expected = vec![ + "null", "null", "null", "null", "0", "255", "null", "null", "null", + ]; + assert_eq!( + u8_expected, + get_cast_values::(&f64_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_f32() { + let f32_values: Vec = vec![ + std::i32::MIN as f32, + std::i32::MIN as f32, + std::i16::MIN as f32, + std::i8::MIN as f32, + 0_f32, + std::u8::MAX as f32, + std::u16::MAX as f32, + std::u32::MAX as f32, + std::u32::MAX as f32, + ]; + let f32_array: ArrayRef = Arc::new(Float32Array::from(f32_values)); + + let f64_expected = vec![ + "-2147483648.0", + "-2147483648.0", + "-32768.0", + "-128.0", + "0.0", + "255.0", + "65535.0", + "4294967296.0", + "4294967296.0", + ]; + assert_eq!( + f64_expected, + get_cast_values::(&f32_array, &DataType::Float64) + ); + + let f32_expected = vec![ + "-2147483600.0", + "-2147483600.0", + "-32768.0", + "-128.0", + "0.0", + "255.0", + "65535.0", + "4294967300.0", + "4294967300.0", 
+ ]; + assert_eq!( + f32_expected, + get_cast_values::(&f32_array, &DataType::Float32) + ); + + let i64_expected = vec![ + "-2147483648", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + "4294967296", + "4294967296", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&f32_array, &DataType::Int64) + ); + + let i32_expected = vec![ + "-2147483648", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + "null", + "null", + ]; + assert_eq!( + i32_expected, + get_cast_values::(&f32_array, &DataType::Int32) + ); + + let i16_expected = vec![ + "null", "null", "-32768", "-128", "0", "255", "null", "null", "null", + ]; + assert_eq!( + i16_expected, + get_cast_values::(&f32_array, &DataType::Int16) + ); + + let i8_expected = vec![ + "null", "null", "null", "-128", "0", "null", "null", "null", "null", + ]; + assert_eq!( + i8_expected, + get_cast_values::(&f32_array, &DataType::Int8) + ); + + let u64_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "255", + "65535", + "4294967296", + "4294967296", + ]; + assert_eq!( + u64_expected, + get_cast_values::(&f32_array, &DataType::UInt64) + ); + + let u32_expected = vec![ + "null", "null", "null", "null", "0", "255", "65535", "null", "null", + ]; + assert_eq!( + u32_expected, + get_cast_values::(&f32_array, &DataType::UInt32) + ); + + let u16_expected = vec![ + "null", "null", "null", "null", "0", "255", "65535", "null", "null", + ]; + assert_eq!( + u16_expected, + get_cast_values::(&f32_array, &DataType::UInt16) + ); + + let u8_expected = vec![ + "null", "null", "null", "null", "0", "255", "null", "null", "null", + ]; + assert_eq!( + u8_expected, + get_cast_values::(&f32_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_uint64() { + let u64_values: Vec = vec![ + 0, + std::u8::MAX as u64, + std::u16::MAX as u64, + std::u32::MAX as u64, + std::u64::MAX, + ]; + let u64_array: ArrayRef = Arc::new(UInt64Array::from(u64_values)); + + let f64_expected = vec![ + "0.0", + 
"255.0", + "65535.0", + "4294967295.0", + "18446744073709552000.0", + ]; + assert_eq!( + f64_expected, + get_cast_values::(&u64_array, &DataType::Float64) + ); + + let f32_expected = vec![ + "0.0", + "255.0", + "65535.0", + "4294967300.0", + "18446744000000000000.0", + ]; + assert_eq!( + f32_expected, + get_cast_values::(&u64_array, &DataType::Float32) + ); + + let i64_expected = vec!["0", "255", "65535", "4294967295", "null"]; + assert_eq!( + i64_expected, + get_cast_values::(&u64_array, &DataType::Int64) + ); + + let i32_expected = vec!["0", "255", "65535", "null", "null"]; + assert_eq!( + i32_expected, + get_cast_values::(&u64_array, &DataType::Int32) + ); + + let i16_expected = vec!["0", "255", "null", "null", "null"]; + assert_eq!( + i16_expected, + get_cast_values::(&u64_array, &DataType::Int16) + ); + + let i8_expected = vec!["0", "null", "null", "null", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&u64_array, &DataType::Int8) + ); + + let u64_expected = + vec!["0", "255", "65535", "4294967295", "18446744073709551615"]; + assert_eq!( + u64_expected, + get_cast_values::(&u64_array, &DataType::UInt64) + ); + + let u32_expected = vec!["0", "255", "65535", "4294967295", "null"]; + assert_eq!( + u32_expected, + get_cast_values::(&u64_array, &DataType::UInt32) + ); + + let u16_expected = vec!["0", "255", "65535", "null", "null"]; + assert_eq!( + u16_expected, + get_cast_values::(&u64_array, &DataType::UInt16) + ); + + let u8_expected = vec!["0", "255", "null", "null", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&u64_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_uint32() { + let u32_values: Vec = vec![ + 0, + std::u8::MAX as u32, + std::u16::MAX as u32, + std::u32::MAX as u32, + ]; + let u32_array: ArrayRef = Arc::new(UInt32Array::from(u32_values)); + + let f64_expected = vec!["0.0", "255.0", "65535.0", "4294967295.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&u32_array, &DataType::Float64) + ); + + let 
f32_expected = vec!["0.0", "255.0", "65535.0", "4294967300.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&u32_array, &DataType::Float32) + ); + + let i64_expected = vec!["0", "255", "65535", "4294967295"]; + assert_eq!( + i64_expected, + get_cast_values::(&u32_array, &DataType::Int64) + ); + + let i32_expected = vec!["0", "255", "65535", "null"]; + assert_eq!( + i32_expected, + get_cast_values::(&u32_array, &DataType::Int32) + ); + + let i16_expected = vec!["0", "255", "null", "null"]; + assert_eq!( + i16_expected, + get_cast_values::(&u32_array, &DataType::Int16) + ); + + let i8_expected = vec!["0", "null", "null", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&u32_array, &DataType::Int8) + ); + + let u64_expected = vec!["0", "255", "65535", "4294967295"]; + assert_eq!( + u64_expected, + get_cast_values::(&u32_array, &DataType::UInt64) + ); + + let u32_expected = vec!["0", "255", "65535", "4294967295"]; + assert_eq!( + u32_expected, + get_cast_values::(&u32_array, &DataType::UInt32) + ); + + let u16_expected = vec!["0", "255", "65535", "null"]; + assert_eq!( + u16_expected, + get_cast_values::(&u32_array, &DataType::UInt16) + ); + + let u8_expected = vec!["0", "255", "null", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&u32_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_uint16() { + let u16_values: Vec = vec![0, std::u8::MAX as u16, std::u16::MAX as u16]; + let u16_array: ArrayRef = Arc::new(UInt16Array::from(u16_values)); + + let f64_expected = vec!["0.0", "255.0", "65535.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&u16_array, &DataType::Float64) + ); + + let f32_expected = vec!["0.0", "255.0", "65535.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&u16_array, &DataType::Float32) + ); + + let i64_expected = vec!["0", "255", "65535"]; + assert_eq!( + i64_expected, + get_cast_values::(&u16_array, &DataType::Int64) + ); + + let i32_expected = vec!["0", "255", "65535"]; + assert_eq!( + 
i32_expected, + get_cast_values::(&u16_array, &DataType::Int32) + ); + + let i16_expected = vec!["0", "255", "null"]; + assert_eq!( + i16_expected, + get_cast_values::(&u16_array, &DataType::Int16) + ); + + let i8_expected = vec!["0", "null", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&u16_array, &DataType::Int8) + ); + + let u64_expected = vec!["0", "255", "65535"]; + assert_eq!( + u64_expected, + get_cast_values::(&u16_array, &DataType::UInt64) + ); + + let u32_expected = vec!["0", "255", "65535"]; + assert_eq!( + u32_expected, + get_cast_values::(&u16_array, &DataType::UInt32) + ); + + let u16_expected = vec!["0", "255", "65535"]; + assert_eq!( + u16_expected, + get_cast_values::(&u16_array, &DataType::UInt16) + ); + + let u8_expected = vec!["0", "255", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&u16_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_uint8() { + let u8_values: Vec = vec![0, std::u8::MAX]; + let u8_array: ArrayRef = Arc::new(UInt8Array::from(u8_values)); + + let f64_expected = vec!["0.0", "255.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&u8_array, &DataType::Float64) + ); + + let f32_expected = vec!["0.0", "255.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&u8_array, &DataType::Float32) + ); + + let i64_expected = vec!["0", "255"]; + assert_eq!( + i64_expected, + get_cast_values::(&u8_array, &DataType::Int64) + ); + + let i32_expected = vec!["0", "255"]; + assert_eq!( + i32_expected, + get_cast_values::(&u8_array, &DataType::Int32) + ); + + let i16_expected = vec!["0", "255"]; + assert_eq!( + i16_expected, + get_cast_values::(&u8_array, &DataType::Int16) + ); + + let i8_expected = vec!["0", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&u8_array, &DataType::Int8) + ); + + let u64_expected = vec!["0", "255"]; + assert_eq!( + u64_expected, + get_cast_values::(&u8_array, &DataType::UInt64) + ); + + let u32_expected = vec!["0", "255"]; + assert_eq!( + u32_expected, + 
get_cast_values::(&u8_array, &DataType::UInt32) + ); + + let u16_expected = vec!["0", "255"]; + assert_eq!( + u16_expected, + get_cast_values::(&u8_array, &DataType::UInt16) + ); + + let u8_expected = vec!["0", "255"]; + assert_eq!( + u8_expected, + get_cast_values::(&u8_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_int64() { + let i64_values: Vec = vec![ + std::i64::MIN, + std::i32::MIN as i64, + std::i16::MIN as i64, + std::i8::MIN as i64, + 0, + std::i8::MAX as i64, + std::i16::MAX as i64, + std::i32::MAX as i64, + std::i64::MAX, + ]; + let i64_array: ArrayRef = Arc::new(Int64Array::from(i64_values)); + + let f64_expected = vec![ + "-9223372036854776000.0", + "-2147483648.0", + "-32768.0", + "-128.0", + "0.0", + "127.0", + "32767.0", + "2147483647.0", + "9223372036854776000.0", + ]; + assert_eq!( + f64_expected, + get_cast_values::(&i64_array, &DataType::Float64) + ); + + let f32_expected = vec![ + "-9223372000000000000.0", + "-2147483600.0", + "-32768.0", + "-128.0", + "0.0", + "127.0", + "32767.0", + "2147483600.0", + "9223372000000000000.0", + ]; + assert_eq!( + f32_expected, + get_cast_values::(&i64_array, &DataType::Float32) + ); + + let i64_expected = vec![ + "-9223372036854775808", + "-2147483648", + "-32768", + "-128", + "0", + "127", + "32767", + "2147483647", + "9223372036854775807", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&i64_array, &DataType::Int64) + ); + + let i32_expected = vec![ + "null", + "-2147483648", + "-32768", + "-128", + "0", + "127", + "32767", + "2147483647", + "null", + ]; + assert_eq!( + i32_expected, + get_cast_values::(&i64_array, &DataType::Int32) + ); + + let i16_expected = vec![ + "null", "null", "-32768", "-128", "0", "127", "32767", "null", "null", + ]; + assert_eq!( + i16_expected, + get_cast_values::(&i64_array, &DataType::Int16) + ); + + let i8_expected = vec![ + "null", "null", "null", "-128", "0", "127", "null", "null", "null", + ]; + assert_eq!( + i8_expected, + 
get_cast_values::(&i64_array, &DataType::Int8) + ); + + let u64_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "127", + "32767", + "2147483647", + "9223372036854775807", + ]; + assert_eq!( + u64_expected, + get_cast_values::(&i64_array, &DataType::UInt64) + ); + + let u32_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "127", + "32767", + "2147483647", + "null", + ]; + assert_eq!( + u32_expected, + get_cast_values::(&i64_array, &DataType::UInt32) + ); + + let u16_expected = vec![ + "null", "null", "null", "null", "0", "127", "32767", "null", "null", + ]; + assert_eq!( + u16_expected, + get_cast_values::(&i64_array, &DataType::UInt16) + ); + + let u8_expected = vec![ + "null", "null", "null", "null", "0", "127", "null", "null", "null", + ]; + assert_eq!( + u8_expected, + get_cast_values::(&i64_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_int32() { + let i32_values: Vec = vec![ + std::i32::MIN as i32, + std::i16::MIN as i32, + std::i8::MIN as i32, + 0, + std::i8::MAX as i32, + std::i16::MAX as i32, + std::i32::MAX as i32, + ]; + let i32_array: ArrayRef = Arc::new(Int32Array::from(i32_values)); + + let f64_expected = vec![ + "-2147483648.0", + "-32768.0", + "-128.0", + "0.0", + "127.0", + "32767.0", + "2147483647.0", + ]; + assert_eq!( + f64_expected, + get_cast_values::(&i32_array, &DataType::Float64) + ); + + let f32_expected = vec![ + "-2147483600.0", + "-32768.0", + "-128.0", + "0.0", + "127.0", + "32767.0", + "2147483600.0", + ]; + assert_eq!( + f32_expected, + get_cast_values::(&i32_array, &DataType::Float32) + ); + + let i16_expected = vec!["null", "-32768", "-128", "0", "127", "32767", "null"]; + assert_eq!( + i16_expected, + get_cast_values::(&i32_array, &DataType::Int16) + ); + + let i8_expected = vec!["null", "null", "-128", "0", "127", "null", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&i32_array, &DataType::Int8) + ); + + let u64_expected = + vec!["null", "null", "null", "0", "127", 
"32767", "2147483647"]; + assert_eq!( + u64_expected, + get_cast_values::(&i32_array, &DataType::UInt64) + ); + + let u32_expected = + vec!["null", "null", "null", "0", "127", "32767", "2147483647"]; + assert_eq!( + u32_expected, + get_cast_values::(&i32_array, &DataType::UInt32) + ); + + let u16_expected = vec!["null", "null", "null", "0", "127", "32767", "null"]; + assert_eq!( + u16_expected, + get_cast_values::(&i32_array, &DataType::UInt16) + ); + + let u8_expected = vec!["null", "null", "null", "0", "127", "null", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&i32_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_int16() { + let i16_values: Vec = vec![ + std::i16::MIN, + std::i8::MIN as i16, + 0, + std::i8::MAX as i16, + std::i16::MAX, + ]; + let i16_array: ArrayRef = Arc::new(Int16Array::from(i16_values)); + + let f64_expected = vec!["-32768.0", "-128.0", "0.0", "127.0", "32767.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&i16_array, &DataType::Float64) + ); + + let f32_expected = vec!["-32768.0", "-128.0", "0.0", "127.0", "32767.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&i16_array, &DataType::Float32) + ); + + let i64_expected = vec!["-32768", "-128", "0", "127", "32767"]; + assert_eq!( + i64_expected, + get_cast_values::(&i16_array, &DataType::Int64) + ); + + let i32_expected = vec!["-32768", "-128", "0", "127", "32767"]; + assert_eq!( + i32_expected, + get_cast_values::(&i16_array, &DataType::Int32) + ); + + let i16_expected = vec!["-32768", "-128", "0", "127", "32767"]; + assert_eq!( + i16_expected, + get_cast_values::(&i16_array, &DataType::Int16) + ); + + let i8_expected = vec!["null", "-128", "0", "127", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&i16_array, &DataType::Int8) + ); + + let u64_expected = vec!["null", "null", "0", "127", "32767"]; + assert_eq!( + u64_expected, + get_cast_values::(&i16_array, &DataType::UInt64) + ); + + let u32_expected = vec!["null", "null", "0", "127", 
"32767"]; + assert_eq!( + u32_expected, + get_cast_values::(&i16_array, &DataType::UInt32) + ); + + let u16_expected = vec!["null", "null", "0", "127", "32767"]; + assert_eq!( + u16_expected, + get_cast_values::(&i16_array, &DataType::UInt16) + ); + + let u8_expected = vec!["null", "null", "0", "127", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&i16_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_int8() { + let i8_values: Vec = vec![std::i8::MIN, 0, std::i8::MAX]; + let i8_array: ArrayRef = Arc::new(Int8Array::from(i8_values)); + + let f64_expected = vec!["-128.0", "0.0", "127.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&i8_array, &DataType::Float64) + ); + + let f32_expected = vec!["-128.0", "0.0", "127.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&i8_array, &DataType::Float32) + ); + + let i64_expected = vec!["-128", "0", "127"]; + assert_eq!( + i64_expected, + get_cast_values::(&i8_array, &DataType::Int64) + ); + + let i32_expected = vec!["-128", "0", "127"]; + assert_eq!( + i32_expected, + get_cast_values::(&i8_array, &DataType::Int32) + ); + + let i16_expected = vec!["-128", "0", "127"]; + assert_eq!( + i16_expected, + get_cast_values::(&i8_array, &DataType::Int16) + ); + + let i8_expected = vec!["-128", "0", "127"]; + assert_eq!( + i8_expected, + get_cast_values::(&i8_array, &DataType::Int8) + ); + + let u64_expected = vec!["null", "0", "127"]; + assert_eq!( + u64_expected, + get_cast_values::(&i8_array, &DataType::UInt64) + ); + + let u32_expected = vec!["null", "0", "127"]; + assert_eq!( + u32_expected, + get_cast_values::(&i8_array, &DataType::UInt32) + ); + + let u16_expected = vec!["null", "0", "127"]; + assert_eq!( + u16_expected, + get_cast_values::(&i8_array, &DataType::UInt16) + ); + + let u8_expected = vec!["null", "0", "127"]; + assert_eq!( + u8_expected, + get_cast_values::(&i8_array, &DataType::UInt8) + ); + } + + /// Convert `array` into a vector of strings by casting to data type dt + fn 
get_cast_values(array: &ArrayRef, dt: &DataType) -> Vec + where + T: ArrowNumericType, + { + let c = cast(&array, dt).unwrap(); + let a = c.as_any().downcast_ref::>().unwrap(); + let mut v: Vec = vec![]; + for i in 0..array.len() { + if a.is_null(i) { + v.push("null".to_string()) + } else { + v.push(format!("{:?}", a.value(i))); + } + } + v + } + + #[test] + fn test_cast_utf8_dict() { + // FROM a dictionary of Utf8 values + use DataType::*; + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = StringBuilder::new(10); + let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + builder.append("one").unwrap(); + builder.append_null().unwrap(); + builder.append("three").unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["one", "null", "three"]; + + // Test casting TO StringArray + let cast_type = Utf8; + let cast_array = cast(&array, &cast_type).expect("cast to UTF-8 failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + // Test casting TO Dictionary (with different index sizes) + + let cast_type = Dictionary(Box::new(Int16), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(Int32), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(Int64), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + 
assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt16), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt32), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt64), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_dict_to_dict_bad_index_value_primitive() { + use DataType::*; + // test converting from an array that has indexes of a type + // that are out of bounds for a particular other kind of + // index. 
+ + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = PrimitiveBuilder::::new(10); + let mut builder = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + + // add 200 distinct values (which can be stored by a + // dictionary indexed by int32, but not a dictionary indexed + // with int8) + for i in 0..200 { + builder.append(i).unwrap(); + } + let array: ArrayRef = Arc::new(builder.finish()); + + let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let res = cast(&array, &cast_type); + assert!(res.is_err()); + let actual_error = format!("{:?}", res); + let expected_error = "Could not convert 72 dictionary indexes from Int32 to Int8"; + assert!( + actual_error.contains(expected_error), + "did not find expected error '{}' in actual error '{}'", + actual_error, + expected_error + ); + } + + #[test] + fn test_cast_dict_to_dict_bad_index_value_utf8() { + use DataType::*; + // Same test as test_cast_dict_to_dict_bad_index_value but use + // string values (and encode the expected behavior here); + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = StringBuilder::new(10); + let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + + // add 200 distinct values (which can be stored by a + // dictionary indexed by int32, but not a dictionary indexed + // with int8) + for i in 0..200 { + let val = format!("val{}", i); + builder.append(&val).unwrap(); + } + let array: ArrayRef = Arc::new(builder.finish()); + + let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let res = cast(&array, &cast_type); + assert!(res.is_err()); + let actual_error = format!("{:?}", res); + let expected_error = "Could not convert 72 dictionary indexes from Int32 to Int8"; + assert!( + actual_error.contains(expected_error), + "did not find expected error '{}' in actual error '{}'", + actual_error, + expected_error + ); + } + + #[test] + fn test_cast_primitive_dict() { + // FROM a dictionary with of INT32 
values + use DataType::*; + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = PrimitiveBuilder::::new(10); + let mut builder = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + builder.append(1).unwrap(); + builder.append_null().unwrap(); + builder.append(3).unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["1", "null", "3"]; + + // Test casting TO PrimitiveArray, different dictionary type + let cast_array = cast(&array, &Utf8).expect("cast to UTF-8 failed"); + assert_eq!(array_to_strings(&cast_array), expected); + assert_eq!(cast_array.data_type(), &Utf8); + + let cast_array = cast(&array, &Int64).expect("cast to int64 failed"); + assert_eq!(array_to_strings(&cast_array), expected); + assert_eq!(cast_array.data_type(), &Int64); + } + + #[test] + fn test_cast_primitive_array_to_dict() { + use DataType::*; + + let mut builder = PrimitiveBuilder::::new(10); + builder.append_value(1).unwrap(); + builder.append_null().unwrap(); + builder.append_value(3).unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["1", "null", "3"]; + + // Cast to a dictionary (same value type, Int32) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Int32)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + // Cast to a dictionary (different value type, Int8) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Int8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_string_array_to_dict() { + use DataType::*; + + let array = Arc::new(StringArray::from(vec![Some("one"), None, Some("three")])) + as ArrayRef; + + let expected = vec!["one", "null", "three"]; + + // Cast to a dictionary (same value type, Utf8) 
+ let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_null_array_to_int32() { + let array = Arc::new(NullArray::new(6)) as ArrayRef; + + let expected = Int32Array::from(vec![None; 6]); + + // Cast the null array to Int32 + let cast_type = DataType::Int32; + let cast_array = cast(&array, &cast_type).expect("cast failed"); + let cast_array = as_primitive_array::(&cast_array); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(cast_array, &expected); + } + + /// Convert each element of `array` to a string, rendering null slots as "null" + fn array_to_strings(array: &ArrayRef) -> Vec { + (0..array.len()) + .map(|i| { + if array.is_null(i) { + "null".to_string() + } else { + array_value_to_string(array, i).expect("Convert array to String") + } + }) + .collect() + } + + #[test] + fn test_cast_utf8_to_date32() { + use chrono::NaiveDate; + let from_ymd = chrono::NaiveDate::from_ymd; + let since = chrono::NaiveDate::signed_duration_since; + + let a = StringArray::from(vec![ + "2000-01-01", // valid date with leading 0s + "2000-2-2", // valid date without leading 0s + "2000-00-00", // invalid month and day + "2000-01-01T12:00:00", // date + time is invalid + "2000", // just a year is invalid + ]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date32).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + + // test valid inputs + let date_value = since(NaiveDate::from_ymd(2000, 1, 1), from_ymd(1970, 1, 1)) + .num_days() as i32; + assert_eq!(true, c.is_valid(0)); // "2000-01-01" + assert_eq!(date_value, c.value(0)); + + let date_value = since(NaiveDate::from_ymd(2000, 2, 2), from_ymd(1970, 1, 1)) + .num_days() as i32; + assert_eq!(true, c.is_valid(1)); // "2000-2-2" + assert_eq!(date_value, c.value(1)); + + // test invalid 
inputs + assert_eq!(false, c.is_valid(2)); // "2000-00-00" + assert_eq!(false, c.is_valid(3)); // "2000-01-01T12:00:00" + assert_eq!(false, c.is_valid(4)); // "2000" + } + + #[test] + fn test_cast_utf8_to_date64() { + let a = StringArray::from(vec![ + "2000-01-01T12:00:00", // date + time valid + "2020-12-15T12:34:56", // date + time valid + "2020-2-2T12:34:56", // valid date time without leading 0s + "2000-00-00T12:00:00", // invalid month and day + "2000-01-01 12:00:00", // missing the 'T' + "2000-01-01", // just a date is invalid + ]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date64).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + + // test valid inputs + assert_eq!(true, c.is_valid(0)); // "2000-01-01T12:00:00" + assert_eq!(946728000000, c.value(0)); + assert_eq!(true, c.is_valid(1)); // "2020-12-15T12:34:56" + assert_eq!(1608035696000, c.value(1)); + assert_eq!(true, c.is_valid(2)); // "2020-2-2T12:34:56" + assert_eq!(1580646896000, c.value(2)); + + // test invalid inputs + assert_eq!(false, c.is_valid(3)); // "2000-00-00T12:00:00" + assert_eq!(false, c.is_valid(4)); // "2000-01-01 12:00:00" + assert_eq!(false, c.is_valid(5)); // "2000-01-01" + } + + #[test] + fn test_can_cast_types() { + // this function attempts to ensure that can_cast_types stays + // in sync with cast. 
It simply tries all combinations of + // types and makes sure that if `can_cast_types` returns + // true, so does `cast` + + let all_types = get_all_types(); + + for array in get_arrays_of_all_types() { + for to_type in &all_types { + println!("Test casting {:?} --> {:?}", array.data_type(), to_type); + let cast_result = cast(&array, &to_type); + let reported_cast_ability = can_cast_types(array.data_type(), to_type); + + // check for mismatch + match (cast_result, reported_cast_ability) { + (Ok(_), false) => { + panic!("Was able to cast array from {:?} to {:?} but can_cast_types reported false", + array.data_type(), to_type) + } + (Err(e), true) => { + panic!("Was not able to cast array from {:?} to {:?} but can_cast_types reported true. \ + Error was {:?}", + array.data_type(), to_type, e) + } + // otherwise it was a match + _ => {} + }; + } + } + } + + /// Create instances of arrays with varying types for cast tests + fn get_arrays_of_all_types() -> Vec { + let tz_name = String::from("America/New_York"); + let binary_data: Vec<&[u8]> = vec![b"foo", b"bar"]; + vec![ + Arc::new(BinaryArray::from(binary_data.clone())), + Arc::new(LargeBinaryArray::from(binary_data.clone())), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_primitive::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + make_dictionary_utf8::(), + Arc::new(make_list_array()), + Arc::new(make_large_list_array()), + Arc::new(make_fixed_size_list_array()), + Arc::new(make_fixed_size_binary_array()), + Arc::new(StructArray::from(vec![ + ( + Field::new("a", DataType::Boolean, false), + Arc::new(BooleanArray::from(vec![false, false, true, true])) + as Arc, + ), + 
( + Field::new("b", DataType::Int32, false), + Arc::new(Int32Array::from(vec![42, 28, 19, 31])), + ), + ])), + //Arc::new(make_union_array()), + Arc::new(NullArray::new(10)), + Arc::new(StringArray::from(vec!["foo", "bar"])), + Arc::new(LargeStringArray::from(vec!["foo", "bar"])), + Arc::new(BooleanArray::from(vec![true, false])), + Arc::new(Int8Array::from(vec![1, 2])), + Arc::new(Int16Array::from(vec![1, 2])), + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(Int64Array::from(vec![1, 2])), + Arc::new(UInt8Array::from(vec![1, 2])), + Arc::new(UInt16Array::from(vec![1, 2])), + Arc::new(UInt32Array::from(vec![1, 2])), + Arc::new(UInt64Array::from(vec![1, 2])), + Arc::new(Float32Array::from(vec![1.0, 2.0])), + Arc::new(Float64Array::from(vec![1.0, 2.0])), + Arc::new(TimestampSecondArray::from_vec(vec![1000, 2000], None)), + Arc::new(TimestampMillisecondArray::from_vec(vec![1000, 2000], None)), + Arc::new(TimestampMicrosecondArray::from_vec(vec![1000, 2000], None)), + Arc::new(TimestampNanosecondArray::from_vec(vec![1000, 2000], None)), + Arc::new(TimestampSecondArray::from_vec( + vec![1000, 2000], + Some(tz_name.clone()), + )), + Arc::new(TimestampMillisecondArray::from_vec( + vec![1000, 2000], + Some(tz_name.clone()), + )), + Arc::new(TimestampMicrosecondArray::from_vec( + vec![1000, 2000], + Some(tz_name.clone()), + )), + Arc::new(TimestampNanosecondArray::from_vec( + vec![1000, 2000], + Some(tz_name), + )), + Arc::new(Date32Array::from(vec![1000, 2000])), + Arc::new(Date64Array::from(vec![1000, 2000])), + Arc::new(Time32SecondArray::from(vec![1000, 2000])), + Arc::new(Time32MillisecondArray::from(vec![1000, 2000])), + Arc::new(Time64MicrosecondArray::from(vec![1000, 2000])), + Arc::new(Time64NanosecondArray::from(vec![1000, 2000])), + Arc::new(IntervalYearMonthArray::from(vec![1000, 2000])), + Arc::new(IntervalDayTimeArray::from(vec![1000, 2000])), + Arc::new(DurationSecondArray::from(vec![1000, 2000])), + Arc::new(DurationMillisecondArray::from(vec![1000, 
2000])), + Arc::new(DurationMicrosecondArray::from(vec![1000, 2000])), + Arc::new(DurationNanosecondArray::from(vec![1000, 2000])), + ] + } + + fn make_list_array() -> ListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7])) + .build(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = Buffer::from_slice_ref(&[0, 3, 6, 8]); + + // Construct a list array from the above two + let list_data_type = + DataType::List(Box::new(Field::new("item", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build(); + ListArray::from(list_data) + } + + fn make_large_list_array() -> LargeListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7])) + .build(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = Buffer::from_slice_ref(&[0i64, 3, 6, 8]); + + // Construct a list array from the above two + let list_data_type = + DataType::LargeList(Box::new(Field::new("item", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build(); + LargeListArray::from(list_data) + } + + fn make_fixed_size_list_array() -> FixedSizeListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(10) + .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .build(); + + // Construct a fixed size list array from the above two + let list_data_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Int32, true)), + 2, + ); + let list_data = ArrayData::builder(list_data_type) + 
.len(5) + .add_child_data(value_data) + .build(); + FixedSizeListArray::from(list_data) + } + + fn make_fixed_size_binary_array() -> FixedSizeBinaryArray { + let values: [u8; 15] = *b"hellotherearrow"; + + let array_data = ArrayData::builder(DataType::FixedSizeBinary(5)) + .len(3) + .add_buffer(Buffer::from(&values[..])) + .build(); + FixedSizeBinaryArray::from(array_data) + } + + fn make_union_array() -> UnionArray { + let mut builder = UnionBuilder::new_dense(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.build().unwrap() + } + + /// Creates a dictionary with primitive dictionary values, and keys of type K + fn make_dictionary_primitive() -> ArrayRef { + let keys_builder = PrimitiveBuilder::::new(2); + // Pick Int32 arbitrarily for dictionary values + let values_builder = PrimitiveBuilder::::new(2); + let mut b = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + b.append(1).unwrap(); + b.append(2).unwrap(); + Arc::new(b.finish()) + } + + /// Creates a dictionary with utf8 values, and keys of type K + fn make_dictionary_utf8() -> ArrayRef { + let keys_builder = PrimitiveBuilder::::new(2); + // Pick Int32 arbitrarily for dictionary values + let values_builder = StringBuilder::new(2); + let mut b = StringDictionaryBuilder::new(keys_builder, values_builder); + b.append("foo").unwrap(); + b.append("bar").unwrap(); + Arc::new(b.finish()) + } + + // Get a selection of datatypes to try and cast to + fn get_all_types() -> Vec { + use DataType::*; + let tz_name = String::from("America/New_York"); + + vec![ + Null, + Boolean, + Int8, + Int16, + Int32, + UInt64, + UInt8, + UInt16, + UInt32, + UInt64, + Float16, + Float32, + Float64, + Timestamp(TimeUnit::Second, None), + Timestamp(TimeUnit::Millisecond, None), + Timestamp(TimeUnit::Microsecond, None), + Timestamp(TimeUnit::Nanosecond, None), + Timestamp(TimeUnit::Second, Some(tz_name.clone())), + Timestamp(TimeUnit::Millisecond, Some(tz_name.clone())), + 
Timestamp(TimeUnit::Microsecond, Some(tz_name.clone())), + Timestamp(TimeUnit::Nanosecond, Some(tz_name)), + Date32, + Date64, + Time32(TimeUnit::Second), + Time32(TimeUnit::Millisecond), + Time64(TimeUnit::Microsecond), + Time64(TimeUnit::Nanosecond), + Duration(TimeUnit::Second), + Duration(TimeUnit::Millisecond), + Duration(TimeUnit::Microsecond), + Duration(TimeUnit::Nanosecond), + Interval(IntervalUnit::YearMonth), + Interval(IntervalUnit::DayTime), + Binary, + FixedSizeBinary(10), + LargeBinary, + Utf8, + LargeUtf8, + List(Box::new(Field::new("item", DataType::Int8, true))), + List(Box::new(Field::new("item", DataType::Utf8, true))), + FixedSizeList(Box::new(Field::new("item", DataType::Int8, true)), 10), + FixedSizeList(Box::new(Field::new("item", DataType::Utf8, false)), 10), + LargeList(Box::new(Field::new("item", DataType::Int8, true))), + LargeList(Box::new(Field::new("item", DataType::Utf8, false))), + Struct(vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Utf8, true), + ]), + Union(vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Utf8, true), + ]), + Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), + Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)), + Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + ] + } + */ +} diff --git a/src/compute/cast/primitive_to.rs b/src/compute/cast/primitive_to.rs new file mode 100644 index 00000000000..4012bbbf06f --- /dev/null +++ b/src/compute/cast/primitive_to.rs @@ -0,0 +1,100 @@ +use std::hash::Hash; + +use crate::{ + array::{ + dict_from_iter, Array, BooleanArray, DictionaryKey, DictionaryPrimitive, Offset, Primitive, + PrimitiveArray, Utf8Array, + }, + buffer::{Bitmap, NativeType}, + datatypes::DataType, +}; +use crate::{error::Result, util::lexical_to_string}; + +use super::cast; + +/// Cast numeric types to Utf8 +pub fn cast_numeric_to_string(array: &dyn Array) -> Result> +where + O: Offset, + T: 
NativeType + lexical_core::ToLexical, +{ + let array = array.as_any().downcast_ref::>().unwrap(); + + let iter = array.iter().map(|x| x.map(lexical_to_string)); + + // Soundness: + // NOTE(review): presumably trustedLen for the same reason given in `cast_typed_primitive` + // (the iterator maps over the downcast array's own iterator) - confirm. + let array = unsafe { Utf8Array::::from_trusted_len_iter(iter) }; + + Ok(Box::new(array)) +} + +/// Convert Array into a PrimitiveArray of type, and apply numeric cast +/// +/// # Panics +/// Panics if the downcast of `from` fails (the result of `downcast_ref` is unwrapped). +pub fn cast_numeric_arrays(from: &dyn Array, to_type: &DataType) -> Result> +where + I: NativeType + num::NumCast, + O: NativeType + num::NumCast, +{ + let from = from.as_any().downcast_ref::>().unwrap(); + Ok(Box::new(cast_typed_primitive::(from, to_type))) +} + +/// Cast PrimitiveArray to PrimitiveArray +/// +/// Every non-`None` item is passed to `num::cast::cast`; when that returns `None` +/// the item becomes `None` (via `and_then`). The result is tagged with `to_type`. +pub fn cast_typed_primitive(from: &PrimitiveArray, to_type: &DataType) -> PrimitiveArray +where + I: NativeType + num::NumCast, + O: NativeType + num::NumCast, +{ + // NOTE(review): `from` is already a `&PrimitiveArray`, so this downcast looks redundant - confirm. + let from = from.as_any().downcast_ref::>().unwrap(); + + let iter = from.iter().map(|v| v.and_then(num::cast::cast::)); + // Soundness: + // The iterator is trustedLen because it comes from a `PrimitiveArray`. 
+ unsafe { Primitive::::from_trusted_len_iter(iter) }.to(to_type.clone()) +} + +/// Cast an array by changing its data type to the desired type +pub fn cast_array_data(from: &dyn Array, to_type: &DataType) -> Result> +where + T: NativeType, +{ + let from = from.as_any().downcast_ref::>().unwrap(); + + Ok(Box::new(PrimitiveArray::::from_data( + to_type.clone(), + from.values().clone(), + from.nulls().clone(), + ))) +} + +/// Cast numeric types to Boolean +/// +/// Any zero value returns `false` while non-zero returns `true` +pub fn cast_numeric_to_bool(array: &dyn Array) -> Result> +where + T: NativeType, +{ + let array = array.as_any().downcast_ref::>().unwrap(); + + let iter = array.values().as_slice().iter().map(|v| *v != T::default()); + // Soundness: `slice.map` is `TrustedLen`. + let values = unsafe { Bitmap::from_trusted_len_iter(iter) }; + + let array = BooleanArray::from_data(values, array.nulls().clone()); + + Ok(Box::new(array)) +} + +/// Cast a primitive array to `to` and pack the result as a dictionary array +/// with keys of type `K` and values of type `to`. +/// +/// # Errors +/// Returns an error if the cast to `to` or the dictionary construction fails. +/// +/// # Panics +/// Panics if the cast result cannot be downcast (the result of `downcast_ref` is unwrapped). +pub fn primitive_to_dictionary( + array: &dyn Array, + to: &DataType, +) -> Result> { + let values = cast(array, to)?; + let values = values.as_any().downcast_ref::>().unwrap(); + + let primitive: DictionaryPrimitive> = dict_from_iter(values.iter())?; + + // The logical type must describe the actual key and value types, + // mirroring what `string_to_dictionary` does. + let array = primitive.to(DataType::Dictionary( + Box::new(K::DATA_TYPE), + Box::new(to.clone()), + )); + + Ok(Box::new(array)) +} diff --git a/src/compute/cast/string_to.rs b/src/compute/cast/string_to.rs new file mode 100644 index 00000000000..5008a121890 --- /dev/null +++ b/src/compute/cast/string_to.rs @@ -0,0 +1,87 @@ +use chrono::Datelike; + +use crate::{ + array::{ + dict_from_iter, Array, DictionaryKey, DictionaryPrimitive, Offset, Primitive, + PrimitiveArray, Utf8Array, Utf8Primitive, + }, + buffer::NativeType, + datatypes::DataType, +}; +use crate::{error::Result, temporal_conversions::EPOCH_DAYS_FROM_CE}; + +use super::cast; + +/// Cast Utf8 to numeric types +pub fn cast_string_to_numeric( + from: &dyn Array, + to: &DataType, +) -> Result> +where + T: NativeType + lexical_core::FromLexical, +{ + let 
from = from.as_any().downcast_ref::>().unwrap(); + + let iter = from + .iter() + .map(|x| x.and_then::(|x| lexical_core::parse(x.as_bytes()).ok())); + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from a `Utf8Array`. + let array = unsafe { Primitive::::from_trusted_len_iter(iter) }.to(to.clone()); + + Ok(Box::new(array)) +} + +/// Parse every string of `array` as a date and store it as the number of days +/// relative to `EPOCH_DAYS_FROM_CE`, tagging the result with `to_type`. +/// +/// Strings that fail to parse become `None` (the parse error is dropped with `ok()`). +/// +/// # Panics +/// Panics if the downcast of `array` fails (the result of `downcast_ref` is unwrapped). +pub fn to_date32(array: &dyn Array, to_type: &DataType) -> PrimitiveArray { + let array = array.as_any().downcast_ref::>().unwrap(); + + let iter = array.iter().map(|x| { + x.and_then(|x| { + x.parse::() + .ok() + .map(|x| x.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + }) + }); + // Soundness: + // The iterator is trustedLen because it comes from a `Utf8Array`. + unsafe { Primitive::::from_trusted_len_iter(iter) }.to(to_type.clone()) +} + +/// Parse every string of `array` as a date-time and store its `timestamp_millis()`, +/// tagging the result with `to_type`. +/// +/// Strings that fail to parse become `None` (the parse error is dropped with `ok()`). +/// +/// # Panics +/// Panics if the downcast of `array` fails (the result of `downcast_ref` is unwrapped). +pub fn to_date64(array: &dyn Array, to_type: &DataType) -> PrimitiveArray { + let array = array.as_any().downcast_ref::>().unwrap(); + + let iter = array.iter().map(|x| { + x.and_then(|x| { + x.parse::() + .ok() + .map(|x| x.timestamp_millis()) + }) + }); + // Soundness: + // The iterator is trustedLen because it comes from a `Utf8Array`. 
+ unsafe { Primitive::::from_trusted_len_iter(iter) }.to(to_type.clone()) +} + +/// Packs the data as a StringDictionaryArray, if possible, with the +/// key types of K +/// +/// The values are first cast to `LargeUtf8` when `O::is_large()` and to `Utf8` otherwise; +/// the resulting dictionary is typed `Dictionary(K::DATA_TYPE, <that string type>)`. +/// +/// # Errors +/// Returns an error if the cast or the dictionary construction fails. +/// +/// # Panics +/// Panics if the cast result cannot be downcast (the result of `downcast_ref` is unwrapped). +pub fn string_to_dictionary( + array: &dyn Array, +) -> Result> { + let to = if O::is_large() { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }; + + let values = cast(array, &to)?; + let values = values.as_any().downcast_ref::>().unwrap(); + + let primitive: DictionaryPrimitive> = dict_from_iter(values.iter())?; + + let array = primitive.to(DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(to))); + + Ok(Box::new(array)) +} diff --git a/src/compute/mod.rs b/src/compute/mod.rs new file mode 100644 index 00000000000..2c29ed3e22f --- /dev/null +++ b/src/compute/mod.rs @@ -0,0 +1,3 @@ +pub mod arity; +pub mod cast; +pub mod take; diff --git a/src/compute/take.rs b/src/compute/take.rs new file mode 100644 index 00000000000..d72cc84cf8a --- /dev/null +++ b/src/compute/take.rs @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Defines take kernel for [Array] + +use crate::{ + buffer::MutableBitmap, + error::{ArrowError, Result}, +}; + +use crate::{ + array::{Array, Offset, PrimitiveArray}, + buffer::{types::NativeType, Bitmap, Buffer, MutableBuffer}, + datatypes::DataType, +}; + +// Downcasts `$values` to a primitive array of `$type` and runs `take_primitive` on it. +macro_rules! downcast_take { + ($type: ty, $values: expr, $indices: expr) => {{ + let values = $values + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to a primitive array"); + Ok(Box::new(take_primitive::<$type, _>(&values, $indices)?)) + }}; +} + +/// Takes the elements of `values` at the positions given by `indices`. +/// +/// NOTE(review): `options` is forwarded but not consulted by the implementation below, +/// so `TakeOptions::check_bounds` currently has no effect - confirm. +/// +/// # Errors +/// Returns an `ArrowError::ComputeError` if an index cannot be converted to `usize` +/// or if the data type of `values` is not supported. +pub fn take( + values: &dyn Array, + indices: &PrimitiveArray, + options: Option, +) -> Result> { + take_impl(values, indices, options) +} + +// Dispatches on the data type of `values` to the primitive kernel. +fn take_impl( + values: &dyn Array, + indices: &PrimitiveArray, + options: Option, +) -> Result> { + match values.data_type() { + DataType::Int8 => downcast_take!(i8, values, indices), + DataType::Int16 => downcast_take!(i16, values, indices), + DataType::Int32 => downcast_take!(i32, values, indices), + DataType::Int64 => downcast_take!(i64, values, indices), + DataType::UInt8 => downcast_take!(u8, values, indices), + DataType::UInt16 => downcast_take!(u16, values, indices), + DataType::UInt32 => downcast_take!(u32, values, indices), + DataType::UInt64 => downcast_take!(u64, values, indices), + DataType::Float32 => downcast_take!(f32, values, indices), + DataType::Float64 => downcast_take!(f64, values, indices), + // an unsupported type is a recoverable condition: report it instead of panicking + t => Err(ArrowError::ComputeError(format!( + "Take not supported for data type {:?}", + t + ))), + } +} + +/// Options that define how `take` should behave +#[derive(Clone, Debug)] +pub struct TakeOptions { + /// Perform bounds check before taking indices from values. + /// If enabled, an `ArrowError` is returned if the indices are out of bounds. + /// If not enabled, and indices exceed bounds, the kernel will panic. 
+ pub check_bounds: bool, +} + +impl Default for TakeOptions { + /// Bounds checking is disabled by default. + fn default() -> Self { + Self { + check_bounds: false, + } + } +} + +/// Converts `index` to `usize`, returning `ArrowError::ComputeError` when `to_usize` yields `None`. +#[inline(always)] +fn maybe_usize(index: I) -> Result { + index + .to_usize() + .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string())) +} + +// take implementation when neither values nor indices contain nulls +// Panics if a converted index is out of bounds of `values` (plain slice indexing). +fn take_no_nulls( + values: &[T], + indices: &[I], +) -> Result<(Buffer, Option)> { + let values = indices + .iter() + .map(|index| Result::Ok(values[maybe_usize::(*index)?])); + // Soundness: `slice.map` is `TrustedLen`. + let buffer = unsafe { MutableBuffer::try_from_trusted_len_iter(values)? }; + + Ok((buffer.into(), None)) +} + +// take implementation when only values contain nulls +// Panics if a converted index is out of bounds of the values slice (plain slice indexing). +fn take_values_nulls( + values: &PrimitiveArray, + indices: &[I], +) -> Result<(Buffer, Option)> { + let mut null = MutableBitmap::with_capacity(indices.len()); + + // unwrap: this function is only called from the branch where `values.null_count() > 0` + let null_values = values.nulls().as_ref().unwrap(); + + let values_values = values.values().as_slice(); + + let values = indices.iter().map(|index| { + let index = maybe_usize::(*index)?; + // copy the bit of the taken value into the output bitmap + if null_values.get_bit(index) { + null.push(true); + } else { + null.push(false); + } + Result::Ok(values_values[index]) + }); + // Soundness: `slice.map` is `TrustedLen`. + let buffer = unsafe { MutableBuffer::try_from_trusted_len_iter(values)? 
}; + + let bitmap = if null.null_count() > 0 { + Some(null.into()) + } else { + None + }; + + Ok((buffer.into(), bitmap)) +} + +// take implementation when only indices contain nulls +// +// The validity of an `indices` slot is looked up by the slot's position in `indices`, +// not by the index value stored in it. +// Panics if a slot is valid and its index is out of bounds of `values`. +fn take_indices_nulls( + values: &[T], + indices: &PrimitiveArray, +) -> Result<(Buffer, Option)> { + let null_indices = indices.nulls().as_ref().unwrap(); + + let values = indices + .values() + .as_slice() + .iter() + .enumerate() + .map(|(position, index)| { + let index = maybe_usize::(*index)?; + Result::Ok(match values.get(index) { + Some(value) => *value, + None => { + // out of bounds is only an error for a valid (non-null) slot; + // the bitmap is indexed by slot position + if null_indices.get_bit(position) { + panic!("Out-of-bounds index {}", index) + } else { + T::default() + } + } + }) + }); + + // Soundness: `slice.map` is `TrustedLen`. + let buffer = unsafe { MutableBuffer::try_from_trusted_len_iter(values)? }; + + Ok((buffer.into(), indices.nulls().clone())) +} + +// take implementation when both values and indices contain nulls +fn take_values_indices_nulls( + values: &PrimitiveArray, + indices: &PrimitiveArray, +) -> Result<(Buffer, Option)> { + let mut bitmap = MutableBitmap::with_capacity(indices.len()); + + let null_values = values.nulls().as_ref().unwrap(); + + let values_values = values.values().as_slice(); + let values = indices.iter().map(|index| match index { + Some(index) => { + let index = maybe_usize::(index)?; + bitmap.push(null_values.get_bit(index)); + Result::Ok(values_values[index]) + } + None => { + bitmap.push(false); + Ok(T::default()) + } + }); + // Soundness: `slice.map` is `TrustedLen`. + let buffer = unsafe { MutableBuffer::try_from_trusted_len_iter(values)? }; + + let bitmap = if bitmap.null_count() > 0 { + Some(bitmap.into()) + } else { + None + }; + + Ok((buffer.into(), bitmap)) +} + +/// `take` implementation for all primitive arrays +/// +/// This checks if an `indices` slot is populated, and gets the value from `values` +/// as the populated index. +/// If the `indices` slot is null, a null value is returned. 
+/// For example, given: +/// values: [1, 2, 3, null, 5] +/// indices: [0, null, 4, 3] +/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)] +fn take_primitive( + values: &PrimitiveArray, + indices: &PrimitiveArray, +) -> Result> { + let indices_has_nulls = indices.null_count() > 0; + let values_has_nulls = values.null_count() > 0; + // note: this function should only panic when "an index is not null and out of bounds". + // if the index is null, its value is undefined and therefore we should not read from it. + + // pick the specialized implementation for this combination of null presence + let (buffer, nulls) = match (values_has_nulls, indices_has_nulls) { + (false, false) => { + // * no nulls + // * all `indices.values()` are valid + take_no_nulls::(values.values().as_slice(), indices.values().as_slice())? + } + (true, false) => { + // * nulls come from `values` alone + // * all `indices.values()` are valid + take_values_nulls::(values, indices.values().as_slice())? + } + (false, true) => { + // in this branch it is unsound to read and use `indices.values()`, + // as doing so is UB when they come from a null slot. + take_indices_nulls::(values.values().as_slice(), indices)? + } + (true, true) => { + // in this branch it is unsound to read and use `indices.values()`, + // as doing so is UB when they come from a null slot. + take_values_indices_nulls::(values, indices)? 
+ } + }; + + Ok(PrimitiveArray::::from_data( + values.data_type().clone(), + buffer, + nulls, + )) +} + +#[cfg(test)] +mod tests { + use crate::{ + array::Primitive, + datatypes::{Int8Type, PrimitiveType}, + }; + + use super::*; + + // Builds arrays from `data` and `expected_data`, runs `take` with `index` and `options`, + // and asserts that the output equals the expected array. + fn test_take_primitive_arrays( + data: &[Option], + index: &PrimitiveArray, + options: Option, + expected_data: &[Option], + ) -> Result<()> + where + T: PrimitiveType, + { + let output = Primitive::::from(data).to(T::DATA_TYPE); + let expected = Primitive::::from(expected_data).to(T::DATA_TYPE); + let output = take(&output, index, options)?; + assert_eq!(output.as_ref(), &expected); + Ok(()) + } + + // values contain nulls, indices do not + #[test] + fn test_take_primitive_non_null_indices() { + let index = Primitive::::from_slice(&[0, 5, 3, 1, 4, 2]).to(DataType::Int32); + test_take_primitive_arrays::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + &index, + None, + &[None, None, Some(2), Some(3), Some(3), Some(5)], + ) + .unwrap(); + } + + // indices contain a null, values do not + #[test] + fn test_take_primitive_non_null_values() { + let index = + Primitive::::from(&[Some(3), None, Some(1), Some(3), Some(2)]).to(DataType::Int32); + test_take_primitive_arrays::( + &[Some(0), Some(1), Some(2), Some(3), Some(4)], + &index, + None, + &[Some(3), None, Some(1), Some(3), Some(2)], + ) + .unwrap(); + } +} diff --git a/src/datatypes/field.rs b/src/datatypes/field.rs new file mode 100644 index 00000000000..192c3408b95 --- /dev/null +++ b/src/datatypes/field.rs @@ -0,0 +1,245 @@ +use std::collections::BTreeMap; + +use crate::error::{ArrowError, Result}; + +use super::DataType; + +/// Contains the meta-data for a single relative type. +/// +/// The `Schema` object is an ordered collection of `Field` objects. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct Field { + /// The name of the field. + pub(super) name: String, + /// The data type of the field. + pub(super) data_type: DataType, + /// Whether the field supports null values. + pub(super) nullable: bool, + /// The dictionary id; only exposed by `dict_id()` when `data_type` is a dictionary. + pub(super) dict_id: i64, + /// Whether the dictionary is ordered; only exposed by `dict_is_ordered()` when `data_type` is a dictionary. + pub(super) dict_is_ordered: bool, + /// A map of key-value pairs containing additional custom meta data. 
+ pub(super) metadata: Option>, +} + +impl Field { + /// Creates a new field with `dict_id` 0, `dict_is_ordered` false and no metadata. + pub fn new(name: &str, data_type: DataType, nullable: bool) -> Self { + Field { + name: name.to_string(), + data_type, + nullable, + dict_id: 0, + dict_is_ordered: false, + metadata: None, + } + } + + /// Creates a new field with an explicit dictionary id and dictionary ordering flag, + /// and no metadata. + pub fn new_dict( + name: &str, + data_type: DataType, + nullable: bool, + dict_id: i64, + dict_is_ordered: bool, + ) -> Self { + Field { + name: name.to_string(), + data_type, + nullable, + dict_id, + dict_is_ordered, + metadata: None, + } + } + + /// Sets the `Field`'s optional custom metadata. + /// The metadata is set as `None` for empty map. + #[inline] + pub fn set_metadata(&mut self, metadata: Option>) { + // To make serde happy, convert Some(empty_map) to None. + self.metadata = None; + if let Some(v) = metadata { + if !v.is_empty() { + self.metadata = Some(v); + } + } + } + + /// Returns the immutable reference to the `Field`'s optional custom metadata. + #[inline] + pub const fn metadata(&self) -> &Option> { + &self.metadata + } + + /// Returns an immutable reference to the `Field`'s name. + #[inline] + pub const fn name(&self) -> &String { + &self.name + } + + /// Returns an immutable reference to the `Field`'s data-type. + #[inline] + pub const fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Indicates whether this `Field` supports null values. + #[inline] + pub const fn is_nullable(&self) -> bool { + self.nullable + } + + /// Returns the dictionary ID, if this is a dictionary type. + #[inline] + pub const fn dict_id(&self) -> Option { + match self.data_type { + DataType::Dictionary(_, _) => Some(self.dict_id), + _ => None, + } + } + + /// Returns whether this `Field`'s dictionary is ordered, if this is a dictionary type. + #[inline] + pub const fn dict_is_ordered(&self) -> Option { + match self.data_type { + DataType::Dictionary(_, _) => Some(self.dict_is_ordered), + _ => None, + } + } + + /// Merge field into self if it is compatible. 
Struct will be merged recursively. + /// NOTE: `self` may be updated to unexpected state in case of merge failure. + /// + /// Example: + /// + /// ``` + /// use arrow2::datatypes::*; + /// + /// let mut field = Field::new("c1", DataType::Int64, false); + /// assert!(field.try_merge(&Field::new("c1", DataType::Int64, true)).is_ok()); + /// assert!(field.is_nullable()); + /// ``` + /// + /// # Errors + /// + /// Returns `ArrowError::SchemaError` when the two fields carry different values for the + /// same metadata key, or differ in `dict_id`, `dict_is_ordered` or data type. + pub fn try_merge(&mut self, from: &Field) -> Result<()> { + // merge metadata + // (this writes to `self` before the checks below run, which is why `self` + // can be left partially updated when a later check fails) + match (self.metadata(), from.metadata()) { + (Some(self_metadata), Some(from_metadata)) => { + let mut merged = self_metadata.clone(); + for (key, from_value) in from_metadata { + if let Some(self_value) = self_metadata.get(key) { + if self_value != from_value { + return Err(ArrowError::SchemaError(format!( + "Fail to merge field due to conflicting metadata data value for key {}", key), + )); + } + } else { + merged.insert(key.clone(), from_value.clone()); + } + } + self.set_metadata(Some(merged)); + } + (None, Some(from_metadata)) => { + self.set_metadata(Some(from_metadata.clone())); + } + _ => {} + } + if from.dict_id != self.dict_id { + return Err(ArrowError::SchemaError( + "Fail to merge schema Field due to conflicting dict_id".to_string(), + )); + } + if from.dict_is_ordered != self.dict_is_ordered { + return Err(ArrowError::SchemaError( + "Fail to merge schema Field due to conflicting dict_is_ordered".to_string(), + )); + } + match &mut self.data_type { + // structs: merge children that share a name recursively, append the others + DataType::Struct(nested_fields) => match &from.data_type { + DataType::Struct(from_nested_fields) => { + for from_field in from_nested_fields { + let mut is_new_field = true; + for self_field in nested_fields.iter_mut() { + if self_field.name != from_field.name { + continue; + } + is_new_field = false; + self_field.try_merge(&from_field)?; + } + if is_new_field { + nested_fields.push(from_field.clone()); + } + } + } + _ => { + return Err(ArrowError::SchemaError( + "Fail to merge schema Field due to conflicting datatype".to_string(), + )); + } + }, + 
// unions: a child of `from` is appended only when no equal child exists yet + // (no recursive merge, unlike structs) + DataType::Union(nested_fields) => match &from.data_type { + DataType::Union(from_nested_fields) => { + for from_field in from_nested_fields { + let mut is_new_field = true; + for self_field in nested_fields.iter_mut() { + if from_field == self_field { + is_new_field = false; + break; + } + } + if is_new_field { + nested_fields.push(from_field.clone()); + } + } + } + _ => { + return Err(ArrowError::SchemaError( + "Fail to merge schema Field due to conflicting datatype".to_string(), + )); + } + }, + // every other type merges only with an identical data type + DataType::Null + | DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Timestamp(_, _) + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Binary + | DataType::LargeBinary + | DataType::Interval(_) + | DataType::LargeList(_) + | DataType::List(_) + | DataType::Dictionary(_, _) + | DataType::FixedSizeList(_, _) + | DataType::FixedSizeBinary(_) + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Decimal(_, _) => { + if self.data_type != from.data_type { + return Err(ArrowError::SchemaError( + "Fail to merge schema Field due to conflicting datatype".to_string(), + )); + } + } + } + // nullability only widens: the merged field is nullable if either side is + if from.nullable { + self.nullable = from.nullable; + } + + Ok(()) + } +} + +/// Displays a `Field` using its `Debug` representation. +impl std::fmt::Display for Field { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} diff --git a/src/datatypes/json.rs b/src/datatypes/json.rs new file mode 100644 index 00000000000..5f14f154de5 --- /dev/null +++ b/src/datatypes/json.rs @@ -0,0 +1,591 @@ +use std::{ + collections::{BTreeMap, HashMap}, + convert::TryFrom, +}; + +use serde_derive::Deserialize; +use serde_json::{json, Value}; + +use crate::error::ArrowError; + +use super::{DataType, Field, IntervalUnit, Schema, TimeUnit}; 
+ +/// Conversion of a type into a `serde_json::Value`. +pub trait ToJson { + /// Generate a JSON representation + fn to_json(&self) -> Value; +} + +/// Serializes a `DataType` as a JSON object keyed by `"name"` plus type-specific attributes. +/// +/// Nested types (struct, union, list, largelist, fixedsizelist, dictionary) do not emit their +/// child types here. +impl ToJson for DataType { + fn to_json(&self) -> Value { + match self { + DataType::Null => json!({"name": "null"}), + DataType::Boolean => json!({"name": "bool"}), + DataType::Int8 => json!({"name": "int", "bitWidth": 8, "isSigned": true}), + DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}), + DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}), + DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}), + DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}), + DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}), + DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}), + DataType::UInt64 => json!({"name": "int", "bitWidth": 64, "isSigned": false}), + DataType::Float16 => json!({"name": "floatingpoint", "precision": "HALF"}), + DataType::Float32 => json!({"name": "floatingpoint", "precision": "SINGLE"}), + DataType::Float64 => json!({"name": "floatingpoint", "precision": "DOUBLE"}), + DataType::Utf8 => json!({"name": "utf8"}), + DataType::LargeUtf8 => json!({"name": "largeutf8"}), + DataType::Binary => json!({"name": "binary"}), + DataType::LargeBinary => json!({"name": "largebinary"}), + DataType::FixedSizeBinary(byte_width) => { + json!({"name": "fixedsizebinary", "byteWidth": byte_width}) + } + DataType::Struct(_) => json!({"name": "struct"}), + DataType::Union(_) => json!({"name": "union"}), + DataType::List(_) => json!({ "name": "list"}), + DataType::LargeList(_) => json!({ "name": "largelist"}), + DataType::FixedSizeList(_, length) => { + json!({"name":"fixedsizelist", "listSize": length}) + } + DataType::Time32(unit) => { + json!({"name": "time", "bitWidth": 32, "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => 
"NANOSECOND", + }}) + } + DataType::Time64(unit) => { + json!({"name": "time", "bitWidth": 64, "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}) + } + DataType::Date32 => { + json!({"name": "date", "unit": "DAY"}) + } + DataType::Date64 => { + json!({"name": "date", "unit": "MILLISECOND"}) + } + DataType::Timestamp(unit, None) => { + json!({"name": "timestamp", "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}) + } + DataType::Timestamp(unit, Some(tz)) => { + json!({"name": "timestamp", "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }, "timezone": tz}) + } + DataType::Interval(unit) => json!({"name": "interval", "unit": match unit { + IntervalUnit::YearMonth => "YEAR_MONTH", + IntervalUnit::DayTime => "DAY_TIME", + }}), + DataType::Duration(unit) => json!({"name": "duration", "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}), + DataType::Dictionary(_, _) => json!({ "name": "dictionary"}), + DataType::Decimal(precision, scale) => { + json!({"name": "decimal", "precision": precision, "scale": scale}) + } + } + } +} + +/// Serializes a `Field` as a JSON object with its name, nullability, type and children. +impl ToJson for Field { + fn to_json(&self) -> Value { + // NOTE(review): `Union` falls into the catch-all arm, so its child fields are + // not serialized - confirm this is intended. + let children: Vec = match self.data_type() { + DataType::Struct(fields) => fields.iter().map(|f| f.to_json()).collect(), + DataType::List(field) => vec![field.to_json()], + DataType::LargeList(field) => vec![field.to_json()], + DataType::FixedSizeList(field, _) => vec![field.to_json()], + _ => vec![], + }; + match self.data_type() { + DataType::Dictionary(ref index_type, ref 
value_type) => json!({ + "name": self.name(), + "nullable": self.is_nullable(), + "type": value_type.to_json(), + "children": children, + "dictionary": { + "id": self.dict_id(), + "indexType": index_type.to_json(), + "isOrdered": self.dict_is_ordered() + } + }), + _ => json!({ + "name": self.name(), + "nullable": self.is_nullable(), + "type": self.data_type().to_json(), + "children": children + }), + } + } +} + +impl TryFrom<&Value> for DataType { + type Error = ArrowError; + + /// Parses a `DataType` from the JSON object describing a type. + /// + /// The children of nested types are not part of this object, so nested types are + /// returned with a placeholder child. Malformed input is reported as an + /// `ArrowError::ParseError`. + fn try_from(value: &Value) -> Result { + let default_field = Field::new("", DataType::Boolean, true); + match *value { + Value::Object(ref map) => match map.get("name") { + Some(s) if s == "null" => Ok(DataType::Null), + Some(s) if s == "bool" => Ok(DataType::Boolean), + Some(s) if s == "binary" => Ok(DataType::Binary), + Some(s) if s == "largebinary" => Ok(DataType::LargeBinary), + Some(s) if s == "utf8" => Ok(DataType::Utf8), + Some(s) if s == "largeutf8" => Ok(DataType::LargeUtf8), + Some(s) if s == "fixedsizebinary" => { + // the width must be an integer that fits in an `i32`; any other value is + // reported as a parse error instead of panicking or being truncated + match map + .get("byteWidth") + .and_then(Value::as_i64) + .and_then(|width| i32::try_from(width).ok()) + { + Some(width) => Ok(DataType::FixedSizeBinary(width)), + None => Err(ArrowError::ParseError( + "Expecting a byteWidth for fixedsizebinary".to_string(), + )), + } + } + Some(s) if s == "decimal" => { + // precision and scale must be non-negative integers; a missing or + // malformed entry is reported as a parse error instead of panicking + let precision = map + .get("precision") + .and_then(Value::as_u64) + .and_then(|v| usize::try_from(v).ok()) + .ok_or_else(|| { + ArrowError::ParseError("Expecting a precision for decimal".to_string()) + })?; + let scale = map + .get("scale") + .and_then(Value::as_u64) + .and_then(|v| usize::try_from(v).ok()) + .ok_or_else(|| { + ArrowError::ParseError("Expecting a scale for decimal".to_string()) + })?; + + Ok(DataType::Decimal(precision, scale)) + } + Some(s) if s == "floatingpoint" => match map.get("precision") { + Some(p) if p == "HALF" => Ok(DataType::Float16), + Some(p) if p == "SINGLE" =>
Ok(DataType::Float32), + Some(p) if p == "DOUBLE" => Ok(DataType::Float64), + _ => Err(ArrowError::ParseError( + "floatingpoint precision missing or invalid".to_string(), + )), + }, + Some(s) if s == "timestamp" => { + let unit = match map.get("unit") { + Some(p) if p == "SECOND" => Ok(TimeUnit::Second), + Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond), + Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond), + Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond), + _ => Err(ArrowError::ParseError( + "timestamp unit missing or invalid".to_string(), + )), + }; + let tz = match map.get("timezone") { + None => Ok(None), + Some(Value::String(tz)) => Ok(Some(tz.clone())), + _ => Err(ArrowError::ParseError( + "timezone must be a string".to_string(), + )), + }; + Ok(DataType::Timestamp(unit?, tz?)) + } + Some(s) if s == "date" => match map.get("unit") { + Some(p) if p == "DAY" => Ok(DataType::Date32), + Some(p) if p == "MILLISECOND" => Ok(DataType::Date64), + _ => Err(ArrowError::ParseError( + "date unit missing or invalid".to_string(), + )), + }, + Some(s) if s == "time" => { + let unit = match map.get("unit") { + Some(p) if p == "SECOND" => Ok(TimeUnit::Second), + Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond), + Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond), + Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond), + _ => Err(ArrowError::ParseError( + "time unit missing or invalid".to_string(), + )), + }; + match map.get("bitWidth") { + Some(p) if p == 32 => Ok(DataType::Time32(unit?)), + Some(p) if p == 64 => Ok(DataType::Time64(unit?)), + _ => Err(ArrowError::ParseError( + "time bitWidth missing or invalid".to_string(), + )), + } + } + Some(s) if s == "duration" => match map.get("unit") { + Some(p) if p == "SECOND" => Ok(DataType::Duration(TimeUnit::Second)), + Some(p) if p == "MILLISECOND" => Ok(DataType::Duration(TimeUnit::Millisecond)), + Some(p) if p == "MICROSECOND" => Ok(DataType::Duration(TimeUnit::Microsecond)), 
+ Some(p) if p == "NANOSECOND" => Ok(DataType::Duration(TimeUnit::Nanosecond)), + _ => Err(ArrowError::ParseError( + "time unit missing or invalid".to_string(), + )), + }, + Some(s) if s == "interval" => match map.get("unit") { + Some(p) if p == "DAY_TIME" => Ok(DataType::Interval(IntervalUnit::DayTime)), + Some(p) if p == "YEAR_MONTH" => Ok(DataType::Interval(IntervalUnit::YearMonth)), + _ => Err(ArrowError::ParseError( + "interval unit missing or invalid".to_string(), + )), + }, + Some(s) if s == "int" => match map.get("isSigned") { + Some(&Value::Bool(true)) => match map.get("bitWidth") { + Some(&Value::Number(ref n)) => match n.as_u64() { + Some(8) => Ok(DataType::Int8), + Some(16) => Ok(DataType::Int16), + Some(32) => Ok(DataType::Int32), + Some(64) => Ok(DataType::Int64), + _ => Err(ArrowError::ParseError( + "int bitWidth missing or invalid".to_string(), + )), + }, + _ => Err(ArrowError::ParseError( + "int bitWidth missing or invalid".to_string(), + )), + }, + Some(&Value::Bool(false)) => match map.get("bitWidth") { + Some(&Value::Number(ref n)) => match n.as_u64() { + Some(8) => Ok(DataType::UInt8), + Some(16) => Ok(DataType::UInt16), + Some(32) => Ok(DataType::UInt32), + Some(64) => Ok(DataType::UInt64), + _ => Err(ArrowError::ParseError( + "int bitWidth missing or invalid".to_string(), + )), + }, + _ => Err(ArrowError::ParseError( + "int bitWidth missing or invalid".to_string(), + )), + }, + _ => Err(ArrowError::ParseError( + "int signed missing or invalid".to_string(), + )), + }, + Some(s) if s == "list" => { + // return a list with any type as its child isn't defined in the map + Ok(DataType::List(Box::new(default_field))) + } + Some(s) if s == "largelist" => { + // return a largelist with any type as its child isn't defined in the map + Ok(DataType::LargeList(Box::new(default_field))) + } + Some(s) if s == "fixedsizelist" => { + // return a list with any type as its child isn't defined in the map + if let Some(Value::Number(size)) = map.get("listSize") 
{ + Ok(DataType::FixedSizeList( + Box::new(default_field), + size.as_i64().unwrap() as i32, + )) + } else { + Err(ArrowError::ParseError( + "Expecting a listSize for fixedsizelist".to_string(), + )) + } + } + Some(s) if s == "struct" => { + // return an empty `struct` type as its children aren't defined in the map + Ok(DataType::Struct(vec![])) + } + Some(other) => Err(ArrowError::ParseError(format!( + "invalid or unsupported type name: {} in {:?}", + other, value + ))), + None => Err(ArrowError::ParseError("type name missing".to_string())), + }, + _ => Err(ArrowError::ParseError( + "invalid json value type".to_string(), + )), + } + } +} + +impl TryFrom<&Value> for Field { + type Error = ArrowError; + + fn try_from(value: &Value) -> Result { + match *value { + Value::Object(ref map) => { + let name = match map.get("name") { + Some(&Value::String(ref name)) => name.to_string(), + _ => { + return Err(ArrowError::ParseError( + "Field missing 'name' attribute".to_string(), + )); + } + }; + let nullable = match map.get("nullable") { + Some(&Value::Bool(b)) => b, + _ => { + return Err(ArrowError::ParseError( + "Field missing 'nullable' attribute".to_string(), + )); + } + }; + let data_type = match map.get("type") { + Some(t) => DataType::try_from(t)?, + _ => { + return Err(ArrowError::ParseError( + "Field missing 'type' attribute".to_string(), + )); + } + }; + + // Referenced example file: testing/data/arrow-ipc-stream/integration/1.0.0-littleendian/generated_custom_metadata.json.gz + let metadata = match map.get("metadata") { + Some(&Value::Array(ref values)) => { + let mut res: BTreeMap = BTreeMap::new(); + for value in values { + match value.as_object() { + Some(map) => { + if map.len() != 2 { + return Err(ArrowError::ParseError( + "Field 'metadata' must have exact two entries for each key-value map".to_string(), + )); + } + if let (Some(k), Some(v)) = (map.get("key"), map.get("value")) { + if let (Some(k_str), Some(v_str)) = (k.as_str(), v.as_str()) + { + res.insert( 
+ k_str.to_string().clone(), + v_str.to_string().clone(), + ); + } else { + return Err(ArrowError::ParseError("Field 'metadata' must have map value of string type".to_string())); + } + } else { + return Err(ArrowError::ParseError("Field 'metadata' lacks map keys named \"key\" or \"value\"".to_string())); + } + } + _ => { + return Err(ArrowError::ParseError( + "Field 'metadata' contains non-object key-value pair" + .to_string(), + )); + } + } + } + Some(res) + } + // We also support map format, because Schema's metadata supports this. + // See https://github.com/apache/arrow/pull/5907 + Some(&Value::Object(ref values)) => { + let mut res: BTreeMap = BTreeMap::new(); + for (k, v) in values { + if let Some(str_value) = v.as_str() { + res.insert(k.clone(), str_value.to_string().clone()); + } else { + return Err(ArrowError::ParseError(format!( + "Field 'metadata' contains non-string value for key {}", + k + ))); + } + } + Some(res) + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field `metadata` is not json array".to_string(), + )); + } + _ => None, + }; + + // if data_type is a struct or list, get its children + let data_type = match data_type { + DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => { + match map.get("children") { + Some(Value::Array(values)) => { + if values.len() != 1 { + return Err(ArrowError::ParseError( + "Field 'children' must have one element for a list data type".to_string(), + )); + } + match data_type { + DataType::List(_) => { + DataType::List(Box::new(Self::try_from(&values[0])?)) + } + DataType::LargeList(_) => { + DataType::LargeList(Box::new(Self::try_from(&values[0])?)) + } + DataType::FixedSizeList(_, int) => DataType::FixedSizeList( + Box::new(Self::try_from(&values[0])?), + int, + ), + _ => unreachable!( + "Data type should be a list, largelist or fixedsizelist" + ), + } + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field 'children' must be an array".to_string(), + )) + } + None => { + 
return Err(ArrowError::ParseError( + "Field missing 'children' attribute".to_string(), + )); + } + } + } + DataType::Struct(mut fields) => match map.get("children") { + Some(Value::Array(values)) => { + let struct_fields: Result, _> = + values.iter().map(|v| Field::try_from(v)).collect(); + fields.append(&mut struct_fields?); + DataType::Struct(fields) + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field 'children' must be an array".to_string(), + )) + } + None => { + return Err(ArrowError::ParseError( + "Field missing 'children' attribute".to_string(), + )); + } + }, + _ => data_type, + }; + + let mut dict_id = 0; + let mut dict_is_ordered = false; + + let data_type = match map.get("dictionary") { + Some(dictionary) => { + let index_type = match dictionary.get("indexType") { + Some(t) => DataType::try_from(t)?, + _ => { + return Err(ArrowError::ParseError( + "Field missing 'indexType' attribute".to_string(), + )); + } + }; + // an id that is absent, not a number, or not representable as an `i64` + // is reported as a parse error instead of panicking + dict_id = match dictionary.get("id").and_then(Value::as_i64) { + Some(id) => id, + None => { + return Err(ArrowError::ParseError( + "Field missing 'id' attribute".to_string(), + )); + } + }; + dict_is_ordered = match dictionary.get("isOrdered") { + Some(&Value::Bool(n)) => n, + _ => { + return Err(ArrowError::ParseError( + "Field missing 'isOrdered' attribute".to_string(), + )); + } + }; + DataType::Dictionary(Box::new(index_type), Box::new(data_type)) + } + _ => data_type, + }; + Ok(Field { + name, + nullable, + data_type, + dict_id, + dict_is_ordered, + metadata, + }) + } + _ => Err(ArrowError::ParseError( + "Invalid json value type for field".to_string(), + )), + } + } +} + +impl ToJson for Schema { + fn to_json(&self) -> Value { + json!({ + "fields": self.fields.iter().map(|field| field.to_json()).collect::>(), + "metadata": serde_json::to_value(&self.metadata).unwrap() + }) + } +} + +#[derive(Deserialize)] +struct MetadataKeyValue { + key: String, + value: String, +} + +/// Parse a `metadata` definition from a JSON
representation. +/// The JSON can either be an Object or an Array of Objects. +fn from_metadata(json: &Value) -> Result, ArrowError> { + match json { + Value::Array(_) => { + let mut hashmap = HashMap::new(); + let values: Vec = + serde_json::from_value(json.clone()).map_err(|_| { + ArrowError::JsonError("Unable to parse object into key-value pair".to_string()) + })?; + for meta in values { + hashmap.insert(meta.key.clone(), meta.value); + } + Ok(hashmap) + } + Value::Object(md) => md + .iter() + .map(|(k, v)| { + if let Value::String(v) = v { + Ok((k.to_string(), v.to_string())) + } else { + Err(ArrowError::ParseError( + "metadata `value` field must be a string".to_string(), + )) + } + }) + .collect::>(), + _ => Err(ArrowError::ParseError( + "`metadata` field must be an object".to_string(), + )), + } +} + +impl TryFrom<&Value> for Schema { + type Error = ArrowError; + + fn try_from(json: &Value) -> Result { + match *json { + Value::Object(ref schema) => { + let fields = if let Some(Value::Array(fields)) = schema.get("fields") { + fields + .iter() + .map(|f| Field::try_from(f)) + .collect::>()? + } else { + return Err(ArrowError::ParseError( + "Schema fields should be an array".to_string(), + )); + }; + + let metadata = if let Some(value) = schema.get("metadata") { + from_metadata(value)? + } else { + HashMap::default() + }; + + Ok(Self { fields, metadata }) + } + _ => Err(ArrowError::ParseError( + "Invalid json value type for schema".to_string(), + )), + } + } +} diff --git a/src/datatypes/mod.rs b/src/datatypes/mod.rs new file mode 100644 index 00000000000..f9c745e2dad --- /dev/null +++ b/src/datatypes/mod.rs @@ -0,0 +1,171 @@ +mod field; +mod json; +mod primitive; +mod schema; + +pub use primitive::*; + +pub use field::Field; +pub use schema::Schema; + +/// The set of datatypes that are supported by this implementation of Apache Arrow. +/// +/// The Arrow specification on data types includes some more types. 
+/// See also [`Schema.fbs`](https://github.com/apache/arrow/blob/master/format/Schema.fbs) +/// for Arrow's specification. +/// +/// The variants of this enum include primitive fixed size types as well as parametric or +/// nested types. +/// Currently the Rust implementation supports the following nested types: +/// - `List` +/// - `Struct` +/// +/// Nested types can themselves be nested within other arrays. +/// For more information on these types please see +/// [the physical memory layout of Apache Arrow](https://arrow.apache.org/docs/format/Columnar.html#physical-memory-layout). +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum DataType { + /// Null type + Null, + /// A boolean datatype representing the values `true` and `false`. + Boolean, + /// A signed 8-bit integer. + Int8, + /// A signed 16-bit integer. + Int16, + /// A signed 32-bit integer. + Int32, + /// A signed 64-bit integer. + Int64, + /// An unsigned 8-bit integer. + UInt8, + /// An unsigned 16-bit integer. + UInt16, + /// An unsigned 32-bit integer. + UInt32, + /// An unsigned 64-bit integer. + UInt64, + /// A 16-bit floating point number. + Float16, + /// A 32-bit floating point number. + Float32, + /// A 64-bit floating point number. + Float64, + /// A timestamp with an optional timezone. + /// + /// Time is measured as a Unix epoch, counting the seconds from + /// 00:00:00.000 on 1 January 1970, excluding leap seconds, + /// as a 64-bit integer. + /// + /// The time zone is a string indicating the name of a time zone, one of: + /// + /// * As used in the Olson time zone database (the "tz database" or + /// "tzdata"), such as "America/New_York" + /// * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30 + Timestamp(TimeUnit, Option), + /// A 32-bit date representing the elapsed time since UNIX epoch (1970-01-01) + /// in days (32 bits). 
+ Date32, + /// A 64-bit date representing the elapsed time since UNIX epoch (1970-01-01) + /// in milliseconds (64 bits). Values are evenly divisible by 86400000. + Date64, + /// A 32-bit time representing the elapsed time since midnight in the unit of `TimeUnit`. + Time32(TimeUnit), + /// A 64-bit time representing the elapsed time since midnight in the unit of `TimeUnit`. + Time64(TimeUnit), + /// Measure of elapsed time in either seconds, milliseconds, microseconds or nanoseconds. + Duration(TimeUnit), + /// A "calendar" interval which models types that don't necessarily + /// have a precise duration without the context of a base timestamp (e.g. + /// days can differ in length during daylight saving time transitions). + Interval(IntervalUnit), + /// Opaque binary data of variable length. + Binary, + /// Opaque binary data of fixed size. + /// Enum parameter specifies the number of bytes per value. + FixedSizeBinary(i32), + /// Opaque binary data of variable length and 64-bit offsets. + LargeBinary, + /// A variable-length string in Unicode with UTF-8 encoding. + Utf8, + /// A variable-length string in Unicode with UTF-8 encoding and 64-bit offsets. + LargeUtf8, + /// A list of some logical data type with variable length. + List(Box), + /// A list of some logical data type with fixed length. + FixedSizeList(Box, i32), + /// A list of some logical data type with variable length and 64-bit offsets. + LargeList(Box), + /// A nested datatype that contains a number of sub-fields. + Struct(Vec), + /// A nested datatype that can represent slots of differing types. + Union(Vec), + /// A dictionary encoded array (`key_type`, `value_type`), where + /// each array element is an index of `key_type` into an + /// associated dictionary of `value_type`. + /// + /// Dictionary arrays are used to store columns of `value_type` + /// that contain many repeated values using less memory, but with + /// a higher CPU overhead for some operations.
+ /// + /// This type mostly used to represent low cardinality string + /// arrays or a limited set of primitive types as integers. + Dictionary(Box, Box), + /// Decimal value with precision and scale + Decimal(usize, usize), +} + +impl std::fmt::Display for DataType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +/// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum TimeUnit { + /// Time in seconds. + Second, + /// Time in milliseconds. + Millisecond, + /// Time in microseconds. + Microsecond, + /// Time in nanoseconds. + Nanosecond, +} + +/// YEAR_MONTH or DAY_TIME interval in SQL style. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum IntervalUnit { + /// Indicates the number of elapsed whole months, stored as 4-byte integers. + YearMonth, + /// Indicates the number of elapsed days and milliseconds, + /// stored as 2 contiguous 32-bit integers (8-bytes in total). + DayTime, +} + +impl DataType { + /// Compares the datatype with another, ignoring nested field names + /// and metadata. 
+ pub(crate) fn equals_datatype(&self, other: &DataType) -> bool { + match (&self, other) { + (DataType::List(a), DataType::List(b)) + | (DataType::LargeList(a), DataType::LargeList(b)) => { + a.is_nullable() == b.is_nullable() && a.data_type().equals_datatype(b.data_type()) + } + (DataType::FixedSizeList(a, a_size), DataType::FixedSizeList(b, b_size)) => { + a_size == b_size + && a.is_nullable() == b.is_nullable() + && a.data_type().equals_datatype(b.data_type()) + } + (DataType::Struct(a), DataType::Struct(b)) => { + a.len() == b.len() + && a.iter().zip(b).all(|(a, b)| { + a.is_nullable() == b.is_nullable() + && a.data_type().equals_datatype(b.data_type()) + }) + } + _ => self == other, + } + } +} diff --git a/src/datatypes/primitive.rs b/src/datatypes/primitive.rs new file mode 100644 index 00000000000..5c737c3fb6a --- /dev/null +++ b/src/datatypes/primitive.rs @@ -0,0 +1,103 @@ +use crate::{ + buffer::NativeType, + datatypes::{DataType, IntervalUnit, TimeUnit}, +}; + +pub trait PrimitiveType: 'static { + /// Corresponding Rust native type for the primitive type. + type Native: NativeType; + + /// the corresponding Arrow data type of this primitive type. + const DATA_TYPE: DataType; +} + +macro_rules! 
make_type { + ($name:ident, $native_ty:ty, $data_ty:expr) => { + #[derive(Debug)] + pub struct $name {} + + impl PrimitiveType for $name { + type Native = $native_ty; + const DATA_TYPE: DataType = $data_ty; + } + }; +} + +make_type!(Int8Type, i8, DataType::Int8); +make_type!(Int16Type, i16, DataType::Int16); +make_type!(Int32Type, i32, DataType::Int32); +make_type!(Int64Type, i64, DataType::Int64); +make_type!(UInt8Type, u8, DataType::UInt8); +make_type!(UInt16Type, u16, DataType::UInt16); +make_type!(UInt32Type, u32, DataType::UInt32); +make_type!(UInt64Type, u64, DataType::UInt64); +make_type!(Float32Type, f32, DataType::Float32); +make_type!(Float64Type, f64, DataType::Float64); +make_type!( + TimestampSecondType, + i64, + DataType::Timestamp(TimeUnit::Second, None) +); +make_type!( + TimestampMillisecondType, + i64, + DataType::Timestamp(TimeUnit::Millisecond, None) +); +make_type!( + TimestampMicrosecondType, + i64, + DataType::Timestamp(TimeUnit::Microsecond, None) +); +make_type!( + TimestampNanosecondType, + i64, + DataType::Timestamp(TimeUnit::Nanosecond, None) +); +make_type!(Date32Type, i32, DataType::Date32); +make_type!(Date64Type, i64, DataType::Date64); +make_type!(Time32SecondType, i32, DataType::Time32(TimeUnit::Second)); +make_type!( + Time32MillisecondType, + i32, + DataType::Time32(TimeUnit::Millisecond) +); +make_type!( + Time64MicrosecondType, + i64, + DataType::Time64(TimeUnit::Microsecond) +); +make_type!( + Time64NanosecondType, + i64, + DataType::Time64(TimeUnit::Nanosecond) +); +make_type!( + IntervalYearMonthType, + i32, + DataType::Interval(IntervalUnit::YearMonth) +); +make_type!( + IntervalDayTimeType, + i64, + DataType::Interval(IntervalUnit::DayTime) +); +make_type!( + DurationSecondType, + i64, + DataType::Duration(TimeUnit::Second) +); +make_type!( + DurationMillisecondType, + i64, + DataType::Duration(TimeUnit::Millisecond) +); +make_type!( + DurationMicrosecondType, + i64, + DataType::Duration(TimeUnit::Microsecond) +); 
+make_type!( + DurationNanosecondType, + i64, + DataType::Duration(TimeUnit::Nanosecond) +); diff --git a/src/datatypes/schema.rs b/src/datatypes/schema.rs new file mode 100644 index 00000000000..8390546ef1d --- /dev/null +++ b/src/datatypes/schema.rs @@ -0,0 +1,196 @@ +use std::collections::HashMap; + +use crate::error::{ArrowError, Result}; + +use super::Field; + +/// Describes the meta-data of an ordered sequence of relative types. +/// +/// Note that this information is only part of the meta-data and not part of the physical +/// memory layout. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Schema { + pub(crate) fields: Vec, + /// A map of key-value pairs containing additional meta data. + //#[serde(skip_serializing_if = "HashMap::is_empty")] + pub(crate) metadata: HashMap, +} + +impl Schema { + /// Creates an empty `Schema` + pub fn empty() -> Self { + Self { + fields: vec![], + metadata: HashMap::new(), + } + } + + /// Creates a new `Schema` from a sequence of `Field` values. + /// + /// # Example + /// + /// ``` + /// # extern crate arrow; + /// # use arrow::datatypes::{Field, DataType, Schema}; + /// let field_a = Field::new("a", DataType::Int64, false); + /// let field_b = Field::new("b", DataType::Boolean, false); + /// + /// let schema = Schema::new(vec![field_a, field_b]); + /// ``` + pub fn new(fields: Vec) -> Self { + Self::new_with_metadata(fields, HashMap::new()) + } + + /// Creates a new `Schema` from a sequence of `Field` values + /// and adds additional metadata in form of key value pairs. 
+ /// + /// # Example + /// + /// ``` + /// # extern crate arrow; + /// # use arrow::datatypes::{Field, DataType, Schema}; + /// # use std::collections::HashMap; + /// let field_a = Field::new("a", DataType::Int64, false); + /// let field_b = Field::new("b", DataType::Boolean, false); + /// + /// let mut metadata: HashMap = HashMap::new(); + /// metadata.insert("row_count".to_string(), "100".to_string()); + /// + /// let schema = Schema::new_with_metadata(vec![field_a, field_b], metadata); + /// ``` + #[inline] + pub const fn new_with_metadata(fields: Vec, metadata: HashMap) -> Self { + Self { fields, metadata } + } + + /// Merge schema into self if it is compatible. Struct fields will be merged recursively. + /// + /// Example: + /// + /// ``` + /// use arrow::datatypes::*; + /// + /// let merged = Schema::try_merge(vec![ + /// Schema::new(vec![ + /// Field::new("c1", DataType::Int64, false), + /// Field::new("c2", DataType::Utf8, false), + /// ]), + /// Schema::new(vec![ + /// Field::new("c1", DataType::Int64, true), + /// Field::new("c2", DataType::Utf8, false), + /// Field::new("c3", DataType::Utf8, false), + /// ]), + /// ]).unwrap(); + /// + /// assert_eq!( + /// merged, + /// Schema::new(vec![ + /// Field::new("c1", DataType::Int64, true), + /// Field::new("c2", DataType::Utf8, false), + /// Field::new("c3", DataType::Utf8, false), + /// ]), + /// ); + /// ``` + pub fn try_merge(schemas: impl IntoIterator) -> Result { + schemas + .into_iter() + .try_fold(Self::empty(), |mut merged, schema| { + let Schema { metadata, fields } = schema; + for (key, value) in metadata.into_iter() { + // merge metadata + if let Some(old_val) = merged.metadata.get(&key) { + if old_val != &value { + return Err(ArrowError::SchemaError( + "Fail to merge schema due to conflicting metadata.".to_string(), + )); + } + } + merged.metadata.insert(key, value); + } + // merge fields + for field in fields.into_iter() { + let mut new_field = true; + for merged_field in &mut merged.fields { + 
if field.name() != merged_field.name() { + continue; + } + new_field = false; + merged_field.try_merge(&field)? + } + // found a new field, add to field list + if new_field { + merged.fields.push(field); + } + } + Ok(merged) + }) + } + + /// Returns an immutable reference of the vector of `Field` instances. + #[inline] + pub const fn fields(&self) -> &Vec { + &self.fields + } + + /// Returns an immutable reference of a specific `Field` instance selected using an + /// offset within the internal `fields` vector. + pub fn field(&self, i: usize) -> &Field { + &self.fields[i] + } + + /// Returns an immutable reference of a specific `Field` instance selected by name. + pub fn field_with_name(&self, name: &str) -> Result<&Field> { + Ok(&self.fields[self.index_of(name)?]) + } + + /// Returns a vector of immutable references to all `Field` instances selected by + /// the dictionary ID they use. + pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> { + self.fields + .iter() + .filter(|f| f.dict_id() == Some(dict_id)) + .collect() + } + + /// Find the index of the column with the given name. + pub fn index_of(&self, name: &str) -> Result { + for i in 0..self.fields.len() { + if self.fields[i].name() == name { + return Ok(i); + } + } + let valid_fields: Vec = self.fields.iter().map(|f| f.name().clone()).collect(); + Err(ArrowError::InvalidArgumentError(format!( + "Unable to get field named \"{}\". Valid fields: {:?}", + name, valid_fields + ))) + } + + /// Returns an immutable reference to the Map of custom metadata key-value pairs. + #[inline] + pub const fn metadata(&self) -> &HashMap { + &self.metadata + } + + /// Look up a column by name and return a immutable reference to the column along with + /// its index. 
+ pub fn column_with_name(&self, name: &str) -> Option<(usize, &Field)> { + self.fields + .iter() + .enumerate() + .find(|&(_, c)| c.name() == name) + } +} + +impl std::fmt::Display for Schema { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str( + &self + .fields + .iter() + .map(|c| c.to_string()) + .collect::>() + .join(", "), + ) + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 00000000000..db2234e17d5 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines `ArrowError` for representing failures in various Arrow operations. +use std::fmt::{Debug, Display, Formatter}; + +use std::error::Error; + +/// Many different operations in the `arrow` crate return this error type. +#[derive(Debug)] +pub enum ArrowError { + /// Returned when functionality is not yet available. 
+ NotYetImplemented(String), + ExternalError(Box), + MemoryError(String), + ParseError(String), + SchemaError(String), + ComputeError(String), + DivideByZero, + CsvError(String), + JsonError(String), + IoError(String), + InvalidArgumentError(String), + ParquetError(String), + /// Error during import or export to/from the C Data Interface + CDataInterface(String), + DictionaryKeyOverflowError, +} + +impl ArrowError { + /// Wraps an external error in an `ArrowError`. + pub fn from_external_error(error: Box) -> Self { + Self::ExternalError(error) + } +} + +impl From<::std::io::Error> for ArrowError { + fn from(error: std::io::Error) -> Self { + ArrowError::IoError(error.to_string()) + } +} + +impl From<::std::string::FromUtf8Error> for ArrowError { + fn from(error: std::string::FromUtf8Error) -> Self { + ArrowError::ParseError(error.to_string()) + } +} + +impl From for ArrowError { + fn from(error: serde_json::Error) -> Self { + ArrowError::JsonError(error.to_string()) + } +} + +impl Display for ArrowError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ArrowError::NotYetImplemented(source) => { + write!(f, "Not yet implemented: {}", &source) + } + ArrowError::ExternalError(source) => write!(f, "External error: {}", &source), + ArrowError::MemoryError(desc) => write!(f, "Memory error: {}", desc), + ArrowError::ParseError(desc) => write!(f, "Parser error: {}", desc), + ArrowError::SchemaError(desc) => write!(f, "Schema error: {}", desc), + ArrowError::ComputeError(desc) => write!(f, "Compute error: {}", desc), + ArrowError::DivideByZero => write!(f, "Divide by zero error"), + ArrowError::CsvError(desc) => write!(f, "Csv error: {}", desc), + ArrowError::JsonError(desc) => write!(f, "Json error: {}", desc), + ArrowError::IoError(desc) => write!(f, "Io error: {}", desc), + ArrowError::InvalidArgumentError(desc) => { + write!(f, "Invalid argument error: {}", desc) + } + ArrowError::ParquetError(desc) => { + write!(f, "Parquet argument error: 
{}", desc) + } + ArrowError::CDataInterface(desc) => { + write!(f, "C Data interface error: {}", desc) + } + ArrowError::DictionaryKeyOverflowError => { + write!(f, "Dictionary key bigger than the key type") + } + } + } +} + +impl Error for ArrowError {} + +pub type Result = std::result::Result; diff --git a/src/ffi/array.rs b/src/ffi/array.rs new file mode 100644 index 00000000000..84b14b3e9da --- /dev/null +++ b/src/ffi/array.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Contains functionality to load an ArrayData from the C Data Interface + +use std::convert::TryFrom; + +use super::ffi::ArrowArray; +use crate::array::{BooleanArray, FromFFI}; +use crate::error::{ArrowError, Result}; +use crate::{ + array::{Array, PrimitiveArray}, + datatypes::DataType, +}; + +impl TryFrom for Box { + type Error = ArrowError; + + fn try_from(array: ArrowArray) -> Result { + let data_type = array.data_type()?; + + let array: Box = match data_type { + DataType::Boolean => Box::new(BooleanArray::try_from_ffi(data_type, array)?), + DataType::UInt32 => Box::new(PrimitiveArray::::try_from_ffi(data_type, array)?), + DataType::Date32 | DataType::Int32 => { + Box::new(PrimitiveArray::::try_from_ffi(data_type, array)?) + } + DataType::Date64 | DataType::Int64 => { + Box::new(PrimitiveArray::::try_from_ffi(data_type, array)?) + } + DataType::UInt64 => Box::new(PrimitiveArray::::try_from_ffi(data_type, array)?), + _ => unimplemented!(), + }; + + Ok(array) + } +} + +impl TryFrom> for ArrowArray { + type Error = ArrowError; + + fn try_from(array: Box) -> Result { + ArrowArray::try_new(array) + } +} + +#[cfg(test)] +mod tests { + use crate::array::{Array, Primitive}; + use crate::{datatypes::DataType, error::Result, ffi::ArrowArray}; + use std::convert::TryFrom; + + fn test_round_trip(expected: impl Array + Clone + 'static) -> Result<()> { + // create a `ArrowArray` from the data. + let b: Box = Box::new(expected.clone()); + let d1 = ArrowArray::try_from(b)?; + + // here we export the array as 2 pointers. We would have no control over ownership if it was not for + // the release mechanism. 
+ let (array, schema) = ArrowArray::into_raw(d1); + + // simulate an external consumer by being the consumer + let d1 = unsafe { ArrowArray::try_from_raw(array, schema) }?; + + let result = Box::::try_from(d1)?; + + assert_eq!(result.as_ref(), &expected); + Ok(()) + } + + #[test] + fn test_u32() -> Result<()> { + let data = Primitive::::from(vec![Some(2), None, Some(1), None]).to(DataType::Int32); + test_round_trip(data) + } + + #[test] + fn test_u64() -> Result<()> { + let data = Primitive::::from(vec![Some(2), None, Some(1), None]).to(DataType::UInt64); + test_round_trip(data) + } + + #[test] + fn test_i64() -> Result<()> { + let data = Primitive::::from(vec![Some(2), None, Some(1), None]).to(DataType::Int64); + test_round_trip(data) + } +} diff --git a/src/ffi/ffi.rs b/src/ffi/ffi.rs new file mode 100644 index 00000000000..773d2d128e5 --- /dev/null +++ b/src/ffi/ffi.rs @@ -0,0 +1,557 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains declarations to bind to the [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html). +//! +//! Generally, this module is divided in two main interfaces: +//! One interface maps C ABI to native Rust types, i.e. 
convert c-pointers, c_char, to native rust. +//! This is handled by [FFI_ArrowSchema] and [FFI_ArrowArray]. +//! +//! The second interface maps native Rust types to the Rust-specific implementation of Arrow such as `format` to `DataType`, +//! `Buffer`, etc. This is handled by `ArrowArray`. +//! +//! ```rust +//! # use std::sync::Arc; +//! # use arrow::array::{Int32Array, Array, ArrayData, make_array_from_raw}; +//! # use arrow::error::{Result, ArrowError}; +//! # use arrow::compute::kernels::arithmetic; +//! # use std::convert::TryFrom; +//! # fn main() -> Result<()> { +//! // create an array natively +//! let array = Int32Array::from(vec![Some(1), None, Some(3)]); +//! +//! // export it +//! let (array_ptr, schema_ptr) = array.to_raw()?; +//! +//! // consumed and used by something else... +//! +//! // import it +//! let array = unsafe { make_array_from_raw(array_ptr, schema_ptr)? }; +//! +//! // perform some operation +//! let array = array.as_any().downcast_ref::().ok_or( +//! ArrowError::ParseError("Expects an int32".to_string()), +//! )?; +//! let array = arithmetic::add(&array, &array)?; +//! +//! // verify +//! assert_eq!(array, Int32Array::from(vec![Some(2), None, Some(6)])); +//! +//! // (drop/release) +//! Ok(()) +//! } +//! ``` + +/* +# Design: + +Main assumptions: +* A memory region is deallocated according to its own release mechanism. +* Rust shares memory regions between arrays. +* A memory region should be deallocated when no one is using it. + +The design of this module is as follows: + +`ArrowArray` contains two `Arc`s, one per ABI-compatible `struct`, each containing data +according to the C Data Interface. These Arcs are used for ref counting of the structs +within Rust and lifetime management. + +Each ABI-compatible `struct` knows how to `drop` itself, calling `release`. + +To import an array, unsafely create an `ArrowArray` from two pointers using [ArrowArray::try_from_raw]. 
+To export an array, create an `ArrowArray` using [ArrowArray::try_new]. +*/ + +use std::{ + ffi::CStr, + ffi::CString, + ptr::{self, NonNull}, + sync::Arc, +}; + +use crate::error::{ArrowError, Result}; +use crate::{ + array::Array, + buffer::{Bitmap, Buffer, NativeType}, +}; +use crate::{ + bits::bytes_for, + datatypes::{DataType, TimeUnit}, +}; + +/// ABI-compatible struct for `ArrowSchema` from C Data Interface +/// See +/// This was created by bindgen +#[repr(C)] +#[derive(Debug)] +pub struct FFI_ArrowSchema { + format: *const ::std::os::raw::c_char, + name: *const ::std::os::raw::c_char, + metadata: *const ::std::os::raw::c_char, + flags: i64, + n_children: i64, + children: *mut *mut FFI_ArrowSchema, + dictionary: *mut FFI_ArrowSchema, + release: ::std::option::Option, + private_data: *mut ::std::os::raw::c_void, +} + +// callback used to drop [FFI_ArrowSchema] when it is exported. +unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) { + let schema = &mut *schema; + + // take ownership back to release it. + CString::from_raw(schema.format as *mut std::os::raw::c_char); + + schema.release = None; +} + +impl FFI_ArrowSchema { + /// create a new [FFI_ArrowSchema] from a format. + fn new(format: &str) -> FFI_ArrowSchema { + // + FFI_ArrowSchema { + format: CString::new(format).unwrap().into_raw(), + name: std::ptr::null_mut(), + metadata: std::ptr::null_mut(), + flags: 0, + n_children: 0, + children: ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: Some(release_schema), + private_data: std::ptr::null_mut(), + } + } + + /// create an empty [FFI_ArrowSchema] + fn empty() -> Self { + Self { + format: std::ptr::null_mut(), + name: std::ptr::null_mut(), + metadata: std::ptr::null_mut(), + flags: 0, + n_children: 0, + children: ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: None, + private_data: std::ptr::null_mut(), + } + } + + /// returns the format of this schema. 
+ pub fn format(&self) -> &str { + unsafe { CStr::from_ptr(self.format) } + .to_str() + .expect("The external API has a non-utf8 as format") + } +} + +impl Drop for FFI_ArrowSchema { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +/// maps a C Data Interface `format` string to a [DataType](arrow::datatypes::DataType). +/// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings +/// # Error +/// Errors with [ArrowError::CDataInterface], naming the offending `format`, when the string is not +/// one of the entries matched below. +fn to_datatype(format: &str) -> Result { + Ok(match format { + "n" => DataType::Null, + "b" => DataType::Boolean, + "c" => DataType::Int8, + "C" => DataType::UInt8, + "s" => DataType::Int16, + "S" => DataType::UInt16, + "i" => DataType::Int32, + "I" => DataType::UInt32, + "l" => DataType::Int64, + "L" => DataType::UInt64, + "e" => DataType::Float16, + "f" => DataType::Float32, + "g" => DataType::Float64, + "z" => DataType::Binary, + "Z" => DataType::LargeBinary, + "u" => DataType::Utf8, + "U" => DataType::LargeUtf8, + "tdD" => DataType::Date32, + "tdm" => DataType::Date64, + "tts" => DataType::Time32(TimeUnit::Second), + "ttm" => DataType::Time32(TimeUnit::Millisecond), + "ttu" => DataType::Time64(TimeUnit::Microsecond), + "ttn" => DataType::Time64(TimeUnit::Nanosecond), + _ => { + // `format!` (not `.to_string()`) so that the `{}` placeholder is actually filled in + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{}\" is still not supported in Rust implementation", + format + ))) + } + }) +} + +/// the inverse of [to_datatype] +fn from_datatype(datatype: &DataType) -> Result { + Ok(match datatype { + DataType::Null => "n", + DataType::Boolean => "b", + DataType::Int8 => "c", + DataType::UInt8 => "C", + DataType::Int16 => "s", + DataType::UInt16 => "S", + DataType::Int32 => "i", + DataType::UInt32 => "I", + DataType::Int64 => "l", + DataType::UInt64 => "L", + DataType::Float16 => "e", + DataType::Float32 => "f", + DataType::Float64 => "g", + DataType::Binary => "z", + DataType::LargeBinary => "Z", + DataType::Utf8 => "u", + DataType::LargeUtf8 => "U", + 
DataType::Date32 => "tdD", + DataType::Date64 => "tdm", + DataType::Time32(TimeUnit::Second) => "tts", + DataType::Time32(TimeUnit::Millisecond) => "ttm", + DataType::Time64(TimeUnit::Microsecond) => "ttu", + DataType::Time64(TimeUnit::Nanosecond) => "ttn", + z => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{:?}\" is still not supported in Rust implementation", + z + ))) + } + } + .to_string()) +} + +/// ABI-compatible struct for ArrowArray from C Data Interface +/// See +/// This was created by bindgen +#[repr(C)] +#[derive(Debug)] +pub struct FFI_ArrowArray { + pub(crate) length: i64, + pub(crate) null_count: i64, + pub(crate) offset: i64, + pub(crate) n_buffers: i64, + pub(crate) n_children: i64, + pub(crate) buffers: *mut *const ::std::os::raw::c_void, + children: *mut *mut FFI_ArrowArray, + dictionary: *mut FFI_ArrowArray, + release: ::std::option::Option, + // When exported, this MUST contain everything that is owned by this array. + // for example, any buffer pointed to in `buffers` must be here, as well as the `buffers` pointer + // itself. + // In other words, everything in [FFI_ArrowArray] must be owned by `private_data` and can assume + // that they do not outlive `private_data`. + private_data: *mut ::std::os::raw::c_void, +} + +// callback used to drop [FFI_ArrowArray] when it is exported +unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) { + if array.is_null() { + return; + } + let array = &mut *array; + // take ownership of `private_data`, therefore dropping it + Box::from_raw(array.private_data as *mut PrivateData); + + array.release = None; +} + +struct PrivateData { + array: Box, + buffers_ptr: Box<[*const std::os::raw::c_void]>, +} + +impl FFI_ArrowArray { + /// creates a new `FFI_ArrowArray` from existing data. + /// # Safety + /// This method releases `buffers`. Consumers of this struct *must* call `release` before + /// releasing this struct, or contents in `buffers` leak. 
+ fn new(array: Box) -> Self { + let buffers_ptr = array + .buffers() + .iter() + .map(|maybe_buffer| match maybe_buffer { + // note that `raw_data` takes into account the buffer's offset + Some(b) => b.as_ptr() as *const std::os::raw::c_void, + None => std::ptr::null(), + }) + .collect::>(); + let pointer = buffers_ptr.as_ptr() as *mut *const std::ffi::c_void; + + Self { + length: array.len() as i64, + null_count: array.null_count() as i64, + offset: 0i64, + n_buffers: 3, // todo: fix me + n_children: 0, + buffers: pointer, + children: std::ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: Some(release_array), + private_data: Box::into_raw(Box::new(PrivateData { array, buffers_ptr })) + as *mut ::std::os::raw::c_void, + } + } + + // create an empty `FFI_ArrowArray`, which can be used to import data into + fn empty() -> Self { + Self { + length: 0, + null_count: 0, + offset: 0, + n_buffers: 0, + n_children: 0, + buffers: std::ptr::null_mut(), + children: std::ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: None, + private_data: std::ptr::null_mut(), + } + } +} + +/// returns a new buffer corresponding to the index `i` of the FFI array. It may not exist (null pointer). +/// `bits` is the number of bits that the native type of this buffer has. +/// # Panic +/// This function panics if `i` is larger or equal to `n_buffers`. +/// # Safety +/// This function assumes that `ceil(self.length * bits, 8)` is the size of the buffer +unsafe fn create_buffer( + array: Arc, + index: usize, + len: usize, +) -> Option> { + if array.buffers.is_null() { + return None; + } + let buffers = array.buffers as *mut *const u8; + + assert!(index < array.n_buffers as usize); + let ptr = *buffers.add(index); + + NonNull::new(ptr as *mut u8).map(|ptr| Buffer::from_unowned(ptr, len, array)) +} + +/// returns a new buffer corresponding to the index `i` of the FFI array. It may not exist (null pointer). 
+/// `bits` is the number of bits that the native type of this buffer has. +/// The size of the buffer will be `ceil(self.length * bits, 8)`. +/// # Panic +/// This function panics if `i` is larger or equal to `n_buffers`. +/// # Safety +/// This function assumes that `ceil(self.length * bits, 8)` is the size of the buffer +unsafe fn create_bitmap(array: Arc, index: usize, len: usize) -> Option { + if array.buffers.is_null() { + return None; + } + let buffers = array.buffers as *mut *const u8; + + assert!(index < array.n_buffers as usize); + let ptr = *buffers.add(index); + + NonNull::new(ptr as *mut u8).map(|ptr| Bitmap::from_unowned(ptr, len, array)) +} + +impl Drop for FFI_ArrowArray { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +/// Struct used to move an Array from and to the C Data Interface. +/// Its main responsibility is to expose functionality that requires +/// both [FFI_ArrowArray] and [FFI_ArrowSchema]. +/// +/// This struct has two main paths: +/// +/// ## Import from the C Data Interface +/// * [ArrowArray::empty] to allocate memory to be filled by an external call +/// * [ArrowArray::try_from_raw] to consume two non-null allocated pointers +/// ## Export to the C Data Interface +/// * [ArrowArray::try_new] to create a new [ArrowArray] from Rust-specific information +/// * [ArrowArray::into_raw] to expose two pointers for [FFI_ArrowArray] and [FFI_ArrowSchema]. +/// +/// # Safety +/// Whoever creates this struct is responsible for releasing their resources. Specifically, +/// consumers *must* call [ArrowArray::into_raw] and take ownership of the individual pointers, +/// calling [FFI_ArrowArray::release] and [FFI_ArrowSchema::release] accordingly. +/// +/// Furthermore, this struct assumes that the incoming data agrees with the C data interface. +#[derive(Debug)] +pub struct ArrowArray { + // these are ref-counted because they can be shared by multiple buffers. 
+ array: Arc, + schema: Arc, +} + +impl ArrowArray { + /// creates a new `ArrowArray`. This is used to export to the C Data Interface. + /// # Safety + /// See safety of [ArrowArray] + pub fn try_new(array: Box) -> Result { + let format = from_datatype(array.data_type())?; + + let schema = Arc::new(FFI_ArrowSchema::new(&format)); + let array = Arc::new(FFI_ArrowArray::new(array)); + + Ok(ArrowArray { schema, array }) + } + + /// creates a new [ArrowArray] from two pointers. Used to import from the C Data Interface. + /// # Safety + /// See safety of [ArrowArray] + /// # Error + /// Errors if any of the pointers is null + pub unsafe fn try_from_raw( + array: *const FFI_ArrowArray, + schema: *const FFI_ArrowSchema, + ) -> Result { + if array.is_null() || schema.is_null() { + return Err(ArrowError::MemoryError( + "At least one of the pointers passed to `try_from_raw` is null".to_string(), + )); + }; + Ok(Self { + array: Arc::from_raw(array as *mut FFI_ArrowArray), + schema: Arc::from_raw(schema as *mut FFI_ArrowSchema), + }) + } + + /// creates a new empty [ArrowArray]. Used to import from the C Data Interface. + /// # Safety + /// See safety of [ArrowArray] + pub unsafe fn empty() -> Self { + let schema = Arc::new(FFI_ArrowSchema::empty()); + let array = Arc::new(FFI_ArrowArray::empty()); + ArrowArray { schema, array } + } + + /// exports [ArrowArray] to the C Data Interface + pub fn into_raw(this: ArrowArray) -> (*const FFI_ArrowArray, *const FFI_ArrowSchema) { + (Arc::into_raw(this.array), Arc::into_raw(this.schema)) + } + + /// returns the null bit buffer. + /// Rust implementation uses a buffer that is not part of the array of buffers. + /// The C Data interface's null buffer is part of the array of buffers. 
+ pub fn null_bit_buffer(&self) -> Option { + let len = self.array.length; + unsafe { create_bitmap(self.array.clone(), 0, len as usize) } + } + + /// Returns the length, in slots, of the buffer `i` (indexed according to the C data interface) + // Rust implementation uses fixed-sized buffers, which require knowledge of their `len`. + // for variable-sized buffers, such as the second buffer of a stringArray, we need + // to fetch offset buffer's len to build the second buffer. + fn buffer_len(&self, i: usize) -> Result { + let data_type = &self.data_type()?; + + Ok(match (data_type, i) { + (DataType::Utf8, 1) + | (DataType::LargeUtf8, 1) + | (DataType::Binary, 1) + | (DataType::LargeBinary, 1) => { + // the len of the offset buffer (buffer 1) equals length + 1 + self.array.length as usize + 1 + } + (DataType::Utf8, 2) | (DataType::Binary, 2) => { + // the len of the data buffer (buffer 2) equals the last value of the offset buffer (buffer 1) + let len = self.buffer_len(1)?; + // first buffer is the null buffer => add(1) + let offset_buffer = unsafe { *(self.array.buffers as *mut *const u8).add(1) }; + // interpret as i32 + let offset_buffer = offset_buffer as *const i32; + // get last offset + (unsafe { *offset_buffer.add(len - 1) }) as usize + } + (DataType::LargeUtf8, 2) | (DataType::LargeBinary, 2) => { + // the len of the data buffer (buffer 2) equals the last value of the offset buffer (buffer 1) + let len = self.buffer_len(1)?; + // first buffer is the null buffer => add(1) + let offset_buffer = unsafe { *(self.array.buffers as *mut *const u8).add(1) }; + // interpret as i64 + let offset_buffer = offset_buffer as *const i64; + // get last offset + (unsafe { *offset_buffer.add(len - 1) }) as usize + } + // buffer len of primitive types + _ => self.array.length as usize, + }) + } + + /// returns all buffers, as organized by Rust (i.e. 
null buffer is skipped) + pub unsafe fn buffer(&self, index: usize) -> Result> { + // + 1: skip null buffer + let index = (index + 1) as usize; + + let len = self.buffer_len(index)?; + + create_buffer(self.array.clone(), index, len).ok_or_else(|| { + ArrowError::CDataInterface(format!( + "The external buffer at position {} is null.", + index - 1 + )) + }) + } + + /// returns all buffers, as organized by Rust (i.e. null buffer is skipped) + pub unsafe fn bitmap(&self, index: usize) -> Result { + // + 1: skip null buffer + let index = (index + 1) as usize; + + let len = bytes_for(self.array.length as usize); + + create_bitmap(self.array.clone(), index, len).ok_or_else(|| { + ArrowError::CDataInterface(format!( + "The external buffer at position {} is null.", + index - 1 + )) + }) + } + + /// the length of the array + pub fn len(&self) -> usize { + self.array.length as usize + } + + /// whether the array is empty + pub fn is_empty(&self) -> bool { + self.array.length == 0 + } + + /// the offset of the array + pub fn offset(&self) -> usize { + self.array.offset as usize + } + + /// the null count of the array + pub fn null_count(&self) -> usize { + self.array.null_count as usize + } + + /// the data_type as declared in the schema + pub fn data_type(&self) -> Result { + to_datatype(self.schema.format()) + } +} diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs new file mode 100644 index 00000000000..9663f3500f5 --- /dev/null +++ b/src/ffi/mod.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains declarations to bind to the [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html). +//! +//! Generally, this module is divided in two main interfaces: +//! One interface maps C ABI to native Rust types, i.e. convert c-pointers, c_char, to native rust. +//! This is handled by [FFI_ArrowSchema] and [FFI_ArrowArray]. +//! +//! The second interface maps native Rust types to the Rust-specific implementation of Arrow such as `format` to `Datatype`, +//! `Buffer`, etc. This is handled by `ArrowArray`. +//! +//! ```rust +//! # use std::sync::Arc; +//! # use arrow::array::{Int32Array, Array, ArrayData, make_array_from_raw}; +//! # use arrow::error::{Result, ArrowError}; +//! # use arrow::compute::kernels::arithmetic; +//! # use std::convert::TryFrom; +//! # fn main() -> Result<()> { +//! // create an array natively +//! let array = Int32Array::from(vec![Some(1), None, Some(3)]); +//! +//! // export it +//! let (array_ptr, schema_ptr) = array.to_raw()?; +//! +//! // consumed and used by something else... +//! +//! // import it +//! let array = unsafe { make_array_from_raw(array_ptr, schema_ptr)? }; +//! +//! // perform some operation +//! let array = array.as_any().downcast_ref::().ok_or( +//! ArrowError::ParseError("Expects an int32".to_string()), +//! )?; +//! let array = arithmetic::add(&array, &array)?; +//! +//! // verify +//! assert_eq!(array, Int32Array::from(vec![Some(2), None, Some(6)])); +//! +//! // (drop/release) +//! Ok(()) +//! } +//! ``` + +mod array; +mod ffi; + +trait ToFFI { + // necessary for ffi. 
first must be the bitmap + fn buffers(&self) -> [Option>; 3]; + + fn offset(&self) -> usize; +} + +pub(crate) use ffi::{ArrowArray, FFI_ArrowArray}; diff --git a/src/io/csv/infer_schema.rs b/src/io/csv/infer_schema.rs new file mode 100644 index 00000000000..c4410104de9 --- /dev/null +++ b/src/io/csv/infer_schema.rs @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + collections::HashSet, + fs::File, + io::{Read, Seek, SeekFrom}, +}; + +use csv::StringRecord; + +use crate::datatypes::DataType; +use crate::datatypes::{Field, Schema}; +use crate::error::Result; + +/// Infer the schema of a CSV file by reading through the first n records of the file, +/// with `max_read_records` controlling the maximum number of records to read. +/// +/// If `max_read_records` is not set, the whole file is read to infer its schema. +/// +/// Return infered schema and number of records used for inference. 
+pub fn infer_file_schema DataType>( + reader: &mut R, + delimiter: u8, + max_read_records: Option, + has_header: bool, + infer: &F, +) -> Result<(Schema, usize)> { + let mut csv_reader = csv::ReaderBuilder::new() + .delimiter(delimiter) + .from_reader(reader); + + // get or create header names + // when has_header is false, creates default column names with column_ prefix + let headers: Vec = if has_header { + let headers = &csv_reader.headers()?.clone(); + headers.iter().map(|s| s.to_string()).collect() + } else { + let first_record_count = &csv_reader.headers()?.len(); + (0..*first_record_count) + .map(|i| format!("column_{}", i + 1)) + .collect() + }; + + // save the csv reader position after reading headers + let position = csv_reader.position().clone(); + + let header_length = headers.len(); + // keep track of inferred field types + let mut column_types: Vec> = vec![HashSet::new(); header_length]; + // keep track of columns with nulls + let mut nulls: Vec = vec![false; header_length]; + + // return csv reader position to after headers + csv_reader.seek(position)?; + + let mut records_count = 0; + let mut fields = vec![]; + + let mut record = StringRecord::new(); + let max_records = max_read_records.unwrap_or(usize::MAX); + while records_count < max_records { + if !csv_reader.read_record(&mut record)? 
{ + break; + } + records_count += 1; + + for i in 0..header_length { + if let Some(string) = record.get(i) { + if string.is_empty() { + nulls[i] = true; + } else { + column_types[i].insert(infer(string)); + } + } + } + } + + // build schema from inference results + for i in 0..header_length { + let possibilities = &column_types[i]; + let has_nulls = nulls[i]; + let field_name = &headers[i]; + + // determine data type based on possible types + // if there are incompatible types, use DataType::Utf8 + match possibilities.len() { + 1 => { + for dtype in possibilities.iter() { + fields.push(Field::new(&field_name, dtype.clone(), has_nulls)); + } + } + 2 => { + if possibilities.contains(&DataType::Int64) + && possibilities.contains(&DataType::Float64) + { + // we have an integer and double, fall down to double + fields.push(Field::new(&field_name, DataType::Float64, has_nulls)); + } else { + // default to Utf8 for conflicting datatypes (e.g bool and int) + fields.push(Field::new(&field_name, DataType::Utf8, has_nulls)); + } + } + _ => fields.push(Field::new(&field_name, DataType::Utf8, has_nulls)), + } + } + + // return the reader seek back to the start + csv_reader.into_inner().seek(SeekFrom::Start(0))?; + + Ok((Schema::new(fields), records_count)) +} + +/// Infer schema from a list of CSV files by reading through first n records +/// with `max_read_records` controlling the maximum number of records to read. +/// +/// Files will be read in the given order untill n records have been reached. +/// +/// If `max_read_records` is not set, all files will be read fully to infer the schema. 
+pub fn infer_schema_from_files DataType>( + files: &[String], + delimiter: u8, + max_read_records: Option, + has_header: bool, + infer: F, +) -> Result { + let mut schemas = vec![]; + let mut records_to_read = max_read_records.unwrap_or(std::usize::MAX); + + for fname in files.iter() { + let (schema, records_read) = infer_file_schema( + &mut File::open(fname)?, + delimiter, + Some(records_to_read), + has_header, + &infer, + )?; + if records_read == 0 { + continue; + } + schemas.push(schema.clone()); + records_to_read -= records_read; + if records_to_read == 0 { + break; + } + } + + Schema::try_merge(schemas) +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::Write; + use tempfile::NamedTempFile; + + use crate::io::csv::reader::infer; + + #[test] + fn test_infer_schema_from_multiple_files() -> Result<()> { + let mut csv1 = NamedTempFile::new()?; + let mut csv2 = NamedTempFile::new()?; + let csv3 = NamedTempFile::new()?; // empty csv file should be skipped + let mut csv4 = NamedTempFile::new()?; + writeln!(csv1, "c1,c2,c3")?; + writeln!(csv1, "1,\"foo\",0.5")?; + writeln!(csv1, "3,\"bar\",1")?; + // reading csv2 will set c2 to optional + writeln!(csv2, "c1,c2,c3,c4")?; + writeln!(csv2, "10,,3.14,true")?; + // reading csv4 will set c3 to optional + writeln!(csv4, "c1,c2,c3")?; + writeln!(csv4, "10,\"foo\",")?; + + let schema = infer_schema_from_files( + &[ + csv3.path().to_str().unwrap().to_string(), + csv1.path().to_str().unwrap().to_string(), + csv2.path().to_str().unwrap().to_string(), + csv4.path().to_str().unwrap().to_string(), + ], + b',', + Some(3), // only csv1 and csv2 should be read + true, + infer, + )?; + + assert_eq!(schema.fields().len(), 4); + assert_eq!(false, schema.field(0).is_nullable()); + assert_eq!(true, schema.field(1).is_nullable()); + assert_eq!(false, schema.field(2).is_nullable()); + assert_eq!(false, schema.field(3).is_nullable()); + + assert_eq!(&DataType::Int64, schema.field(0).data_type()); + assert_eq!(&DataType::Utf8, 
schema.field(1).data_type()); + assert_eq!(&DataType::Float64, schema.field(2).data_type()); + assert_eq!(&DataType::Boolean, schema.field(3).data_type()); + + Ok(()) + } +} diff --git a/src/io/csv/mod.rs b/src/io/csv/mod.rs new file mode 100644 index 00000000000..033508d8e7d --- /dev/null +++ b/src/io/csv/mod.rs @@ -0,0 +1,28 @@ +//! Transfer data between the Arrow memory format and CSV (comma-separated values). + +use crate::error::ArrowError; +use chrono; +use csv; + +impl From for ArrowError { + fn from(error: csv::Error) -> Self { + ArrowError::ExternalError(Box::new(error)) + } +} + +impl From for ArrowError { + fn from(error: chrono::ParseError) -> Self { + ArrowError::ExternalError(Box::new(error)) + } +} + +mod parser; +pub mod reader; +pub mod writer; + +mod infer_schema; +mod read_boolean; +mod read_primitive; +pub use infer_schema::{infer_file_schema, infer_schema_from_files}; +pub use read_boolean::{new_boolean_array, BooleanParser}; +pub use read_primitive::{new_primitive_array, PrimitiveParser}; diff --git a/src/io/csv/parser.rs b/src/io/csv/parser.rs new file mode 100644 index 00000000000..a00aff310e0 --- /dev/null +++ b/src/io/csv/parser.rs @@ -0,0 +1,216 @@ +use chrono::Datelike; + +use crate::temporal_conversions::EPOCH_DAYS_FROM_CE; +use crate::{datatypes::*, error::ArrowError}; + +use super::{BooleanParser, PrimitiveParser}; + +pub trait GenericParser: + PrimitiveParser + + PrimitiveParser + + PrimitiveParser + + PrimitiveParser + + PrimitiveParser + + PrimitiveParser + + PrimitiveParser + + PrimitiveParser + + PrimitiveParser + + PrimitiveParser + + BooleanParser +{ +} + +#[derive(Debug, Clone, Copy)] +pub struct DefaultParser {} + +impl Default for DefaultParser { + fn default() -> Self { + Self {} + } +} + +impl + From> PrimitiveParser for DefaultParser {} + +impl + From> PrimitiveParser for DefaultParser {} + +impl + From> PrimitiveParser for DefaultParser {} + +impl + From> PrimitiveParser for DefaultParser {} + +impl + From> 
PrimitiveParser for DefaultParser {} + +impl + From> PrimitiveParser for DefaultParser {} + +impl + From> PrimitiveParser for DefaultParser { + #[inline] + fn parse( + &self, + string: &str, + data_type: &DataType, + row_number: usize, + ) -> Result, E> { + // default behavior: error if not able to parse, else `None` + match data_type { + DataType::Int32 => { + PrimitiveParser::::parse(self, string, data_type, row_number) + } + DataType::Date32 => { + let date = string.parse::()?; + Ok(Some(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE)) + } + DataType::Time32(_) => Err(ArrowError::NotYetImplemented( + "Reading Time32 from CSV is not yet implemented".to_string(), + ) + .into()), + _ => unreachable!(), + } + } +} + +impl + From> PrimitiveParser for DefaultParser { + #[inline] + fn parse( + &self, + string: &str, + data_type: &DataType, + row_number: usize, + ) -> Result, E> { + match data_type { + DataType::Int64 => { + PrimitiveParser::::parse(self, string, data_type, row_number) + } + DataType::Date64 => { + let date_time = string.parse::()?; + Ok(Some(date_time.timestamp_millis())) + } + DataType::Time64(_) => Err(ArrowError::NotYetImplemented( + "Reading Time32 from CSV is not yet implemented".to_string(), + ) + .into()), + DataType::Timestamp(TimeUnit::Nanosecond, None) => { + let date_time = string.parse::()?; + Ok(Some(date_time.timestamp_nanos())) + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { + let date_time = string.parse::()?; + Ok(Some(date_time.timestamp_nanos() / 1000)) + } + DataType::Timestamp(TimeUnit::Millisecond, None) => { + let date_time = string.parse::()?; + Ok(Some(date_time.timestamp_nanos() / 1_000_000)) + } + DataType::Timestamp(TimeUnit::Second, None) => { + let date_time = string.parse::()?; + Ok(Some(date_time.timestamp_nanos() / 1_000_000_000)) + } + DataType::Timestamp(_, _) => Err(ArrowError::NotYetImplemented( + "Reading time-zone aware timestamp from CSV is not yet implemented".to_string(), + ) + .into()), + _ => 
unreachable!(), + } + } +} + +impl + From> PrimitiveParser for DefaultParser {} + +impl + From> PrimitiveParser for DefaultParser {} + +impl BooleanParser for DefaultParser {} + +impl + From> GenericParser for DefaultParser {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_date32() { + let parser = DefaultParser::default(); + + let cases = vec![ + ("1970-01-01", Some(0)), + ("2020-03-15", Some(18336)), + ("1945-05-08", Some(-9004)), + ]; + + for (input, output) in cases { + let r: Result<_, ArrowError> = + PrimitiveParser::parse(&parser, input, &DataType::Date32, 0); + assert_eq!(r.unwrap(), output); + } + } + + #[test] + fn parse_date64() { + let parser = DefaultParser::default(); + + let cases = vec![ + ("1970-01-01T00:00:00", Some(0i64)), + ("2018-11-13T17:11:10", Some(1542129070000)), + ("2018-11-13T17:11:10.011", Some(1542129070011)), + ("1900-02-28T12:34:56", Some(-2203932304000)), + ]; + + for (input, output) in cases { + let r: Result<_, ArrowError> = + PrimitiveParser::parse(&parser, input, &DataType::Date64, 0); + assert_eq!(r.unwrap(), output); + } + } + + #[test] + fn test_parsing_bool() { + let parser = DefaultParser::default(); + + let cases = vec![ + ("true", Some(true)), + ("tRUe", Some(true)), + ("True", Some(true)), + ("TRUE", Some(true)), + ("t", None), + ("T", None), + ("", None), + ("false", Some(false)), + ("fALse", Some(false)), + ("False", Some(false)), + ("FALSE", Some(false)), + ("f", None), + ("F", None), + ("", None), + ]; + + for (input, output) in cases { + let r: Result<_, ArrowError> = BooleanParser::parse(&parser, input, 0); + assert_eq!(r.unwrap(), output); + } + } + + #[test] + fn test_parsing_float() { + let parser = DefaultParser::default(); + + let cases = vec![ + ("12.34", Some(12.34f64)), + ("12.0", Some(12.0)), + ("0.0", Some(0.0)), + ("inf", Some(f64::INFINITY)), + ("-inf", Some(f64::NEG_INFINITY)), + ("dd", None), + ("", None), + ]; + + for (input, output) in cases { + let r: Result<_, ArrowError> = + 
PrimitiveParser::parse(&parser, input, &DataType::Float64, 0); + assert_eq!(r.unwrap(), output); + } + + let r: Result, ArrowError> = + PrimitiveParser::parse(&parser, "nan", &DataType::Float64, 0); + assert!(r.unwrap().unwrap().is_nan()); + let r: Result, ArrowError> = + PrimitiveParser::parse(&parser, "NaN", &DataType::Float64, 0); + assert!(r.unwrap().unwrap().is_nan()); + } +} diff --git a/src/io/csv/read_boolean.rs b/src/io/csv/read_boolean.rs new file mode 100644 index 00000000000..4d0562d5414 --- /dev/null +++ b/src/io/csv/read_boolean.rs @@ -0,0 +1,38 @@ +use csv::StringRecord; + +use crate::array::BooleanArray; + +/// default behavior is infalible: `None` if unable to parse +pub trait BooleanParser { + fn parse(&self, string: &str, _: usize) -> Result, E> { + Ok(if string.eq_ignore_ascii_case("false") { + Some(false) + } else if string.eq_ignore_ascii_case("true") { + Some(true) + } else { + None + }) + } +} + +// parses a specific column (col_idx) into an Arrow Array. +pub fn new_boolean_array>( + line_number: usize, + rows: &[StringRecord], + col_idx: usize, + parser: &P, +) -> Result { + let iter = rows + .iter() + .enumerate() + .map(|(row_index, row)| match row.get(col_idx) { + Some(s) => { + if s.is_empty() { + return Ok(None); + } + parser.parse(s, row_index + line_number) + } + None => Ok(None), + }); + unsafe { BooleanArray::try_from_trusted_len_iter(iter) } +} diff --git a/src/io/csv/read_primitive.rs b/src/io/csv/read_primitive.rs new file mode 100644 index 00000000000..ab49f4b6200 --- /dev/null +++ b/src/io/csv/read_primitive.rs @@ -0,0 +1,39 @@ +use csv::StringRecord; +use lexical_core; + +use crate::{array::Primitive, datatypes::*}; +use crate::{array::PrimitiveArray, buffer::NativeType}; + +pub trait PrimitiveParser { + fn parse(&self, string: &str, _: &DataType, _: usize) -> Result, E> { + // default behavior is infalible: `None` if unable to parse + Ok(lexical_core::parse(string.as_bytes()).ok()) + } +} + +pub fn new_primitive_array< + T: 
NativeType + lexical_core::FromLexical, + E, + P: PrimitiveParser, +>( + line_number: usize, + rows: &[StringRecord], + col_idx: usize, + data_type: &DataType, + parser: &P, +) -> Result, E> { + let iter = rows + .iter() + .enumerate() + .map(|(row_index, row)| match row.get(col_idx) { + Some(s) => { + if s.is_empty() { + return Ok(None); + } + parser.parse(s, &data_type, line_number + row_index) + } + None => Ok(None), + }); + // Soundness: slice is trusted len. + Ok(unsafe { Primitive::::try_from_trusted_len_iter(iter) }?.to(data_type.clone())) +} diff --git a/src/io/csv/reader.rs b/src/io/csv/reader.rs new file mode 100644 index 00000000000..b56ba462875 --- /dev/null +++ b/src/io/csv/reader.rs @@ -0,0 +1,436 @@ +use std::io::Read; +use std::sync::Arc; +use std::{fmt, io::Seek}; + +use csv; +use csv::{ByteRecord, StringRecord}; +use lazy_static::lazy_static; +use regex::{Regex, RegexBuilder}; + +use crate::record_batch::RecordBatch; +use crate::{ + array::Array, + datatypes::*, + error::{ArrowError, Result}, +}; + +use super::{ + infer_file_schema, new_boolean_array, new_primitive_array, + parser::{DefaultParser, GenericParser}, +}; + +type SchemaRef = Arc; + +// optional bounds of the reader, of the form (min line, max line). 
+type Bounds = Option<(usize, usize)>; + +/// CSV file reader +pub struct Reader> { + /// Explicit schema for the CSV file + schema: SchemaRef, + /// Optional projection for which columns to load (zero-based column indices) + projection: Option>, + /// File reader + reader: csv::Reader, + /// Current line number + line_number: usize, + /// Maximum number of rows to read + end: usize, + /// Number of records per batch + batch_size: usize, + /// Vector that can hold the `StringRecord`s of the batches + batch_records: Vec, + + /// the parser to use + parser: P, +} + +impl fmt::Debug for Reader +where + R: Read, + P: GenericParser, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Reader") + .field("schema", &self.schema) + .field("projection", &self.projection) + .field("line_number", &self.line_number) + .finish() + } +} + +impl> Reader { + /// Create a new CsvReader from any value that implements the `Read` trait. + /// + /// If reading a `File` or an input that supports `std::io::Read` and `std::io::Seek`; + /// you can customise the Reader, such as to enable schema inference, use + /// `ReaderBuilder`. 
+ pub fn new( + reader: R, + schema: SchemaRef, + has_header: bool, + delimiter: Option, + batch_size: usize, + bounds: Bounds, + projection: Option>, + parser: P, + ) -> Self { + Self::from_reader( + reader, schema, has_header, delimiter, batch_size, bounds, projection, parser, + ) + } + + /// Returns the schema of the reader, useful for getting the schema without reading + /// record batches + pub fn schema(&self) -> SchemaRef { + match &self.projection { + Some(projection) => { + let fields = self.schema.fields(); + let projected_fields: Vec = + projection.iter().map(|i| fields[*i].clone()).collect(); + + Arc::new(Schema::new(projected_fields)) + } + None => self.schema.clone(), + } + } + + /// Create a new CsvReader from a Reader + /// + /// This constructor allows you more flexibility in what records are processed by the + /// csv reader. + pub fn from_reader( + reader: R, + schema: SchemaRef, + has_header: bool, + delimiter: Option, + batch_size: usize, + bounds: Bounds, + projection: Option>, + parser: P, + ) -> Self { + let mut reader_builder = csv::ReaderBuilder::new(); + reader_builder.has_headers(has_header); + + if let Some(c) = delimiter { + reader_builder.delimiter(c); + } + + let mut csv_reader = reader_builder.from_reader(reader); + + let (start, end) = match bounds { + None => (0, usize::MAX), + Some((start, end)) => (start, end), + }; + + // First we will skip `start` rows + // note that this skips by iteration. This is because in general it is not possible + // to seek in CSV. 
However, skiping still saves the burden of creating arrow arrays, + // which is a slow operation that scales with the number of columns + + let mut record = ByteRecord::new(); + // Skip first start items + for _ in 0..start { + let res = csv_reader.read_byte_record(&mut record); + if !res.unwrap_or(false) { + break; + } + } + + // Initialize batch_records with StringRecords so they + // can be reused accross batches + let mut batch_records = Vec::with_capacity(batch_size); + batch_records.resize_with(batch_size, Default::default); + + Self { + schema, + projection, + reader: csv_reader, + line_number: if has_header { start + 1 } else { start }, + batch_size, + end, + batch_records, + parser, + } + } +} + +impl> Iterator for Reader { + type Item = Result; + + fn next(&mut self) -> Option { + let remaining = self.end - self.line_number; + + let mut read_records = 0; + for i in 0..std::cmp::min(self.batch_size, remaining) { + match self.reader.read_record(&mut self.batch_records[i]) { + Ok(true) => { + read_records += 1; + } + Ok(false) => break, + Err(e) => { + return Some(Err(ArrowError::ParseError(format!( + "Error parsing line {}: {:?}", + self.line_number + i, + e + )))) + } + } + } + + // return early if no data was loaded + if read_records == 0 { + return None; + } + + // parse the batches into a RecordBatch + let result = parse( + &self.batch_records[..read_records], + &self.schema.fields(), + &self.projection, + self.line_number, + &self.parser, + ); + + self.line_number += read_records; + + Some(result) + } +} + +macro_rules! primitive { + ($type:ty, $line_number:expr, $rows:expr, $i:expr, $data_type:expr, $parser:expr) => { + new_primitive_array::<$type, ArrowError, _>($line_number, $rows, $i, $data_type, $parser) + .map(|x| Arc::new(x) as Arc) + }; +} + +/// parses a slice of [csv_crate::StringRecord] into a [array::record_batch::RecordBatch]. 
+fn parse>( + rows: &[StringRecord], + fields: &[Field], + projection: &Option>, + line_number: usize, + parser: &P, +) -> Result { + let projection: Vec = match projection { + Some(ref v) => v.clone(), + None => fields.iter().enumerate().map(|(i, _)| i).collect(), + }; + + let arrays: Result>> = projection + .iter() + .map(|i| { + let i = *i; + let field = &fields[i]; + let data_type = field.data_type(); + match data_type { + DataType::Boolean => new_boolean_array(line_number, rows, i, parser) + .map(|x| Arc::new(x) as Arc), + DataType::Int8 => { + primitive!(i8, line_number, rows, i, data_type, parser) + } + DataType::Int16 => primitive!(i16, line_number, rows, i, data_type, parser), + DataType::Int32 | DataType::Date32 | DataType::Time32(_) => { + primitive!(i32, line_number, rows, i, data_type, parser) + } + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, None) => { + primitive!(i64, line_number, rows, i, data_type, parser) + } + DataType::UInt8 => primitive!(u8, line_number, rows, i, data_type, parser), + DataType::UInt16 => primitive!(u16, line_number, rows, i, data_type, parser), + DataType::UInt32 => primitive!(u32, line_number, rows, i, data_type, parser), + DataType::UInt64 => primitive!(u64, line_number, rows, i, data_type, parser), + DataType::Float32 => primitive!(f32, line_number, rows, i, data_type, parser), + DataType::Float64 => primitive!(f64, line_number, rows, i, data_type, parser), + other => Err(ArrowError::ParseError(format!( + "Unsupported data type {:?}", + other + ))), + } + }) + .collect(); + + let projected_fields: Vec = projection.iter().map(|i| fields[*i].clone()).collect(); + + let projected_schema = Arc::new(Schema::new(projected_fields)); + + arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr)) +} + +lazy_static! 
{ + static ref DECIMAL_RE: Regex = Regex::new(r"^-?(\d+\.\d+)$").unwrap(); + static ref INTEGER_RE: Regex = Regex::new(r"^-?(\d+)$").unwrap(); + static ref BOOLEAN_RE: Regex = RegexBuilder::new(r"^(true)$|^(false)$") + .case_insensitive(true) + .build() + .unwrap(); + static ref DATE_RE: Regex = Regex::new(r"^\d{4}-\d\d-\d\d$").unwrap(); + static ref DATETIME_RE: Regex = Regex::new(r"^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\d$").unwrap(); +} + +/// Infer the data type of a record +pub fn infer(string: &str) -> DataType { + // when quoting is enabled in the reader, these quotes aren't escaped, we default to + // Utf8 for them + if string.starts_with('"') { + return DataType::Utf8; + } + // match regex in a particular order + if BOOLEAN_RE.is_match(string) { + DataType::Boolean + } else if DECIMAL_RE.is_match(string) { + DataType::Float64 + } else if INTEGER_RE.is_match(string) { + DataType::Int64 + } else if DATETIME_RE.is_match(string) { + DataType::Date64 + } else if DATE_RE.is_match(string) { + DataType::Date32 + } else { + DataType::Utf8 + } +} + +/// CSV file reader builder +#[derive(Debug)] +pub struct ReaderBuilder { + /// Optional schema for the CSV file + /// + /// If the schema is not supplied, the reader will try to infer the schema + /// based on the CSV structure. + schema: Option, + /// Whether the file has headers or not + /// + /// If schema inference is run on a file with no headers, default column names + /// are created. + has_header: bool, + /// An optional column delimiter. Defaults to `b','` + delimiter: Option, + /// Optional maximum number of records to read during schema inference + /// + /// If a number is not provided, all the records are read. + max_records: Option, + /// Batch size (number of records to load each time) + /// + /// The default batch size when using the `ReaderBuilder` is 1024 records + batch_size: usize, + /// The bounds over which to scan the reader. `None` starts from 0 and runs until EOF. 
+ bounds: Bounds, + /// Optional projection for which columns to load (zero-based column indices) + projection: Option>, +} + +impl Default for ReaderBuilder { + fn default() -> Self { + Self { + schema: None, + has_header: false, + delimiter: None, + max_records: None, + batch_size: 1024, + bounds: None, + projection: None, + } + } +} + +impl ReaderBuilder { + /// Create a new builder for configuring CSV parsing options. + /// + /// To convert a builder into a reader, call `ReaderBuilder::build` + /// + /// # Example + /// + /// ``` + /// extern crate arrow; + /// + /// use arrow::csv; + /// use std::fs::File; + /// + /// fn example() -> csv::Reader { + /// let file = File::open("test/data/uk_cities_with_headers.csv").unwrap(); + /// + /// // create a builder, inferring the schema with the first 100 records + /// let builder = csv::ReaderBuilder::new().infer_schema(Some(100)); + /// + /// let reader = builder.build(file).unwrap(); + /// + /// reader + /// } + /// ``` + pub fn new() -> ReaderBuilder { + ReaderBuilder::default() + } + + /// Set the CSV file's schema + pub fn with_schema(mut self, schema: SchemaRef) -> Self { + self.schema = Some(schema); + self + } + + /// Set whether the CSV file has headers + pub fn has_header(mut self, has_header: bool) -> Self { + self.has_header = has_header; + self + } + + /// Set the CSV file's column delimiter as a byte character + pub fn with_delimiter(mut self, delimiter: u8) -> Self { + self.delimiter = Some(delimiter); + self + } + + /// Set the CSV reader to infer the schema of the file + pub fn infer_schema(mut self, max_records: Option) -> Self { + // remove any schema that is set + self.schema = None; + self.max_records = max_records; + self + } + + /// Set the batch size (number of records to load at one time) + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Set the reader's column projection + pub fn with_projection(mut self, projection: Vec) -> Self 
{ + self.projection = Some(projection); + self + } + + /// Create a new `Reader` from the `ReaderBuilder` + pub fn build(self, mut reader: R) -> Result> { + // check if schema should be inferred + let delimiter = self.delimiter.unwrap_or(b','); + let schema = match self.schema { + Some(schema) => schema, + None => { + let (inferred_schema, _) = infer_file_schema( + &mut reader, + delimiter, + self.max_records, + self.has_header, + &infer, + )?; + + Arc::new(inferred_schema) + } + }; + Ok(Reader::from_reader( + reader, + schema, + self.has_header, + self.delimiter, + self.batch_size, + None, + self.projection.clone(), + DefaultParser::default(), + )) + } +} diff --git a/src/io/csv/writer.rs b/src/io/csv/writer.rs new file mode 100644 index 00000000000..05d52174531 --- /dev/null +++ b/src/io/csv/writer.rs @@ -0,0 +1,548 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! CSV Writer +//! +//! This CSV writer allows Arrow data (in record batches) to be written as CSV files. +//! The writer does not support writing `ListArray` and `StructArray`. +//! +//! Example: +//! +//! ``` +//! use arrow::array::*; +//! use arrow::csv; +//! use arrow::datatypes::*; +//! use arrow::record_batch::RecordBatch; +//! 
use arrow::util::test_util::get_temp_file; +//! use std::fs::File; +//! use std::sync::Arc; +//! +//! let schema = Schema::new(vec![ +//! Field::new("c1", DataType::Utf8, false), +//! Field::new("c2", DataType::Float64, true), +//! Field::new("c3", DataType::UInt32, false), +//! Field::new("c3", DataType::Boolean, true), +//! ]); +//! let c1 = StringArray::from(vec![ +//! "Lorem ipsum dolor sit amet", +//! "consectetur adipiscing elit", +//! "sed do eiusmod tempor", +//! ]); +//! let c2 = PrimitiveArray::::from(vec![ +//! Some(123.564532), +//! None, +//! Some(-556132.25), +//! ]); +//! let c3 = PrimitiveArray::::from(vec![3, 2, 1]); +//! let c4 = BooleanArray::from(vec![Some(true), Some(false), None]); +//! +//! let batch = RecordBatch::try_new( +//! Arc::new(schema), +//! vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], +//! ) +//! .unwrap(); +//! +//! let file = get_temp_file("out.csv", &[]); +//! +//! let mut writer = csv::Writer::new(file); +//! let batches = vec![&batch, &batch]; +//! for batch in batches { +//! writer.write(batch).unwrap(); +//! } +//! ``` + +use csv as csv_crate; + +use std::io::Write; + +use crate::array::*; +use crate::error::{ArrowError, Result}; +use crate::record_batch::RecordBatch; +use crate::{buffer::NativeType, datatypes::*, temporal_conversions, util::lexical_to_string}; +const DEFAULT_DATE_FORMAT: &str = "%F"; +const DEFAULT_TIME_FORMAT: &str = "%T"; +const DEFAULT_TIMESTAMP_FORMAT: &str = "%FT%H:%M:%S.%9f"; + +fn write_primitive_value(array: &dyn Array, i: usize) -> String +where + T: NativeType + lexical_core::ToLexical, +{ + let c = array.as_any().downcast_ref::>().unwrap(); + lexical_to_string(c.value(i)) +} + +/// A CSV writer +#[derive(Debug)] +pub struct Writer { + /// The object to write to + writer: csv_crate::Writer, + /// Column delimiter. Defaults to `b','` + delimiter: u8, + /// Whether file should be written with headers. 
Defaults to `true` + has_headers: bool, + /// The date format for date arrays + date_format: String, + /// The timestamp format for timestamp arrays + timestamp_format: String, + /// The time format for time arrays + time_format: String, + /// Is the beginning-of-writer + beginning: bool, +} + +impl Writer { + /// Create a new CsvWriter from a writable object, with default options + pub fn new(writer: W) -> Self { + let delimiter = b','; + let mut builder = csv_crate::WriterBuilder::new(); + let writer = builder.delimiter(delimiter).from_writer(writer); + Writer { + writer, + delimiter, + has_headers: true, + date_format: DEFAULT_DATE_FORMAT.to_string(), + time_format: DEFAULT_TIME_FORMAT.to_string(), + timestamp_format: DEFAULT_TIMESTAMP_FORMAT.to_string(), + beginning: true, + } + } + + /// Convert a record to a string vector + fn convert(&self, batch: &RecordBatch, row_index: usize, buffer: &mut [String]) -> Result<()> { + // TODO: it'd be more efficient if we could create `record: Vec<&[u8]> + for (col_index, item) in buffer.iter_mut().enumerate() { + let col = batch.column(col_index).as_ref(); + if col.is_null(row_index) { + // write an empty value + *item = "".to_string(); + continue; + } + *item = match col.data_type() { + DataType::Float64 => write_primitive_value::(col, row_index), + DataType::Float32 => write_primitive_value::(col, row_index), + DataType::Int8 => write_primitive_value::(col, row_index), + DataType::Int16 => write_primitive_value::(col, row_index), + DataType::Int32 => write_primitive_value::(col, row_index), + DataType::Int64 => write_primitive_value::(col, row_index), + DataType::UInt8 => write_primitive_value::(col, row_index), + DataType::UInt16 => write_primitive_value::(col, row_index), + DataType::UInt32 => write_primitive_value::(col, row_index), + DataType::UInt64 => write_primitive_value::(col, row_index), + DataType::Boolean => { + let c = col.as_any().downcast_ref::().unwrap(); + c.value(row_index).to_string() + } + 
DataType::Utf8 => { + let c = col.as_any().downcast_ref::>().unwrap(); + c.value(row_index).to_owned() + } + DataType::LargeUtf8 => { + let c = col.as_any().downcast_ref::>().unwrap(); + c.value(row_index).to_owned() + } + DataType::Date32 => { + let c = col.as_any().downcast_ref::>().unwrap(); + temporal_conversions::date32_to_datetime(c.value(row_index)) + .format(&self.date_format) + .to_string() + } + DataType::Date64 => { + let c = col.as_any().downcast_ref::>().unwrap(); + temporal_conversions::date64_to_datetime(c.value(row_index)) + .format(&self.date_format) + .to_string() + } + DataType::Time32(TimeUnit::Second) => { + let c = col.as_any().downcast_ref::>().unwrap(); + temporal_conversions::time32s_to_time(c.value(row_index)) + .format(&self.time_format) + .to_string() + } + DataType::Time32(TimeUnit::Millisecond) => { + let c = col.as_any().downcast_ref::>().unwrap(); + temporal_conversions::time32ms_to_time(c.value(row_index)) + .format(&self.time_format) + .to_string() + } + DataType::Time64(TimeUnit::Microsecond) => { + let c = col.as_any().downcast_ref::>().unwrap(); + temporal_conversions::time64us_to_time(c.value(row_index)) + .format(&self.time_format) + .to_string() + } + DataType::Time64(TimeUnit::Nanosecond) => { + let c = col.as_any().downcast_ref::>().unwrap(); + temporal_conversions::time64ns_to_time(c.value(row_index)) + .format(&self.time_format) + .to_string() + } + DataType::Timestamp(time_unit, _) => { + let c = col.as_any().downcast_ref::>().unwrap(); + + let datetime = match time_unit { + TimeUnit::Second => { + temporal_conversions::timestamp_s_to_datetime(c.value(row_index)) + } + TimeUnit::Millisecond => { + temporal_conversions::timestamp_ms_to_datetime(c.value(row_index)) + } + TimeUnit::Microsecond => { + temporal_conversions::timestamp_us_to_datetime(c.value(row_index)) + } + TimeUnit::Nanosecond => { + temporal_conversions::timestamp_ns_to_datetime(c.value(row_index)) + } + }; + format!("{}", 
datetime.format(&self.timestamp_format)) + } + t => { + // List and Struct arrays not supported by the writer, any + // other type needs to be implemented + return Err(ArrowError::CsvError(format!( + "CSV Writer does not support {:?} data type", + t + ))); + } + }; + } + Ok(()) + } + + /// Write a vector of record batches to a writable object + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + let num_columns = batch.num_columns(); + if self.beginning { + if self.has_headers { + let mut headers: Vec = Vec::with_capacity(num_columns); + batch + .schema() + .fields() + .iter() + .for_each(|field| headers.push(field.name().to_string())); + self.writer.write_record(&headers[..])?; + } + self.beginning = false; + } + + let mut buffer = vec!["".to_string(); batch.num_columns()]; + + for row_index in 0..batch.num_rows() { + self.convert(batch, row_index, &mut buffer)?; + self.writer.write_record(&buffer)?; + } + self.writer.flush()?; + + Ok(()) + } +} + +/// A CSV writer builder +#[derive(Debug)] +pub struct WriterBuilder { + /// Optional column delimiter. Defaults to `b','` + delimiter: Option, + /// Whether to write column names as file headers. Defaults to `true` + has_headers: bool, + /// Optional date format for date arrays + date_format: Option, + /// Optional timestamp format for timestamp arrays + timestamp_format: Option, + /// Optional time format for time arrays + time_format: Option, +} + +impl Default for WriterBuilder { + fn default() -> Self { + Self { + has_headers: true, + delimiter: None, + date_format: Some(DEFAULT_DATE_FORMAT.to_string()), + time_format: Some(DEFAULT_TIME_FORMAT.to_string()), + timestamp_format: Some(DEFAULT_TIMESTAMP_FORMAT.to_string()), + } + } +} + +impl WriterBuilder { + /// Create a new builder for configuring CSV writing options. 
+ /// + /// To convert a builder into a writer, call `WriterBuilder::build` + /// + /// # Example + /// + /// ``` + /// extern crate arrow; + /// + /// use arrow::csv; + /// use std::fs::File; + /// + /// fn example() -> csv::Writer { + /// let file = File::create("target/out.csv").unwrap(); + /// + /// // create a builder that doesn't write headers + /// let builder = csv::WriterBuilder::new().has_headers(false); + /// let writer = builder.build(file); + /// + /// writer + /// } + /// ``` + pub fn new() -> Self { + Self::default() + } + + /// Set whether to write headers + pub fn has_headers(mut self, has_headers: bool) -> Self { + self.has_headers = has_headers; + self + } + + /// Set the CSV file's column delimiter as a byte character + pub fn with_delimiter(mut self, delimiter: u8) -> Self { + self.delimiter = Some(delimiter); + self + } + + /// Set the CSV file's date format + pub fn with_date_format(mut self, format: String) -> Self { + self.date_format = Some(format); + self + } + + /// Set the CSV file's time format + pub fn with_time_format(mut self, format: String) -> Self { + self.time_format = Some(format); + self + } + + /// Set the CSV file's timestamp format + pub fn with_timestamp_format(mut self, format: String) -> Self { + self.timestamp_format = Some(format); + self + } + + /// Create a new `Writer` + pub fn build(self, writer: W) -> Writer { + let delimiter = self.delimiter.unwrap_or(b','); + let mut builder = csv_crate::WriterBuilder::new(); + let writer = builder.delimiter(delimiter).from_writer(writer); + Writer { + writer, + delimiter, + has_headers: self.has_headers, + date_format: self + .date_format + .unwrap_or_else(|| DEFAULT_DATE_FORMAT.to_string()), + time_format: self + .time_format + .unwrap_or_else(|| DEFAULT_TIME_FORMAT.to_string()), + timestamp_format: self + .timestamp_format + .unwrap_or_else(|| DEFAULT_TIMESTAMP_FORMAT.to_string()), + beginning: true, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use 
crate::datatypes::{Field, Schema}; + use crate::util::string_writer::StringWriter; + use crate::util::test_util::get_temp_file; + use std::fs::File; + use std::io::Read; + use std::sync::Arc; + + #[test] + fn test_write_csv() { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::UInt32, false), + Field::new("c4", DataType::Boolean, true), + Field::new("c5", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new("c6", DataType::Time32(TimeUnit::Second), false), + ]); + + let c1 = Utf8Array::::from_slice([ + "Lorem ipsum dolor sit amet", + "consectetur adipiscing elit", + "sed do eiusmod tempor", + ]); + let c2 = Primitive::::from([Some(123.564532), None, Some(-556132.25)]) + .to(DataType::Float64); + let c3 = Primitive::::from_slice(vec![3, 2, 1]).to(DataType::UInt32); + let c4 = BooleanArray::from(vec![Some(true), Some(false), None]); + let c5 = Primitive::::from([None, Some(1555584887378), Some(1555555555555)]) + .to(DataType::Timestamp(TimeUnit::Millisecond, None)); + let c6 = Primitive::::from_slice(vec![1234, 24680, 85563]) + .to(DataType::Time32(TimeUnit::Second)); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(c1), + Arc::new(c2), + Arc::new(c3), + Arc::new(c4), + Arc::new(c5), + Arc::new(c6), + ], + ) + .unwrap(); + + let file = get_temp_file("columns.csv", &[]); + + let mut writer = Writer::new(file); + let batches = vec![&batch, &batch]; + for batch in batches { + writer.write(batch).unwrap(); + } + // check that file was written successfully + let mut file = File::open("target/debug/testdata/columns.csv").unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + assert_eq!( + r#"c1,c2,c3,c4,c5,c6 +Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34 +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20 +sed do eiusmod 
tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03
+Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34
+consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20
+sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03
+"#
+            .to_string(),
+            String::from_utf8(buffer).unwrap()
+        );
+    }
+
+    /// Writes one batch with headers disabled, `|` as the delimiter and a
+    /// 12-hour (`%r`) time format, then compares the bytes found on disk.
+    #[test]
+    fn test_write_csv_custom_options() {
+        let schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, false),
+            Field::new("c2", DataType::Float64, true),
+            Field::new("c3", DataType::UInt32, false),
+            Field::new("c4", DataType::Boolean, true),
+            Field::new("c6", DataType::Time32(TimeUnit::Second), false),
+        ]);
+
+        let c1 = Utf8Array::<i32>::from_slice([
+            "Lorem ipsum dolor sit amet",
+            "consectetur adipiscing elit",
+            "sed do eiusmod tempor",
+        ]);
+        let c2 = Primitive::<f64>::from([Some(123.564532), None, Some(-556132.25)])
+            .to(DataType::Float64);
+        let c3 = Primitive::<u32>::from_slice([3, 2, 1]).to(DataType::UInt32);
+        let c4 = BooleanArray::from(vec![Some(true), Some(false), None]);
+        // seconds since midnight: 00:20:34, 06:51:20 and 23:46:03
+        let c6 = Primitive::<i32>::from_slice([1234, 24680, 85563])
+            .to(DataType::Time32(TimeUnit::Second));
+
+        let batch = RecordBatch::try_new(
+            Arc::new(schema),
+            vec![
+                Arc::new(c1),
+                Arc::new(c2),
+                Arc::new(c3),
+                Arc::new(c4),
+                Arc::new(c6),
+            ],
+        )
+        .unwrap();
+
+        let file = get_temp_file("custom_options.csv", &[]);
+
+        let builder = WriterBuilder::new()
+            .has_headers(false)
+            .with_delimiter(b'|')
+            .with_time_format("%r".to_string());
+        let mut writer = builder.build(file);
+        writer.write(&batch).unwrap();
+
+        // check that file was written successfully
+        let mut file = File::open("target/debug/testdata/custom_options.csv").unwrap();
+        let mut buffer: Vec<u8> = vec![];
+        file.read_to_end(&mut buffer).unwrap();
+
+        assert_eq!(
+            "Lorem ipsum dolor sit amet|123.564532|3|true|12:20:34 AM\nconsectetur adipiscing elit||2|false|06:51:20 AM\nsed do eiusmod tempor|-556132.25|1||11:46:03 PM\n"
+            .to_string(),
+            String::from_utf8(buffer).unwrap()
+        );
+    }
+
+    /// Writes the same batch twice into an in-memory `StringWriter` and checks
+    /// that the header is emitted once, followed by both sets of rows.
+    #[test]
+    fn test_export_csv_string() {
+        let schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, false),
+            Field::new("c2", DataType::Float64, true),
+            Field::new("c3", DataType::UInt32, false),
+            Field::new("c4", DataType::Boolean, true),
+            Field::new("c5", DataType::Timestamp(TimeUnit::Millisecond, None), true),
+            Field::new("c6", DataType::Time32(TimeUnit::Second), false),
+        ]);
+
+        let c1 = Utf8Array::<i32>::from_slice([
+            "Lorem ipsum dolor sit amet",
+            "consectetur adipiscing elit",
+            "sed do eiusmod tempor",
+        ]);
+        let c2 = Primitive::<f64>::from(vec![Some(123.564532), None, Some(-556132.25)])
+            .to(DataType::Float64);
+        let c3 = Primitive::<u32>::from_slice(&[3, 2, 1]).to(DataType::UInt32);
+        let c4 = BooleanArray::from(vec![Some(true), Some(false), None]);
+        // milliseconds since the UNIX epoch
+        let c5 = Primitive::<i64>::from(vec![None, Some(1555584887378), Some(1555555555555)])
+            .to(DataType::Timestamp(TimeUnit::Millisecond, None));
+        let c6 = Primitive::<i32>::from_slice([1234, 24680, 85563])
+            .to(DataType::Time32(TimeUnit::Second));
+
+        let batch = RecordBatch::try_new(
+            Arc::new(schema),
+            vec![
+                Arc::new(c1),
+                Arc::new(c2),
+                Arc::new(c3),
+                Arc::new(c4),
+                Arc::new(c5),
+                Arc::new(c6),
+            ],
+        )
+        .unwrap();
+
+        let sw = StringWriter::new();
+        let mut writer = Writer::new(sw);
+        // the same batch is written twice
+        for _ in 0..2 {
+            writer.write(&batch).unwrap();
+        }
+
+        let left = "c1,c2,c3,c4,c5,c6
+Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34
+consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20
+sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03
+Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34
+consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20
+sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03\n";
+        let right = writer.writer.into_inner().map(|s| s.to_string());
+        assert_eq!(Some(left.to_string()),
right.ok());
+    }
+}
diff --git a/src/io/mod.rs b/src/io/mod.rs
new file mode 100644
index 00000000000..3b8c0845762
--- /dev/null
+++ b/src/io/mod.rs
@@ -0,0 +1,2 @@
+// CSV support needs all the optional dependencies that Cargo.toml groups
+// under the `io_csv` feature, so the module is gated on that feature rather
+// than on the implicit feature of the `csv` dependency alone.
+#[cfg(feature = "io_csv")]
+pub mod csv;
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 00000000000..031f769da5f
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,16 @@
+pub mod buffer;
+
+pub(crate) mod bits;
+pub mod datatypes;
+
+pub mod array;
+
+pub mod compute;
+pub mod error;
+
+mod ffi;
+
+pub mod io;
+pub mod record_batch;
+pub mod temporal_conversions;
+pub mod util;
diff --git a/src/record_batch.rs b/src/record_batch.rs
new file mode 100644
index 00000000000..d280822d7b3
--- /dev/null
+++ b/src/record_batch.rs
@@ -0,0 +1,311 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A two-dimensional batch of column-oriented data with a defined
+//! [schema](crate::datatypes::Schema).
+
+use std::sync::Arc;
+
+use crate::array::*;
+use crate::datatypes::*;
+use crate::error::{ArrowError, Result};
+
+// Shared, reference-counted schema handle.
+type SchemaRef = Arc<Schema>;
+// Shared, reference-counted handle to a dynamically typed array.
+type ArrayRef = Arc<dyn Array>;
+
+/// A two-dimensional batch of column-oriented data with a defined
+/// [schema](crate::datatypes::Schema).
+///
+/// A `RecordBatch` is a two-dimensional dataset of a number of
+/// contiguous arrays, each the same length.
+/// A record batch has a schema which must match its arrays’
+/// datatypes.
+///
+/// Record batches are a convenient unit of work for various
+/// serialization and computation functions, possibly incremental.
+/// See also the CSV support in `crate::io::csv`.
+#[derive(Clone, Debug)]
+pub struct RecordBatch {
+    schema: SchemaRef,
+    columns: Vec<ArrayRef>,
+}
+
+impl RecordBatch {
+    /// Creates a `RecordBatch` from a schema and columns, validating them
+    /// with the default [`RecordBatchOptions`].
+    ///
+    /// Expects the following:
+    /// * the vec of columns to not be empty
+    /// * the schema and column data types to have equal lengths
+    ///   and match
+    /// * each array in columns to have the same length
+    ///
+    /// # Errors
+    ///
+    /// Returns [`ArrowError::InvalidArgumentError`] if any of the conditions
+    /// above is not met.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use std::sync::Arc;
+    /// use arrow2::array::*;
+    /// use arrow2::datatypes::{Schema, Field, DataType};
+    /// use arrow2::record_batch::RecordBatch;
+    ///
+    /// # fn main() -> arrow2::error::Result<()> {
+    /// let id_array = Primitive::<i32>::from_slice(&[1, 2, 3, 4, 5]).to(DataType::Int32);
+    /// let schema = Schema::new(vec![
+    ///     Field::new("id", DataType::Int32, false)
+    /// ]);
+    ///
+    /// let batch = RecordBatch::try_new(
+    ///     Arc::new(schema),
+    ///     vec![Arc::new(id_array)]
+    /// )?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self> {
+        Self::try_new_with_options(schema, columns, &RecordBatchOptions::default())
+    }
+
+    /// Creates a `RecordBatch` from a schema and columns, with additional options,
+    /// such as whether to strictly validate field names.
+    ///
+    /// See [`RecordBatch::try_new`] for the expected conditions.
+    pub fn try_new_with_options(
+        schema: SchemaRef,
+        columns: Vec<ArrayRef>,
+        options: &RecordBatchOptions,
+    ) -> Result<Self> {
+        Self::validate_new_batch(&schema, columns.as_slice(), options)?;
+        Ok(RecordBatch { schema, columns })
+    }
+
+    /// Creates a new empty [`RecordBatch`]: one empty array per schema field.
+    pub fn new_empty(schema: SchemaRef) -> Self {
+        let columns = schema
+            .fields()
+            .iter()
+            .map(|field| new_empty_array(field.data_type().clone()).into())
+            .collect();
+        RecordBatch { schema, columns }
+    }
+
+    /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error
+    /// if any validation check fails.
+    fn validate_new_batch(
+        schema: &SchemaRef,
+        columns: &[ArrayRef],
+        options: &RecordBatchOptions,
+    ) -> Result<()> {
+        // check that there are some columns; the first one fixes the row count
+        let len = match columns.first() {
+            Some(column) => column.len(),
+            None => {
+                return Err(ArrowError::InvalidArgumentError(
+                    "at least one column must be defined to create a record batch".to_string(),
+                ))
+            }
+        };
+        // check that number of fields in schema match column length
+        if schema.fields().len() != columns.len() {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "number of columns({}) must match number of fields({}) in schema",
+                columns.len(),
+                schema.fields().len(),
+            )));
+        }
+        // check that all columns have the same row count, and match the schema
+        for (i, column) in columns.iter().enumerate() {
+            if column.len() != len {
+                return Err(ArrowError::InvalidArgumentError(
+                    "all columns in a record batch must have the same length".to_string(),
+                ));
+            }
+            let expected = schema.field(i).data_type();
+            let found = column.data_type();
+            // `match_field_names` selects `==` on the data types; otherwise
+            // the check is delegated to `equals_datatype`
+            let types_match = if options.match_field_names {
+                found == expected
+            } else {
+                found.equals_datatype(expected)
+            };
+            if !types_match {
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "column types must match schema types, expected {:?} but found {:?} at column index {}",
+                    expected, found, i
+                )));
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Returns the [`Schema`](crate::datatypes::Schema) of the record batch.
+    pub fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    /// Returns the number of columns in the record batch.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use std::sync::Arc;
+    /// use arrow2::array::*;
+    /// use arrow2::datatypes::{Schema, Field, DataType};
+    /// use arrow2::record_batch::RecordBatch;
+    ///
+    /// # fn main() -> arrow2::error::Result<()> {
+    /// let id_array = Primitive::<i32>::from_slice(&[1, 2, 3, 4, 5]).to(DataType::Int32);
+    /// let schema = Schema::new(vec![
+    ///     Field::new("id", DataType::Int32, false)
+    /// ]);
+    ///
+    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)])?;
+    ///
+    /// assert_eq!(batch.num_columns(), 1);
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn num_columns(&self) -> usize {
+        self.columns.len()
+    }
+
+    /// Returns the number of rows in each column.
+    ///
+    /// A batch without columns (e.g. [`RecordBatch::new_empty`] called with a
+    /// schema that has no fields) has `0` rows.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use std::sync::Arc;
+    /// use arrow2::array::*;
+    /// use arrow2::datatypes::{Schema, Field, DataType};
+    /// use arrow2::record_batch::RecordBatch;
+    ///
+    /// # fn main() -> arrow2::error::Result<()> {
+    /// let id_array = Primitive::<i32>::from_slice(&[1, 2, 3, 4, 5]).to(DataType::Int32);
+    /// let schema = Schema::new(vec![
+    ///     Field::new("id", DataType::Int32, false)
+    /// ]);
+    ///
+    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)])?;
+    ///
+    /// assert_eq!(batch.num_rows(), 5);
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn num_rows(&self) -> usize {
+        self.columns.first().map_or(0, |column| column.len())
+    }
+
+    /// Get a reference to a column's array by index.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `index` is outside of `0..num_columns`.
+    pub fn column(&self, index: usize) -> &ArrayRef {
+        &self.columns[index]
+    }
+
+    /// Get a reference to all columns in the record batch.
+    pub fn columns(&self) -> &[ArrayRef] {
+        &self.columns[..]
+    }
+}
+
+/// Options that control the behaviour used when creating a [`RecordBatch`].
+#[derive(Clone, Debug)]
+pub struct RecordBatchOptions {
+    /// Match field names of structs and lists. If set to `true`, the names must match.
+    pub match_field_names: bool,
+}
+
+impl Default for RecordBatchOptions {
+    /// Strict validation: field names must match.
+    fn default() -> Self {
+        Self {
+            match_field_names: true,
+        }
+    }
+}
+
+impl From<&StructArray> for RecordBatch {
+    /// Create a record batch from struct array.
+    ///
+    /// This currently does not flatten any nested struct types
+    fn from(struct_array: &StructArray) -> Self {
+        if let DataType::Struct(fields) = struct_array.data_type() {
+            let schema = Schema::new(fields.clone());
+            let columns = struct_array.values().to_vec();
+            RecordBatch {
+                schema: Arc::new(schema),
+                columns,
+            }
+        } else {
+            unreachable!("unable to get datatype as struct")
+        }
+    }
+}
+
+impl From<RecordBatch> for StructArray {
+    /// Builds a struct array whose fields are the schema fields and whose
+    /// values are the columns of `batch`; no validity is passed.
+    ///
+    /// Implemented as `From` so that `batch.into()` keeps working through the
+    /// blanket `Into` implementation.
+    fn from(batch: RecordBatch) -> Self {
+        let (fields, values): (Vec<Field>, Vec<ArrayRef>) = batch
+            .schema
+            .fields()
+            .iter()
+            .cloned()
+            .zip(batch.columns)
+            .unzip();
+        StructArray::from_data(fields, values, None)
+    }
+}
+
+/// Trait for types that can read `RecordBatch`'s.
+pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch>> {
+    /// Returns the schema of this `RecordBatchReader`.
+    ///
+    /// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this
+    /// reader should have the same schema as returned from this method.
+    fn schema(&self) -> SchemaRef;
+
+    /// Reads the next `RecordBatch`.
+    #[deprecated(
+        since = "2.0.0",
+        note = "This method is deprecated in favour of `next` from the trait Iterator."
+    )]
+    fn next_batch(&mut self) -> Result<Option<RecordBatch>> {
+        self.next().transpose()
+    }
+}
diff --git a/src/temporal_conversions.rs b/src/temporal_conversions.rs
new file mode 100644
index 00000000000..d8c37248026
--- /dev/null
+++ b/src/temporal_conversions.rs
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Conversion methods for dates and times.
+
+use chrono::{NaiveDateTime, NaiveTime};
+
+/// Number of seconds in a day
+pub const SECONDS_IN_DAY: i64 = 86_400;
+/// Number of milliseconds in a second
+pub const MILLISECONDS: i64 = 1_000;
+/// Number of microseconds in a second
+pub const MICROSECONDS: i64 = 1_000_000;
+/// Number of nanoseconds in a second
+pub const NANOSECONDS: i64 = 1_000_000_000;
+/// Number of milliseconds in a day
+pub const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MILLISECONDS;
+/// Number of days between 0001-01-01 and 1970-01-01
+pub const EPOCH_DAYS_FROM_CE: i32 = 719_163;
+
+/// converts a `i32` representing a `date32` to [`NaiveDateTime`]
+#[inline]
+pub fn date32_to_datetime(v: i32) -> NaiveDateTime {
+    NaiveDateTime::from_timestamp(v as i64 * SECONDS_IN_DAY, 0)
+}
+
+/// converts a `i64` representing a `date64` to [`NaiveDateTime`]
+///
+/// Values before the epoch (negative `v`) are supported: the seconds are
+/// rounded towards negative infinity so the sub-second part stays positive.
+#[inline]
+pub fn date64_to_datetime(v: i64) -> NaiveDateTime {
+    NaiveDateTime::from_timestamp(
+        // extract whole seconds from milliseconds, flooring
+        v.div_euclid(MILLISECONDS),
+        // remaining milliseconds (0..1000); the factor 1_000_000 is also the
+        // number of nanoseconds in a millisecond
+        (v.rem_euclid(MILLISECONDS) * MICROSECONDS) as u32,
+    )
+}
+
+/// converts a `i32` representing a `time32(s)` to [`NaiveTime`]
+///
+/// `v` is interpreted as seconds since midnight; chrono panics when the
+/// value is not a valid time of day.
+#[inline]
+pub fn time32s_to_time(v: i32) -> NaiveTime {
+    NaiveTime::from_num_seconds_from_midnight(v as u32, 0)
+}
+
+/// converts a `i32` representing a `time32(ms)` to [`NaiveTime`]
+///
+/// `v` is interpreted as milliseconds since midnight; chrono panics when the
+/// value is not a valid time of day.
+#[inline]
+pub fn time32ms_to_time(v: i32) -> NaiveTime {
+    let v = v as i64;
+    NaiveTime::from_num_seconds_from_midnight(
+        // extract seconds from milliseconds
+        (v / MILLISECONDS) as u32,
+        // discard extracted seconds and convert milliseconds to
+        // nanoseconds
+        (v % MILLISECONDS * MICROSECONDS) as u32,
+    )
+}
+
+/// converts a `i64` representing a `time64(us)` to [`NaiveTime`]
+///
+/// `v` is interpreted as microseconds since midnight; chrono panics when the
+/// value is not a valid time of day.
+#[inline]
+pub fn time64us_to_time(v: i64) -> NaiveTime {
+    NaiveTime::from_num_seconds_from_midnight(
+        // extract seconds from microseconds
+        (v / MICROSECONDS) as u32,
+        // discard extracted seconds and convert microseconds to
+        // nanoseconds
+        (v % MICROSECONDS * MILLISECONDS) as u32,
+    )
+}
+
+/// converts a `i64` representing a `time64(ns)` to [`NaiveTime`]
+///
+/// `v` is interpreted as nanoseconds since midnight; chrono panics when the
+/// value is not a valid time of day.
+#[inline]
+pub fn time64ns_to_time(v: i64) -> NaiveTime {
+    NaiveTime::from_num_seconds_from_midnight(
+        // extract seconds from nanoseconds
+        (v / NANOSECONDS) as u32,
+        // discard extracted seconds
+        (v % NANOSECONDS) as u32,
+    )
+}
+
+/// converts a `i64` representing a `timestamp(s)` to [`NaiveDateTime`]
+#[inline]
+pub fn timestamp_s_to_datetime(v: i64) -> NaiveDateTime {
+    NaiveDateTime::from_timestamp(v, 0)
+}
+
+/// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`]
+///
+/// Timestamps before the epoch (negative `v`) are supported.
+#[inline]
+pub fn timestamp_ms_to_datetime(v: i64) -> NaiveDateTime {
+    NaiveDateTime::from_timestamp(
+        // extract whole seconds from milliseconds, flooring so that the
+        // remainder below is never negative
+        v.div_euclid(MILLISECONDS),
+        // remaining milliseconds (0..1000) converted to nanoseconds
+        (v.rem_euclid(MILLISECONDS) * MICROSECONDS) as u32,
+    )
+}
+
+/// converts a `i64` representing a `timestamp(us)` to [`NaiveDateTime`]
+///
+/// Timestamps before the epoch (negative `v`) are supported.
+#[inline]
+pub fn timestamp_us_to_datetime(v: i64) -> NaiveDateTime {
+    NaiveDateTime::from_timestamp(
+        // extract whole seconds from microseconds, flooring
+        v.div_euclid(MICROSECONDS),
+        // remaining microseconds (0..1_000_000) converted to nanoseconds
+        (v.rem_euclid(MICROSECONDS) * MILLISECONDS) as u32,
+    )
+}
+
+/// converts a `i64` representing a `timestamp(ns)` to [`NaiveDateTime`]
+///
+/// Timestamps before the epoch (negative `v`) are supported.
+#[inline]
+pub fn timestamp_ns_to_datetime(v: i64) -> NaiveDateTime {
+    NaiveDateTime::from_timestamp(
+        // extract whole seconds from nanoseconds, flooring
+        v.div_euclid(NANOSECONDS),
+        // remaining nanoseconds (0..1_000_000_000)
+        v.rem_euclid(NANOSECONDS) as u32,
+    )
+}
diff --git a/src/util/mod.rs b/src/util/mod.rs
new file mode 100644
index 00000000000..256e6b67a46
--- /dev/null
+++ b/src/util/mod.rs
@@ -0,0 +1,19 @@
+/// Converts numeric type to a `String`
+///
+/// Only available when the optional `lexical-core` dependency is enabled
+/// (it is part of the default `io_csv` feature).
+#[cfg(feature = "lexical-core")]
+pub fn lexical_to_string<N: lexical_core::ToLexical>(n: N) -> String {
+    // Zero-initialised scratch buffer: handing lexical_core a slice over
+    // uninitialised capacity would be undefined behaviour.
+    let mut buf = vec![0u8; N::FORMATTED_SIZE_DECIMAL];
+    // `write` returns the part of the buffer it filled
+    let len = lexical_core::write(n, &mut buf).len();
+    buf.truncate(len);
+    // SAFETY: after the truncation `buf` holds exactly the bytes written by
+    // lexical_core, which creates a valid string, so the UTF-8 check can be
+    // skipped (this is the point of using the faster serializer).
+    unsafe { String::from_utf8_unchecked(buf) }
+}
+
+pub mod string_writer;
+pub mod test_util;
diff --git a/src/util/string_writer.rs b/src/util/string_writer.rs
new file mode 100644
index 00000000000..2a8175d1562
--- /dev/null
+++ b/src/util/string_writer.rs
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! String Writer
+//! This string writer encapsulates `std::string::String` and
+//!
implements `std::io::Write` trait, which makes String as a
+//! writable object like File.
+//!
+//! Example:
+//!
+//! ```
+//! use arrow2::array::*;
+//! use arrow2::datatypes::*;
+//! use arrow2::io::csv;
+//! use arrow2::record_batch::RecordBatch;
+//! use arrow2::util::string_writer::StringWriter;
+//! use std::sync::Arc;
+//!
+//! let schema = Schema::new(vec![
+//!     Field::new("c1", DataType::Utf8, false),
+//!     Field::new("c2", DataType::Float64, true),
+//!     Field::new("c3", DataType::UInt32, false),
+//!     Field::new("c4", DataType::Boolean, true),
+//! ]);
+//! let c1 = Utf8Array::<i32>::from_slice(&[
+//!     "Lorem ipsum dolor sit amet",
+//!     "consectetur adipiscing elit",
+//!     "sed do eiusmod tempor",
+//! ]);
+//! let c2 = Primitive::<f64>::from(vec![
+//!     Some(123.564532),
+//!     None,
+//!     Some(-556132.25),
+//! ])
+//! .to(DataType::Float64);
+//! let c3 = Primitive::<u32>::from_slice(&[3, 2, 1]).to(DataType::UInt32);
+//! let c4 = BooleanArray::from(vec![Some(true), Some(false), None]);
+//!
+//! let batch = RecordBatch::try_new(
+//!     Arc::new(schema),
+//!     vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
+//! )
+//! .unwrap();
+//!
+//! let sw = StringWriter::new();
+//! let mut writer = csv::Writer::new(sw);
+//! writer.write(&batch).unwrap();
+//! ```
+
+use std::io::{Error, ErrorKind, Result, Write};
+
+/// An in-memory [`Write`] sink that accumulates UTF-8 text in a `String`.
+///
+/// The accumulated text is available through `to_string()`.
+#[derive(Debug, Default)]
+pub struct StringWriter {
+    data: String,
+}
+
+impl StringWriter {
+    /// Creates a writer holding an empty string.
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+// `Display` is implemented instead of `ToString`, so `to_string()` keeps
+// working through the blanket implementation.
+impl std::fmt::Display for StringWriter {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&self.data)
+    }
+}
+
+impl Write for StringWriter {
+    /// Appends `buf` to the underlying string.
+    ///
+    /// # Errors
+    ///
+    /// Returns an [`ErrorKind::InvalidData`] error, without consuming
+    /// anything, when `buf` is not valid UTF-8 on its own. A multi-byte
+    /// character split across two calls is therefore rejected.
+    fn write(&mut self, buf: &[u8]) -> Result<usize> {
+        // validate in place instead of copying the buffer into a new `String`
+        let text =
+            std::str::from_utf8(buf).map_err(|e| Error::new(ErrorKind::InvalidData, e))?;
+        self.data.push_str(text);
+        Ok(buf.len())
+    }
+
+    /// Nothing is buffered outside the string, so flushing is a no-op.
+    fn flush(&mut self) -> Result<()> {
+        Ok(())
+    }
+}
diff --git a/src/util/test_util.rs b/src/util/test_util.rs
new file mode 100644
index 00000000000..6a70edda417
--- /dev/null
+++ b/src/util/test_util.rs
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Utils to make testing easier
+
+use rand::{rngs::StdRng, Rng, SeedableRng};
+use std::{env, fs, io::Write};
+
+/// Returns a vector of size `n`, filled with randomly generated bytes.
+///
+/// The generator is the fixed-seed one from [`seedable_rng`], so every call
+/// produces the same sequence.
+pub fn random_bytes(n: usize) -> Vec<u8> {
+    let mut result = vec![0u8; n];
+    let mut rng = seedable_rng();
+    // Fill the buffer in one call. The previous `gen_range(0, 255)` had an
+    // exclusive upper bound and so could never produce the byte 255.
+    rng.fill(result.as_mut_slice());
+    result
+}
+
+/// Returns fixed seedable RNG
+pub fn seedable_rng() -> StdRng {
+    StdRng::seed_from_u64(42)
+}
+
+/// Returns file handle for a temp file in 'target' directory with a provided content
+///
+/// The file is created (or truncated) at `target/debug/testdata/<file_name>`
+/// under the current working directory, and the returned handle is opened
+/// for both reading and writing.
+///
+/// # Panics
+///
+/// Panics if the directory or the file cannot be created, written or opened.
+///
+/// TODO: Originates from `parquet` utils, can be merged in [ARROW-4064]
+pub fn get_temp_file(file_name: &str, content: &[u8]) -> fs::File {
+    // build tmp path to a file in "target/debug/testdata"
+    let mut path_buf = env::current_dir().unwrap();
+    path_buf.push("target");
+    path_buf.push("debug");
+    path_buf.push("testdata");
+    fs::create_dir_all(&path_buf).unwrap();
+    path_buf.push(file_name);
+
+    // write file content
+    let mut tmp_file = fs::File::create(path_buf.as_path()).unwrap();
+    tmp_file.write_all(content).unwrap();
+    tmp_file.sync_all().unwrap();
+
+    // return file handle for both read and write
+    fs::OpenOptions::new()
+        .read(true)
+        .write(true)
+        .open(path_buf.as_path())
+        .expect("the temp file was just created and must be openable")
+}