diff --git a/arrow/src/util/data_gen.rs b/arrow/src/util/data_gen.rs index 89bbe4b1fbcb..023436e0a7f7 100644 --- a/arrow/src/util/data_gen.rs +++ b/arrow/src/util/data_gen.rs @@ -66,110 +66,72 @@ pub fn create_random_batch( pub fn create_random_array( field: &Field, size: usize, - null_density: f32, + mut null_density: f32, true_density: f32, ) -> Result { - // Override null density with 0.0 if the array is non-nullable - // and a primitive type in case a nested field is nullable - let primitive_null_density = match field.is_nullable() { - true => null_density, - false => 0.0, - }; + // Override nullability in case of not nested and not dictionary + // For nested we don't want to override as we want to keep the nullability for the children + // For dictionary it handles the nullability internally + if !field.data_type().is_nested() && !matches!(field.data_type(), Dictionary(_, _)) { + // Override null density with 0.0 if the array is non-nullable + null_density = match field.is_nullable() { + true => null_density, + false => 0.0, + }; + } + use DataType::*; - Ok(match field.data_type() { + let array = match field.data_type() { Null => Arc::new(NullArray::new(size)) as ArrayRef, - Boolean => Arc::new(create_boolean_array( - size, - primitive_null_density, - true_density, - )), - Int8 => Arc::new(create_primitive_array::( - size, - primitive_null_density, - )), - Int16 => Arc::new(create_primitive_array::( - size, - primitive_null_density, - )), - Int32 => Arc::new(create_primitive_array::( - size, - primitive_null_density, - )), - Int64 => Arc::new(create_primitive_array::( - size, - primitive_null_density, - )), - UInt8 => Arc::new(create_primitive_array::( - size, - primitive_null_density, - )), - UInt16 => Arc::new(create_primitive_array::( - size, - primitive_null_density, - )), - UInt32 => Arc::new(create_primitive_array::( - size, - primitive_null_density, - )), - UInt64 => Arc::new(create_primitive_array::( - size, - primitive_null_density, - )), + Boolean => 
Arc::new(create_boolean_array(size, null_density, true_density)), + Int8 => Arc::new(create_primitive_array::(size, null_density)), + Int16 => Arc::new(create_primitive_array::(size, null_density)), + Int32 => Arc::new(create_primitive_array::(size, null_density)), + Int64 => Arc::new(create_primitive_array::(size, null_density)), + UInt8 => Arc::new(create_primitive_array::(size, null_density)), + UInt16 => Arc::new(create_primitive_array::(size, null_density)), + UInt32 => Arc::new(create_primitive_array::(size, null_density)), + UInt64 => Arc::new(create_primitive_array::(size, null_density)), Float16 => { return Err(ArrowError::NotYetImplemented( "Float16 is not implemented".to_string(), )); } - Float32 => Arc::new(create_primitive_array::( - size, - primitive_null_density, - )), - Float64 => Arc::new(create_primitive_array::( - size, - primitive_null_density, - )), + Float32 => Arc::new(create_primitive_array::(size, null_density)), + Float64 => Arc::new(create_primitive_array::(size, null_density)), Timestamp(unit, tz) => match unit { TimeUnit::Second => Arc::new( - create_random_temporal_array::(size, primitive_null_density) + create_random_temporal_array::(size, null_density) .with_timezone_opt(tz.clone()), - ), + ) as ArrayRef, TimeUnit::Millisecond => Arc::new( - create_random_temporal_array::( - size, - primitive_null_density, - ) - .with_timezone_opt(tz.clone()), + create_random_temporal_array::(size, null_density) + .with_timezone_opt(tz.clone()), ), TimeUnit::Microsecond => Arc::new( - create_random_temporal_array::( - size, - primitive_null_density, - ) - .with_timezone_opt(tz.clone()), + create_random_temporal_array::(size, null_density) + .with_timezone_opt(tz.clone()), ), TimeUnit::Nanosecond => Arc::new( - create_random_temporal_array::( - size, - primitive_null_density, - ) - .with_timezone_opt(tz.clone()), + create_random_temporal_array::(size, null_density) + .with_timezone_opt(tz.clone()), ), }, Date32 => 
Arc::new(create_random_temporal_array::( size, - primitive_null_density, + null_density, )), Date64 => Arc::new(create_random_temporal_array::( size, - primitive_null_density, + null_density, )), Time32(unit) => match unit { TimeUnit::Second => Arc::new(create_random_temporal_array::( size, - primitive_null_density, + null_density, )) as ArrayRef, TimeUnit::Millisecond => Arc::new( - create_random_temporal_array::(size, primitive_null_density), + create_random_temporal_array::(size, null_density), ), _ => { return Err(ArrowError::InvalidArgumentError(format!( @@ -179,11 +141,11 @@ pub fn create_random_array( }, Time64(unit) => match unit { TimeUnit::Microsecond => Arc::new( - create_random_temporal_array::(size, primitive_null_density), + create_random_temporal_array::(size, null_density), ) as ArrayRef, TimeUnit::Nanosecond => Arc::new(create_random_temporal_array::( size, - primitive_null_density, + null_density, )), _ => { return Err(ArrowError::InvalidArgumentError(format!( @@ -191,24 +153,19 @@ pub fn create_random_array( ))); } }, - Utf8 => Arc::new(create_string_array::(size, primitive_null_density)), - LargeUtf8 => Arc::new(create_string_array::(size, primitive_null_density)), + Utf8 => Arc::new(create_string_array::(size, null_density)), + LargeUtf8 => Arc::new(create_string_array::(size, null_density)), Utf8View => Arc::new(create_string_view_array_with_len( size, - primitive_null_density, + null_density, 4, false, )), - Binary => Arc::new(create_binary_array::(size, primitive_null_density)), - LargeBinary => Arc::new(create_binary_array::(size, primitive_null_density)), - FixedSizeBinary(len) => Arc::new(create_fsb_array( - size, - primitive_null_density, - *len as usize, - )), + Binary => Arc::new(create_binary_array::(size, null_density)), + LargeBinary => Arc::new(create_binary_array::(size, null_density)), + FixedSizeBinary(len) => Arc::new(create_fsb_array(size, null_density, *len as usize)), BinaryView => Arc::new( - 
create_string_view_array_with_len(size, primitive_null_density, 4, false) - .to_binary_view(), + create_string_view_array_with_len(size, null_density, 4, false).to_binary_view(), ), List(_) => create_random_list_array(field, size, null_density, true_density)?, LargeList(_) => create_random_list_array(field, size, null_density, true_density)?, @@ -230,7 +187,13 @@ pub fn create_random_array( "Generating random arrays not yet implemented for {other:?}" ))); } - }) + }; + + if !field.is_nullable() { + assert_eq!(array.null_count(), 0); + } + + Ok(array) } #[inline] @@ -812,4 +775,23 @@ mod tests { assert_eq!(array.len(), size); } } + + #[test] + fn create_non_nullable_decimal_array_with_null_density() { + let size = 10; + let fields = vec![ + Field::new("a", DataType::Decimal128(10, -2), false), + Field::new("b", DataType::Decimal256(10, -2), false), + ]; + let schema = Schema::new(fields); + let schema_ref = Arc::new(schema); + let batch = create_random_batch(schema_ref.clone(), size, 0.35, 0.7).unwrap(); + + assert_eq!(batch.schema(), schema_ref); + assert_eq!(batch.num_columns(), schema_ref.fields().len()); + for array in batch.columns() { + assert_eq!(array.len(), size); + assert_eq!(array.null_count(), 0); + } + } }