Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions rust/arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,27 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {

// temporal casts
(Int32, Date32(_)) => cast_array_data::<Date32Type>(array, to_type.clone()),
(Int32, Time32(_)) => cast_array_data::<Date32Type>(array, to_type.clone()),
(Int32, Time32(unit)) => match unit {
TimeUnit::Second => {
cast_array_data::<Time32SecondType>(array, to_type.clone())
}
TimeUnit::Millisecond => {
cast_array_data::<Time32MillisecondType>(array, to_type.clone())
}
_ => unreachable!(),
},
(Date32(_), Int32) => cast_array_data::<Int32Type>(array, to_type.clone()),
(Time32(_), Int32) => cast_array_data::<Int32Type>(array, to_type.clone()),
(Int64, Date64(_)) => cast_array_data::<Date64Type>(array, to_type.clone()),
(Int64, Time64(_)) => cast_array_data::<Date64Type>(array, to_type.clone()),
(Int64, Time64(unit)) => match unit {
TimeUnit::Microsecond => {
cast_array_data::<Time64MicrosecondType>(array, to_type.clone())
}
TimeUnit::Nanosecond => {
cast_array_data::<Time64NanosecondType>(array, to_type.clone())
}
_ => unreachable!(),
},
(Date64(_), Int64) => cast_array_data::<Int64Type>(array, to_type.clone()),
(Time64(_), Int64) => cast_array_data::<Int64Type>(array, to_type.clone()),
(Date32(DateUnit::Day), Date64(DateUnit::Millisecond)) => {
Expand Down
102 changes: 75 additions & 27 deletions rust/parquet/src/arrow/array_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ use crate::arrow::converter::{
BinaryArrayConverter, BinaryConverter, BoolConverter, BooleanArrayConverter,
Converter, Date32Converter, FixedLenBinaryConverter, FixedSizeArrayConverter,
Float32Converter, Float64Converter, Int16Converter, Int32Converter, Int64Converter,
Int8Converter, Int96ArrayConverter, Int96Converter, TimestampMicrosecondConverter,
TimestampMillisecondConverter, UInt16Converter, UInt32Converter, UInt64Converter,
UInt8Converter, Utf8ArrayConverter, Utf8Converter,
Int8Converter, Int96ArrayConverter, Int96Converter, Time32MillisecondConverter,
Time32SecondConverter, Time64MicrosecondConverter, Time64NanosecondConverter,
TimestampMicrosecondConverter, TimestampMillisecondConverter, UInt16Converter,
UInt32Converter, UInt64Converter, UInt8Converter, Utf8ArrayConverter, Utf8Converter,
};
use crate::arrow::record_reader::RecordReader;
use crate::arrow::schema::parquet_to_arrow_field;
Expand Down Expand Up @@ -196,11 +197,27 @@ impl<T: DataType> ArrayReader for PrimitiveArrayReader<T> {
.convert(self.record_reader.cast::<Int32Type>()),
_ => Err(general_err!("No conversion from parquet type to arrow type for date with unit {:?}", unit)),
}
(ArrowType::Time32(_), PhysicalType::INT32) => {
UInt32Converter::new().convert(self.record_reader.cast::<Int32Type>())
(ArrowType::Time32(unit), PhysicalType::INT32) => {
match unit {
TimeUnit::Second => {
Time32SecondConverter::new().convert(self.record_reader.cast::<Int32Type>())
}
TimeUnit::Millisecond => {
Time32MillisecondConverter::new().convert(self.record_reader.cast::<Int32Type>())
}
_ => Err(general_err!("Invalid or unsupported arrow array with datatype {:?}", self.get_data_type()))
}
}
(ArrowType::Time64(_), PhysicalType::INT64) => {
UInt64Converter::new().convert(self.record_reader.cast::<Int64Type>())
(ArrowType::Time64(unit), PhysicalType::INT64) => {
match unit {
TimeUnit::Microsecond => {
Time64MicrosecondConverter::new().convert(self.record_reader.cast::<Int64Type>())
}
TimeUnit::Nanosecond => {
Time64NanosecondConverter::new().convert(self.record_reader.cast::<Int64Type>())
}
_ => Err(general_err!("Invalid or unsupported arrow array with datatype {:?}", self.get_data_type()))
}
}
(ArrowType::Interval(IntervalUnit::YearMonth), PhysicalType::INT32) => {
UInt32Converter::new().convert(self.record_reader.cast::<Int32Type>())
Expand Down Expand Up @@ -941,10 +958,12 @@ mod tests {
use crate::util::test_common::{get_test_file, make_pages};
use arrow::array::{Array, ArrayRef, PrimitiveArray, StringArray, StructArray};
use arrow::datatypes::{
DataType as ArrowType, Date32Type as ArrowDate32, Field, Int32Type as ArrowInt32,
ArrowPrimitiveType, DataType as ArrowType, Date32Type as ArrowDate32, Field,
Int32Type as ArrowInt32, Int64Type as ArrowInt64,
Time32MillisecondType as ArrowTime32MillisecondArray,
Time64MicrosecondType as ArrowTime64MicrosecondArray,
TimestampMicrosecondType as ArrowTimestampMicrosecondType,
TimestampMillisecondType as ArrowTimestampMillisecondType,
UInt32Type as ArrowUInt32, UInt64Type as ArrowUInt64,
};
use rand::distributions::uniform::SampleUniform;
use rand::{thread_rng, Rng};
Expand Down Expand Up @@ -1101,7 +1120,7 @@ mod tests {
}

macro_rules! test_primitive_array_reader_one_type {
($arrow_parquet_type:ty, $physical_type:expr, $logical_type_str:expr, $result_arrow_type:ty, $result_primitive_type:ty) => {{
($arrow_parquet_type:ty, $physical_type:expr, $logical_type_str:expr, $result_arrow_type:ty, $result_arrow_cast_type:ty, $result_primitive_type:ty) => {{
let message_type = format!(
"
message test_schema {{
Expand All @@ -1112,7 +1131,7 @@ mod tests {
);
let schema = parse_message_type(&message_type)
.map(|t| Rc::new(SchemaDescriptor::new(Rc::new(t))))
.unwrap();
.expect("Unable to parse message type into a schema descriptor");

let column_desc = schema.column(0);

Expand Down Expand Up @@ -1142,24 +1161,48 @@ mod tests {
Box::new(page_iterator),
column_desc.clone(),
)
.unwrap();
.expect("Unable to get array reader");

let array = array_reader.next_batch(50).unwrap();
let array = array_reader
.next_batch(50)
.expect("Unable to get batch from reader");

let result_data_type = <$result_arrow_type>::get_data_type();
let array = array
.as_any()
.downcast_ref::<PrimitiveArray<$result_arrow_type>>()
.unwrap();

assert_eq!(
&PrimitiveArray::<$result_arrow_type>::from(
data[0..50]
.iter()
.map(|x| *x as $result_primitive_type)
.collect::<Vec<$result_primitive_type>>()
),
array
.expect(
format!(
"Unable to downcast {:?} to {:?}",
array.data_type(),
result_data_type
)
.as_str(),
);

// create expected array as primitive, and cast to result type
let expected = PrimitiveArray::<$result_arrow_cast_type>::from(
data[0..50]
.iter()
.map(|x| *x as $result_primitive_type)
.collect::<Vec<$result_primitive_type>>(),
);
let expected = Arc::new(expected) as ArrayRef;
let expected = arrow::compute::cast(&expected, &result_data_type)
.expect("Unable to cast expected array");
assert_eq!(expected.data_type(), &result_data_type);
let expected = expected
.as_any()
.downcast_ref::<PrimitiveArray<$result_arrow_type>>()
.expect(
format!(
"Unable to downcast expected {:?} to {:?}",
expected.data_type(),
result_data_type
)
.as_str(),
);
assert_eq!(expected, array);
}
}};
}
Expand All @@ -1171,34 +1214,39 @@ mod tests {
PhysicalType::INT32,
"DATE",
ArrowDate32,
ArrowInt32,
i32
);
test_primitive_array_reader_one_type!(
Int32Type,
PhysicalType::INT32,
"TIME_MILLIS",
ArrowUInt32,
u32
ArrowTime32MillisecondArray,
ArrowInt32,
i32
);
test_primitive_array_reader_one_type!(
Int64Type,
PhysicalType::INT64,
"TIME_MICROS",
ArrowUInt64,
u64
ArrowTime64MicrosecondArray,
ArrowInt64,
i64
);
test_primitive_array_reader_one_type!(
Int64Type,
PhysicalType::INT64,
"TIMESTAMP_MILLIS",
ArrowTimestampMillisecondType,
ArrowInt64,
i64
);
test_primitive_array_reader_one_type!(
Int64Type,
PhysicalType::INT64,
"TIMESTAMP_MICROS",
ArrowTimestampMicrosecondType,
ArrowInt64,
i64
);
}
Expand Down
Loading