diff --git a/rust/arrow/src/array/data.rs b/rust/arrow/src/array/data.rs index 9589f73caf8..a1426a6fb88 100644 --- a/rust/arrow/src/array/data.rs +++ b/rust/arrow/src/array/data.rs @@ -29,7 +29,7 @@ use crate::util::bit_util; /// An generic representation of Arrow array data which encapsulates common attributes and /// operations for Arrow array. Specific operations for different arrays types (e.g., /// primitive, list, struct) are implemented in `Array`. -#[derive(PartialEq, Debug, Clone)] +#[derive(Debug, Clone)] pub struct ArrayData { /// The data type for this array data data_type: DataType, @@ -209,6 +209,61 @@ impl ArrayData { } } +impl PartialEq for ArrayData { + fn eq(&self, other: &Self) -> bool { + assert_eq!( + self.data_type(), + other.data_type(), + "Data types not the same" + ); + assert_eq!(self.len(), other.len(), "Lengths not the same"); + // TODO: when adding tests for this, test that we can compare with arrays that have offsets + assert_eq!(self.offset(), other.offset(), "Offsets not the same"); + assert_eq!(self.null_count(), other.null_count()); + // compare buffers excluding padding + let self_buffers = self.buffers(); + let other_buffers = other.buffers(); + assert_eq!(self_buffers.len(), other_buffers.len()); + self_buffers.iter().zip(other_buffers).for_each(|(s, o)| { + compare_buffer_regions( + s, + self.offset(), // TODO mul by data length + o, + other.offset(), // TODO mul by data len + ); + }); + // assert_eq!(self.buffers(), other.buffers()); + + assert_eq!(self.child_data(), other.child_data()); + // null arrays can skip the null bitmap, thus only compare if there are no nulls + if self.null_count() != 0 || other.null_count() != 0 { + compare_buffer_regions( + self.null_buffer().unwrap(), + self.offset(), + other.null_buffer().unwrap(), + other.offset(), + ) + } + true + } +} + +/// A helper to compare buffer regions of 2 buffers. +/// Compares the length of the shorter buffer. +fn compare_buffer_regions( + left: &Buffer, + left_offset: usize, + right: &Buffer, + right_offset: usize, +) { + // for convenience, we assume that the buffer lengths are only unequal if one has padding, + // so we take the shorter length so we can discard the padding from the longer length + let shorter_len = left.len().min(right.len()); + let s_sliced = left.bit_slice(left_offset, shorter_len); + let o_sliced = right.bit_slice(right_offset, shorter_len); + assert_eq!(s_sliced, o_sliced); +} + /// Builder for `ArrayData` type #[derive(Debug)] pub struct ArrayDataBuilder { diff --git a/rust/arrow/src/array/null.rs b/rust/arrow/src/array/null.rs index 190d2fa78fc..08c7cf1f21e 100644 --- a/rust/arrow/src/array/null.rs +++ b/rust/arrow/src/array/null.rs @@ -113,7 +113,7 @@ impl From for NullArray { impl fmt::Debug for NullArray { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "NullArray") + write!(f, "NullArray({})", self.len()) } } @@ -146,4 +146,10 @@ mod tests { assert_eq!(array2.null_count(), 16); assert_eq!(array2.offset(), 8); } + + #[test] + fn test_debug_null_array() { + let array = NullArray::new(1024 * 1024); + assert_eq!(format!("{:?}", array), "NullArray(1048576)"); + } } diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs index 4e1bc852d42..ab34c6a0950 100644 --- a/rust/arrow/src/compute/kernels/cast.rs +++ b/rust/arrow/src/compute/kernels/cast.rs @@ -200,8 +200,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Timestamp(_, _), Date32(_)) => true, (Timestamp(_, _), Date64(_)) => true, // date64 to timestamp might not make sense, - - // end temporal casts + (Null, Int32) => true, (_, _) => false, } } @@ -729,25 +728,31 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { // single integer operation, but need to avoid integer // math rounding down to zero - if to_size > from_size { - let time_array = Date64Array::from(array.data()); - Ok(Arc::new(multiply( - &time_array, - &Date64Array::from(vec![to_size / from_size; array.len()]), - )?) as ArrayRef) - } else if to_size < from_size { - let time_array = Date64Array::from(array.data()); - Ok(Arc::new(divide( - &time_array, - &Date64Array::from(vec![from_size / to_size; array.len()]), - )?) as ArrayRef) - } else { - cast_array_data::(array, to_type.clone()) + match to_size.cmp(&from_size) { + std::cmp::Ordering::Less => { + let time_array = Date64Array::from(array.data()); + Ok(Arc::new(divide( + &time_array, + &Date64Array::from(vec![from_size / to_size; array.len()]), + )?) as ArrayRef) + } + std::cmp::Ordering::Equal => { + cast_array_data::(array, to_type.clone()) + } + std::cmp::Ordering::Greater => { + let time_array = Date64Array::from(array.data()); + Ok(Arc::new(multiply( + &time_array, + &Date64Array::from(vec![to_size / from_size; array.len()]), + )?) as ArrayRef) + } } } // date64 to timestamp might not make sense, - // end temporal casts + // null to primitive/flat types + (Null, Int32) => Ok(Arc::new(Int32Array::from(vec![None; array.len()]))), + (_, _) => Err(ArrowError::ComputeError(format!( "Casting from {:?} to {:?} not supported", from_type, to_type, @@ -2476,44 +2481,44 @@ mod tests { // Test casting TO StringArray let cast_type = Utf8; - let cast_array = cast(&array, &cast_type).expect("cast to UTF-8 succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast to UTF-8 failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); // Test casting TO Dictionary (with different index sizes) let cast_type = Dictionary(Box::new(Int16), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); let cast_type = Dictionary(Box::new(Int32), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); let cast_type = Dictionary(Box::new(Int64), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); let cast_type = Dictionary(Box::new(UInt16), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); let cast_type = Dictionary(Box::new(UInt32), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); let cast_type = Dictionary(Box::new(UInt64), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); } @@ -2598,11 +2603,11 @@ mod tests { let expected = vec!["1", "null", "3"]; // Test casting TO PrimitiveArray, different dictionary type - let cast_array = cast(&array, &Utf8).expect("cast to UTF-8 succeeded"); + let cast_array = cast(&array, &Utf8).expect("cast to UTF-8 failed"); assert_eq!(array_to_strings(&cast_array), expected); assert_eq!(cast_array.data_type(), &Utf8); - let cast_array = cast(&array, &Int64).expect("cast to int64 succeeded"); + let cast_array = cast(&array, &Int64).expect("cast to int64 failed"); assert_eq!(array_to_strings(&cast_array), expected); assert_eq!(cast_array.data_type(), &Int64); } @@ -2621,13 +2626,13 @@ mod tests { // Cast to a dictionary (same value type, Int32) let cast_type = Dictionary(Box::new(UInt8), Box::new(Int32)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); // Cast to a dictionary (different value type, Int8) let cast_type = Dictionary(Box::new(UInt8), Box::new(Int8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); } @@ -2646,11 +2651,25 @@ mod tests { // Cast to a dictionary (same value type, Utf8) let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); - let cast_array = cast(&array, &cast_type).expect("cast succeeded"); + let cast_array = cast(&array, &cast_type).expect("cast failed"); assert_eq!(cast_array.data_type(), &cast_type); assert_eq!(array_to_strings(&cast_array), expected); } + #[test] + fn test_cast_null_array_to_int32() { + let array = Arc::new(NullArray::new(6)) as ArrayRef; + + let expected = Int32Array::from(vec![None; 6]); + + // Cast to a dictionary (same value type, Utf8) + let cast_type = DataType::Int32; + let cast_array = cast(&array, &cast_type).expect("cast failed"); + let cast_array = as_primitive_array::(&cast_array); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(cast_array, &expected); + } + /// Print the `DictionaryArray` `array` as a vector of strings fn array_to_strings(array: &ArrayRef) -> Vec { (0..array.len()) @@ -2768,7 +2787,7 @@ mod tests { )), Arc::new(TimestampNanosecondArray::from_vec( vec![1000, 2000], - Some(tz_name.clone()), + Some(tz_name), )), Arc::new(Date32Array::from(vec![1000, 2000])), Arc::new(Date64Array::from(vec![1000, 2000])), diff --git a/rust/arrow/src/datatypes.rs b/rust/arrow/src/datatypes.rs index 3756542b718..6f0dc240ac9 100644 --- a/rust/arrow/src/datatypes.rs +++ b/rust/arrow/src/datatypes.rs @@ -189,8 +189,8 @@ pub struct Field { name: String, data_type: DataType, nullable: bool, - dict_id: i64, - dict_is_ordered: bool, + pub(crate) dict_id: i64, + pub(crate) dict_is_ordered: bool, } pub trait ArrowNativeType: diff --git a/rust/arrow/src/ipc/convert.rs b/rust/arrow/src/ipc/convert.rs index 7a5795de91c..63d55f043c6 100644 --- a/rust/arrow/src/ipc/convert.rs +++ b/rust/arrow/src/ipc/convert.rs @@ -34,18 +34,8 @@ pub fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder { let mut fields = vec![]; for field in schema.fields() { - let fb_field_name = fbb.create_string(field.name().as_str()); - let field_type = get_fb_field_type(field.data_type(), &mut fbb); - let mut field_builder = ipc::FieldBuilder::new(&mut fbb); - field_builder.add_name(fb_field_name); - field_builder.add_type_type(field_type.type_type); - field_builder.add_nullable(field.is_nullable()); - match field_type.children { - None => {} - Some(children) => field_builder.add_children(children), - }; - field_builder.add_type_(field_type.type_); - fields.push(field_builder.finish()); + let fb_field = build_field(&mut fbb, field); + fields.push(fb_field); } let mut custom_metadata = vec![]; @@ -80,18 +70,8 @@ pub fn schema_to_fb_offset<'a: 'b, 'b>( ) -> WIPOffset> { let mut fields = vec![]; for field in schema.fields() { - let fb_field_name = fbb.create_string(field.name().as_str()); - let field_type = get_fb_field_type(field.data_type(), fbb); - let mut field_builder = ipc::FieldBuilder::new(fbb); - field_builder.add_name(fb_field_name); - field_builder.add_type_type(field_type.type_type); - field_builder.add_nullable(field.is_nullable()); - match field_type.children { - None => {} - Some(children) => field_builder.add_children(children), - }; - field_builder.add_type_(field_type.type_); - fields.push(field_builder.finish()); + let fb_field = build_field(fbb, field); + fields.push(fb_field); } let mut custom_metadata = vec![]; @@ -333,6 +313,40 @@ pub(crate) struct FBFieldType<'b> { pub(crate) children: Option>>>>, } +/// Create an IPC Field from an Arrow Field +pub(crate) fn build_field<'a: 'b, 'b>( + fbb: &mut FlatBufferBuilder<'a>, + field: &Field, +) -> WIPOffset> { + let fb_field_name = fbb.create_string(field.name().as_str()); + let field_type = get_fb_field_type(field.data_type(), fbb); + + let fb_dictionary = if let Dictionary(index_type, _) = field.data_type() { + Some(get_fb_dictionary( + index_type, + field.dict_id, + field.dict_is_ordered, + fbb, + )) + } else { + None + }; + + let mut field_builder = ipc::FieldBuilder::new(fbb); + field_builder.add_name(fb_field_name); + if let Some(dictionary) = fb_dictionary { + field_builder.add_dictionary(dictionary) + } + field_builder.add_type_type(field_type.type_type); + field_builder.add_nullable(field.is_nullable()); + match field_type.children { + None => {} + Some(children) => field_builder.add_children(children), + }; + field_builder.add_type_(field_type.type_); + field_builder.finish() +} + /// Get the IPC type of a data type pub(crate) fn get_fb_field_type<'a: 'b, 'b>( data_type: &DataType, @@ -609,10 +623,51 @@ pub(crate) fn get_fb_field_type<'a: 'b, 'b>( children: Some(fbb.create_vector(&children[..])), } } + Dictionary(_, value_type) => { + // In this library, the dictionary "type" is a logical construct. Here we + // pass through to the value type, as we've already captured the index + // type in the DictionaryEncoding metadata in the parent field + get_fb_field_type(value_type, fbb) + } t => unimplemented!("Type {:?} not supported", t), } } +/// Create an IPC dictionary encoding +pub(crate) fn get_fb_dictionary<'a: 'b, 'b>( + index_type: &DataType, + dict_id: i64, + dict_is_ordered: bool, + fbb: &mut FlatBufferBuilder<'a>, +) -> WIPOffset> { + // We assume that the dictionary index type (as an integer) has already been + // validated elsewhere, and can safely assume we are dealing with integers + let mut index_builder = ipc::IntBuilder::new(fbb); + + match *index_type { + Int8 | Int16 | Int32 | Int64 => index_builder.add_is_signed(true), + UInt8 | UInt16 | UInt32 | UInt64 => index_builder.add_is_signed(false), + _ => {} + } + + match *index_type { + Int8 | UInt8 => index_builder.add_bitWidth(8), + Int16 | UInt16 => index_builder.add_bitWidth(16), + Int32 | UInt32 => index_builder.add_bitWidth(32), + Int64 | UInt64 => index_builder.add_bitWidth(64), + _ => {} + } + + let index_builder = index_builder.finish(); + + let mut builder = ipc::DictionaryEncodingBuilder::new(fbb); + builder.add_id(dict_id); + builder.add_indexType(index_builder); + builder.add_isOrdered(dict_is_ordered); + + builder.finish() +} + #[cfg(test)] mod tests { use super::*; @@ -714,6 +769,26 @@ mod tests { false, ), Field::new("struct<>", DataType::Struct(vec![]), true), + Field::new_dict( + "dictionary", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + true, + 123, + true, + ), + Field::new_dict( + "dictionary", + DataType::Dictionary( + Box::new(DataType::UInt8), + Box::new(DataType::UInt32), + ), + true, + 123, + true, + ), ], md, ); diff --git a/rust/arrow/src/util/display.rs b/rust/arrow/src/util/display.rs index 102ec5d99ed..1a873f1b083 100644 --- a/rust/arrow/src/util/display.rs +++ b/rust/arrow/src/util/display.rs @@ -44,6 +44,22 @@ macro_rules! make_string { }}; } +macro_rules! make_string_from_list { + ($column: ident, $row: ident) => {{ + let list = $column + .as_any() + .downcast_ref::() + .ok_or(ArrowError::InvalidArgumentError(format!( + "Repl error: could not convert list column to list array." + )))? + .value($row); + let string_values = (0..list.len()) + .map(|i| array_value_to_string(&list.clone(), i)) + .collect::>>()?; + Ok(format!("[{}]", string_values.join(", "))) + }}; +} + /// Get the value at the given row in an array as a String. /// /// Note this function is quite inefficient and is unlikely to be @@ -89,6 +105,7 @@ pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result { make_string!(array::Time64NanosecondArray, column, row) } + DataType::List(_) => make_string_from_list!(column, row), DataType::Dictionary(index_type, _value_type) => match **index_type { DataType::Int8 => dict_array_value_to_string::(column, row), DataType::Int16 => dict_array_value_to_string::(column, row), diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 9eceb768a3f..a35476d7f25 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::convert::TryFrom; use std::env; use std::sync::Arc; @@ -22,7 +23,7 @@ extern crate arrow; extern crate datafusion; use arrow::{array::*, datatypes::TimeUnit}; -use arrow::{datatypes::Int32Type, record_batch::RecordBatch}; +use arrow::{datatypes::Int32Type, datatypes::Int64Type, record_batch::RecordBatch}; use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, util::display::array_value_to_string, @@ -128,6 +129,100 @@ async fn parquet_single_nan_schema() { } } +#[tokio::test] +async fn parquet_list_columns() { + let mut ctx = ExecutionContext::new(); + let testdata = env::var("PARQUET_TEST_DATA").expect("PARQUET_TEST_DATA not defined"); + ctx.register_parquet( + "list_columns", + &format!("{}/list_columns.parquet", testdata), + ) + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "int64_list", + DataType::List(Box::new(DataType::Int64)), + true, + ), + Field::new("utf8_list", DataType::List(Box::new(DataType::Utf8)), true), + ])); + + let sql = "SELECT int64_list, utf8_list FROM list_columns"; + let plan = ctx.create_logical_plan(&sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + let plan = ctx.create_physical_plan(&plan).unwrap(); + let results = ctx.collect(plan).await.unwrap(); + + // int64_list utf8_list + // 0 [1, 2, 3] [abc, efg, hij] + // 1 [None, 1] None + // 2 [4] [efg, None, hij, xyz] + + assert_eq!(1, results.len()); + let batch = &results[0]; + assert_eq!(3, batch.num_rows()); + assert_eq!(2, batch.num_columns()); + assert_eq!(schema, batch.schema()); + + let int_list_array = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let utf8_list_array = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + int_list_array + .value(0) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(1), Some(2), Some(3),]) + ); + + assert_eq!( + utf8_list_array + .value(0) + .as_any() + .downcast_ref::() + .unwrap(), + &StringArray::try_from(vec![Some("abc"), Some("efg"), Some("hij"),]).unwrap() + ); + + assert_eq!( + int_list_array + .value(1) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![None, Some(1),]) + ); + + assert!(utf8_list_array.is_null(1)); + + assert_eq!( + int_list_array + .value(2) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(4),]) + ); + + let result = utf8_list_array.value(2); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.value(0), "efg"); + assert!(result.is_null(1)); + assert_eq!(result.value(2), "hij"); + assert_eq!(result.value(3), "xyz"); +} + #[tokio::test] async fn csv_count_star() -> Result<()> { let mut ctx = ExecutionContext::new(); diff --git a/rust/parquet/Cargo.toml b/rust/parquet/Cargo.toml index c1fd1b6ca0a..122c5b6356d 100644 --- a/rust/parquet/Cargo.toml +++ b/rust/parquet/Cargo.toml @@ -40,6 +40,7 @@ zstd = { version = "0.5", optional = true } chrono = "0.4" num-bigint = "0.3" arrow = { path = "../arrow", version = "3.0.0-SNAPSHOT", optional = true } +base64 = { version = "*", optional = true } [dev-dependencies] rand = "0.7" @@ -52,4 +53,4 @@ arrow = { path = "../arrow", version = "3.0.0-SNAPSHOT" } serde_json = { version = "1.0", features = ["preserve_order"] } [features] -default = ["arrow", "snap", "brotli", "flate2", "lz4", "zstd"] +default = ["arrow", "snap", "brotli", "flate2", "lz4", "zstd", "base64"] diff --git a/rust/parquet/src/arrow/array_reader.rs b/rust/parquet/src/arrow/array_reader.rs index b9db4f8c37f..76b672bb301 100644 --- a/rust/parquet/src/arrow/array_reader.rs +++ b/rust/parquet/src/arrow/array_reader.rs @@ -25,19 +25,41 @@ use std::sync::Arc; use std::vec::Vec; use arrow::array::{ - ArrayDataBuilder, ArrayDataRef, ArrayRef, BooleanBufferBuilder, BufferBuilderTrait, - Int16BufferBuilder, StructArray, + Array, ArrayData, ArrayDataBuilder, ArrayDataRef, ArrayRef, BinaryArray, + BinaryBuilder, BooleanBufferBuilder, BufferBuilderTrait, FixedSizeBinaryArray, + FixedSizeBinaryBuilder, GenericListArray, Int16BufferBuilder, ListBuilder, + OffsetSizeTrait, PrimitiveArray, PrimitiveBuilder, StringArray, StringBuilder, + StructArray, }; use arrow::buffer::{Buffer, MutableBuffer}; -use arrow::datatypes::{DataType as ArrowType, DateUnit, Field, IntervalUnit, TimeUnit}; +use arrow::datatypes::{ + ArrowPrimitiveType, BooleanType as ArrowBooleanType, DataType as ArrowType, + Date32Type as ArrowDate32Type, Date64Type as ArrowDate64Type, + DurationMicrosecondType as ArrowDurationMicrosecondType, + DurationMillisecondType as ArrowDurationMillisecondType, + DurationNanosecondType as ArrowDurationNanosecondType, + DurationSecondType as ArrowDurationSecondType, Field, + Float32Type as ArrowFloat32Type, Float64Type as ArrowFloat64Type, + Int16Type as ArrowInt16Type, Int32Type as ArrowInt32Type, + Int64Type as ArrowInt64Type, Int8Type as ArrowInt8Type, Schema, + Time32MillisecondType as ArrowTime32MillisecondType, + Time32SecondType as ArrowTime32SecondType, + Time64MicrosecondType as ArrowTime64MicrosecondType, + Time64NanosecondType as ArrowTime64NanosecondType, TimeUnit as ArrowTimeUnit, + TimestampMicrosecondType as ArrowTimestampMicrosecondType, + TimestampMillisecondType as ArrowTimestampMillisecondType, + TimestampNanosecondType as ArrowTimestampNanosecondType, + TimestampSecondType as ArrowTimestampSecondType, ToByteSlice, + UInt16Type as ArrowUInt16Type, UInt32Type as ArrowUInt32Type, + UInt64Type as ArrowUInt64Type, UInt8Type as ArrowUInt8Type, +}; +use arrow::util::bit_util; use crate::arrow::converter::{ - BinaryArrayConverter, BinaryConverter, BoolConverter, BooleanArrayConverter, - Converter, Date32Converter, FixedLenBinaryConverter, FixedSizeArrayConverter, - Float32Converter, Float64Converter, Int16Converter, Int32Converter, Int64Converter, - Int8Converter, Int96ArrayConverter, Int96Converter, TimestampMicrosecondConverter, - TimestampMillisecondConverter, UInt16Converter, UInt32Converter, UInt64Converter, - UInt8Converter, Utf8ArrayConverter, Utf8Converter, + BinaryArrayConverter, BinaryConverter, Converter, FixedLenBinaryConverter, + FixedSizeArrayConverter, Int96ArrayConverter, Int96Converter, + LargeBinaryArrayConverter, LargeBinaryConverter, LargeUtf8ArrayConverter, + LargeUtf8Converter, Utf8ArrayConverter, Utf8Converter, }; use crate::arrow::record_reader::RecordReader; use crate::arrow::schema::parquet_to_arrow_field; @@ -77,6 +99,97 @@ pub trait ArrayReader { fn get_rep_levels(&self) -> Option<&[i16]>; } +/// A NullArrayReader reads Parquet columns stored as null int32s with an Arrow +/// NullArray type. +pub struct NullArrayReader { + data_type: ArrowType, + pages: Box, + def_levels_buffer: Option, + rep_levels_buffer: Option, + column_desc: ColumnDescPtr, + record_reader: RecordReader, + _type_marker: PhantomData, +} + +impl NullArrayReader { + /// Construct null array reader. + pub fn new( + mut pages: Box, + column_desc: ColumnDescPtr, + ) -> Result { + let mut record_reader = RecordReader::::new(column_desc.clone()); + if let Some(page_reader) = pages.next() { + record_reader.set_page_reader(page_reader?)?; + } + + Ok(Self { + data_type: ArrowType::Null, + pages, + def_levels_buffer: None, + rep_levels_buffer: None, + column_desc, + record_reader, + _type_marker: PhantomData, + }) + } +} + +/// Implementation of primitive array reader. +impl ArrayReader for NullArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns data type of primitive array. + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + /// Reads at most `batch_size` records into array. + fn next_batch(&mut self, batch_size: usize) -> Result { + let mut records_read = 0usize; + while records_read < batch_size { + let records_to_read = batch_size - records_read; + + // NB can be 0 if at end of page + let records_read_once = self.record_reader.read_records(records_to_read)?; + records_read += records_read_once; + + // Record reader exhausted + if records_read_once < records_to_read { + if let Some(page_reader) = self.pages.next() { + // Read from new page reader + self.record_reader.set_page_reader(page_reader?)?; + } else { + // Page reader also exhausted + break; + } + } + } + + // convert to arrays + let array = arrow::array::NullArray::new(records_read); + + // save definition and repetition buffers + self.def_levels_buffer = self.record_reader.consume_def_levels()?; + self.rep_levels_buffer = self.record_reader.consume_rep_levels()?; + self.record_reader.reset(); + Ok(Arc::new(array)) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_levels_buffer + .as_ref() + .map(|buf| unsafe { buf.typed_data() }) + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.rep_levels_buffer + .as_ref() + .map(|buf| unsafe { buf.typed_data() }) + } +} + /// Primitive array readers are leaves of array reader tree. They accept page iterator /// and read them into primitive arrays. pub struct PrimitiveArrayReader { @@ -94,10 +207,15 @@ impl PrimitiveArrayReader { pub fn new( mut pages: Box, column_desc: ColumnDescPtr, + arrow_type: Option, ) -> Result { - let data_type = parquet_to_arrow_field(column_desc.as_ref())? - .data_type() - .clone(); + // Check if Arrow type is specified, else create it from Parquet type + let data_type = match arrow_type { + Some(t) => t, + None => parquet_to_arrow_field(column_desc.as_ref())? + .data_type() + .clone(), + }; let mut record_reader = RecordReader::::new(column_desc.clone()); if let Some(page_reader) = pages.next() { @@ -149,74 +267,79 @@ impl ArrayReader for PrimitiveArrayReader { } } - // convert to arrays - let array = - match (&self.data_type, T::get_physical_type()) { - (ArrowType::Boolean, PhysicalType::BOOLEAN) => { - BoolConverter::new(BooleanArrayConverter {}) - .convert(self.record_reader.cast::()) - } - (ArrowType::Int8, PhysicalType::INT32) => { - Int8Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Int16, PhysicalType::INT32) => { - Int16Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Int32, PhysicalType::INT32) => { - Int32Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::UInt8, PhysicalType::INT32) => { - UInt8Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::UInt16, PhysicalType::INT32) => { - UInt16Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::UInt32, PhysicalType::INT32) => { - UInt32Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Int64, PhysicalType::INT64) => { - Int64Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::UInt64, PhysicalType::INT64) => { - UInt64Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Float32, PhysicalType::FLOAT) => Float32Converter::new() - .convert(self.record_reader.cast::()), - (ArrowType::Float64, PhysicalType::DOUBLE) => Float64Converter::new() - .convert(self.record_reader.cast::()), - (ArrowType::Timestamp(unit, _), PhysicalType::INT64) => match unit { - TimeUnit::Millisecond => TimestampMillisecondConverter::new() - .convert(self.record_reader.cast::()), - TimeUnit::Microsecond => TimestampMicrosecondConverter::new() - .convert(self.record_reader.cast::()), - _ => Err(general_err!("No conversion from parquet type to arrow type for timestamp with unit {:?}", unit)), - }, - (ArrowType::Date32(unit), PhysicalType::INT32) => match unit { - DateUnit::Day => Date32Converter::new() - .convert(self.record_reader.cast::()), - _ => Err(general_err!("No conversion from parquet type to arrow type for date with unit {:?}", unit)), - } - (ArrowType::Time32(_), PhysicalType::INT32) => { - UInt32Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Time64(_), PhysicalType::INT64) => { - UInt64Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Interval(IntervalUnit::YearMonth), PhysicalType::INT32) => { - UInt32Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Interval(IntervalUnit::DayTime), PhysicalType::INT64) => { - UInt64Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Duration(_), PhysicalType::INT64) => { - UInt64Converter::new().convert(self.record_reader.cast::()) - } - (arrow_type, physical_type) => Err(general_err!( - "Reading {:?} type from parquet {:?} is not supported yet.", - arrow_type, - physical_type - )), - }?; + let arrow_data_type = match T::get_physical_type() { + PhysicalType::BOOLEAN => ArrowBooleanType::DATA_TYPE, + PhysicalType::INT32 => ArrowInt32Type::DATA_TYPE, + PhysicalType::INT64 => ArrowInt64Type::DATA_TYPE, + PhysicalType::FLOAT => ArrowFloat32Type::DATA_TYPE, + PhysicalType::DOUBLE => ArrowFloat64Type::DATA_TYPE, + PhysicalType::INT96 + | PhysicalType::BYTE_ARRAY + | PhysicalType::FIXED_LEN_BYTE_ARRAY => { + unreachable!( + "PrimitiveArrayReaders don't support complex physical types" + ); + } + }; + + // Convert to arrays by using the Parquet phyisical type. + // The physical types are then cast to Arrow types if necessary + + let mut record_data = self.record_reader.consume_record_data()?; + + if T::get_physical_type() == PhysicalType::BOOLEAN { + let mut boolean_buffer = BooleanBufferBuilder::new(record_data.len()); + + for e in record_data.data() { + boolean_buffer.append(*e > 0)?; + } + record_data = boolean_buffer.finish(); + } + + let mut array_data = ArrayDataBuilder::new(arrow_data_type) + .len(self.record_reader.num_values()) + .add_buffer(record_data); + + if let Some(b) = self.record_reader.consume_bitmap_buffer()? { + array_data = array_data.null_bit_buffer(b); + } + + let array = match T::get_physical_type() { + PhysicalType::BOOLEAN => { + Arc::new(PrimitiveArray::::from(array_data.build())) + as ArrayRef + } + PhysicalType::INT32 => { + Arc::new(PrimitiveArray::::from(array_data.build())) + as ArrayRef + } + PhysicalType::INT64 => { + Arc::new(PrimitiveArray::::from(array_data.build())) + as ArrayRef + } + PhysicalType::FLOAT => { + Arc::new(PrimitiveArray::::from(array_data.build())) + as ArrayRef + } + PhysicalType::DOUBLE => { + Arc::new(PrimitiveArray::::from(array_data.build())) + as ArrayRef + } + PhysicalType::INT96 + | PhysicalType::BYTE_ARRAY + | PhysicalType::FIXED_LEN_BYTE_ARRAY => { + unreachable!( + "PrimitiveArrayReaders don't support complex physical types" + ); + } + }; + + // cast to Arrow type + // TODO: we need to check if it's fine for this to be fallible. + // My assumption is that we can't get to an illegal cast as we can only + // generate types that are supported, because we'd have gotten them from + // the metadata which was written to the Parquet sink + let array = arrow::compute::cast(&array, self.get_data_type())?; // save definition and repetition buffers self.def_levels_buffer = self.record_reader.consume_def_levels()?; @@ -369,7 +492,13 @@ where data_buffer.into_iter().map(Some).collect() }; - self.converter.convert(data) + let mut array = self.converter.convert(data)?; + + if let ArrowType::Dictionary(_, _) = self.data_type { + array = arrow::compute::cast(&array, &self.data_type)?; + } + + Ok(array) } fn get_def_levels(&self) -> Option<&[i16]> { @@ -390,10 +519,14 @@ where pages: Box, column_desc: ColumnDescPtr, converter: C, + arrow_type: Option, ) -> Result { - let data_type = parquet_to_arrow_field(column_desc.as_ref())? - .data_type() - .clone(); + let data_type = match arrow_type { + Some(t) => t, + None => parquet_to_arrow_field(column_desc.as_ref())? + .data_type() + .clone(), + }; Ok(Self { data_type, @@ -420,6 +553,400 @@ where } } +/// Implementation of list array reader. +pub struct ListArrayReader { + item_reader: Box, + data_type: ArrowType, + item_type: ArrowType, + list_def_level: i16, + list_rep_level: i16, + def_level_buffer: Option, + rep_level_buffer: Option, + _marker: PhantomData, +} + +impl ListArrayReader { + /// Construct list array reader. + pub fn new( + item_reader: Box, + data_type: ArrowType, + item_type: ArrowType, + def_level: i16, + rep_level: i16, + ) -> Self { + Self { + item_reader, + data_type, + item_type, + list_def_level: def_level, + list_rep_level: rep_level, + def_level_buffer: None, + rep_level_buffer: None, + _marker: PhantomData, + } + } +} + +macro_rules! build_empty_list_array_with_primitive_items { + ($item_type:ident) => {{ + let values_builder = PrimitiveBuilder::<$item_type>::new(0); + let mut builder = ListBuilder::new(values_builder); + let empty_list_array = builder.finish(); + Ok(Arc::new(empty_list_array)) + }}; +} + +macro_rules! build_empty_list_array_with_non_primitive_items { + ($builder:ident) => {{ + let values_builder = $builder::new(0); + let mut builder = ListBuilder::new(values_builder); + let empty_list_array = builder.finish(); + Ok(Arc::new(empty_list_array)) + }}; +} + +fn build_empty_list_array(item_type: ArrowType) -> Result { + match item_type { + ArrowType::UInt8 => build_empty_list_array_with_primitive_items!(ArrowUInt8Type), + ArrowType::UInt16 => { + build_empty_list_array_with_primitive_items!(ArrowUInt16Type) + } + ArrowType::UInt32 => { + build_empty_list_array_with_primitive_items!(ArrowUInt32Type) + } + ArrowType::UInt64 => { + build_empty_list_array_with_primitive_items!(ArrowUInt64Type) + } + ArrowType::Int8 => build_empty_list_array_with_primitive_items!(ArrowInt8Type), + ArrowType::Int16 => build_empty_list_array_with_primitive_items!(ArrowInt16Type), + ArrowType::Int32 => build_empty_list_array_with_primitive_items!(ArrowInt32Type), + ArrowType::Int64 => build_empty_list_array_with_primitive_items!(ArrowInt64Type), + ArrowType::Float32 => { + build_empty_list_array_with_primitive_items!(ArrowFloat32Type) + } + ArrowType::Float64 => { + build_empty_list_array_with_primitive_items!(ArrowFloat64Type) + } + ArrowType::Boolean => { + build_empty_list_array_with_primitive_items!(ArrowBooleanType) + } + ArrowType::Date32(_) => { + build_empty_list_array_with_primitive_items!(ArrowDate32Type) + } + ArrowType::Date64(_) => { + build_empty_list_array_with_primitive_items!(ArrowDate64Type) + } + ArrowType::Time32(ArrowTimeUnit::Second) => { + build_empty_list_array_with_primitive_items!(ArrowTime32SecondType) + } + ArrowType::Time32(ArrowTimeUnit::Millisecond) => { + build_empty_list_array_with_primitive_items!(ArrowTime32MillisecondType) + } + ArrowType::Time64(ArrowTimeUnit::Microsecond) => { + build_empty_list_array_with_primitive_items!(ArrowTime64MicrosecondType) + } + ArrowType::Time64(ArrowTimeUnit::Nanosecond) => { + build_empty_list_array_with_primitive_items!(ArrowTime64NanosecondType) + } + ArrowType::Duration(ArrowTimeUnit::Second) => { + build_empty_list_array_with_primitive_items!(ArrowDurationSecondType) + } + ArrowType::Duration(ArrowTimeUnit::Millisecond) => { + build_empty_list_array_with_primitive_items!(ArrowDurationMillisecondType) + } + ArrowType::Duration(ArrowTimeUnit::Microsecond) => { + build_empty_list_array_with_primitive_items!(ArrowDurationMicrosecondType) + } + ArrowType::Duration(ArrowTimeUnit::Nanosecond) => { + build_empty_list_array_with_primitive_items!(ArrowDurationNanosecondType) + } + ArrowType::Timestamp(ArrowTimeUnit::Second, _) => { + build_empty_list_array_with_primitive_items!(ArrowTimestampSecondType) + } + ArrowType::Timestamp(ArrowTimeUnit::Millisecond, _) => { + build_empty_list_array_with_primitive_items!(ArrowTimestampMillisecondType) + } + ArrowType::Timestamp(ArrowTimeUnit::Microsecond, _) => { + build_empty_list_array_with_primitive_items!(ArrowTimestampMicrosecondType) + } + ArrowType::Timestamp(ArrowTimeUnit::Nanosecond, _) => { + build_empty_list_array_with_primitive_items!(ArrowTimestampNanosecondType) + } + ArrowType::Utf8 => { + build_empty_list_array_with_non_primitive_items!(StringBuilder) + } + ArrowType::Binary => { + build_empty_list_array_with_non_primitive_items!(BinaryBuilder) + } + _ => Err(ParquetError::General(format!( + "ListArray of type List({:?}) is not supported by array_reader", + item_type + ))), + } +} + +macro_rules! remove_primitive_array_indices { + ($arr: expr, $item_type:ty, $indices:expr) => {{ + let array_data = match $arr.as_any().downcast_ref::>() { + Some(a) => a, + _ => return Err(ParquetError::General(format!("Error generating next batch for ListArray: {:?} cannot be downcast to PrimitiveArray", $arr))), + }; + let mut builder = PrimitiveBuilder::<$item_type>::new($arr.len()); + for i in 0..array_data.len() { + if !$indices.contains(&i) { + if array_data.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array_data.value(i))?; + } + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +macro_rules! remove_array_indices_custom_builder { + ($arr: expr, $array_type:ty, $item_builder:ident, $indices:expr) => {{ + let array_data = match $arr.as_any().downcast_ref::<$array_type>() { + Some(a) => a, + _ => return Err(ParquetError::General(format!("Error generating next batch for ListArray: {:?} cannot be downcast to PrimitiveArray", $arr))), + }; + let mut builder = $item_builder::new(array_data.len()); + + for i in 0..array_data.len() { + if !$indices.contains(&i) { + if array_data.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array_data.value(i))?; + } + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +macro_rules! remove_fixed_size_binary_array_indices { + ($arr: expr, $array_type:ty, $item_builder:ident, $indices:expr, $len:expr) => {{ + let array_data = match $arr.as_any().downcast_ref::<$array_type>() { + Some(a) => a, + _ => return Err(ParquetError::General(format!("Error generating next batch for ListArray: {:?} cannot be downcast to PrimitiveArray", $arr))), + }; + let mut builder = FixedSizeBinaryBuilder::new(array_data.len(), $len); + for i in 0..array_data.len() { + if !$indices.contains(&i) { + if array_data.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array_data.value(i))?; + } + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +fn remove_indices( + arr: ArrayRef, + item_type: ArrowType, + indices: Vec, +) -> Result { + match item_type { + ArrowType::UInt8 => remove_primitive_array_indices!(arr, ArrowUInt8Type, indices), + ArrowType::UInt16 => { + remove_primitive_array_indices!(arr, ArrowUInt16Type, indices) + } + ArrowType::UInt32 => { + remove_primitive_array_indices!(arr, ArrowUInt32Type, indices) + } + ArrowType::UInt64 => { + remove_primitive_array_indices!(arr, ArrowUInt64Type, indices) + } + ArrowType::Int8 => remove_primitive_array_indices!(arr, ArrowInt8Type, indices), + ArrowType::Int16 => remove_primitive_array_indices!(arr, ArrowInt16Type, indices), + ArrowType::Int32 => remove_primitive_array_indices!(arr, ArrowInt32Type, indices), + ArrowType::Int64 => remove_primitive_array_indices!(arr, ArrowInt64Type, indices), + ArrowType::Float32 => { + remove_primitive_array_indices!(arr, ArrowFloat32Type, indices) + } + ArrowType::Float64 => { + remove_primitive_array_indices!(arr, ArrowFloat64Type, indices) + } + ArrowType::Boolean => { + remove_primitive_array_indices!(arr, ArrowBooleanType, indices) + } + ArrowType::Date32(_) => { + remove_primitive_array_indices!(arr, ArrowDate32Type, indices) + } + ArrowType::Date64(_) => { + remove_primitive_array_indices!(arr, ArrowDate64Type, indices) + } + ArrowType::Time32(ArrowTimeUnit::Second) => { + remove_primitive_array_indices!(arr, ArrowTime32SecondType, indices) + } + ArrowType::Time32(ArrowTimeUnit::Millisecond) => { + remove_primitive_array_indices!(arr, ArrowTime32MillisecondType, indices) + } + ArrowType::Time64(ArrowTimeUnit::Microsecond) => { + remove_primitive_array_indices!(arr, ArrowTime64MicrosecondType, indices) + } + ArrowType::Time64(ArrowTimeUnit::Nanosecond) => { + remove_primitive_array_indices!(arr, ArrowTime64NanosecondType, indices) + } + ArrowType::Duration(ArrowTimeUnit::Second) => { + remove_primitive_array_indices!(arr, ArrowDurationSecondType, indices) + } + ArrowType::Duration(ArrowTimeUnit::Millisecond) => { + remove_primitive_array_indices!(arr, ArrowDurationMillisecondType, indices) + } + ArrowType::Duration(ArrowTimeUnit::Microsecond) => { + remove_primitive_array_indices!(arr, ArrowDurationMicrosecondType, indices) + } + ArrowType::Duration(ArrowTimeUnit::Nanosecond) => { + remove_primitive_array_indices!(arr, ArrowDurationNanosecondType, indices) + } + ArrowType::Timestamp(ArrowTimeUnit::Second, _) => { + remove_primitive_array_indices!(arr, ArrowTimestampSecondType, indices) + } + ArrowType::Timestamp(ArrowTimeUnit::Millisecond, _) => { + remove_primitive_array_indices!(arr, ArrowTimestampMillisecondType, indices) + } + ArrowType::Timestamp(ArrowTimeUnit::Microsecond, _) => { + remove_primitive_array_indices!(arr, ArrowTimestampMicrosecondType, indices) + } + ArrowType::Timestamp(ArrowTimeUnit::Nanosecond, _) => { + remove_primitive_array_indices!(arr, ArrowTimestampNanosecondType, indices) + } + ArrowType::Utf8 => { + remove_array_indices_custom_builder!(arr, StringArray, StringBuilder, indices) + } + ArrowType::Binary => { + remove_array_indices_custom_builder!(arr, BinaryArray, BinaryBuilder, indices) + } + ArrowType::FixedSizeBinary(size) => remove_fixed_size_binary_array_indices!( + arr, + FixedSizeBinaryArray, + FixedSizeBinaryBuilder, + indices, + size + ), + _ => Err(ParquetError::General(format!( + "ListArray of type List({:?}) is not supported by array_reader", + item_type + ))), + } +} + +/// Implementation of ListArrayReader. Nested lists and lists of structs are not yet supported. +impl ArrayReader for ListArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns data type. + /// This must be a List. + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + fn next_batch(&mut self, batch_size: usize) -> Result { + let next_batch_array = self.item_reader.next_batch(batch_size)?; + let item_type = self.item_reader.get_data_type().clone(); + + if next_batch_array.len() == 0 { + return build_empty_list_array(item_type); + } + let def_levels = self + .item_reader + .get_def_levels() + .ok_or_else(|| ArrowError("item_reader def levels are None.".to_string()))?; + let rep_levels = self + .item_reader + .get_rep_levels() + .ok_or_else(|| ArrowError("item_reader rep levels are None.".to_string()))?; + + if !((def_levels.len() == rep_levels.len()) + && (rep_levels.len() == next_batch_array.len())) + { + return Err(ArrowError( + "Expected item_reader def_levels and rep_levels to be same length as batch".to_string(), + )); + } + + // Need to remove from the values array the nulls that represent null lists rather than null items + // null lists have def_level = 0 + let mut null_list_indices: Vec = Vec::new(); + for i in 0..def_levels.len() { + if def_levels[i] == 0 { + null_list_indices.push(i); + } + } + let batch_values = match null_list_indices.len() { + 0 => next_batch_array.clone(), + _ => remove_indices(next_batch_array.clone(), item_type, null_list_indices)?, + }; + + // null list has def_level = 0 + // empty list has def_level = 1 + // null item in a list has def_level = 2 + // non-null item has def_level = 3 + // first item in each list has rep_level = 0, subsequent items have rep_level = 1 + + let mut offsets: Vec = Vec::new(); + let mut cur_offset = OffsetSize::zero(); + for i in 0..rep_levels.len() { + if rep_levels[i] == 0 { + offsets.push(cur_offset) + } + if def_levels[i] > 0 { + cur_offset = cur_offset + OffsetSize::one(); + } + } + offsets.push(cur_offset); + + let num_bytes = bit_util::ceil(offsets.len(), 8); + let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); + let null_slice = null_buf.data_mut(); + let mut list_index = 0; + for i in 0..rep_levels.len() { + if rep_levels[i] == 0 && def_levels[i] != 0 { + bit_util::set_bit(null_slice, list_index); + } + if rep_levels[i] == 0 { + list_index += 1; + } + } + let value_offsets = Buffer::from(&offsets.to_byte_slice()); + + // null list has def_level = 0 + let null_count = def_levels.iter().filter(|x| x == &&0).count(); + + let list_data = ArrayData::builder(self.get_data_type().clone()) + .len(offsets.len() - 1) + .add_buffer(value_offsets) + .add_child_data(batch_values.data()) + .null_bit_buffer(null_buf.freeze()) + .null_count(null_count) + .offset(next_batch_array.offset()) + .build(); + + let result_array = GenericListArray::::from(list_data); + Ok(Arc::new(result_array)) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_level_buffer + .as_ref() + .map(|buf| unsafe { buf.typed_data() }) + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.rep_level_buffer + .as_ref() + .map(|buf| unsafe { buf.typed_data() }) + } +} + /// Implementation of struct array reader. pub struct StructArrayReader { children: Vec>, @@ -595,6 +1122,7 @@ impl ArrayReader for StructArrayReader { /// Create array reader from parquet schema, column indices, and parquet file reader. pub fn build_array_reader( parquet_schema: SchemaDescPtr, + arrow_schema: Schema, column_indices: T, file_reader: Rc, ) -> Result> @@ -633,13 +1161,19 @@ where fields: filtered_root_fields, }; - ArrayReaderBuilder::new(Rc::new(proj), Rc::new(leaves), file_reader) - .build_array_reader() + ArrayReaderBuilder::new( + Rc::new(proj), + Rc::new(arrow_schema), + Rc::new(leaves), + file_reader, + ) + .build_array_reader() } /// Used to build array reader. struct ArrayReaderBuilder { root_schema: TypePtr, + arrow_schema: Rc, // Key: columns that need to be included in final array builder // Value: column index in schema columns_included: Rc>, @@ -756,16 +1290,94 @@ impl<'a> TypeVisitor>, &'a ArrayReaderBuilderContext } /// Build array reader for list type. - /// Currently this is not supported. fn visit_list_with_item( &mut self, - _list_type: Rc, - _item_type: &Type, - _context: &'a ArrayReaderBuilderContext, + list_type: Rc, + item_type: Rc, + context: &'a ArrayReaderBuilderContext, ) -> Result>> { - Err(ArrowError( - "Reading parquet list array into arrow is not supported yet!".to_string(), - )) + let list_child = &list_type + .get_fields() + .first() + .ok_or_else(|| ArrowError("List field must have a child.".to_string()))?; + let mut new_context = context.clone(); + + new_context.path.append(vec![list_type.name().to_string()]); + + match list_type.get_basic_info().repetition() { + Repetition::REPEATED => { + new_context.def_level += 1; + new_context.rep_level += 1; + } + Repetition::OPTIONAL => { + new_context.def_level += 1; + } + _ => (), + } + + match list_child.get_basic_info().repetition() { + Repetition::REPEATED => { + new_context.def_level += 1; + new_context.rep_level += 1; + } + Repetition::OPTIONAL => { + new_context.def_level += 1; + } + _ => (), + } + + let item_reader = self + .dispatch(item_type.clone(), &new_context) + .unwrap() + .unwrap(); + + let item_reader_type = item_reader.get_data_type().clone(); + + match item_reader_type { + ArrowType::List(_) + | ArrowType::FixedSizeList(_, _) + | ArrowType::Struct(_) + | ArrowType::Dictionary(_, _) => Err(ArrowError(format!( + "reading List({:?}) into arrow not supported yet", + item_type + ))), + _ => { + let arrow_type = self + .arrow_schema + .field_with_name(list_type.name()) + .ok() + .map(|f| f.data_type().to_owned()) + .unwrap_or_else(|| { + ArrowType::List(Box::new(item_reader_type.clone())) + }); + + let list_array_reader: Box = match arrow_type { + ArrowType::List(_) => Box::new(ListArrayReader::::new( + item_reader, + arrow_type, + item_reader_type, + new_context.def_level, + new_context.rep_level, + )), + ArrowType::LargeList(_) => Box::new(ListArrayReader::::new( + item_reader, + arrow_type, + item_reader_type, + new_context.def_level, + new_context.rep_level, + )), + + _ => { + return Err(ArrowError(format!( + "creating ListArrayReader with type {:?} should be unreachable", + arrow_type + ))) + } + }; + + Ok(Some(list_array_reader)) + } + } } } @@ -773,11 +1385,13 @@ impl<'a> ArrayReaderBuilder { /// Construct array reader builder. fn new( root_schema: TypePtr, + arrow_schema: Rc, columns_included: Rc>, file_reader: Rc, ) -> Self { Self { root_schema, + arrow_schema, columns_included, file_reader, } @@ -818,18 +1432,37 @@ impl<'a> ArrayReaderBuilder { self.file_reader.clone(), )?); + let arrow_type = self + .arrow_schema + .field_with_name(cur_type.name()) + .ok() + .map(|f| f.data_type()) + .cloned(); + match cur_type.get_physical_type() { PhysicalType::BOOLEAN => Ok(Box::new(PrimitiveArrayReader::::new( page_iterator, column_desc, + arrow_type, )?)), - PhysicalType::INT32 => Ok(Box::new(PrimitiveArrayReader::::new( - page_iterator, - column_desc, - )?)), + PhysicalType::INT32 => { + if let Some(ArrowType::Null) = arrow_type { + Ok(Box::new(NullArrayReader::::new( + page_iterator, + column_desc, + )?)) + } else { + Ok(Box::new(PrimitiveArrayReader::::new( + page_iterator, + column_desc, + arrow_type, + )?)) + } + } PhysicalType::INT64 => Ok(Box::new(PrimitiveArrayReader::::new( page_iterator, column_desc, + arrow_type, )?)), PhysicalType::INT96 => { let converter = Int96Converter::new(Int96ArrayConverter {}); @@ -837,24 +1470,61 @@ impl<'a> ArrayReaderBuilder { Int96Type, Int96Converter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } PhysicalType::FLOAT => Ok(Box::new(PrimitiveArrayReader::::new( page_iterator, column_desc, + arrow_type, )?)), - PhysicalType::DOUBLE => Ok(Box::new( - PrimitiveArrayReader::::new(page_iterator, column_desc)?, - )), + PhysicalType::DOUBLE => { + Ok(Box::new(PrimitiveArrayReader::::new( + page_iterator, + column_desc, + arrow_type, + )?)) + } PhysicalType::BYTE_ARRAY => { if cur_type.get_basic_info().logical_type() == LogicalType::UTF8 { - let converter = Utf8Converter::new(Utf8ArrayConverter {}); + if let Some(ArrowType::LargeUtf8) = arrow_type { + let converter = + LargeUtf8Converter::new(LargeUtf8ArrayConverter {}); + Ok(Box::new(ComplexObjectArrayReader::< + ByteArrayType, + LargeUtf8Converter, + >::new( + page_iterator, + column_desc, + converter, + arrow_type, + )?)) + } else { + let converter = Utf8Converter::new(Utf8ArrayConverter {}); + Ok(Box::new(ComplexObjectArrayReader::< + ByteArrayType, + Utf8Converter, + >::new( + page_iterator, + column_desc, + converter, + arrow_type, + )?)) + } + } else if let Some(ArrowType::LargeBinary) = arrow_type { + let converter = + LargeBinaryConverter::new(LargeBinaryArrayConverter {}); Ok(Box::new(ComplexObjectArrayReader::< ByteArrayType, - Utf8Converter, + LargeBinaryConverter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } else { let converter = BinaryConverter::new(BinaryArrayConverter {}); @@ -862,7 +1532,10 @@ impl<'a> ArrayReaderBuilder { ByteArrayType, BinaryConverter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } } @@ -884,7 +1557,10 @@ impl<'a> ArrayReaderBuilder { FixedLenByteArrayType, FixedLenBinaryConverter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } } @@ -901,11 +1577,15 @@ impl<'a> ArrayReaderBuilder { for child in cur_type.get_fields() { if let Some(child_reader) = self.dispatch(child.clone(), context)? { - fields.push(Field::new( - child.name(), - child_reader.get_data_type().clone(), - child.is_optional(), - )); + let field = match self.arrow_schema.field_with_name(child.name()) { + Ok(f) => f.to_owned(), + _ => Field::new( + child.name(), + child_reader.get_data_type().clone(), + child.is_optional(), + ), + }; + fields.push(field); children_reader.push(child_reader); } } @@ -928,6 +1608,7 @@ impl<'a> ArrayReaderBuilder { mod tests { use super::*; use crate::arrow::converter::Utf8Converter; + use crate::arrow::schema::parquet_to_arrow_schema; use crate::basic::{Encoding, Type as PhysicalType}; use crate::column::page::{Page, PageReader}; use crate::data_type::{ByteArray, DataType, Int32Type, Int64Type}; @@ -939,12 +1620,17 @@ mod tests { DataPageBuilder, DataPageBuilderImpl, InMemoryPageIterator, }; use crate::util::test_common::{get_test_file, make_pages}; - use arrow::array::{Array, ArrayRef, PrimitiveArray, StringArray, StructArray}; + use arrow::array::{ + Array, ArrayRef, LargeListArray, ListArray, PrimitiveArray, StringArray, + StructArray, + }; use arrow::datatypes::{ - DataType as ArrowType, Date32Type as ArrowDate32, Field, Int32Type as ArrowInt32, + ArrowPrimitiveType, DataType as ArrowType, Date32Type as ArrowDate32, Field, + Int32Type as ArrowInt32, Int64Type as ArrowInt64, + Time32MillisecondType as ArrowTime32MillisecondArray, + Time64MicrosecondType as ArrowTime64MicrosecondArray, TimestampMicrosecondType as ArrowTimestampMicrosecondType, TimestampMillisecondType as ArrowTimestampMillisecondType, - UInt32Type as ArrowUInt32, UInt64Type as ArrowUInt64, }; use rand::distributions::uniform::SampleUniform; use rand::{thread_rng, Rng}; @@ -1011,9 +1697,12 @@ mod tests { let column_desc = schema.column(0); let page_iterator = EmptyPageIterator::new(schema); - let mut array_reader = - PrimitiveArrayReader::::new(Box::new(page_iterator), column_desc) - .unwrap(); + let mut array_reader = PrimitiveArrayReader::::new( + Box::new(page_iterator), + column_desc, + None, + ) + .unwrap(); // expect no values to be read let array = array_reader.next_batch(50).unwrap(); @@ -1058,6 +1747,7 @@ mod tests { let mut array_reader = PrimitiveArrayReader::::new( Box::new(page_iterator), column_desc, + None, ) .unwrap(); @@ -1101,7 +1791,7 @@ mod tests { } macro_rules! test_primitive_array_reader_one_type { - ($arrow_parquet_type:ty, $physical_type:expr, $logical_type_str:expr, $result_arrow_type:ty, $result_primitive_type:ty) => {{ + ($arrow_parquet_type:ty, $physical_type:expr, $logical_type_str:expr, $result_arrow_type:ty, $result_arrow_cast_type:ty, $result_primitive_type:ty) => {{ let message_type = format!( " message test_schema {{ @@ -1112,7 +1802,7 @@ mod tests { ); let schema = parse_message_type(&message_type) .map(|t| Rc::new(SchemaDescriptor::new(Rc::new(t)))) - .unwrap(); + .expect("Unable to parse message type into a schema descriptor"); let column_desc = schema.column(0); @@ -1141,25 +1831,50 @@ mod tests { let mut array_reader = PrimitiveArrayReader::<$arrow_parquet_type>::new( Box::new(page_iterator), column_desc.clone(), + None, ) - .unwrap(); + .expect("Unable to get array reader"); - let array = array_reader.next_batch(50).unwrap(); + let array = array_reader + .next_batch(50) + .expect("Unable to get batch from reader"); + let result_data_type = <$result_arrow_type>::DATA_TYPE; let array = array .as_any() .downcast_ref::>() - .unwrap(); - - assert_eq!( - &PrimitiveArray::<$result_arrow_type>::from( - data[0..50] - .iter() - .map(|x| *x as $result_primitive_type) - .collect::>() - ), - array + .expect( + format!( + "Unable to downcast {:?} to {:?}", + array.data_type(), + result_data_type + ) + .as_str(), + ); + + // create expected array as primitive, and cast to result type + let expected = PrimitiveArray::<$result_arrow_cast_type>::from( + data[0..50] + .iter() + .map(|x| *x as $result_primitive_type) + .collect::>(), ); + let expected = Arc::new(expected) as ArrayRef; + let expected = arrow::compute::cast(&expected, &result_data_type) + .expect("Unable to cast expected array"); + assert_eq!(expected.data_type(), &result_data_type); + let expected = expected + .as_any() + .downcast_ref::>() + .expect( + format!( + "Unable to downcast expected {:?} to {:?}", + expected.data_type(), + result_data_type + ) + .as_str(), + ); + assert_eq!(expected, array); } }}; } @@ -1171,27 +1886,31 @@ mod tests { PhysicalType::INT32, "DATE", ArrowDate32, + ArrowInt32, i32 ); test_primitive_array_reader_one_type!( Int32Type, PhysicalType::INT32, "TIME_MILLIS", - ArrowUInt32, - u32 + ArrowTime32MillisecondArray, + ArrowInt32, + i32 ); test_primitive_array_reader_one_type!( Int64Type, PhysicalType::INT64, "TIME_MICROS", - ArrowUInt64, - u64 + ArrowTime64MicrosecondArray, + ArrowInt64, + i64 ); test_primitive_array_reader_one_type!( Int64Type, PhysicalType::INT64, "TIMESTAMP_MILLIS", ArrowTimestampMillisecondType, + ArrowInt64, i64 ); test_primitive_array_reader_one_type!( @@ -1199,6 +1918,7 @@ mod tests { PhysicalType::INT64, "TIMESTAMP_MICROS", ArrowTimestampMicrosecondType, + ArrowInt64, i64 ); } @@ -1245,6 +1965,7 @@ mod tests { let mut array_reader = PrimitiveArrayReader::::new( Box::new(page_iterator), column_desc, + None, ) .unwrap(); @@ -1358,6 +2079,7 @@ mod tests { Box::new(page_iterator), column_desc, converter, + None, ) .unwrap(); @@ -1543,8 +2265,16 @@ mod tests { let file = get_test_file("nulls.snappy.parquet"); let file_reader = Rc::new(SerializedFileReader::new(file).unwrap()); + let file_metadata = file_reader.metadata().file_metadata(); + let arrow_schema = parquet_to_arrow_schema( + file_metadata.schema_descr(), + file_metadata.key_value_metadata(), + ) + .unwrap(); + let array_reader = build_array_reader( file_reader.metadata().file_metadata().schema_descr_ptr(), + arrow_schema, vec![0usize].into_iter(), file_reader, ) @@ -1559,4 +2289,113 @@ mod tests { assert_eq!(array_reader.get_data_type(), &arrow_type); } + + #[test] + fn test_list_array_reader() { + // [[1, null, 2], null, [3, 4]] + let array = Arc::new(PrimitiveArray::::from(vec![ + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + ])); + let item_array_reader = InMemoryArrayReader::new( + ArrowType::Int32, + array, + Some(vec![3, 2, 3, 0, 3, 3]), + Some(vec![0, 1, 1, 0, 0, 1]), + ); + + let mut list_array_reader = ListArrayReader::::new( + Box::new(item_array_reader), + ArrowType::List(Box::new(ArrowType::Int32)), + ArrowType::Int32, + 1, + 1, + ); + + let next_batch = list_array_reader.next_batch(1024).unwrap(); + let list_array = next_batch.as_any().downcast_ref::().unwrap(); + + assert_eq!(3, list_array.len()); + // This passes as I expect + assert_eq!(1, list_array.null_count()); + + assert_eq!( + list_array + .value(0) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(1), None, Some(2)]) + ); + + assert!(list_array.is_null(1)); + + assert_eq!( + list_array + .value(2) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(3), Some(4)]) + ); + } + + #[test] + fn test_large_list_array_reader() { + // [[1, null, 2], null, [3, 4]] + let array = Arc::new(PrimitiveArray::::from(vec![ + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + ])); + let item_array_reader = InMemoryArrayReader::new( + ArrowType::Int32, + array, + Some(vec![3, 2, 3, 0, 3, 3]), + Some(vec![0, 1, 1, 0, 0, 1]), + ); + + let mut list_array_reader = ListArrayReader::::new( + Box::new(item_array_reader), + ArrowType::LargeList(Box::new(ArrowType::Int32)), + ArrowType::Int32, + 1, + 1, + ); + + let next_batch = list_array_reader.next_batch(1024).unwrap(); + let list_array = next_batch + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(3, list_array.len()); + + assert_eq!( + list_array + .value(0) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(1), None, Some(2)]) + ); + + assert!(list_array.is_null(1)); + + assert_eq!( + list_array + .value(2) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(3), Some(4)]) + ); + } } diff --git a/rust/parquet/src/arrow/arrow_reader.rs b/rust/parquet/src/arrow/arrow_reader.rs index b654de1ad0a..88af583a3d4 100644 --- a/rust/parquet/src/arrow/arrow_reader.rs +++ b/rust/parquet/src/arrow/arrow_reader.rs @@ -19,7 +19,9 @@ use crate::arrow::array_reader::{build_array_reader, ArrayReader, StructArrayReader}; use crate::arrow::schema::parquet_to_arrow_schema; -use crate::arrow::schema::parquet_to_arrow_schema_by_columns; +use crate::arrow::schema::{ + parquet_to_arrow_schema_by_columns, parquet_to_arrow_schema_by_root_columns, +}; use crate::errors::{ParquetError, Result}; use crate::file::reader::FileReader; use arrow::datatypes::{DataType as ArrowType, Schema, SchemaRef}; @@ -40,7 +42,12 @@ pub trait ArrowReader { /// Read parquet schema and convert it into arrow schema. /// This schema only includes columns identified by `column_indices`. - fn get_schema_by_columns(&mut self, column_indices: T) -> Result + /// To select leaf columns (i.e. `a.b.c` instead of `a`), set `leaf_columns = true` + fn get_schema_by_columns( + &mut self, + column_indices: T, + leaf_columns: bool, + ) -> Result where T: IntoIterator; @@ -84,16 +91,28 @@ impl ArrowReader for ParquetFileArrowReader { ) } - fn get_schema_by_columns(&mut self, column_indices: T) -> Result + fn get_schema_by_columns( + &mut self, + column_indices: T, + leaf_columns: bool, + ) -> Result where T: IntoIterator, { let file_metadata = self.file_reader.metadata().file_metadata(); - parquet_to_arrow_schema_by_columns( - file_metadata.schema_descr(), - column_indices, - file_metadata.key_value_metadata(), - ) + if leaf_columns { + parquet_to_arrow_schema_by_columns( + file_metadata.schema_descr(), + column_indices, + file_metadata.key_value_metadata(), + ) + } else { + parquet_to_arrow_schema_by_root_columns( + file_metadata.schema_descr(), + column_indices, + file_metadata.key_value_metadata(), + ) + } } fn get_record_reader( @@ -123,6 +142,7 @@ impl ArrowReader for ParquetFileArrowReader { .metadata() .file_metadata() .schema_descr_ptr(), + self.get_schema()?, column_indices, self.file_reader.clone(), )?; diff --git a/rust/parquet/src/arrow/arrow_writer.rs b/rust/parquet/src/arrow/arrow_writer.rs new file mode 100644 index 00000000000..68a55c67f6a --- /dev/null +++ b/rust/parquet/src/arrow/arrow_writer.rs @@ -0,0 +1,1386 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains writer which writes arrow data into parquet data. + +use std::rc::Rc; + +use arrow::array as arrow_array; +use arrow::datatypes::{DataType as ArrowDataType, SchemaRef}; +use arrow::record_batch::RecordBatch; +use arrow_array::Array; + +use super::schema::add_encoded_arrow_schema_to_metadata; +use crate::column::writer::{ColumnWriter, ColumnWriterImpl}; +use crate::errors::{ParquetError, Result}; +use crate::file::properties::WriterProperties; +use crate::{ + data_type::*, + file::writer::{FileWriter, ParquetWriter, RowGroupWriter, SerializedFileWriter}, +}; + +/// Arrow writer +/// +/// Writes Arrow `RecordBatch`es to a Parquet writer +pub struct ArrowWriter { + /// Underlying Parquet writer + writer: SerializedFileWriter, + /// A copy of the Arrow schema. + /// + /// The schema is used to verify that each record batch written has the correct schema + arrow_schema: SchemaRef, +} + +impl ArrowWriter { + /// Try to create a new Arrow writer + /// + /// The writer will fail if: + /// * a `SerializedFileWriter` cannot be created from the ParquetWriter + /// * the Arrow schema contains unsupported datatypes such as Unions + pub fn try_new( + writer: W, + arrow_schema: SchemaRef, + props: Option, + ) -> Result { + let schema = crate::arrow::arrow_to_parquet_schema(&arrow_schema)?; + // add serialized arrow schema + let mut props = props.unwrap_or_else(|| WriterProperties::builder().build()); + add_encoded_arrow_schema_to_metadata(&arrow_schema, &mut props); + + let file_writer = SerializedFileWriter::new( + writer.try_clone()?, + schema.root_schema_ptr(), + Rc::new(props), + )?; + + Ok(Self { + writer: file_writer, + arrow_schema, + }) + } + + /// Write a RecordBatch to writer + /// + /// *NOTE:* The writer currently does not support all Arrow data types + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + // validate batch schema against writer's supplied schema + if self.arrow_schema != batch.schema() { + return Err(ParquetError::ArrowError( + "Record batch schema does not match writer schema".to_string(), + )); + } + // compute the definition and repetition levels of the batch + let mut levels = vec![]; + batch.columns().iter().for_each(|array| { + let mut array_levels = + get_levels(array, 0, &vec![1i16; batch.num_rows()][..], None); + levels.append(&mut array_levels); + }); + // reverse levels so we can use Vec::pop(&mut self) + levels.reverse(); + + let mut row_group_writer = self.writer.next_row_group()?; + + // write leaves + for column in batch.columns() { + write_leaves(&mut row_group_writer, column, &mut levels)?; + } + + self.writer.close_row_group(row_group_writer) + } + + /// Close and finalise the underlying Parquet writer + pub fn close(&mut self) -> Result<()> { + self.writer.close() + } +} + +/// Convenience method to get the next ColumnWriter from the RowGroupWriter +#[inline] +#[allow(clippy::borrowed_box)] +fn get_col_writer( + row_group_writer: &mut Box, +) -> Result { + let col_writer = row_group_writer + .next_column()? + .expect("Unable to get column writer"); + Ok(col_writer) +} + +#[allow(clippy::borrowed_box)] +fn write_leaves( + mut row_group_writer: &mut Box, + array: &arrow_array::ArrayRef, + mut levels: &mut Vec, +) -> Result<()> { + match array.data_type() { + ArrowDataType::Null + | ArrowDataType::Int8 + | ArrowDataType::Int16 + | ArrowDataType::Int32 + | ArrowDataType::Int64 + | ArrowDataType::UInt8 + | ArrowDataType::UInt16 + | ArrowDataType::UInt32 + | ArrowDataType::UInt64 + | ArrowDataType::Float32 + | ArrowDataType::Float64 + | ArrowDataType::Timestamp(_, _) + | ArrowDataType::Date32(_) + | ArrowDataType::Date64(_) + | ArrowDataType::Time32(_) + | ArrowDataType::Time64(_) + | ArrowDataType::Duration(_) + | ArrowDataType::Interval(_) + | ArrowDataType::LargeBinary + | ArrowDataType::Binary + | ArrowDataType::Utf8 + | ArrowDataType::LargeUtf8 => { + let mut col_writer = get_col_writer(&mut row_group_writer)?; + write_leaf( + &mut col_writer, + array, + levels.pop().expect("Levels exhausted"), + )?; + row_group_writer.close_column(col_writer)?; + Ok(()) + } + ArrowDataType::List(_) | ArrowDataType::LargeList(_) => { + // write the child list + let data = array.data(); + let child_array = arrow_array::make_array(data.child_data()[0].clone()); + write_leaves(&mut row_group_writer, &child_array, &mut levels)?; + Ok(()) + } + ArrowDataType::Struct(_) => { + let struct_array: &arrow_array::StructArray = array + .as_any() + .downcast_ref::() + .expect("Unable to get struct array"); + for field in struct_array.columns() { + write_leaves(&mut row_group_writer, field, &mut levels)?; + } + Ok(()) + } + ArrowDataType::Dictionary(key_type, value_type) => { + use arrow_array::{ + Int16DictionaryArray, Int32DictionaryArray, Int64DictionaryArray, + Int8DictionaryArray, PrimitiveArray, StringArray, UInt16DictionaryArray, + UInt32DictionaryArray, UInt64DictionaryArray, UInt8DictionaryArray, + }; + use ArrowDataType::*; + use ColumnWriter::*; + + let array = &**array; + let mut col_writer = get_col_writer(&mut row_group_writer)?; + let levels = levels.pop().expect("Levels exhausted"); + + macro_rules! dispatch_dictionary { + ($($kt: pat, $vt: pat, $w: ident => $kat: ty, $vat: ty,)*) => ( + match (&**key_type, &**value_type, &mut col_writer) { + $(($kt, $vt, $w(writer)) => write_dict::<$kat, $vat, _>(array, writer, levels),)* + (kt, vt, _) => unreachable!("Shouldn't be attempting to write dictionary of <{:?}, {:?}>", kt, vt), + } + ); + } + + match (&**key_type, &**value_type, &mut col_writer) { + (UInt8, UInt32, Int32ColumnWriter(writer)) => { + let typed_array = array + .as_any() + .downcast_ref::() + .expect("Unable to get dictionary array"); + + let keys = typed_array.keys(); + + let value_buffer = typed_array.values(); + let value_array = + arrow::compute::cast(&value_buffer, &ArrowDataType::Int32)?; + + let values = value_array + .as_any() + .downcast_ref::() + .unwrap(); + + use std::convert::TryFrom; + // This removes NULL values from the NullableIter, but + // they're encoded by the levels, so that's fine. + let materialized_values: Vec<_> = keys + .flatten() + .map(|key| { + usize::try_from(key).unwrap_or_else(|k| { + panic!("key {} does not fit in usize", k) + }) + }) + .map(|key| values.value(key)) + .collect(); + + let materialized_primitive_array = + PrimitiveArray::::from( + materialized_values, + ); + + writer.write_batch( + get_numeric_array_slice::( + &materialized_primitive_array, + ) + .as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )?; + row_group_writer.close_column(col_writer)?; + + return Ok(()); + } + _ => {} + } + + dispatch_dictionary!( + Int8, Utf8, ByteArrayColumnWriter => Int8DictionaryArray, StringArray, + Int16, Utf8, ByteArrayColumnWriter => Int16DictionaryArray, StringArray, + Int32, Utf8, ByteArrayColumnWriter => Int32DictionaryArray, StringArray, + Int64, Utf8, ByteArrayColumnWriter => Int64DictionaryArray, StringArray, + UInt8, Utf8, ByteArrayColumnWriter => UInt8DictionaryArray, StringArray, + UInt16, Utf8, ByteArrayColumnWriter => UInt16DictionaryArray, StringArray, + UInt32, Utf8, ByteArrayColumnWriter => UInt32DictionaryArray, StringArray, + UInt64, Utf8, ByteArrayColumnWriter => UInt64DictionaryArray, StringArray, + )?; + + row_group_writer.close_column(col_writer)?; + + Ok(()) + } + ArrowDataType::Float16 => Err(ParquetError::ArrowError( + "Float16 arrays not supported".to_string(), + )), + ArrowDataType::FixedSizeList(_, _) + | ArrowDataType::Boolean + | ArrowDataType::FixedSizeBinary(_) + | ArrowDataType::Union(_) => Err(ParquetError::NYI( + "Attempting to write an Arrow type that is not yet implemented".to_string(), + )), + } +} + +trait Materialize { + type Output; + + // Materialize the packed dictionary. The writer will later repack it. + fn materialize(&self) -> Vec; +} + +macro_rules! materialize_string { + ($($k:ty,)*) => { + $(impl Materialize<$k, arrow_array::StringArray> for dyn Array { + type Output = ByteArray; + + fn materialize(&self) -> Vec { + use std::convert::TryFrom; + + let typed_array = self.as_any() + .downcast_ref::<$k>() + .expect("Unable to get dictionary array"); + + let keys = typed_array.keys(); + + let value_buffer = typed_array.values(); + let values = value_buffer + .as_any() + .downcast_ref::() + .unwrap(); + + // This removes NULL values from the NullableIter, but + // they're encoded by the levels, so that's fine. + keys + .flatten() + .map(|key| usize::try_from(key).unwrap_or_else(|k| panic!("key {} does not fit in usize", k))) + .map(|key| values.value(key)) + .map(ByteArray::from) + .collect() + } + })* + }; +} + +materialize_string! { + arrow_array::Int8DictionaryArray, + arrow_array::Int16DictionaryArray, + arrow_array::Int32DictionaryArray, + arrow_array::Int64DictionaryArray, + arrow_array::UInt8DictionaryArray, + arrow_array::UInt16DictionaryArray, + arrow_array::UInt32DictionaryArray, + arrow_array::UInt64DictionaryArray, +} + +fn write_dict( + array: &(dyn Array + 'static), + writer: &mut ColumnWriterImpl, + levels: Levels, +) -> Result<()> +where + T: DataType, + dyn Array: Materialize, +{ + writer.write_batch( + &array.materialize(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )?; + + Ok(()) +} + +fn write_leaf( + writer: &mut ColumnWriter, + column: &arrow_array::ArrayRef, + levels: Levels, +) -> Result { + let written = match writer { + ColumnWriter::Int32ColumnWriter(ref mut typed) => { + let array = arrow::compute::cast(column, &ArrowDataType::Int32)?; + let array = array + .as_any() + .downcast_ref::() + .expect("Unable to get int32 array"); + typed.write_batch( + get_numeric_array_slice::(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ColumnWriter::BoolColumnWriter(ref mut _typed) => { + unreachable!("Currently unreachable because data type not supported") + } + ColumnWriter::Int64ColumnWriter(ref mut typed) => { + let array = arrow_array::Int64Array::from(column.data()); + typed.write_batch( + get_numeric_array_slice::(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ColumnWriter::Int96ColumnWriter(ref mut _typed) => { + unreachable!("Currently unreachable because data type not supported") + } + ColumnWriter::FloatColumnWriter(ref mut typed) => { + let array = arrow_array::Float32Array::from(column.data()); + typed.write_batch( + get_numeric_array_slice::(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ColumnWriter::DoubleColumnWriter(ref mut typed) => { + let array = arrow_array::Float64Array::from(column.data()); + typed.write_batch( + get_numeric_array_slice::(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ColumnWriter::ByteArrayColumnWriter(ref mut typed) => match column.data_type() { + ArrowDataType::Binary => { + let array = arrow_array::BinaryArray::from(column.data()); + typed.write_batch( + get_binary_array(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ArrowDataType::Utf8 => { + let array = arrow_array::StringArray::from(column.data()); + typed.write_batch( + get_string_array(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ArrowDataType::LargeBinary => { + let array = arrow_array::LargeBinaryArray::from(column.data()); + typed.write_batch( + get_large_binary_array(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ArrowDataType::LargeUtf8 => { + let array = arrow_array::LargeStringArray::from(column.data()); + typed.write_batch( + get_large_string_array(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + _ => unreachable!("Currently unreachable because data type not supported"), + }, + ColumnWriter::FixedLenByteArrayColumnWriter(ref mut _typed) => { + unreachable!("Currently unreachable because data type not supported") + } + }; + Ok(written as i64) +} + +/// A struct that represents definition and repetition levels. +/// Repetition levels are only populated if the parent or current leaf is repeated +#[derive(Debug)] +struct Levels { + definition: Vec, + repetition: Option>, +} + +/// Compute nested levels of the Arrow array, recursing into lists and structs +fn get_levels( + array: &arrow_array::ArrayRef, + level: i16, + parent_def_levels: &[i16], + parent_rep_levels: Option<&[i16]>, +) -> Vec { + match array.data_type() { + ArrowDataType::Null => vec![Levels { + definition: parent_def_levels.iter().map(|v| (v - 1).max(0)).collect(), + repetition: None, + }], + ArrowDataType::Boolean + | ArrowDataType::Int8 + | ArrowDataType::Int16 + | ArrowDataType::Int32 + | ArrowDataType::Int64 + | ArrowDataType::UInt8 + | ArrowDataType::UInt16 + | ArrowDataType::UInt32 + | ArrowDataType::UInt64 + | ArrowDataType::Float16 + | ArrowDataType::Float32 + | ArrowDataType::Float64 + | ArrowDataType::Utf8 + | ArrowDataType::LargeUtf8 + | ArrowDataType::Timestamp(_, _) + | ArrowDataType::Date32(_) + | ArrowDataType::Date64(_) + | ArrowDataType::Time32(_) + | ArrowDataType::Time64(_) + | ArrowDataType::Duration(_) + | ArrowDataType::Interval(_) + | ArrowDataType::Binary + | ArrowDataType::LargeBinary => vec![Levels { + definition: get_primitive_def_levels(array, parent_def_levels), + repetition: None, + }], + ArrowDataType::FixedSizeBinary(_) => unimplemented!(), + ArrowDataType::List(_) | ArrowDataType::LargeList(_) => { + let array_data = array.data(); + let child_data = array_data.child_data().get(0).unwrap(); + // get offsets, accounting for large offsets if present + let offsets: Vec = { + if let ArrowDataType::LargeList(_) = array.data_type() { + unsafe { array_data.buffers()[0].typed_data::() }.to_vec() + } else { + let offsets = unsafe { array_data.buffers()[0].typed_data::() }; + offsets.to_vec().into_iter().map(|v| v as i64).collect() + } + }; + let child_array = arrow_array::make_array(child_data.clone()); + + let mut list_def_levels = Vec::with_capacity(child_array.len()); + let mut list_rep_levels = Vec::with_capacity(child_array.len()); + let rep_levels: Vec = parent_rep_levels + .map(|l| l.to_vec()) + .unwrap_or_else(|| vec![0i16; parent_def_levels.len()]); + parent_def_levels + .iter() + .zip(rep_levels) + .zip(offsets.windows(2)) + .for_each(|((parent_def_level, parent_rep_level), window)| { + if *parent_def_level == 0 { + // parent is null, list element must also be null + list_def_levels.push(0); + list_rep_levels.push(0); + } else { + // parent is not null, check if list is empty or null + let start = window[0]; + let end = window[1]; + let len = end - start; + if len == 0 { + list_def_levels.push(*parent_def_level - 1); + list_rep_levels.push(parent_rep_level); + } else { + list_def_levels.push(*parent_def_level); + list_rep_levels.push(parent_rep_level); + for _ in 1..len { + list_def_levels.push(*parent_def_level); + list_rep_levels.push(parent_rep_level + 1); + } + } + } + }); + + // if datatype is a primitive, we can construct levels of the child array + match child_array.data_type() { + // TODO: The behaviour of a > is untested + ArrowDataType::Null => vec![Levels { + definition: list_def_levels, + repetition: Some(list_rep_levels), + }], + ArrowDataType::Boolean => unimplemented!(), + ArrowDataType::Int8 + | ArrowDataType::Int16 + | ArrowDataType::Int32 + | ArrowDataType::Int64 + | ArrowDataType::UInt8 + | ArrowDataType::UInt16 + | ArrowDataType::UInt32 + | ArrowDataType::UInt64 + | ArrowDataType::Float16 + | ArrowDataType::Float32 + | ArrowDataType::Float64 + | ArrowDataType::Timestamp(_, _) + | ArrowDataType::Date32(_) + | ArrowDataType::Date64(_) + | ArrowDataType::Time32(_) + | ArrowDataType::Time64(_) + | ArrowDataType::Duration(_) + | ArrowDataType::Interval(_) => { + let def_levels = + get_primitive_def_levels(&child_array, &list_def_levels[..]); + vec![Levels { + definition: def_levels, + repetition: Some(list_rep_levels), + }] + } + ArrowDataType::Binary + | ArrowDataType::Utf8 + | ArrowDataType::LargeUtf8 => unimplemented!(), + ArrowDataType::FixedSizeBinary(_) => unimplemented!(), + ArrowDataType::LargeBinary => unimplemented!(), + ArrowDataType::List(_) | ArrowDataType::LargeList(_) => { + // nested list + unimplemented!() + } + ArrowDataType::FixedSizeList(_, _) => unimplemented!(), + ArrowDataType::Struct(_) => get_levels( + array, + level + 1, // indicates a nesting level of 2 (list + struct) + &list_def_levels[..], + Some(&list_rep_levels[..]), + ), + ArrowDataType::Union(_) => unimplemented!(), + ArrowDataType::Dictionary(_, _) => unimplemented!(), + } + } + ArrowDataType::FixedSizeList(_, _) => unimplemented!(), + ArrowDataType::Struct(_) => { + let struct_array: &arrow_array::StructArray = array + .as_any() + .downcast_ref::() + .expect("Unable to get struct array"); + let mut struct_def_levels = Vec::with_capacity(struct_array.len()); + for i in 0..array.len() { + struct_def_levels.push(level + struct_array.is_valid(i) as i16); + } + // trying to create levels for struct's fields + let mut struct_levels = vec![]; + struct_array.columns().into_iter().for_each(|col| { + let mut levels = + get_levels(col, level + 1, &struct_def_levels[..], parent_rep_levels); + struct_levels.append(&mut levels); + }); + struct_levels + } + ArrowDataType::Union(_) => unimplemented!(), + ArrowDataType::Dictionary(_, _) => { + // Need to check for these cases not implemented in C++: + // - "Writing DictionaryArray with nested dictionary type not yet supported" + // - "Writing DictionaryArray with null encoded in dictionary type not yet supported" + vec![Levels { + definition: get_primitive_def_levels(array, parent_def_levels), + repetition: None, + }] + } + } +} + +/// Get the definition levels of the numeric array, with level 0 being null and 1 being not null +/// In the case where the array in question is a child of either a list or struct, the levels +/// are incremented in accordance with the `level` parameter. +/// Parent levels are either 0 or 1, and are used to higher (correct terminology?) leaves as null +fn get_primitive_def_levels( + array: &arrow_array::ArrayRef, + parent_def_levels: &[i16], +) -> Vec { + let mut array_index = 0; + let max_def_level = parent_def_levels.iter().max().unwrap(); + let mut primitive_def_levels = vec![]; + parent_def_levels.iter().for_each(|def_level| { + if def_level < max_def_level { + primitive_def_levels.push(*def_level); + } else { + primitive_def_levels.push(def_level - array.is_null(array_index) as i16); + array_index += 1; + } + }); + primitive_def_levels +} + +macro_rules! def_get_binary_array_fn { + ($name:ident, $ty:ty) => { + fn $name(array: &$ty) -> Vec { + let mut values = Vec::with_capacity(array.len() - array.null_count()); + for i in 0..array.len() { + if array.is_valid(i) { + let bytes: Vec = array.value(i).into(); + let bytes = ByteArray::from(bytes); + values.push(bytes); + } + } + values + } + }; +} + +def_get_binary_array_fn!(get_binary_array, arrow_array::BinaryArray); +def_get_binary_array_fn!(get_string_array, arrow_array::StringArray); +def_get_binary_array_fn!(get_large_binary_array, arrow_array::LargeBinaryArray); +def_get_binary_array_fn!(get_large_string_array, arrow_array::LargeStringArray); + +/// Get the underlying numeric array slice, skipping any null values. +/// If there are no null values, it might be quicker to get the slice directly instead of +/// calling this function. +fn get_numeric_array_slice(array: &arrow_array::PrimitiveArray) -> Vec +where + T: DataType, + A: arrow::datatypes::ArrowNumericType, + T::T: From, +{ + let mut values = Vec::with_capacity(array.len() - array.null_count()); + for i in 0..array.len() { + if array.is_valid(i) { + values.push(array.value(i).into()) + } + } + values +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::Seek; + use std::sync::Arc; + + use arrow::array::*; + use arrow::datatypes::ToByteSlice; + use arrow::datatypes::{DataType, Field, Schema, UInt32Type, UInt8Type}; + use arrow::record_batch::RecordBatch; + + use crate::arrow::{ArrowReader, ParquetFileArrowReader}; + use crate::file::{metadata::KeyValue, reader::SerializedFileReader}; + use crate::util::test_common::get_temp_file; + + #[test] + fn arrow_writer() { + // define schema + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, true), + ]); + + // create some data + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let b = Int32Array::from(vec![Some(1), None, None, Some(4), Some(5)]); + + // build a record batch + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(a), Arc::new(b)], + ) + .unwrap(); + + let file = get_temp_file("test_arrow_writer.parquet", &[]); + let mut writer = ArrowWriter::try_new(file, Arc::new(schema), None).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + } + + #[test] + #[ignore = "repetitions might be incorrect, will be addressed as part of ARROW-9728"] + fn arrow_writer_list() { + // define schema + let schema = Schema::new(vec![Field::new( + "a", + DataType::List(Box::new(DataType::Int32)), + false, + )]); + + // create some data + let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + + // Construct a buffer for value offsets, for the nested array: + // [[1], [2, 3], null, [4, 5, 6], [7, 8, 9, 10]] + let a_value_offsets = + arrow::buffer::Buffer::from(&[0, 1, 3, 3, 6, 10].to_byte_slice()); + + // Construct a list array from the above two + let a_list_data = ArrayData::builder(DataType::List(Box::new(DataType::Int32))) + .len(5) + .add_buffer(a_value_offsets) + .add_child_data(a_values.data()) + .build(); + let a = ListArray::from(a_list_data); + + // build a record batch + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)]).unwrap(); + + // I think this setup is incorrect because this should pass + assert_eq!(batch.column(0).data().null_count(), 1); + + let file = get_temp_file("test_arrow_writer_list.parquet", &[]); + let mut writer = ArrowWriter::try_new(file, Arc::new(schema), None).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + } + + #[test] + fn arrow_writer_binary() { + let string_field = Field::new("a", DataType::Utf8, false); + let binary_field = Field::new("b", DataType::Binary, false); + let schema = Schema::new(vec![string_field, binary_field]); + + let raw_string_values = vec!["foo", "bar", "baz", "quux"]; + let raw_binary_values = vec![ + b"foo".to_vec(), + b"bar".to_vec(), + b"baz".to_vec(), + b"quux".to_vec(), + ]; + let raw_binary_value_refs = raw_binary_values + .iter() + .map(|x| x.as_slice()) + .collect::>(); + + let string_values = StringArray::from(raw_string_values.clone()); + let binary_values = BinaryArray::from(raw_binary_value_refs); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(string_values), Arc::new(binary_values)], + ) + .unwrap(); + + let mut file = get_temp_file("test_arrow_writer_binary.parquet", &[]); + let mut writer = + ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema), None) + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + file.seek(std::io::SeekFrom::Start(0)).unwrap(); + let file_reader = SerializedFileReader::new(file).unwrap(); + let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(file_reader)); + let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap(); + + let batch = record_batch_reader.next().unwrap().unwrap(); + let string_col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let binary_col = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + for i in 0..batch.num_rows() { + assert_eq!(string_col.value(i), raw_string_values[i]); + assert_eq!(binary_col.value(i), raw_binary_values[i].as_slice()); + } + } + + #[test] + fn arrow_writer_complex() { + // define schema + let struct_field_d = Field::new("d", DataType::Float64, true); + let struct_field_f = Field::new("f", DataType::Float32, true); + let struct_field_g = + Field::new("g", DataType::List(Box::new(DataType::Int16)), false); + let struct_field_e = Field::new( + "e", + DataType::Struct(vec![struct_field_f.clone(), struct_field_g.clone()]), + true, + ); + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, true), + Field::new( + "c", + DataType::Struct(vec![struct_field_d.clone(), struct_field_e.clone()]), + false, + ), + ]); + + // create some data + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let b = Int32Array::from(vec![Some(1), None, None, Some(4), Some(5)]); + let d = Float64Array::from(vec![None, None, None, Some(1.0), None]); + let f = Float32Array::from(vec![Some(0.0), None, Some(333.3), None, Some(5.25)]); + + let g_value = Int16Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + + // Construct a buffer for value offsets, for the nested array: + // [[1], [2, 3], null, [4, 5, 6], [7, 8, 9, 10]] + let g_value_offsets = + arrow::buffer::Buffer::from(&[0, 1, 3, 3, 6, 10].to_byte_slice()); + + // Construct a list array from the above two + let g_list_data = ArrayData::builder(struct_field_g.data_type().clone()) + .len(5) + .add_buffer(g_value_offsets) + .add_child_data(g_value.data()) + .build(); + let g = ListArray::from(g_list_data); + + let e = StructArray::from(vec![ + (struct_field_f, Arc::new(f) as ArrayRef), + (struct_field_g, Arc::new(g) as ArrayRef), + ]); + + let c = StructArray::from(vec![ + (struct_field_d, Arc::new(d) as ArrayRef), + (struct_field_e, Arc::new(e) as ArrayRef), + ]); + + // build a record batch + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(a), Arc::new(b), Arc::new(c)], + ) + .unwrap(); + + let props = WriterProperties::builder() + .set_key_value_metadata(Some(vec![KeyValue { + key: "test_key".to_string(), + value: Some("test_value".to_string()), + }])) + .build(); + + let file = get_temp_file("test_arrow_writer_complex.parquet", &[]); + let mut writer = + ArrowWriter::try_new(file, Arc::new(schema), Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + } + + const SMALL_SIZE: usize = 100; + + fn roundtrip(filename: &str, expected_batch: RecordBatch) { + let file = get_temp_file(filename, &[]); + + let mut writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + expected_batch.schema(), + None, + ) + .expect("Unable to write file"); + writer.write(&expected_batch).unwrap(); + writer.close().unwrap(); + + let reader = SerializedFileReader::new(file).unwrap(); + let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(reader)); + let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap(); + + let actual_batch = record_batch_reader + .next() + .expect("No batch found") + .expect("Unable to get batch"); + + assert_eq!(expected_batch.schema(), actual_batch.schema()); + assert_eq!(expected_batch.num_columns(), actual_batch.num_columns()); + assert_eq!(expected_batch.num_rows(), actual_batch.num_rows()); + for i in 0..expected_batch.num_columns() { + let expected_data = expected_batch.column(i).data(); + let actual_data = actual_batch.column(i).data(); + + assert_eq!(expected_data.data_type(), actual_data.data_type()); + assert_eq!(expected_data.len(), actual_data.len()); + assert_eq!(expected_data.null_count(), actual_data.null_count()); + assert_eq!(expected_data.offset(), actual_data.offset()); + assert_eq!(expected_data.buffers(), actual_data.buffers()); + assert_eq!(expected_data.child_data(), actual_data.child_data()); + // Null counts should be the same, not necessarily bitmaps + // A null bitmap is optional if an array has no nulls + if expected_data.null_count() != 0 { + assert_eq!(expected_data.null_bitmap(), actual_data.null_bitmap()); + } + } + } + + fn one_column_roundtrip(filename: &str, values: ArrayRef, nullable: bool) { + let schema = Schema::new(vec![Field::new( + "col", + values.data_type().clone(), + nullable, + )]); + let expected_batch = + RecordBatch::try_new(Arc::new(schema), vec![values]).unwrap(); + + roundtrip(filename, expected_batch); + } + + fn values_required(iter: I, filename: &str) + where + A: From> + Array + 'static, + I: IntoIterator, + { + let raw_values: Vec<_> = iter.into_iter().collect(); + let values = Arc::new(A::from(raw_values)); + one_column_roundtrip(filename, values, false); + } + + fn values_optional(iter: I, filename: &str) + where + A: From>> + Array + 'static, + I: IntoIterator, + { + let optional_raw_values: Vec<_> = iter + .into_iter() + .enumerate() + .map(|(i, v)| if i % 2 == 0 { None } else { Some(v) }) + .collect(); + let optional_values = Arc::new(A::from(optional_raw_values)); + one_column_roundtrip(filename, optional_values, true); + } + + fn required_and_optional(iter: I, filename: &str) + where + A: From> + From>> + Array + 'static, + I: IntoIterator + Clone, + { + values_required::(iter.clone(), filename); + values_optional::(iter, filename); + } + + #[test] + fn all_null_primitive_single_column() { + let values = Arc::new(Int32Array::from(vec![None; SMALL_SIZE])); + one_column_roundtrip("all_null_primitive_single_column", values, true); + } + #[test] + fn null_single_column() { + let values = Arc::new(NullArray::new(SMALL_SIZE)); + one_column_roundtrip("null_single_column", values, true); + // null arrays are always nullable, a test with non-nullable nulls fails + } + + #[test] + #[should_panic( + expected = "Attempting to write an Arrow type that is not yet implemented" + )] + fn bool_single_column() { + required_and_optional::( + [true, false].iter().cycle().copied().take(SMALL_SIZE), + "bool_single_column", + ); + } + + #[test] + fn i8_single_column() { + required_and_optional::(0..SMALL_SIZE as i8, "i8_single_column"); + } + + #[test] + fn i16_single_column() { + required_and_optional::(0..SMALL_SIZE as i16, "i16_single_column"); + } + + #[test] + fn i32_single_column() { + required_and_optional::(0..SMALL_SIZE as i32, "i32_single_column"); + } + + #[test] + fn i64_single_column() { + required_and_optional::(0..SMALL_SIZE as i64, "i64_single_column"); + } + + #[test] + fn u8_single_column() { + required_and_optional::(0..SMALL_SIZE as u8, "u8_single_column"); + } + + #[test] + fn u16_single_column() { + required_and_optional::( + 0..SMALL_SIZE as u16, + "u16_single_column", + ); + } + + #[test] + fn u32_single_column() { + required_and_optional::( + 0..SMALL_SIZE as u32, + "u32_single_column", + ); + } + + #[test] + fn u64_single_column() { + required_and_optional::( + 0..SMALL_SIZE as u64, + "u64_single_column", + ); + } + + #[test] + fn f32_single_column() { + required_and_optional::( + (0..SMALL_SIZE).map(|i| i as f32), + "f32_single_column", + ); + } + + #[test] + fn f64_single_column() { + required_and_optional::( + (0..SMALL_SIZE).map(|i| i as f64), + "f64_single_column", + ); + } + + // The timestamp array types don't implement From> because they need the timezone + // argument, and they also doesn't support building from a Vec>, so call + // one_column_roundtrip manually instead of calling required_and_optional for these tests. + + #[test] + #[ignore] // Timestamp support isn't correct yet + fn timestamp_second_single_column() { + let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); + let values = Arc::new(TimestampSecondArray::from_vec(raw_values, None)); + + one_column_roundtrip("timestamp_second_single_column", values, false); + } + + #[test] + fn timestamp_millisecond_single_column() { + let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); + let values = Arc::new(TimestampMillisecondArray::from_vec(raw_values, None)); + + one_column_roundtrip("timestamp_millisecond_single_column", values, false); + } + + #[test] + fn timestamp_microsecond_single_column() { + let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); + let values = Arc::new(TimestampMicrosecondArray::from_vec(raw_values, None)); + + one_column_roundtrip("timestamp_microsecond_single_column", values, false); + } + + #[test] + #[ignore] // Timestamp support isn't correct yet + fn timestamp_nanosecond_single_column() { + let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); + let values = Arc::new(TimestampNanosecondArray::from_vec(raw_values, None)); + + one_column_roundtrip("timestamp_nanosecond_single_column", values, false); + } + + #[test] + fn date32_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i32, + "date32_single_column", + ); + } + + #[test] + #[ignore] // Date support isn't correct yet + fn date64_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "date64_single_column", + ); + } + + #[test] + #[ignore] // DateUnit resolution mismatch + fn time32_second_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i32, + "time32_second_single_column", + ); + } + + #[test] + fn time32_millisecond_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i32, + "time32_millisecond_single_column", + ); + } + + #[test] + fn time64_microsecond_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "time64_microsecond_single_column", + ); + } + + #[test] + #[ignore] // DateUnit resolution mismatch + fn time64_nanosecond_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "time64_nanosecond_single_column", + ); + } + + #[test] + #[should_panic(expected = "Converting Duration to parquet not supported")] + fn duration_second_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "duration_second_single_column", + ); + } + + #[test] + #[should_panic(expected = "Converting Duration to parquet not supported")] + fn duration_millisecond_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "duration_millisecond_single_column", + ); + } + + #[test] + #[should_panic(expected = "Converting Duration to parquet not supported")] + fn duration_microsecond_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "duration_microsecond_single_column", + ); + } + + #[test] + #[should_panic(expected = "Converting Duration to parquet not supported")] + fn duration_nanosecond_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "duration_nanosecond_single_column", + ); + } + + #[test] + #[should_panic(expected = "Currently unreachable because data type not supported")] + fn interval_year_month_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i32, + "interval_year_month_single_column", + ); + } + + #[test] + #[should_panic(expected = "Currently unreachable because data type not supported")] + fn interval_day_time_single_column() { + required_and_optional::( + 0..SMALL_SIZE as i64, + "interval_day_time_single_column", + ); + } + + #[test] + #[ignore] // Binary support isn't correct yet - buffers don't match + fn binary_single_column() { + let one_vec: Vec = (0..SMALL_SIZE as u8).collect(); + let many_vecs: Vec<_> = std::iter::repeat(one_vec).take(SMALL_SIZE).collect(); + let many_vecs_iter = many_vecs.iter().map(|v| v.as_slice()); + + // BinaryArrays can't be built from Vec>, so only call `values_required` + values_required::(many_vecs_iter, "binary_single_column"); + } + + #[test] + #[ignore] // Large binary support isn't correct yet - buffers don't match + fn large_binary_single_column() { + let one_vec: Vec = (0..SMALL_SIZE as u8).collect(); + let many_vecs: Vec<_> = std::iter::repeat(one_vec).take(SMALL_SIZE).collect(); + let many_vecs_iter = many_vecs.iter().map(|v| v.as_slice()); + + // LargeBinaryArrays can't be built from Vec>, so only call `values_required` + values_required::( + many_vecs_iter, + "large_binary_single_column", + ); + } + + #[test] + fn string_single_column() { + let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect(); + let raw_strs = raw_values.iter().map(|s| s.as_str()); + + required_and_optional::(raw_strs, "string_single_column"); + } + + #[test] + fn large_string_single_column() { + let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect(); + let raw_strs = raw_values.iter().map(|s| s.as_str()); + + required_and_optional::( + raw_strs, + "large_string_single_column", + ); + } + + #[test] + #[ignore = "repetitions might be incorrect, will be addressed as part of ARROW-9728"] + fn list_single_column() { + let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let a_value_offsets = + arrow::buffer::Buffer::from(&[0, 1, 3, 3, 6, 10].to_byte_slice()); + let a_list_data = ArrayData::builder(DataType::List(Box::new(DataType::Int32))) + .len(5) + .add_buffer(a_value_offsets) + .add_child_data(a_values.data()) + .build(); + + // I think this setup is incorrect because this should pass + assert_eq!(a_list_data.null_count(), 1); + + let a = ListArray::from(a_list_data); + let values = Arc::new(a); + + one_column_roundtrip("list_single_column", values, false); + } + + #[test] + #[ignore = "repetitions might be incorrect, will be addressed as part of ARROW-9728"] + fn large_list_single_column() { + let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let a_value_offsets = + arrow::buffer::Buffer::from(&[0i64, 1, 3, 3, 6, 10].to_byte_slice()); + let a_list_data = + ArrayData::builder(DataType::LargeList(Box::new(DataType::Int32))) + .len(5) + .add_buffer(a_value_offsets) + .add_child_data(a_values.data()) + .build(); + + // I think this setup is incorrect because this should pass + assert_eq!(a_list_data.null_count(), 1); + + let a = LargeListArray::from(a_list_data); + let values = Arc::new(a); + + one_column_roundtrip("large_list_single_column", values, false); + } + + #[test] + fn struct_single_column() { + let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let struct_field_a = Field::new("f", DataType::Int32, false); + let s = StructArray::from(vec![(struct_field_a, Arc::new(a_values) as ArrayRef)]); + + let values = Arc::new(s); + one_column_roundtrip("struct_single_column", values, false); + } + + #[test] + fn arrow_writer_string_dictionary() { + // define schema + let schema = Arc::new(Schema::new(vec![Field::new_dict( + "dictionary", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + 42, + true, + )])); + + // create some data + let d: Int32DictionaryArray = [Some("alpha"), None, Some("beta"), Some("alpha")] + .iter() + .copied() + .collect(); + + // build a record batch + let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); + + roundtrip( + "test_arrow_writer_string_dictionary.parquet", + expected_batch, + ); + } + + #[test] + fn arrow_writer_primitive_dictionary() { + // define schema + let schema = Arc::new(Schema::new(vec![Field::new_dict( + "dictionary", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt32)), + true, + 42, + true, + )])); + + // create some data + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(12345678).unwrap(); + builder.append_null().unwrap(); + builder.append(22345678).unwrap(); + builder.append(12345678).unwrap(); + let d = builder.finish(); + + // build a record batch + let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); + + roundtrip( + "test_arrow_writer_primitive_dictionary.parquet", + expected_batch, + ); + } + + #[test] + fn arrow_writer_string_dictionary_unsigned_index() { + // define schema + let schema = Arc::new(Schema::new(vec![Field::new_dict( + "dictionary", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + true, + 42, + true, + )])); + + // create some data + let d: UInt8DictionaryArray = [Some("alpha"), None, Some("beta"), Some("alpha")] + .iter() + .copied() + .collect(); + + // build a record batch + let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); + + roundtrip( + "test_arrow_writer_string_dictionary_unsigned_index.parquet", + expected_batch, + ); + } +} diff --git a/rust/parquet/src/arrow/converter.rs b/rust/parquet/src/arrow/converter.rs index da0cc6c984c..33b29c897e6 100644 --- a/rust/parquet/src/arrow/converter.rs +++ b/rust/parquet/src/arrow/converter.rs @@ -15,36 +15,28 @@ // specific language governing permissions and limitations // under the License. -use crate::arrow::record_reader::RecordReader; use crate::data_type::{ByteArray, DataType, Int96}; +// TODO: clean up imports (best done when there are few moving parts) use arrow::array::{ - Array, ArrayRef, BinaryBuilder, BooleanArray, BooleanBufferBuilder, - BufferBuilderTrait, FixedSizeBinaryBuilder, StringBuilder, - TimestampNanosecondBuilder, + Array, ArrayRef, BinaryBuilder, FixedSizeBinaryBuilder, LargeBinaryBuilder, + LargeStringBuilder, PrimitiveBuilder, PrimitiveDictionaryBuilder, StringBuilder, + StringDictionaryBuilder, TimestampNanosecondBuilder, }; use arrow::compute::cast; use std::convert::From; use std::sync::Arc; use crate::errors::Result; -use arrow::datatypes::{ArrowPrimitiveType, DataType as ArrowDataType}; +use arrow::datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}; -use arrow::array::ArrayDataBuilder; use arrow::array::{ - BinaryArray, FixedSizeBinaryArray, PrimitiveArray, StringArray, - TimestampNanosecondArray, + BinaryArray, DictionaryArray, FixedSizeBinaryArray, LargeBinaryArray, + LargeStringArray, PrimitiveArray, StringArray, TimestampNanosecondArray, }; use std::marker::PhantomData; -use crate::data_type::{ - BoolType, DoubleType as ParquetDoubleType, FloatType as ParquetFloatType, - Int32Type as ParquetInt32Type, Int64Type as ParquetInt64Type, -}; -use arrow::datatypes::{ - Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - TimestampMicrosecondType, TimestampMillisecondType, UInt16Type, UInt32Type, - UInt64Type, UInt8Type, -}; +use crate::data_type::Int32Type as ParquetInt32Type; +use arrow::datatypes::Int32Type; /// A converter is used to consume record reader's content and convert it to arrow /// primitive array. @@ -55,83 +47,6 @@ pub trait Converter { fn convert(&self, source: S) -> Result; } -/// Cast converter first converts record reader's buffer to arrow's -/// `PrimitiveArray`, then casts it to `PrimitiveArray`. -pub struct CastConverter { - _parquet_marker: PhantomData, - _arrow_source_marker: PhantomData, - _arrow_target_marker: PhantomData, -} - -impl - CastConverter -where - ParquetType: DataType, - ArrowSourceType: ArrowPrimitiveType, - ArrowTargetType: ArrowPrimitiveType, -{ - pub fn new() -> Self { - Self { - _parquet_marker: PhantomData, - _arrow_source_marker: PhantomData, - _arrow_target_marker: PhantomData, - } - } -} - -impl - Converter<&mut RecordReader, ArrayRef> - for CastConverter -where - ParquetType: DataType, - ArrowSourceType: ArrowPrimitiveType, - ArrowTargetType: ArrowPrimitiveType, -{ - fn convert(&self, record_reader: &mut RecordReader) -> Result { - let record_data = record_reader.consume_record_data(); - - let mut array_data = ArrayDataBuilder::new(ArrowSourceType::DATA_TYPE) - .len(record_reader.num_values()) - .add_buffer(record_data?); - - if let Some(b) = record_reader.consume_bitmap_buffer()? { - array_data = array_data.null_bit_buffer(b); - } - - let primitive_array: ArrayRef = - Arc::new(PrimitiveArray::::from(array_data.build())); - - Ok(cast(&primitive_array, &ArrowTargetType::DATA_TYPE)?) - } -} - -pub struct BooleanArrayConverter {} - -impl Converter<&mut RecordReader, BooleanArray> for BooleanArrayConverter { - fn convert( - &self, - record_reader: &mut RecordReader, - ) -> Result { - let record_data = record_reader.consume_record_data()?; - - let mut boolean_buffer = BooleanBufferBuilder::new(record_data.len()); - - for e in record_data.data() { - boolean_buffer.append(*e > 0)?; - } - - let mut array_data = ArrayDataBuilder::new(ArrowDataType::Boolean) - .len(record_data.len()) - .add_buffer(boolean_buffer.finish()); - - if let Some(b) = record_reader.consume_bitmap_buffer()? { - array_data = array_data.null_bit_buffer(b); - } - - Ok(BooleanArray::from(array_data.build())) - } -} - pub struct FixedSizeArrayConverter { byte_width: i32, } @@ -193,6 +108,27 @@ impl Converter>, StringArray> for Utf8ArrayConverter { } } +pub struct LargeUtf8ArrayConverter {} + +impl Converter>, LargeStringArray> for LargeUtf8ArrayConverter { + fn convert(&self, source: Vec>) -> Result { + let data_size = source + .iter() + .map(|x| x.as_ref().map(|b| b.len()).unwrap_or(0)) + .sum(); + + let mut builder = LargeStringBuilder::with_capacity(source.len(), data_size); + for v in source { + match v { + Some(array) => builder.append_value(array.as_utf8()?), + None => builder.append_null(), + }? + } + + Ok(builder.finish()) + } +} + pub struct BinaryArrayConverter {} impl Converter>, BinaryArray> for BinaryArrayConverter { @@ -209,30 +145,135 @@ impl Converter>, BinaryArray> for BinaryArrayConverter { } } -pub type BoolConverter<'a> = ArrayRefConverter< - &'a mut RecordReader, - BooleanArray, - BooleanArrayConverter, ->; -pub type Int8Converter = CastConverter; -pub type UInt8Converter = CastConverter; -pub type Int16Converter = CastConverter; -pub type UInt16Converter = CastConverter; -pub type Int32Converter = CastConverter; -pub type UInt32Converter = CastConverter; -pub type Int64Converter = CastConverter; -pub type Date32Converter = CastConverter; -pub type TimestampMillisecondConverter = - CastConverter; -pub type TimestampMicrosecondConverter = - CastConverter; -pub type UInt64Converter = CastConverter; -pub type Float32Converter = CastConverter; -pub type Float64Converter = CastConverter; +pub struct LargeBinaryArrayConverter {} + +impl Converter>, LargeBinaryArray> for LargeBinaryArrayConverter { + fn convert(&self, source: Vec>) -> Result { + let mut builder = LargeBinaryBuilder::new(source.len()); + for v in source { + match v { + Some(array) => builder.append_value(array.data()), + None => builder.append_null(), + }? + } + + Ok(builder.finish()) + } +} + +pub struct StringDictionaryArrayConverter {} + +impl Converter>, DictionaryArray> + for StringDictionaryArrayConverter +{ + fn convert(&self, source: Vec>) -> Result> { + let data_size = source + .iter() + .map(|x| x.as_ref().map(|b| b.len()).unwrap_or(0)) + .sum(); + + let keys_builder = PrimitiveBuilder::::new(source.len()); + let values_builder = StringBuilder::with_capacity(source.len(), data_size); + + let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + for v in source { + match v { + Some(array) => { + let _ = builder.append(array.as_utf8()?)?; + } + None => builder.append_null()?, + } + } + + Ok(builder.finish()) + } +} + +pub struct DictionaryArrayConverter +{ + _dict_value_source_marker: PhantomData, + _dict_value_target_marker: PhantomData, + _parquet_marker: PhantomData, +} + +impl + DictionaryArrayConverter +{ + pub fn new() -> Self { + Self { + _dict_value_source_marker: PhantomData, + _dict_value_target_marker: PhantomData, + _parquet_marker: PhantomData, + } + } +} + +impl + Converter::T>>, DictionaryArray> + for DictionaryArrayConverter +where + K: ArrowPrimitiveType, + DictValueSourceType: ArrowPrimitiveType, + DictValueTargetType: ArrowPrimitiveType, + ParquetType: DataType, + PrimitiveArray: From::T>>>, +{ + fn convert( + &self, + source: Vec::T>>, + ) -> Result> { + let keys_builder = PrimitiveBuilder::::new(source.len()); + let values_builder = PrimitiveBuilder::::new(source.len()); + + let mut builder = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + + let source_array: Arc = + Arc::new(PrimitiveArray::::from(source)); + let target_array = cast(&source_array, &DictValueTargetType::DATA_TYPE)?; + let target = target_array + .as_any() + .downcast_ref::>() + .unwrap(); + + for i in 0..target.len() { + if target.is_null(i) { + builder.append_null()?; + } else { + let _ = builder.append(target.value(i))?; + } + } + + Ok(builder.finish()) + } +} + pub type Utf8Converter = ArrayRefConverter>, StringArray, Utf8ArrayConverter>; +pub type LargeUtf8Converter = + ArrayRefConverter>, LargeStringArray, LargeUtf8ArrayConverter>; pub type BinaryConverter = ArrayRefConverter>, BinaryArray, BinaryArrayConverter>; +pub type LargeBinaryConverter = ArrayRefConverter< + Vec>, + LargeBinaryArray, + LargeBinaryArrayConverter, +>; +pub type StringDictionaryConverter = ArrayRefConverter< + Vec>, + DictionaryArray, + StringDictionaryArrayConverter, +>; +pub type DictionaryConverter = ArrayRefConverter< + Vec::T>>, + DictionaryArray, + DictionaryArrayConverter, +>; +pub type PrimitiveDictionaryConverter = ArrayRefConverter< + Vec::T>>, + DictionaryArray, + DictionaryArrayConverter, +>; + pub type Int96Converter = ArrayRefConverter>, TimestampNanosecondArray, Int96ArrayConverter>; pub type FixedLenBinaryConverter = ArrayRefConverter< @@ -298,120 +339,3 @@ where .map(|array| Arc::new(array) as ArrayRef) } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::arrow::converter::Int16Converter; - use crate::arrow::record_reader::RecordReader; - use crate::basic::Encoding; - use crate::schema::parser::parse_message_type; - use crate::schema::types::SchemaDescriptor; - use crate::util::test_common::page_util::InMemoryPageReader; - use crate::util::test_common::page_util::{DataPageBuilder, DataPageBuilderImpl}; - use arrow::array::ArrayEqual; - use arrow::array::PrimitiveArray; - use arrow::datatypes::{Int16Type, Int32Type}; - use std::rc::Rc; - - macro_rules! converter_arrow_source_target { - ($raw_data:expr, $physical_type:expr, $result_arrow_type:ty, $converter:ty) => {{ - // Construct record reader - let mut record_reader = { - // Construct column schema - let message_type = &format!( - " - message test_schema {{ - OPTIONAL {} leaf; - }} - ", - $physical_type - ); - - let def_levels = [1i16, 0i16, 1i16, 1i16]; - build_record_reader( - message_type, - &[1, 2, 3], - 0i16, - None, - 1i16, - Some(&def_levels), - 10, - ) - }; - - let array = <$converter>::new().convert(&mut record_reader).unwrap(); - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - assert!(array.equals(&PrimitiveArray::<$result_arrow_type>::from($raw_data))); - }}; - } - - #[test] - fn test_converter_arrow_source_i16_target_i32() { - let raw_data = vec![Some(1i16), None, Some(2i16), Some(3i16)]; - converter_arrow_source_target!(raw_data, "INT32", Int16Type, Int16Converter) - } - - #[test] - fn test_converter_arrow_source_i32_target_date32() { - let raw_data = vec![Some(1i32), None, Some(2i32), Some(3i32)]; - converter_arrow_source_target!(raw_data, "INT32", Date32Type, Date32Converter) - } - - #[test] - fn test_converter_arrow_source_i32_target_i32() { - let raw_data = vec![Some(1i32), None, Some(2i32), Some(3i32)]; - converter_arrow_source_target!(raw_data, "INT32", Int32Type, Int32Converter) - } - - fn build_record_reader( - message_type: &str, - values: &[T::T], - max_rep_level: i16, - rep_levels: Option<&[i16]>, - max_def_level: i16, - def_levels: Option<&[i16]>, - num_records: usize, - ) -> RecordReader { - let desc = parse_message_type(message_type) - .map(|t| SchemaDescriptor::new(Rc::new(t))) - .map(|s| s.column(0)) - .unwrap(); - - let mut record_reader = RecordReader::::new(desc.clone()); - - // Prepare record reader - let mut pb = DataPageBuilderImpl::new(desc, 4, true); - if rep_levels.is_some() { - pb.add_rep_levels( - max_rep_level, - match rep_levels { - Some(a) => a, - _ => unreachable!(), - }, - ); - } - if def_levels.is_some() { - pb.add_def_levels( - max_def_level, - match def_levels { - Some(a) => a, - _ => unreachable!(), - }, - ); - } - pb.add_values::(Encoding::PLAIN, &values); - let page = pb.consume(); - - let page_reader = Box::new(InMemoryPageReader::new(vec![page])); - record_reader.set_page_reader(page_reader).unwrap(); - - record_reader.read_records(num_records).unwrap(); - - record_reader - } -} diff --git a/rust/parquet/src/arrow/mod.rs b/rust/parquet/src/arrow/mod.rs index ef1544d65bb..979345722d2 100644 --- a/rust/parquet/src/arrow/mod.rs +++ b/rust/parquet/src/arrow/mod.rs @@ -35,7 +35,7 @@ //! //! println!("Converted arrow schema is: {}", arrow_reader.get_schema().unwrap()); //! println!("Arrow schema after projection is: {}", -//! arrow_reader.get_schema_by_columns(vec![2, 4, 6]).unwrap()); +//! arrow_reader.get_schema_by_columns(vec![2, 4, 6], true).unwrap()); //! //! let mut record_batch_reader = arrow_reader.get_record_reader(2048).unwrap(); //! @@ -51,10 +51,18 @@ pub(in crate::arrow) mod array_reader; pub mod arrow_reader; +pub mod arrow_writer; pub(in crate::arrow) mod converter; pub(in crate::arrow) mod record_reader; pub mod schema; pub use self::arrow_reader::ArrowReader; pub use self::arrow_reader::ParquetFileArrowReader; -pub use self::schema::{parquet_to_arrow_schema, parquet_to_arrow_schema_by_columns}; +pub use self::arrow_writer::ArrowWriter; +pub use self::schema::{ + arrow_to_parquet_schema, parquet_to_arrow_schema, parquet_to_arrow_schema_by_columns, + parquet_to_arrow_schema_by_root_columns, +}; + +/// Schema metadata key used to store serialized Arrow IPC schema +pub const ARROW_SCHEMA_META_KEY: &str = "ARROW:schema"; diff --git a/rust/parquet/src/arrow/record_reader.rs b/rust/parquet/src/arrow/record_reader.rs index ccfdaf8f0e5..519bd15fb0c 100644 --- a/rust/parquet/src/arrow/record_reader.rs +++ b/rust/parquet/src/arrow/record_reader.rs @@ -86,6 +86,7 @@ impl<'a, T> FatPtr<'a, T> { self.ptr } + #[allow(clippy::wrong_self_convention)] fn to_slice_mut(&mut self) -> &mut [T] { self.ptr } @@ -123,26 +124,6 @@ impl RecordReader { } } - pub(crate) fn cast(&mut self) -> &mut RecordReader { - trait CastRecordReader { - fn cast(&mut self) -> &mut RecordReader; - } - - impl CastRecordReader for RecordReader { - default fn cast(&mut self) -> &mut RecordReader { - panic!("Attempted to cast RecordReader to the wrong type") - } - } - - impl CastRecordReader for RecordReader { - fn cast(&mut self) -> &mut RecordReader { - self - } - } - - CastRecordReader::::cast(self) - } - /// Set the current page reader. pub fn set_page_reader(&mut self, page_reader: Box) -> Result<()> { self.column_reader = diff --git a/rust/parquet/src/arrow/schema.rs b/rust/parquet/src/arrow/schema.rs index aebb9e776cc..10270fff464 100644 --- a/rust/parquet/src/arrow/schema.rs +++ b/rust/parquet/src/arrow/schema.rs @@ -26,27 +26,91 @@ use std::collections::{HashMap, HashSet}; use std::rc::Rc; +use arrow::datatypes::{DataType, DateUnit, Field, Schema, TimeUnit}; +use arrow::ipc::writer; + use crate::basic::{LogicalType, Repetition, Type as PhysicalType}; use crate::errors::{ParquetError::ArrowError, Result}; -use crate::file::metadata::KeyValue; +use crate::file::{metadata::KeyValue, properties::WriterProperties}; use crate::schema::types::{ColumnDescriptor, SchemaDescriptor, Type, TypePtr}; -use arrow::datatypes::TimeUnit; -use arrow::datatypes::{DataType, DateUnit, Field, Schema}; - -/// Convert parquet schema to arrow schema including optional metadata. +/// Convert Parquet schema to Arrow schema including optional metadata. +/// Attempts to decode any existing Arrow shcema metadata, falling back +/// to converting the Parquet schema column-wise pub fn parquet_to_arrow_schema( parquet_schema: &SchemaDescriptor, - metadata: &Option>, + key_value_metadata: &Option>, ) -> Result { - parquet_to_arrow_schema_by_columns( - parquet_schema, - 0..parquet_schema.columns().len(), - metadata, - ) + let mut metadata = parse_key_value_metadata(key_value_metadata).unwrap_or_default(); + let arrow_schema_metadata = metadata + .remove(super::ARROW_SCHEMA_META_KEY) + .map(|encoded| get_arrow_schema_from_metadata(&encoded)); + + match arrow_schema_metadata { + Some(Some(schema)) => Ok(schema), + _ => parquet_to_arrow_schema_by_columns( + parquet_schema, + 0..parquet_schema.columns().len(), + key_value_metadata, + ), + } +} + +/// Convert parquet schema to arrow schema including optional metadata, +/// only preserving some root columns. +/// This is useful if we have columns `a.b`, `a.c.e` and `a.d`, +/// and want `a` with all its child fields +pub fn parquet_to_arrow_schema_by_root_columns( + parquet_schema: &SchemaDescriptor, + column_indices: T, + key_value_metadata: &Option>, +) -> Result +where + T: IntoIterator, +{ + // Reconstruct the index ranges of the parent columns + // An Arrow struct gets represented by 1+ columns based on how many child fields the + // struct has. This means that getting fields 1 and 2 might return the struct twice, + // if field 1 is the struct having say 3 fields, and field 2 is a primitive. + // + // The below gets the parent columns, and counts the number of child fields in each parent, + // such that we would end up with: + // - field 1 - columns: [0, 1, 2] + // - field 2 - columns: [3] + let mut parent_columns = vec![]; + let mut curr_name = ""; + let mut prev_name = ""; + let mut indices = vec![]; + (0..(parquet_schema.num_columns())).for_each(|i| { + let p_type = parquet_schema.get_column_root(i); + curr_name = p_type.get_basic_info().name(); + if prev_name == "" { + // first index + indices.push(i); + prev_name = curr_name; + } else if curr_name != prev_name { + prev_name = curr_name; + parent_columns.push((curr_name.to_string(), indices.clone())); + indices = vec![i]; + } else { + indices.push(i); + } + }); + // push the last column if indices has values + if !indices.is_empty() { + parent_columns.push((curr_name.to_string(), indices)); + } + + // gather the required leaf columns + let leaf_columns = column_indices + .into_iter() + .flat_map(|i| parent_columns[i].1.clone()); + + parquet_to_arrow_schema_by_columns(parquet_schema, leaf_columns, key_value_metadata) } -/// Convert parquet schema to arrow schema including optional metadata, only preserving some leaf columns. +/// Convert parquet schema to arrow schema including optional metadata, +/// only preserving some leaf columns. pub fn parquet_to_arrow_schema_by_columns( parquet_schema: &SchemaDescriptor, column_indices: T, @@ -55,32 +119,136 @@ pub fn parquet_to_arrow_schema_by_columns( where T: IntoIterator, { + let mut metadata = parse_key_value_metadata(key_value_metadata).unwrap_or_default(); + let arrow_schema_metadata = metadata + .remove(super::ARROW_SCHEMA_META_KEY) + .map(|encoded| get_arrow_schema_from_metadata(&encoded)) + .unwrap_or_default(); + + // add the Arrow metadata to the Parquet metadata + if let Some(arrow_schema) = &arrow_schema_metadata { + arrow_schema.metadata().iter().for_each(|(k, v)| { + metadata.insert(k.clone(), v.clone()); + }); + } + let mut base_nodes = Vec::new(); let mut base_nodes_set = HashSet::new(); let mut leaves = HashSet::new(); + enum FieldType<'a> { + Parquet(&'a Type), + Arrow(Field), + } + for c in column_indices { - let column = parquet_schema.column(c).self_type() as *const Type; - let root = parquet_schema.get_column_root(c); - let root_raw_ptr = root as *const Type; - - leaves.insert(column); - if !base_nodes_set.contains(&root_raw_ptr) { - base_nodes.push(root); - base_nodes_set.insert(root_raw_ptr); + let column = parquet_schema.column(c); + let name = column.name(); + + if let Some(field) = arrow_schema_metadata + .as_ref() + .and_then(|schema| schema.field_with_name(name).ok().cloned()) + { + base_nodes.push(FieldType::Arrow(field)); + } else { + let column = column.self_type() as *const Type; + let root = parquet_schema.get_column_root(c); + let root_raw_ptr = root as *const Type; + + leaves.insert(column); + if !base_nodes_set.contains(&root_raw_ptr) { + base_nodes.push(FieldType::Parquet(root)); + base_nodes_set.insert(root_raw_ptr); + } } } - let metadata = parse_key_value_metadata(key_value_metadata).unwrap_or_default(); - base_nodes .into_iter() - .map(|t| ParquetTypeConverter::new(t, &leaves).to_field()) + .map(|t| match t { + FieldType::Parquet(t) => ParquetTypeConverter::new(t, &leaves).to_field(), + FieldType::Arrow(f) => Ok(Some(f)), + }) .collect::>>>() .map(|result| result.into_iter().filter_map(|f| f).collect::>()) .map(|fields| Schema::new_with_metadata(fields, metadata)) } +/// Try to convert Arrow schema metadata into a schema +fn get_arrow_schema_from_metadata(encoded_meta: &str) -> Option { + let decoded = base64::decode(encoded_meta); + match decoded { + Ok(bytes) => { + let slice = if bytes[0..4] == [255u8; 4] { + &bytes[8..] + } else { + bytes.as_slice() + }; + let message = arrow::ipc::get_root_as_message(slice); + message + .header_as_schema() + .map(arrow::ipc::convert::fb_to_schema) + } + Err(err) => { + // The C++ implementation returns an error if the schema can't be parsed. + // To prevent this, we explicitly log this, then compute the schema without the metadata + eprintln!( + "Unable to decode the encoded schema stored in {}, {:?}", + super::ARROW_SCHEMA_META_KEY, + err + ); + None + } + } +} + +/// Encodes the Arrow schema into the IPC format, and base64 encodes it +fn encode_arrow_schema(schema: &Schema) -> String { + let options = writer::IpcWriteOptions::default(); + let mut serialized_schema = arrow::ipc::writer::schema_to_bytes(&schema, &options); + + // manually prepending the length to the schema as arrow uses the legacy IPC format + // TODO: change after addressing ARROW-9777 + let schema_len = serialized_schema.ipc_message.len(); + let mut len_prefix_schema = Vec::with_capacity(schema_len + 8); + len_prefix_schema.append(&mut vec![255u8, 255, 255, 255]); + len_prefix_schema.append((schema_len as u32).to_le_bytes().to_vec().as_mut()); + len_prefix_schema.append(&mut serialized_schema.ipc_message); + + base64::encode(&len_prefix_schema) +} + +/// Mutates writer metadata by storing the encoded Arrow schema. +/// If there is an existing Arrow schema metadata, it is replaced. +pub(crate) fn add_encoded_arrow_schema_to_metadata( + schema: &Schema, + props: &mut WriterProperties, +) { + let encoded = encode_arrow_schema(schema); + + let schema_kv = KeyValue { + key: super::ARROW_SCHEMA_META_KEY.to_string(), + value: Some(encoded), + }; + + let mut meta = props.key_value_metadata.clone().unwrap_or_default(); + // check if ARROW:schema exists, and overwrite it + let schema_meta = meta + .iter() + .enumerate() + .find(|(_, kv)| kv.key.as_str() == super::ARROW_SCHEMA_META_KEY); + match schema_meta { + Some((i, _)) => { + meta.remove(i); + meta.push(schema_kv); + } + None => { + meta.push(schema_kv); + } + } + props.key_value_metadata = Some(meta); +} + /// Convert arrow schema to parquet schema pub fn arrow_to_parquet_schema(schema: &Schema) -> Result { let fields: Result> = schema @@ -140,7 +308,10 @@ fn arrow_to_parquet_type(field: &Field) -> Result { }; // create type from field match field.data_type() { - DataType::Null => Err(ArrowError("Null arrays not supported".to_string())), + DataType::Null => Type::primitive_type_builder(name, PhysicalType::INT32) + .with_logical_type(LogicalType::NONE) + .with_repetition(repetition) + .build(), DataType::Boolean => Type::primitive_type_builder(name, PhysicalType::BOOLEAN) .with_repetition(repetition) .build(), @@ -215,42 +386,48 @@ fn arrow_to_parquet_type(field: &Field) -> Result { Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_logical_type(LogicalType::INTERVAL) .with_repetition(repetition) - .with_length(3) + .with_length(12) + .build() + } + DataType::Binary | DataType::LargeBinary => { + Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) + .with_repetition(repetition) .build() } - DataType::Binary => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) - .with_repetition(repetition) - .build(), DataType::FixedSizeBinary(length) => { Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_repetition(repetition) .with_length(*length) .build() } - DataType::Utf8 => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) - .with_logical_type(LogicalType::UTF8) - .with_repetition(repetition) - .build(), - DataType::List(dtype) | DataType::FixedSizeList(dtype, _) => { - Type::group_type_builder(name) - .with_fields(&mut vec![Rc::new( - Type::group_type_builder("list") - .with_fields(&mut vec![Rc::new({ - let list_field = Field::new( - "element", - *dtype.clone(), - field.is_nullable(), - ); - arrow_to_parquet_type(&list_field)? - })]) - .with_repetition(Repetition::REPEATED) - .build()?, - )]) - .with_logical_type(LogicalType::LIST) - .with_repetition(Repetition::REQUIRED) + DataType::Utf8 | DataType::LargeUtf8 => { + Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) + .with_logical_type(LogicalType::UTF8) + .with_repetition(repetition) .build() } + DataType::List(dtype) + | DataType::FixedSizeList(dtype, _) + | DataType::LargeList(dtype) => Type::group_type_builder(name) + .with_fields(&mut vec![Rc::new( + Type::group_type_builder("list") + .with_fields(&mut vec![Rc::new({ + let list_field = + Field::new("element", *dtype.clone(), field.is_nullable()); + arrow_to_parquet_type(&list_field)? + })]) + .with_repetition(Repetition::REPEATED) + .build()?, + )]) + .with_logical_type(LogicalType::LIST) + .with_repetition(Repetition::REQUIRED) + .build(), DataType::Struct(fields) => { + if fields.is_empty() { + return Err(ArrowError( + "Parquet does not support writing empty structs".to_string(), + )); + } // recursively convert children to types/nodes let fields: Result> = fields .iter() @@ -267,9 +444,6 @@ fn arrow_to_parquet_type(field: &Field) -> Result { let dict_field = Field::new(name, *value.clone(), field.is_nullable()); arrow_to_parquet_type(&dict_field) } - DataType::LargeUtf8 | DataType::LargeBinary | DataType::LargeList(_) => { - Err(ArrowError("Large arrays not supported".to_string())) - } } } /// This struct is used to group methods and data structures used to convert parquet @@ -555,12 +729,16 @@ impl ParquetTypeConverter<'_> { mod tests { use super::*; - use std::collections::HashMap; + use std::{collections::HashMap, convert::TryFrom, sync::Arc}; - use arrow::datatypes::{DataType, DateUnit, Field, TimeUnit}; + use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, TimeUnit}; - use crate::file::metadata::KeyValue; - use crate::schema::{parser::parse_message_type, types::SchemaDescriptor}; + use crate::file::{metadata::KeyValue, reader::SerializedFileReader}; + use crate::{ + arrow::{ArrowReader, ArrowWriter, ParquetFileArrowReader}, + schema::{parser::parse_message_type, types::SchemaDescriptor}, + util::test_common::get_temp_file, + }; #[test] fn test_flat_primitives() { @@ -1194,6 +1372,17 @@ mod tests { }); } + #[test] + #[should_panic(expected = "Parquet does not support writing empty structs")] + fn test_empty_struct_field() { + let arrow_fields = vec![Field::new("struct", DataType::Struct(vec![]), false)]; + let arrow_schema = Schema::new(arrow_fields); + let converted_arrow_schema = arrow_to_parquet_schema(&arrow_schema); + + assert!(converted_arrow_schema.is_err()); + converted_arrow_schema.unwrap(); + } + #[test] fn test_metadata() { let message_type = " @@ -1216,4 +1405,184 @@ mod tests { assert_eq!(converted_arrow_schema.metadata(), &expected_metadata); } + + #[test] + fn test_arrow_schema_roundtrip() -> Result<()> { + // This tests the roundtrip of an Arrow schema + // Fields that are commented out fail roundtrip tests or are unsupported by the writer + let metadata: HashMap = + [("Key".to_string(), "Value".to_string())] + .iter() + .cloned() + .collect(); + + let schema = Schema::new_with_metadata( + vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Binary, false), + Field::new("c3", DataType::FixedSizeBinary(3), false), + Field::new("c4", DataType::Boolean, false), + Field::new("c5", DataType::Date32(DateUnit::Day), false), + Field::new("c6", DataType::Date64(DateUnit::Millisecond), false), + Field::new("c7", DataType::Time32(TimeUnit::Second), false), + Field::new("c8", DataType::Time32(TimeUnit::Millisecond), false), + Field::new("c13", DataType::Time64(TimeUnit::Microsecond), false), + Field::new("c14", DataType::Time64(TimeUnit::Nanosecond), false), + Field::new("c15", DataType::Timestamp(TimeUnit::Second, None), false), + Field::new( + "c16", + DataType::Timestamp( + TimeUnit::Millisecond, + Some(Arc::new("UTC".to_string())), + ), + false, + ), + Field::new( + "c17", + DataType::Timestamp( + TimeUnit::Microsecond, + Some(Arc::new("Africa/Johannesburg".to_string())), + ), + false, + ), + Field::new( + "c18", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("c19", DataType::Interval(IntervalUnit::DayTime), false), + Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), false), + Field::new("c21", DataType::List(Box::new(DataType::Boolean)), false), + // Field::new( + // "c22", + // DataType::FixedSizeList(Box::new(DataType::Boolean), 5), + // false, + // ), + // Field::new( + // "c23", + // DataType::List(Box::new(DataType::LargeList(Box::new( + // DataType::Struct(vec![ + // Field::new("a", DataType::Int16, true), + // Field::new("b", DataType::Float64, false), + // ]), + // )))), + // true, + // ), + Field::new( + "c24", + DataType::Struct(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::UInt16, false), + ]), + false, + ), + Field::new("c25", DataType::Interval(IntervalUnit::YearMonth), true), + Field::new("c26", DataType::Interval(IntervalUnit::DayTime), true), + // Field::new("c27", DataType::Duration(TimeUnit::Second), false), + // Field::new("c28", DataType::Duration(TimeUnit::Millisecond), false), + // Field::new("c29", DataType::Duration(TimeUnit::Microsecond), false), + // Field::new("c30", DataType::Duration(TimeUnit::Nanosecond), false), + Field::new_dict( + "c31", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + true, + 123, + true, + ), + Field::new("c32", DataType::LargeBinary, true), + Field::new("c33", DataType::LargeUtf8, true), + // Field::new( + // "c34", + // DataType::LargeList(Box::new(DataType::List(Box::new( + // DataType::Struct(vec![ + // Field::new("a", DataType::Int16, true), + // Field::new("b", DataType::Float64, true), + // ]), + // )))), + // true, + // ), + Field::new("c35", DataType::Null, true), + ], + metadata, + ); + + // write to an empty parquet file so that schema is serialized + let file = get_temp_file("test_arrow_schema_roundtrip.parquet", &[]); + let mut writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + Arc::new(schema.clone()), + None, + )?; + writer.close()?; + + // read file back + let parquet_reader = SerializedFileReader::try_from(file)?; + let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(parquet_reader)); + let read_schema = arrow_reader.get_schema()?; + assert_eq!(schema, read_schema); + + // read all fields by columns + let partial_read_schema = + arrow_reader.get_schema_by_columns(0..(schema.fields().len()), false)?; + assert_eq!(schema, partial_read_schema); + + Ok(()) + } + + #[test] + #[ignore = "Roundtrip of lists currently fails because we don't check their types correctly in the Arrow schema"] + fn test_arrow_schema_roundtrip_lists() -> Result<()> { + let metadata: HashMap = + [("Key".to_string(), "Value".to_string())] + .iter() + .cloned() + .collect(); + + let schema = Schema::new_with_metadata( + vec![ + Field::new("c21", DataType::List(Box::new(DataType::Boolean)), false), + Field::new( + "c22", + DataType::FixedSizeList(Box::new(DataType::Boolean), 5), + false, + ), + Field::new( + "c23", + DataType::List(Box::new(DataType::LargeList(Box::new( + DataType::Struct(vec![ + Field::new("a", DataType::Int16, true), + Field::new("b", DataType::Float64, false), + ]), + )))), + true, + ), + ], + metadata, + ); + + // write to an empty parquet file so that schema is serialized + let file = get_temp_file("test_arrow_schema_roundtrip_lists.parquet", &[]); + let mut writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + Arc::new(schema.clone()), + None, + )?; + writer.close()?; + + // read file back + let parquet_reader = SerializedFileReader::try_from(file)?; + let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(parquet_reader)); + let read_schema = arrow_reader.get_schema()?; + assert_eq!(schema, read_schema); + + // read all fields by columns + let partial_read_schema = + arrow_reader.get_schema_by_columns(0..(schema.fields().len()), false)?; + assert_eq!(schema, partial_read_schema); + + Ok(()) + } } diff --git a/rust/parquet/src/file/properties.rs b/rust/parquet/src/file/properties.rs index 188d6ec3c9e..b62ce7bbc38 100644 --- a/rust/parquet/src/file/properties.rs +++ b/rust/parquet/src/file/properties.rs @@ -89,8 +89,8 @@ pub type WriterPropertiesPtr = Rc; /// Writer properties. /// -/// It is created as an immutable data structure, use [`WriterPropertiesBuilder`] to -/// assemble the properties. +/// All properties except the key-value metadata are immutable, +/// use [`WriterPropertiesBuilder`] to assemble these properties. #[derive(Debug, Clone)] pub struct WriterProperties { data_pagesize_limit: usize, @@ -99,7 +99,7 @@ pub struct WriterProperties { max_row_group_size: usize, writer_version: WriterVersion, created_by: String, - key_value_metadata: Option>, + pub(crate) key_value_metadata: Option>, default_column_properties: ColumnProperties, column_properties: HashMap, } diff --git a/rust/parquet/src/schema/types.rs b/rust/parquet/src/schema/types.rs index 416073af035..57999050ab3 100644 --- a/rust/parquet/src/schema/types.rs +++ b/rust/parquet/src/schema/types.rs @@ -788,7 +788,7 @@ impl SchemaDescriptor { result.clone() } - fn column_root_of(&self, i: usize) -> &Rc { + fn column_root_of(&self, i: usize) -> &TypePtr { assert!( i < self.leaves.len(), "Index out of bound: {} not in [0, {})", @@ -810,6 +810,10 @@ impl SchemaDescriptor { self.schema.as_ref() } + pub fn root_schema_ptr(&self) -> TypePtr { + self.schema.clone() + } + /// Returns schema name. pub fn name(&self) -> &str { self.schema.name() diff --git a/rust/parquet/src/schema/visitor.rs b/rust/parquet/src/schema/visitor.rs index 6d712ce441f..a1866fb1471 100644 --- a/rust/parquet/src/schema/visitor.rs +++ b/rust/parquet/src/schema/visitor.rs @@ -50,7 +50,7 @@ pub trait TypeVisitor { { self.visit_list_with_item( list_type.clone(), - list_item, + list_item.clone(), context, ) } else { @@ -70,13 +70,13 @@ pub trait TypeVisitor { { self.visit_list_with_item( list_type.clone(), - fields.first().unwrap(), + fields.first().unwrap().clone(), context, ) } else { self.visit_list_with_item( list_type.clone(), - list_item, + list_item.clone(), context, ) } @@ -114,7 +114,7 @@ pub trait TypeVisitor { fn visit_list_with_item( &mut self, list_type: TypePtr, - item_type: &Type, + item_type: TypePtr, context: C, ) -> Result; } @@ -125,7 +125,7 @@ mod tests { use crate::basic::Type as PhysicalType; use crate::errors::Result; use crate::schema::parser::parse_message_type; - use crate::schema::types::{Type, TypePtr}; + use crate::schema::types::TypePtr; use std::rc::Rc; struct TestVisitorContext {} @@ -174,7 +174,7 @@ mod tests { fn visit_list_with_item( &mut self, list_type: TypePtr, - item_type: &Type, + item_type: TypePtr, _context: TestVisitorContext, ) -> Result { assert_eq!(