20 changes: 18 additions & 2 deletions rust/arrow/src/compute/kernels/cast.rs
@@ -327,11 +327,27 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {

// temporal casts
(Int32, Date32(_)) => cast_array_data::<Date32Type>(array, to_type.clone()),
(Int32, Time32(_)) => cast_array_data::<Date32Type>(array, to_type.clone()),
(Int32, Time32(unit)) => match unit {
TimeUnit::Second => {
cast_array_data::<Time32SecondType>(array, to_type.clone())
}
TimeUnit::Millisecond => {
cast_array_data::<Time32MillisecondType>(array, to_type.clone())
}
_ => unreachable!(),
},
(Date32(_), Int32) => cast_array_data::<Int32Type>(array, to_type.clone()),
(Time32(_), Int32) => cast_array_data::<Int32Type>(array, to_type.clone()),
(Int64, Date64(_)) => cast_array_data::<Date64Type>(array, to_type.clone()),
(Int64, Time64(_)) => cast_array_data::<Date64Type>(array, to_type.clone()),
(Int64, Time64(unit)) => match unit {
TimeUnit::Microsecond => {
cast_array_data::<Time64MicrosecondType>(array, to_type.clone())
}
TimeUnit::Nanosecond => {
cast_array_data::<Time64NanosecondType>(array, to_type.clone())
}
_ => unreachable!(),
},
(Date64(_), Int64) => cast_array_data::<Int64Type>(array, to_type.clone()),
(Time64(_), Int64) => cast_array_data::<Int64Type>(array, to_type.clone()),
(Date32(DateUnit::Day), Date64(DateUnit::Millisecond)) => {
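With these arms in place, casting an integer array to a time type yields a correctly typed array instead of one mislabeled via `Date32Type`/`Date64Type`. A minimal sketch of the new behavior (the values and test scaffolding are illustrative; `cast` has the signature shown in the hunk header above):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, Time32MillisecondArray};
use arrow::compute::cast;
use arrow::datatypes::{DataType, TimeUnit};

fn main() {
    // Milliseconds since midnight, stored as plain Int32 values.
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![1_000, 2_000, 3_000]));

    // The new match arm dispatches on TimeUnit::Millisecond and produces
    // a Time32MillisecondArray rather than reusing Date32Type.
    let times = cast(&ints, &DataType::Time32(TimeUnit::Millisecond))
        .expect("Int32 -> Time32(ms) is a supported cast");

    let times = times
        .as_any()
        .downcast_ref::<Time32MillisecondArray>()
        .expect("cast should yield a Time32MillisecondArray");
    assert_eq!(times.value(1), 2_000);
}
```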
4 changes: 3 additions & 1 deletion rust/arrow/src/ipc/convert.rs
@@ -334,7 +334,9 @@ pub(crate) fn build_field<'a: 'b, 'b>(

let mut field_builder = ipc::FieldBuilder::new(fbb);
field_builder.add_name(fb_field_name);
fb_dictionary.map(|dictionary| field_builder.add_dictionary(dictionary));
if let Some(dictionary) = fb_dictionary {
field_builder.add_dictionary(dictionary)
}
field_builder.add_type_type(field_type.type_type);
field_builder.add_nullable(field.is_nullable());
match field_type.children {
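In `convert.rs`, an `Option::map` called purely for its side effect is replaced by `if let`, which states the intent directly; this is the pattern flagged by clippy's `option_map_unit_fn` lint (cited here as the likely motivation, not stated in the diff). A minimal contrast:

```rust
fn add_dictionary(_d: u32) {}

fn main() {
    let fb_dictionary: Option<u32> = Some(1);

    // Before: `map` used only for a side effect; the resulting Option<()>
    // is discarded.
    let _ = fb_dictionary.map(|d| add_dictionary(d));

    // After: run the call only when a value is present, stated explicitly.
    if let Some(d) = fb_dictionary {
        add_dictionary(d)
    }
}
```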
102 changes: 75 additions & 27 deletions rust/parquet/src/arrow/array_reader.rs
@@ -35,9 +35,10 @@ use crate::arrow::converter::{
BinaryArrayConverter, BinaryConverter, BoolConverter, BooleanArrayConverter,
Converter, Date32Converter, FixedLenBinaryConverter, FixedSizeArrayConverter,
Float32Converter, Float64Converter, Int16Converter, Int32Converter, Int64Converter,
Int8Converter, Int96ArrayConverter, Int96Converter, TimestampMicrosecondConverter,
TimestampMillisecondConverter, UInt16Converter, UInt32Converter, UInt64Converter,
UInt8Converter, Utf8ArrayConverter, Utf8Converter,
Int8Converter, Int96ArrayConverter, Int96Converter, Time32MillisecondConverter,
Time32SecondConverter, Time64MicrosecondConverter, Time64NanosecondConverter,
TimestampMicrosecondConverter, TimestampMillisecondConverter, UInt16Converter,
UInt32Converter, UInt64Converter, UInt8Converter, Utf8ArrayConverter, Utf8Converter,
};
use crate::arrow::record_reader::RecordReader;
use crate::arrow::schema::parquet_to_arrow_field;
@@ -196,11 +197,27 @@ impl<T: DataType> ArrayReader for PrimitiveArrayReader<T> {
.convert(self.record_reader.cast::<Int32Type>()),
_ => Err(general_err!("No conversion from parquet type to arrow type for date with unit {:?}", unit)),
}
(ArrowType::Time32(_), PhysicalType::INT32) => {
UInt32Converter::new().convert(self.record_reader.cast::<Int32Type>())
(ArrowType::Time32(unit), PhysicalType::INT32) => {
match unit {
TimeUnit::Second => {
Time32SecondConverter::new().convert(self.record_reader.cast::<Int32Type>())
}
TimeUnit::Millisecond => {
Time32MillisecondConverter::new().convert(self.record_reader.cast::<Int32Type>())
}
_ => Err(general_err!("Invalid or unsupported arrow array with datatype {:?}", self.get_data_type()))
}
}
(ArrowType::Time64(_), PhysicalType::INT64) => {
UInt64Converter::new().convert(self.record_reader.cast::<Int64Type>())
(ArrowType::Time64(unit), PhysicalType::INT64) => {
match unit {
TimeUnit::Microsecond => {
Time64MicrosecondConverter::new().convert(self.record_reader.cast::<Int64Type>())
}
TimeUnit::Nanosecond => {
Time64NanosecondConverter::new().convert(self.record_reader.cast::<Int64Type>())
}
_ => Err(general_err!("Invalid or unsupported arrow array with datatype {:?}", self.get_data_type()))
}
}
(ArrowType::Interval(IntervalUnit::YearMonth), PhysicalType::INT32) => {
UInt32Converter::new().convert(self.record_reader.cast::<Int32Type>())
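The reader now selects a time converter that matches the Arrow unit, replacing the earlier fallback to `UInt32Converter`/`UInt64Converter`. A hypothetical helper mirroring that dispatch (the function and string labels are illustrative only; the real code instantiates the converter types imported at the top of the file):

```rust
use arrow::datatypes::{DataType as ArrowType, TimeUnit};

// Hypothetical summary of the unit-to-converter pairing; `None` corresponds
// to the general_err! branches in the real reader.
fn converter_for(arrow_type: &ArrowType) -> Option<&'static str> {
    match arrow_type {
        ArrowType::Time32(TimeUnit::Second) => Some("Time32SecondConverter"),
        ArrowType::Time32(TimeUnit::Millisecond) => Some("Time32MillisecondConverter"),
        ArrowType::Time64(TimeUnit::Microsecond) => Some("Time64MicrosecondConverter"),
        ArrowType::Time64(TimeUnit::Nanosecond) => Some("Time64NanosecondConverter"),
        _ => None,
    }
}

fn main() {
    let t = ArrowType::Time32(TimeUnit::Millisecond);
    assert_eq!(converter_for(&t), Some("Time32MillisecondConverter"));
}
```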
@@ -941,10 +958,12 @@ mod tests {
use crate::util::test_common::{get_test_file, make_pages};
use arrow::array::{Array, ArrayRef, PrimitiveArray, StringArray, StructArray};
use arrow::datatypes::{
DataType as ArrowType, Date32Type as ArrowDate32, Field, Int32Type as ArrowInt32,
ArrowPrimitiveType, DataType as ArrowType, Date32Type as ArrowDate32, Field,
Int32Type as ArrowInt32, Int64Type as ArrowInt64,
Time32MillisecondType as ArrowTime32MillisecondArray,
Time64MicrosecondType as ArrowTime64MicrosecondArray,
TimestampMicrosecondType as ArrowTimestampMicrosecondType,
TimestampMillisecondType as ArrowTimestampMillisecondType,
UInt32Type as ArrowUInt32, UInt64Type as ArrowUInt64,
};
use rand::distributions::uniform::SampleUniform;
use rand::{thread_rng, Rng};
@@ -1101,7 +1120,7 @@ mod tests {
}

macro_rules! test_primitive_array_reader_one_type {
($arrow_parquet_type:ty, $physical_type:expr, $logical_type_str:expr, $result_arrow_type:ty, $result_primitive_type:ty) => {{
($arrow_parquet_type:ty, $physical_type:expr, $logical_type_str:expr, $result_arrow_type:ty, $result_arrow_cast_type:ty, $result_primitive_type:ty) => {{
let message_type = format!(
"
message test_schema {{
@@ -1112,7 +1131,7 @@
);
let schema = parse_message_type(&message_type)
.map(|t| Rc::new(SchemaDescriptor::new(Rc::new(t))))
.unwrap();
.expect("Unable to parse message type into a schema descriptor");

let column_desc = schema.column(0);

@@ -1142,24 +1161,48 @@
Box::new(page_iterator),
column_desc.clone(),
)
.unwrap();
.expect("Unable to get array reader");

let array = array_reader.next_batch(50).unwrap();
let array = array_reader
.next_batch(50)
.expect("Unable to get batch from reader");

let result_data_type = <$result_arrow_type>::get_data_type();
let array = array
.as_any()
.downcast_ref::<PrimitiveArray<$result_arrow_type>>()
.unwrap();

assert_eq!(
&PrimitiveArray::<$result_arrow_type>::from(
data[0..50]
.iter()
.map(|x| *x as $result_primitive_type)
.collect::<Vec<$result_primitive_type>>()
),
array
.expect(
format!(
"Unable to downcast {:?} to {:?}",
array.data_type(),
result_data_type
)
.as_str(),
);

// create expected array as primitive, and cast to result type
let expected = PrimitiveArray::<$result_arrow_cast_type>::from(
data[0..50]
.iter()
.map(|x| *x as $result_primitive_type)
.collect::<Vec<$result_primitive_type>>(),
);
let expected = Arc::new(expected) as ArrayRef;
let expected = arrow::compute::cast(&expected, &result_data_type)
.expect("Unable to cast expected array");
assert_eq!(expected.data_type(), &result_data_type);
let expected = expected
.as_any()
.downcast_ref::<PrimitiveArray<$result_arrow_type>>()
.expect(
format!(
"Unable to downcast expected {:?} to {:?}",
expected.data_type(),
result_data_type
)
.as_str(),
);
assert_eq!(expected, array);
}
}};
}
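The macro now builds the expected values as a plain integer array and converts them with the cast kernel, so the expected and actual arrays travel through the same type machinery. The same pattern outside the macro, as a hedged sketch (the helper name and values are illustrative):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, PrimitiveArray};
use arrow::compute::cast;
use arrow::datatypes::{DataType, Time32MillisecondType, TimeUnit};

// Hypothetical helper: build the expected Time32(ms) array from raw i32s
// and compare it against an array produced elsewhere (e.g. by the reader).
fn assert_time32_ms(actual: &PrimitiveArray<Time32MillisecondType>, values: Vec<i32>) {
    let ints = Arc::new(Int32Array::from(values)) as ArrayRef;
    let expected = cast(&ints, &DataType::Time32(TimeUnit::Millisecond))
        .expect("Int32 -> Time32(ms) is a supported cast");
    let expected = expected
        .as_any()
        .downcast_ref::<PrimitiveArray<Time32MillisecondType>>()
        .expect("cast produced the requested type");
    assert_eq!(expected, actual);
}

fn main() {
    // Self-check: produce an "actual" array via the same cast path.
    let actual = cast(
        &(Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef),
        &DataType::Time32(TimeUnit::Millisecond),
    )
    .unwrap();
    let actual = actual
        .as_any()
        .downcast_ref::<PrimitiveArray<Time32MillisecondType>>()
        .unwrap();
    assert_time32_ms(actual, vec![1, 2, 3]);
}
```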
@@ -1171,34 +1214,39 @@
PhysicalType::INT32,
"DATE",
ArrowDate32,
ArrowInt32,
i32
);
test_primitive_array_reader_one_type!(
Int32Type,
PhysicalType::INT32,
"TIME_MILLIS",
ArrowUInt32,
u32
ArrowTime32MillisecondArray,
ArrowInt32,
i32
);
test_primitive_array_reader_one_type!(
Int64Type,
PhysicalType::INT64,
"TIME_MICROS",
ArrowUInt64,
u64
ArrowTime64MicrosecondArray,
ArrowInt64,
i64
);
test_primitive_array_reader_one_type!(
Int64Type,
PhysicalType::INT64,
"TIMESTAMP_MILLIS",
ArrowTimestampMillisecondType,
ArrowInt64,
i64
);
test_primitive_array_reader_one_type!(
Int64Type,
PhysicalType::INT64,
"TIMESTAMP_MICROS",
ArrowTimestampMicrosecondType,
ArrowInt64,
i64
);
}
35 changes: 27 additions & 8 deletions rust/parquet/src/arrow/arrow_reader.rs
@@ -19,7 +19,9 @@

use crate::arrow::array_reader::{build_array_reader, ArrayReader, StructArrayReader};
use crate::arrow::schema::parquet_to_arrow_schema;
use crate::arrow::schema::parquet_to_arrow_schema_by_columns;
use crate::arrow::schema::{
parquet_to_arrow_schema_by_columns, parquet_to_arrow_schema_by_root_columns,
};
use crate::errors::{ParquetError, Result};
use crate::file::reader::FileReader;
use arrow::datatypes::{DataType as ArrowType, Schema, SchemaRef};
@@ -40,7 +42,12 @@ pub trait ArrowReader {

/// Read parquet schema and convert it into arrow schema.
/// This schema only includes columns identified by `column_indices`.
fn get_schema_by_columns<T>(&mut self, column_indices: T) -> Result<Schema>
/// To select leaf columns (i.e. `a.b.c` instead of `a`), set `leaf_columns = true`
fn get_schema_by_columns<T>(
&mut self,
column_indices: T,
leaf_columns: bool,
[Review comment from the PR author on `leaf_columns`: "added this extra option"]
) -> Result<Schema>
where
T: IntoIterator<Item = usize>;

@@ -84,16 +91,28 @@ impl ArrowReader for ParquetFileArrowReader {
)
}

fn get_schema_by_columns<T>(&mut self, column_indices: T) -> Result<Schema>
fn get_schema_by_columns<T>(
&mut self,
column_indices: T,
leaf_columns: bool,
) -> Result<Schema>
where
T: IntoIterator<Item = usize>,
{
let file_metadata = self.file_reader.metadata().file_metadata();
parquet_to_arrow_schema_by_columns(
file_metadata.schema_descr(),
column_indices,
file_metadata.key_value_metadata(),
)
if leaf_columns {
parquet_to_arrow_schema_by_columns(
file_metadata.schema_descr(),
column_indices,
file_metadata.key_value_metadata(),
)
} else {
parquet_to_arrow_schema_by_root_columns(
file_metadata.schema_descr(),
column_indices,
file_metadata.key_value_metadata(),
)
}
}

fn get_record_reader(
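With the new `leaf_columns` flag, callers choose whether the indices refer to leaf columns or root columns. A hedged usage sketch (the file name is illustrative, and reader construction follows the crate's usual `SerializedFileReader` + `Rc` pattern from this era of the codebase):

```rust
use std::fs::File;
use std::rc::Rc;

use parquet::arrow::{ArrowReader, ParquetFileArrowReader};
use parquet::file::reader::SerializedFileReader;

fn main() {
    let file = File::open("data.parquet").expect("open sample file");
    let file_reader =
        Rc::new(SerializedFileReader::new(file).expect("valid parquet file"));
    let mut arrow_reader = ParquetFileArrowReader::new(file_reader);

    // Index 0 as the first *leaf* column (e.g. `a.b.c`).
    let leaf_schema = arrow_reader
        .get_schema_by_columns(vec![0], true)
        .expect("schema by leaf columns");

    // Index 0 as the first *root* column (e.g. all of `a`).
    let root_schema = arrow_reader
        .get_schema_by_columns(vec![0], false)
        .expect("schema by root columns");

    println!("leaf: {:?}\nroot: {:?}", leaf_schema, root_schema);
}
```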