-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[Variant] Implement DataType::Float16 => Variant::Float
#8073
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,10 +18,11 @@ | |
| use crate::{VariantArray, VariantArrayBuilder}; | ||
| use arrow::array::{Array, AsArray}; | ||
| use arrow::datatypes::{ | ||
| Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, | ||
| UInt64Type, UInt8Type, | ||
| Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, | ||
| UInt32Type, UInt64Type, UInt8Type, | ||
| }; | ||
| use arrow_schema::{ArrowError, DataType}; | ||
| use half::f16; | ||
| use parquet_variant::Variant; | ||
|
|
||
| /// Convert the input array of a specific primitive type to a `VariantArray` | ||
|
|
@@ -39,6 +40,22 @@ macro_rules! primitive_conversion { | |
| }}; | ||
| } | ||
|
|
||
| /// Convert the input array to a `VariantArray` row by row, | ||
| /// transforming each element with `cast_fn` | ||
| macro_rules! cast_conversion { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This macro applies We could also add the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm working on a couple of these issues in parallel and made some additional tweaks to the macro here: https://github.com/apache/arrow-rs/pull/8074/files. |
||
| ($t:ty, $cast_fn:expr, $input:expr, $builder:expr) => {{ | ||
| let array = $input.as_primitive::<$t>(); | ||
| for i in 0..array.len() { | ||
| if array.is_null(i) { | ||
| $builder.append_null(); | ||
| continue; | ||
| } | ||
| let cast_value = $cast_fn(array.value(i)); | ||
| $builder.append_variant(Variant::from(cast_value)); | ||
| } | ||
| }}; | ||
| } | ||
|
|
||
| /// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you | ||
| /// need to convert a specific data type | ||
| /// | ||
|
|
@@ -92,6 +109,9 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { | |
| DataType::UInt64 => { | ||
| primitive_conversion!(UInt64Type, input, builder); | ||
| } | ||
| DataType::Float16 => { | ||
| cast_conversion!(Float16Type, |v: f16| -> f32 { v.into() }, input, builder); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Casted f16 to f32 so that the value can be wrapped by Variant::Float.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense to me. In general, getting a macro that knows how to convert various Arrow types to Variant I think is an important building block |
||
| } | ||
| DataType::Float32 => { | ||
| primitive_conversion!(Float32Type, input, builder); | ||
| } | ||
|
|
@@ -115,8 +135,8 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { | |
| mod tests { | ||
| use super::*; | ||
| use arrow::array::{ | ||
| ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, | ||
| UInt16Array, UInt32Array, UInt64Array, UInt8Array, | ||
| ArrayRef, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, | ||
| Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, | ||
| }; | ||
| use parquet_variant::{Variant, VariantDecimal16}; | ||
| use std::sync::Arc; | ||
|
|
@@ -284,6 +304,28 @@ mod tests { | |
| ) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_cast_to_variant_float16() { | ||
| run_test( | ||
| Arc::new(Float16Array::from(vec![ | ||
| Some(f16::MIN), | ||
| None, | ||
| Some(f16::from_f32(-1.5)), | ||
| Some(f16::from_f32(0.0)), | ||
| Some(f16::from_f32(1.5)), | ||
| Some(f16::MAX), | ||
| ])), | ||
| vec![ | ||
| Some(Variant::Float(f16::MIN.into())), | ||
| None, | ||
| Some(Variant::Float(-1.5)), | ||
| Some(Variant::Float(0.0)), | ||
| Some(Variant::Float(1.5)), | ||
| Some(Variant::Float(f16::MAX.into())), | ||
| ], | ||
| ) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_cast_to_variant_float32() { | ||
| run_test( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needed to reference
f16in the code and in the tests.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah -- we should probably directly export he
f16type from the arrow crate (pub use) to avoid having users explicitly have tousehalf . Maybe as a follow on PR