diff --git a/parquet-variant-compute/Cargo.toml b/parquet-variant-compute/Cargo.toml index cc13810a2971..0aa926ee7fa4 100644 --- a/parquet-variant-compute/Cargo.toml +++ b/parquet-variant-compute/Cargo.toml @@ -33,6 +33,7 @@ rust-version = { workspace = true } [dependencies] arrow = { workspace = true } arrow-schema = { workspace = true } +half = { version = "2.1", default-features = false } parquet-variant = { workspace = true } parquet-variant-json = { workspace = true } @@ -49,4 +50,3 @@ arrow = { workspace = true, features = ["test_utils"] } [[bench]] name = "variant_kernels" harness = false - diff --git a/parquet-variant-compute/src/cast_to_variant.rs b/parquet-variant-compute/src/cast_to_variant.rs index 49bdd30cea6b..cbd16c589c61 100644 --- a/parquet-variant-compute/src/cast_to_variant.rs +++ b/parquet-variant-compute/src/cast_to_variant.rs @@ -18,10 +18,11 @@ use crate::{VariantArray, VariantArrayBuilder}; use arrow::array::{Array, AsArray}; use arrow::datatypes::{ - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, - UInt64Type, UInt8Type, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, }; use arrow_schema::{ArrowError, DataType}; +use half::f16; use parquet_variant::Variant; /// Convert the input array of a specific primitive type to a `VariantArray` @@ -39,6 +40,22 @@ macro_rules! primitive_conversion { }}; } +/// Convert the input array to a `VariantArray` row by row, +/// transforming each element with `cast_fn` +macro_rules! cast_conversion { + ($t:ty, $cast_fn:expr, $input:expr, $builder:expr) => {{ + let array = $input.as_primitive::<$t>(); + for i in 0..array.len() { + if array.is_null(i) { + $builder.append_null(); + continue; + } + let cast_value = $cast_fn(array.value(i)); + $builder.append_variant(Variant::from(cast_value)); + } + }}; +} + /// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you /// need to convert a specific data type /// @@ -92,6 +109,9 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { DataType::UInt64 => { primitive_conversion!(UInt64Type, input, builder); } + DataType::Float16 => { + cast_conversion!(Float16Type, |v: f16| -> f32 { v.into() }, input, builder); + } DataType::Float32 => { primitive_conversion!(Float32Type, input, builder); } @@ -115,8 +135,8 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { mod tests { use super::*; use arrow::array::{ - ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, + ArrayRef, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, + Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use parquet_variant::{Variant, VariantDecimal16}; use std::sync::Arc; @@ -284,6 +304,28 @@ mod tests { ) } + #[test] + fn test_cast_to_variant_float16() { + run_test( + Arc::new(Float16Array::from(vec![ + Some(f16::MIN), + None, + Some(f16::from_f32(-1.5)), + Some(f16::from_f32(0.0)), + Some(f16::from_f32(1.5)), + Some(f16::MAX), + ])), + vec![ + Some(Variant::Float(f16::MIN.into())), + None, + Some(Variant::Float(-1.5)), + Some(Variant::Float(0.0)), + Some(Variant::Float(1.5)), + Some(Variant::Float(f16::MAX.into())), + ], + ) + } + #[test] fn test_cast_to_variant_float32() { run_test(