From 7abc92b219b2f1b76052f201c5da0349dc3ffa92 Mon Sep 17 00:00:00 2001 From: klion26 Date: Fri, 12 Dec 2025 09:16:13 +0800 Subject: [PATCH] [Variant] Unify the CastOptions usage in parquet-variant-compute --- .../src/arrow_to_variant.rs | 79 ++++++++++++++----- .../src/cast_to_variant.rs | 20 +++-- parquet-variant-compute/src/lib.rs | 1 - .../src/type_conversion.rs | 13 --- 4 files changed, 74 insertions(+), 39 deletions(-) diff --git a/parquet-variant-compute/src/arrow_to_variant.rs b/parquet-variant-compute/src/arrow_to_variant.rs index 5e01aba3c1a1..3f6c4d444945 100644 --- a/parquet-variant-compute/src/arrow_to_variant.rs +++ b/parquet-variant-compute/src/arrow_to_variant.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::type_conversion::CastOptions; use arrow::array::{ Array, AsArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, GenericListViewArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; -use arrow::compute::kernels::cast; +use arrow::compute::{CastOptions, kernels::cast}; use arrow::datatypes::{ self as datatypes, ArrowNativeType, ArrowPrimitiveType, ArrowTemporalType, ArrowTimestampType, DecimalType, RunEndIndexType, @@ -367,7 +366,7 @@ macro_rules! define_row_builder { $( // NOTE: The `?` macro expansion fails without the type annotation. let Some(value): Option<$option_ty> = value else { - if self.options.strict { + if !self.options.safe { return Err(ArrowError::ComputeError(format!( "Failed to convert value at index {index}: conversion failed", ))); @@ -404,7 +403,7 @@ define_row_builder!( where V: VariantDecimalType, { - options: &'a CastOptions, + options: &'a CastOptions<'a>, scale: i8, }, |array| -> PrimitiveArray { array.as_primitive() }, @@ -414,7 +413,7 @@ define_row_builder!( // Decimal256 needs a two-stage conversion via i128 define_row_builder!( struct Decimal256ArrowToVariantBuilder<'a> { - options: &'a CastOptions, + options: &'a CastOptions<'a>, scale: i8, }, |array| -> arrow::array::Decimal256Array { array.as_primitive() }, @@ -426,7 +425,7 @@ define_row_builder!( define_row_builder!( struct TimestampArrowToVariantBuilder<'a, T: ArrowTimestampType> { - options: &'a CastOptions, + options: &'a CastOptions<'a>, has_time_zone: bool, }, |array| -> PrimitiveArray { array.as_primitive() }, @@ -450,7 +449,7 @@ define_row_builder!( where i64: From, { - options: &'a CastOptions, + options: &'a CastOptions<'a>, }, |array| -> PrimitiveArray { array.as_primitive() }, |value| -> Option<_> { @@ -464,7 +463,7 @@ define_row_builder!( where i64: From, { - options: &'a CastOptions, + options: &'a CastOptions<'a>, }, |array| -> PrimitiveArray { array.as_primitive() }, |value| -> Option<_> { @@ -899,7 +898,13 @@ mod tests { /// Builds a VariantArray from an Arrow array using the row builder. fn execute_row_builder_test(array: &dyn Array) -> VariantArray { - execute_row_builder_test_with_options(array, CastOptions::default()) + execute_row_builder_test_with_options( + array, + CastOptions { + safe: false, + ..Default::default() + }, + ) } /// Variant of `execute_row_builder_test` that allows specifying options @@ -925,7 +930,14 @@ mod tests { /// Generic helper function to test row builders with basic assertion patterns. /// Uses execute_row_builder_test and adds simple value comparison assertions. fn test_row_builder_basic(array: &dyn Array, expected_values: Vec>) { - test_row_builder_basic_with_options(array, expected_values, CastOptions::default()); + test_row_builder_basic_with_options( + array, + expected_values, + CastOptions { + safe: false, + ..Default::default() + }, + ); } /// Variant of `test_row_builder_basic` that allows specifying options @@ -1058,7 +1070,10 @@ mod tests { let run_ends = Int32Array::from(vec![2, 5, 6]); let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - let options = CastOptions::default(); + let options = CastOptions { + safe: false, + ..Default::default() + }; let mut row_builder = make_arrow_to_variant_row_builder(run_array.data_type(), &run_array, &options).unwrap(); @@ -1084,7 +1099,10 @@ mod tests { let run_ends = Int32Array::from(vec![2, 4, 5]); let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); - let options = CastOptions::default(); + let options = CastOptions { + safe: false, + ..Default::default() + }; let mut row_builder = make_arrow_to_variant_row_builder(run_array.data_type(), &run_array, &options).unwrap(); let mut array_builder = VariantArrayBuilder::new(5); @@ -1135,7 +1153,10 @@ mod tests { let keys = Int32Array::from(vec![Some(0), None, Some(1), None, Some(2)]); let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); - let options = CastOptions::default(); + let options = CastOptions { + safe: false, + ..Default::default() + }; let mut row_builder = make_arrow_to_variant_row_builder(dict_array.data_type(), &dict_array, &options) .unwrap(); @@ -1167,7 +1188,10 @@ mod tests { let keys = Int32Array::from(vec![0, 1, 2, 0, 1, 2]); let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); - let options = CastOptions::default(); + let options = CastOptions { + safe: false, + ..Default::default() + }; let mut row_builder = make_arrow_to_variant_row_builder(dict_array.data_type(), &dict_array, &options) .unwrap(); @@ -1207,7 +1231,10 @@ mod tests { let dict_array = DictionaryArray::::try_new(keys, Arc::new(struct_array)).unwrap(); - let options = CastOptions::default(); + let options = CastOptions { + safe: false, + ..Default::default() + }; let mut row_builder = make_arrow_to_variant_row_builder(dict_array.data_type(), &dict_array, &options) .unwrap(); @@ -1302,7 +1329,10 @@ mod tests { // Slice to get just the middle element: [[3, 4, 5]] let sliced_array = list_array.slice(1, 1); - let options = CastOptions::default(); + let options = CastOptions { + safe: false, + ..Default::default() + }; let mut row_builder = make_arrow_to_variant_row_builder(sliced_array.data_type(), &sliced_array, &options) .unwrap(); @@ -1346,7 +1376,10 @@ mod tests { Some(arrow::buffer::NullBuffer::from(vec![true, false])), ); - let options = CastOptions::default(); + let options = CastOptions { + safe: false, + ..Default::default() + }; let mut row_builder = make_arrow_to_variant_row_builder(outer_list.data_type(), &outer_list, &options) .unwrap(); @@ -1539,7 +1572,10 @@ mod tests { .unwrap(); // Test the row builder - let options = CastOptions::default(); + let options = CastOptions { + safe: false, + ..Default::default() + }; let mut row_builder = make_arrow_to_variant_row_builder(union_array.data_type(), &union_array, &options) .unwrap(); @@ -1590,7 +1626,10 @@ mod tests { .unwrap(); // Test the row builder - let options = CastOptions::default(); + let options = CastOptions { + safe: false, + ..Default::default() + }; let mut row_builder = make_arrow_to_variant_row_builder(union_array.data_type(), &union_array, &options) .unwrap(); @@ -1668,7 +1707,7 @@ mod tests { Some(Variant::Null), // Overflow value becomes Variant::Null Some(Variant::from(VariantDecimal16::try_new(123, 3).unwrap())), ], - CastOptions { strict: false }, + CastOptions::default(), ); } diff --git a/parquet-variant-compute/src/cast_to_variant.rs b/parquet-variant-compute/src/cast_to_variant.rs index 4f400a5f7bd5..9f290d774451 100644 --- a/parquet-variant-compute/src/cast_to_variant.rs +++ b/parquet-variant-compute/src/cast_to_variant.rs @@ -16,8 +16,9 @@ // under the License. use crate::arrow_to_variant::make_arrow_to_variant_row_builder; -use crate::{CastOptions, VariantArray, VariantArrayBuilder}; +use crate::{VariantArray, VariantArrayBuilder}; use arrow::array::Array; +use arrow::compute::CastOptions; use arrow_schema::ArrowError; /// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you @@ -75,9 +76,15 @@ pub fn cast_to_variant_with_options( /// failures). /// /// This function provides backward compatibility. For non-strict behavior, -/// use [`cast_to_variant_with_options`] with `CastOptions { strict: false }`. +/// use [`cast_to_variant_with_options`] with `CastOptions { safe: true, ..Default::default() }`. pub fn cast_to_variant(input: &dyn Array) -> Result { - cast_to_variant_with_options(input, &CastOptions::default()) + cast_to_variant_with_options( + input, + &CastOptions { + safe: false, + ..Default::default() + }, + ) } #[cfg(test)] @@ -2261,14 +2268,17 @@ mod tests { } fn run_test(values: ArrayRef, expected: Vec>) { - run_test_with_options(values, expected, CastOptions { strict: false }); + run_test_with_options(values, expected, CastOptions::default()); } fn run_test_in_strict_mode( values: ArrayRef, expected: Result>, ArrowError>, ) { - let options = CastOptions { strict: true }; + let options = CastOptions { + safe: false, + ..Default::default() + }; match expected { Ok(expected) => run_test_with_options(values, expected, options), Err(_) => { diff --git a/parquet-variant-compute/src/lib.rs b/parquet-variant-compute/src/lib.rs index 9b8008f58422..b05d0e023653 100644 --- a/parquet-variant-compute/src/lib.rs +++ b/parquet-variant-compute/src/lib.rs @@ -58,6 +58,5 @@ pub use cast_to_variant::{cast_to_variant, cast_to_variant_with_options}; pub use from_json::json_to_variant; pub use shred_variant::{IntoShreddingField, ShreddedSchemaBuilder, shred_variant}; pub use to_json::variant_to_json; -pub use type_conversion::CastOptions; pub use unshred_variant::unshred_variant; pub use variant_get::{GetOptions, variant_get}; diff --git a/parquet-variant-compute/src/type_conversion.rs b/parquet-variant-compute/src/type_conversion.rs index 01065175653f..6a0a743c9029 100644 --- a/parquet-variant-compute/src/type_conversion.rs +++ b/parquet-variant-compute/src/type_conversion.rs @@ -25,19 +25,6 @@ use arrow::datatypes::{ use chrono::Timelike; use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16}; -/// Options for controlling the behavior of `cast_to_variant_with_options`. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CastOptions { - /// If true, return error on conversion failure. If false, insert null for failed conversions. - pub strict: bool, -} - -impl Default for CastOptions { - fn default() -> Self { - Self { strict: true } - } -} - /// Extension trait for Arrow primitive types that can extract their native value from a Variant pub(crate) trait PrimitiveFromVariant: ArrowPrimitiveType { fn from_variant(variant: &Variant<'_, '_>) -> Option;