Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 71 additions & 89 deletions arrow/src/util/data_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,110 +66,72 @@ pub fn create_random_batch(
pub fn create_random_array(
field: &Field,
size: usize,
null_density: f32,
mut null_density: f32,
true_density: f32,
) -> Result<ArrayRef> {
// Override null density with 0.0 if the array is non-nullable
// and a primitive type in case a nested field is nullable
let primitive_null_density = match field.is_nullable() {
true => null_density,
false => 0.0,
};
// Only override the null density for types that are neither nested nor dictionary.
// For nested types we don't override it, so that the children keep their own nullability.
// Dictionary types handle nullability internally.
if !field.data_type().is_nested() && !matches!(field.data_type(), Dictionary(_, _)) {
// Override null density with 0.0 if the array is non-nullable
null_density = match field.is_nullable() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overriding `null_density` here, instead of making the decimal branch use `primitive_null_density`, to avoid similar problems in the future.

true => null_density,
false => 0.0,
};
}

use DataType::*;
Ok(match field.data_type() {
let array = match field.data_type() {
Null => Arc::new(NullArray::new(size)) as ArrayRef,
Boolean => Arc::new(create_boolean_array(
size,
primitive_null_density,
true_density,
)),
Int8 => Arc::new(create_primitive_array::<Int8Type>(
size,
primitive_null_density,
)),
Int16 => Arc::new(create_primitive_array::<Int16Type>(
size,
primitive_null_density,
)),
Int32 => Arc::new(create_primitive_array::<Int32Type>(
size,
primitive_null_density,
)),
Int64 => Arc::new(create_primitive_array::<Int64Type>(
size,
primitive_null_density,
)),
UInt8 => Arc::new(create_primitive_array::<UInt8Type>(
size,
primitive_null_density,
)),
UInt16 => Arc::new(create_primitive_array::<UInt16Type>(
size,
primitive_null_density,
)),
UInt32 => Arc::new(create_primitive_array::<UInt32Type>(
size,
primitive_null_density,
)),
UInt64 => Arc::new(create_primitive_array::<UInt64Type>(
size,
primitive_null_density,
)),
Boolean => Arc::new(create_boolean_array(size, null_density, true_density)),
Int8 => Arc::new(create_primitive_array::<Int8Type>(size, null_density)),
Int16 => Arc::new(create_primitive_array::<Int16Type>(size, null_density)),
Int32 => Arc::new(create_primitive_array::<Int32Type>(size, null_density)),
Int64 => Arc::new(create_primitive_array::<Int64Type>(size, null_density)),
UInt8 => Arc::new(create_primitive_array::<UInt8Type>(size, null_density)),
UInt16 => Arc::new(create_primitive_array::<UInt16Type>(size, null_density)),
UInt32 => Arc::new(create_primitive_array::<UInt32Type>(size, null_density)),
UInt64 => Arc::new(create_primitive_array::<UInt64Type>(size, null_density)),
Float16 => {
return Err(ArrowError::NotYetImplemented(
"Float16 is not implemented".to_string(),
));
}
Float32 => Arc::new(create_primitive_array::<Float32Type>(
size,
primitive_null_density,
)),
Float64 => Arc::new(create_primitive_array::<Float64Type>(
size,
primitive_null_density,
)),
Float32 => Arc::new(create_primitive_array::<Float32Type>(size, null_density)),
Float64 => Arc::new(create_primitive_array::<Float64Type>(size, null_density)),
Timestamp(unit, tz) => match unit {
TimeUnit::Second => Arc::new(
create_random_temporal_array::<TimestampSecondType>(size, primitive_null_density)
create_random_temporal_array::<TimestampSecondType>(size, null_density)
.with_timezone_opt(tz.clone()),
),
) as ArrayRef,
TimeUnit::Millisecond => Arc::new(
create_random_temporal_array::<TimestampMillisecondType>(
size,
primitive_null_density,
)
.with_timezone_opt(tz.clone()),
create_random_temporal_array::<TimestampMillisecondType>(size, null_density)
.with_timezone_opt(tz.clone()),
),
TimeUnit::Microsecond => Arc::new(
create_random_temporal_array::<TimestampMicrosecondType>(
size,
primitive_null_density,
)
.with_timezone_opt(tz.clone()),
create_random_temporal_array::<TimestampMicrosecondType>(size, null_density)
.with_timezone_opt(tz.clone()),
),
TimeUnit::Nanosecond => Arc::new(
create_random_temporal_array::<TimestampNanosecondType>(
size,
primitive_null_density,
)
.with_timezone_opt(tz.clone()),
create_random_temporal_array::<TimestampNanosecondType>(size, null_density)
.with_timezone_opt(tz.clone()),
),
},
Date32 => Arc::new(create_random_temporal_array::<Date32Type>(
size,
primitive_null_density,
null_density,
)),
Date64 => Arc::new(create_random_temporal_array::<Date64Type>(
size,
primitive_null_density,
null_density,
)),
Time32(unit) => match unit {
TimeUnit::Second => Arc::new(create_random_temporal_array::<Time32SecondType>(
size,
primitive_null_density,
null_density,
)) as ArrayRef,
TimeUnit::Millisecond => Arc::new(
create_random_temporal_array::<Time32MillisecondType>(size, primitive_null_density),
create_random_temporal_array::<Time32MillisecondType>(size, null_density),
),
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
Expand All @@ -179,36 +141,31 @@ pub fn create_random_array(
},
Time64(unit) => match unit {
TimeUnit::Microsecond => Arc::new(
create_random_temporal_array::<Time64MicrosecondType>(size, primitive_null_density),
create_random_temporal_array::<Time64MicrosecondType>(size, null_density),
) as ArrayRef,
TimeUnit::Nanosecond => Arc::new(create_random_temporal_array::<Time64NanosecondType>(
size,
primitive_null_density,
null_density,
)),
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
"Unsupported unit {unit:?} for Time64"
)));
}
},
Utf8 => Arc::new(create_string_array::<i32>(size, primitive_null_density)),
LargeUtf8 => Arc::new(create_string_array::<i64>(size, primitive_null_density)),
Utf8 => Arc::new(create_string_array::<i32>(size, null_density)),
LargeUtf8 => Arc::new(create_string_array::<i64>(size, null_density)),
Utf8View => Arc::new(create_string_view_array_with_len(
size,
primitive_null_density,
null_density,
4,
false,
)),
Binary => Arc::new(create_binary_array::<i32>(size, primitive_null_density)),
LargeBinary => Arc::new(create_binary_array::<i64>(size, primitive_null_density)),
FixedSizeBinary(len) => Arc::new(create_fsb_array(
size,
primitive_null_density,
*len as usize,
)),
Binary => Arc::new(create_binary_array::<i32>(size, null_density)),
LargeBinary => Arc::new(create_binary_array::<i64>(size, null_density)),
FixedSizeBinary(len) => Arc::new(create_fsb_array(size, null_density, *len as usize)),
BinaryView => Arc::new(
create_string_view_array_with_len(size, primitive_null_density, 4, false)
.to_binary_view(),
create_string_view_array_with_len(size, null_density, 4, false).to_binary_view(),
),
List(_) => create_random_list_array(field, size, null_density, true_density)?,
LargeList(_) => create_random_list_array(field, size, null_density, true_density)?,
Expand All @@ -230,7 +187,13 @@ pub fn create_random_array(
"Generating random arrays not yet implemented for {other:?}"
)));
}
})
};

if !field.is_nullable() {
assert_eq!(array.null_count(), 0);
}

Ok(array)
}

#[inline]
Expand Down Expand Up @@ -812,4 +775,23 @@ mod tests {
assert_eq!(array.len(), size);
}
}

#[test]
fn create_non_nullable_decimal_array_with_null_density() {
    // Non-nullable decimal columns must contain no nulls even though a
    // non-zero null density is requested from the generator.
    let row_count = 10;
    let schema_ref = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Decimal128(10, -2), false),
        Field::new("b", DataType::Decimal256(10, -2), false),
    ]));
    let batch = create_random_batch(schema_ref.clone(), row_count, 0.35, 0.7).unwrap();

    // The generated batch must mirror the requested schema exactly.
    assert_eq!(batch.schema(), schema_ref);
    assert_eq!(batch.num_columns(), schema_ref.fields().len());

    // Every column has the requested length and is free of nulls.
    batch.columns().iter().for_each(|column| {
        assert_eq!(column.len(), row_count);
        assert_eq!(column.null_count(), 0);
    });
}
}
Loading