diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs
index d8cb480a80c..a1142b481d5 100644
--- a/rust/arrow/src/compute/kernels/cast.rs
+++ b/rust/arrow/src/compute/kernels/cast.rs
@@ -327,11 +327,27 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
         // temporal casts
         (Int32, Date32(_)) => cast_array_data::<Date32Type>(array, to_type.clone()),
-        (Int32, Time32(_)) => cast_array_data::<Time32SecondType>(array, to_type.clone()),
+        (Int32, Time32(unit)) => match unit {
+            TimeUnit::Second => {
+                cast_array_data::<Time32SecondType>(array, to_type.clone())
+            }
+            TimeUnit::Millisecond => {
+                cast_array_data::<Time32MillisecondType>(array, to_type.clone())
+            }
+            _ => unreachable!(),
+        },
         (Date32(_), Int32) => cast_array_data::<Int32Type>(array, to_type.clone()),
         (Time32(_), Int32) => cast_array_data::<Int32Type>(array, to_type.clone()),
         (Int64, Date64(_)) => cast_array_data::<Date64Type>(array, to_type.clone()),
-        (Int64, Time64(_)) => cast_array_data::<Time64MicrosecondType>(array, to_type.clone()),
+        (Int64, Time64(unit)) => match unit {
+            TimeUnit::Microsecond => {
+                cast_array_data::<Time64MicrosecondType>(array, to_type.clone())
+            }
+            TimeUnit::Nanosecond => {
+                cast_array_data::<Time64NanosecondType>(array, to_type.clone())
+            }
+            _ => unreachable!(),
+        },
         (Date64(_), Int64) => cast_array_data::<Int64Type>(array, to_type.clone()),
         (Time64(_), Int64) => cast_array_data::<Int64Type>(array, to_type.clone()),
         (Date32(DateUnit::Day), Date64(DateUnit::Millisecond)) => {
diff --git a/rust/parquet/Cargo.toml b/rust/parquet/Cargo.toml
index 50d7c34d341..60e43c93ffa 100644
--- a/rust/parquet/Cargo.toml
+++ b/rust/parquet/Cargo.toml
@@ -40,6 +40,7 @@ zstd = { version = "0.5", optional = true }
 chrono = "0.4"
 num-bigint = "0.3"
 arrow = { path = "../arrow", version = "2.0.0-SNAPSHOT", optional = true }
+base64 = { version = "*", optional = true }

 [dev-dependencies]
 rand = "0.7"
@@ -52,4 +53,4 @@ arrow = { path = "../arrow", version = "2.0.0-SNAPSHOT" }
 serde_json = { version = "1.0", features = ["preserve_order"] }

 [features]
-default = ["arrow", "snap", "brotli", "flate2", "lz4", "zstd"]
+default = ["arrow", "snap", "brotli", "flate2", "lz4", "zstd", "base64"]
diff --git a/rust/parquet/src/arrow/array_reader.rs b/rust/parquet/src/arrow/array_reader.rs
index 14bf7d287a3..4fbc54d209d 100644
--- a/rust/parquet/src/arrow/array_reader.rs
+++ b/rust/parquet/src/arrow/array_reader.rs
@@ -35,9 +35,10 @@ use crate::arrow::converter::{
     BinaryArrayConverter, BinaryConverter, BoolConverter, BooleanArrayConverter,
     Converter, Date32Converter, FixedLenBinaryConverter, FixedSizeArrayConverter,
     Float32Converter, Float64Converter, Int16Converter, Int32Converter, Int64Converter,
-    Int8Converter, Int96ArrayConverter, Int96Converter, TimestampMicrosecondConverter,
-    TimestampMillisecondConverter, UInt16Converter, UInt32Converter, UInt64Converter,
-    UInt8Converter, Utf8ArrayConverter, Utf8Converter,
+    Int8Converter, Int96ArrayConverter, Int96Converter, Time32MillisecondConverter,
+    Time32SecondConverter, Time64MicrosecondConverter, Time64NanosecondConverter,
+    TimestampMicrosecondConverter, TimestampMillisecondConverter, UInt16Converter,
+    UInt32Converter, UInt64Converter, UInt8Converter, Utf8ArrayConverter, Utf8Converter,
 };
 use crate::arrow::record_reader::RecordReader;
 use crate::arrow::schema::parquet_to_arrow_field;
@@ -196,11 +197,27 @@ impl<T: DataType> ArrayReader for PrimitiveArrayReader<T> {
                     .convert(self.record_reader.cast::<Int32Type>()),
                 _ => Err(general_err!("No conversion from parquet type to arrow type for date with unit {:?}", unit)),
             }
-            (ArrowType::Time32(_), PhysicalType::INT32) => {
-                UInt32Converter::new().convert(self.record_reader.cast::<Int32Type>())
+            (ArrowType::Time32(unit), PhysicalType::INT32) => {
+                match unit {
+                    TimeUnit::Second => {
+                        Time32SecondConverter::new().convert(self.record_reader.cast::<Int32Type>())
+                    }
+                    TimeUnit::Millisecond => {
+                        Time32MillisecondConverter::new().convert(self.record_reader.cast::<Int32Type>())
+                    }
+                    _ => Err(general_err!("Invalid or unsupported arrow array with datatype {:?}", self.get_data_type()))
+                }
             }
-            (ArrowType::Time64(_), PhysicalType::INT64) => {
-                UInt64Converter::new().convert(self.record_reader.cast::<Int64Type>())
+            (ArrowType::Time64(unit), PhysicalType::INT64) => {
+                match unit {
+                    TimeUnit::Microsecond => {
+                        Time64MicrosecondConverter::new().convert(self.record_reader.cast::<Int64Type>())
+                    }
+                    TimeUnit::Nanosecond => {
+                        Time64NanosecondConverter::new().convert(self.record_reader.cast::<Int64Type>())
+                    }
+                    _ => Err(general_err!("Invalid or unsupported arrow array with datatype {:?}", self.get_data_type()))
+                }
             }
             (ArrowType::Interval(IntervalUnit::YearMonth), PhysicalType::INT32) => {
                 UInt32Converter::new().convert(self.record_reader.cast::<Int32Type>())
@@ -941,10 +958,12 @@ mod tests {
     use crate::util::test_common::{get_test_file, make_pages};
     use arrow::array::{Array, ArrayRef, PrimitiveArray, StringArray, StructArray};
     use arrow::datatypes::{
-        DataType as ArrowType, Date32Type as ArrowDate32, Field, Int32Type as ArrowInt32,
+        ArrowPrimitiveType, DataType as ArrowType, Date32Type as ArrowDate32, Field,
+        Int32Type as ArrowInt32, Int64Type as ArrowInt64,
+        Time32MillisecondType as ArrowTime32MillisecondArray,
+        Time64MicrosecondType as ArrowTime64MicrosecondArray,
         TimestampMicrosecondType as ArrowTimestampMicrosecondType,
         TimestampMillisecondType as ArrowTimestampMillisecondType,
-        UInt32Type as ArrowUInt32, UInt64Type as ArrowUInt64,
     };
     use rand::distributions::uniform::SampleUniform;
     use rand::{thread_rng, Rng};
@@ -1101,7 +1120,7 @@ mod tests {
     }
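 
+    // The macro below now takes a `$result_arrow_cast_type` in addition to the
+    // result type: expected values are built with the cast type and then cast to
+    // `$result_arrow_type` before comparison, mirroring the reader's conversion.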
     macro_rules! test_primitive_array_reader_one_type {
-        ($arrow_parquet_type:ty, $physical_type:expr, $logical_type_str:expr, $result_arrow_type:ty, $result_primitive_type:ty) => {{
+        ($arrow_parquet_type:ty, $physical_type:expr, $logical_type_str:expr, $result_arrow_type:ty, $result_arrow_cast_type:ty, $result_primitive_type:ty) => {{
             let message_type = format!(
                 "
             message test_schema {{
@@ -1112,7 +1131,7 @@ mod tests {
             );
             let schema = parse_message_type(&message_type)
                 .map(|t| Rc::new(SchemaDescriptor::new(Rc::new(t))))
-                .unwrap();
+                .expect("Unable to parse message type into a schema descriptor");

             let column_desc = schema.column(0);
@@ -1142,24 +1161,48 @@ mod tests {
                 Box::new(page_iterator),
                 column_desc.clone(),
             )
-            .unwrap();
+            .expect("Unable to get array reader");

-            let array = array_reader.next_batch(50).unwrap();
+            let array = array_reader
+                .next_batch(50)
+                .expect("Unable to get batch from reader");
+            let result_data_type = <$result_arrow_type>::get_data_type();
             let array = array
                 .as_any()
                 .downcast_ref::<PrimitiveArray<$result_arrow_type>>()
-                .unwrap();
-
-            assert_eq!(
-                &PrimitiveArray::<$result_arrow_type>::from(
-                    data[0..50]
-                        .iter()
-                        .map(|x| *x as $result_primitive_type)
-                        .collect::<Vec<$result_primitive_type>>()
-                ),
-                array
+                .expect(
+                    format!(
+                        "Unable to downcast {:?} to {:?}",
+                        array.data_type(),
+                        result_data_type
+                    )
+                    .as_str(),
+                );
+
+            // create expected array as primitive, and cast to result type
+            let expected = PrimitiveArray::<$result_arrow_cast_type>::from(
+                data[0..50]
+                    .iter()
+                    .map(|x| *x as $result_primitive_type)
+                    .collect::<Vec<$result_primitive_type>>(),
             );
+            let expected = Arc::new(expected) as ArrayRef;
+            let expected = arrow::compute::cast(&expected, &result_data_type)
+                .expect("Unable to cast expected array");
+            assert_eq!(expected.data_type(), &result_data_type);
+            let expected = expected
+                .as_any()
+                .downcast_ref::<PrimitiveArray<$result_arrow_type>>()
+                .expect(
+                    format!(
+                        "Unable to downcast expected {:?} to {:?}",
+                        expected.data_type(),
+                        result_data_type
+                    )
+                    .as_str(),
+                );
+            assert_eq!(expected, array);
         }
     }};
 }
@@ -1171,27 +1214,31 @@ mod tests {
         PhysicalType::INT32,
         "DATE",
         ArrowDate32,
+        ArrowInt32,
         i32
     );
     test_primitive_array_reader_one_type!(
         Int32Type,
         PhysicalType::INT32,
         "TIME_MILLIS",
-        ArrowUInt32,
-        u32
+        ArrowTime32MillisecondArray,
+        ArrowInt32,
+        i32
     );
     test_primitive_array_reader_one_type!(
         Int64Type,
         PhysicalType::INT64,
         "TIME_MICROS",
-        ArrowUInt64,
-        u64
+        ArrowTime64MicrosecondArray,
+        ArrowInt64,
+        i64
     );
     test_primitive_array_reader_one_type!(
         Int64Type,
         PhysicalType::INT64,
         "TIMESTAMP_MILLIS",
         ArrowTimestampMillisecondType,
+        ArrowInt64,
         i64
     );
     test_primitive_array_reader_one_type!(
@@ -1199,6 +1246,7 @@ mod tests {
         PhysicalType::INT64,
         "TIMESTAMP_MICROS",
         ArrowTimestampMicrosecondType,
+        ArrowInt64,
         i64
     );
 }
diff --git a/rust/parquet/src/arrow/arrow_writer.rs b/rust/parquet/src/arrow/arrow_writer.rs
new file mode 100644
index 00000000000..314a800c325
--- /dev/null
+++ b/rust/parquet/src/arrow/arrow_writer.rs
@@ -0,0 +1,1098 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Contains writer which writes arrow data into parquet data.
+
+use std::rc::Rc;
+
+use arrow::array as arrow_array;
+use arrow::datatypes::{DataType as ArrowDataType, SchemaRef};
+use arrow::record_batch::RecordBatch;
+use arrow_array::{Array, PrimitiveArrayOps};
+
+use super::schema::add_encoded_arrow_schema_to_metadata;
+use crate::column::writer::ColumnWriter;
+use crate::errors::{ParquetError, Result};
+use crate::file::properties::WriterProperties;
+use crate::{
+    data_type::*,
+    file::writer::{FileWriter, ParquetWriter, RowGroupWriter, SerializedFileWriter},
+};
+
+/// Arrow writer
+///
+/// Writes Arrow `RecordBatch`es to a Parquet writer
+pub struct ArrowWriter<W: ParquetWriter> {
+    /// Underlying Parquet writer
+    writer: SerializedFileWriter<W>,
+    /// A copy of the Arrow schema.
+    ///
+    /// The schema is used to verify that each record batch written has the correct schema
+    arrow_schema: SchemaRef,
+}
+
+impl<W: 'static + ParquetWriter> ArrowWriter<W> {
+    /// Try to create a new Arrow writer
+    ///
+    /// The writer will fail if:
+    /// * a `SerializedFileWriter` cannot be created from the ParquetWriter
+    /// * the Arrow schema contains unsupported datatypes such as Unions
+    pub fn try_new(
+        writer: W,
+        arrow_schema: SchemaRef,
+        props: Option<WriterProperties>,
+    ) -> Result<Self> {
+        let schema = crate::arrow::arrow_to_parquet_schema(&arrow_schema)?;
+        // add serialized arrow schema
+        let mut props = props.unwrap_or_else(|| WriterProperties::builder().build());
+        add_encoded_arrow_schema_to_metadata(&arrow_schema, &mut props);
+
+        let file_writer = SerializedFileWriter::new(
+            writer.try_clone()?,
+            schema.root_schema_ptr(),
+            Rc::new(props),
+        )?;
+
+        Ok(Self {
+            writer: file_writer,
+            arrow_schema,
+        })
+    }
+
+    /// Write a RecordBatch to writer
+    ///
+    /// *NOTE:* The writer currently does not support all Arrow data types
+    pub fn write(&mut self, batch: &RecordBatch) -> Result<()> {
+        // validate batch schema against writer's supplied schema
+        if self.arrow_schema != batch.schema() {
+            return Err(ParquetError::ArrowError(
+                "Record batch schema does not match writer schema".to_string(),
+            ));
+        }
+        // compute the definition and repetition levels of the batch
+        let mut levels = vec![];
+        batch.columns().iter().for_each(|array| {
+            let mut array_levels =
+                get_levels(array, 0, &vec![1i16; batch.num_rows()][..], None);
+            levels.append(&mut array_levels);
+        });
+        // reverse levels so we can use Vec::pop
+        levels.reverse();
+
+        let mut row_group_writer = self.writer.next_row_group()?;
+
+        // write leaves
+        for column in batch.columns() {
+            write_leaves(&mut row_group_writer, column, &mut levels)?;
+        }
+
+        self.writer.close_row_group(row_group_writer)
+    }
+
+    /// Close and finalise the underlying Parquet writer
+    pub fn close(&mut self) -> Result<()> {
+        self.writer.close()
+    }
+}
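+
+// A minimal usage sketch (illustrative only, not part of this change; the file
+// name and single-column schema are made up):
+//
+//     let file = std::fs::File::create("example.parquet")?;
+//     let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+//     let batch = RecordBatch::try_new(
+//         schema.clone(),
+//         vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+//     )?;
+//     let mut writer = ArrowWriter::try_new(file, schema, None)?;
+//     writer.write(&batch)?;
+//     writer.close()?;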
+ .expect("Unable to get column writer"); + Ok(col_writer) +} + +#[allow(clippy::borrowed_box)] +fn write_leaves( + mut row_group_writer: &mut Box, + array: &arrow_array::ArrayRef, + mut levels: &mut Vec, +) -> Result<()> { + match array.data_type() { + ArrowDataType::Int8 + | ArrowDataType::Int16 + | ArrowDataType::Int32 + | ArrowDataType::Int64 + | ArrowDataType::UInt8 + | ArrowDataType::UInt16 + | ArrowDataType::UInt32 + | ArrowDataType::UInt64 + | ArrowDataType::Float16 + | ArrowDataType::Float32 + | ArrowDataType::Float64 + | ArrowDataType::Timestamp(_, _) + | ArrowDataType::Date32(_) + | ArrowDataType::Date64(_) + | ArrowDataType::Time32(_) + | ArrowDataType::Time64(_) + | ArrowDataType::Duration(_) + | ArrowDataType::Interval(_) + | ArrowDataType::LargeBinary + | ArrowDataType::Binary + | ArrowDataType::Utf8 + | ArrowDataType::LargeUtf8 => { + let mut col_writer = get_col_writer(&mut row_group_writer)?; + write_leaf( + &mut col_writer, + array, + levels.pop().expect("Levels exhausted"), + )?; + row_group_writer.close_column(col_writer)?; + Ok(()) + } + ArrowDataType::List(_) | ArrowDataType::LargeList(_) => { + // write the child list + let data = array.data(); + let child_array = arrow_array::make_array(data.child_data()[0].clone()); + write_leaves(&mut row_group_writer, &child_array, &mut levels)?; + Ok(()) + } + ArrowDataType::Struct(_) => { + let struct_array: &arrow_array::StructArray = array + .as_any() + .downcast_ref::() + .expect("Unable to get struct array"); + for field in struct_array.columns() { + write_leaves(&mut row_group_writer, field, &mut levels)?; + } + Ok(()) + } + ArrowDataType::FixedSizeList(_, _) + | ArrowDataType::Null + | ArrowDataType::Boolean + | ArrowDataType::FixedSizeBinary(_) + | ArrowDataType::Union(_) + | ArrowDataType::Dictionary(_, _) => Err(ParquetError::NYI( + "Attempting to write an Arrow type that is not yet implemented".to_string(), + )), + } +} + +fn write_leaf( + writer: &mut ColumnWriter, + column: &arrow_array::ArrayRef, + levels: Levels, +) -> Result { + let written = match writer { + ColumnWriter::Int32ColumnWriter(ref mut typed) => { + let array = arrow::compute::cast(column, &ArrowDataType::Int32)?; + let array = array + .as_any() + .downcast_ref::() + .expect("Unable to get int32 array"); + typed.write_batch( + get_numeric_array_slice::(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ColumnWriter::BoolColumnWriter(ref mut _typed) => { + unreachable!("Currently unreachable because data type not supported") + } + ColumnWriter::Int64ColumnWriter(ref mut typed) => { + let array = arrow_array::Int64Array::from(column.data()); + typed.write_batch( + get_numeric_array_slice::(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ColumnWriter::Int96ColumnWriter(ref mut _typed) => { + unreachable!("Currently unreachable because data type not supported") + } + ColumnWriter::FloatColumnWriter(ref mut typed) => { + let array = arrow_array::Float32Array::from(column.data()); + typed.write_batch( + get_numeric_array_slice::(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? + } + ColumnWriter::DoubleColumnWriter(ref mut typed) => { + let array = arrow_array::Float64Array::from(column.data()); + typed.write_batch( + get_numeric_array_slice::(&array).as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )? 
+        }
+        ColumnWriter::ByteArrayColumnWriter(ref mut typed) => match column.data_type() {
+            ArrowDataType::Binary | ArrowDataType::Utf8 => {
+                let array = arrow_array::BinaryArray::from(column.data());
+                typed.write_batch(
+                    get_binary_array(&array).as_slice(),
+                    Some(levels.definition.as_slice()),
+                    levels.repetition.as_deref(),
+                )?
+            }
+            ArrowDataType::LargeBinary | ArrowDataType::LargeUtf8 => {
+                let array = arrow_array::LargeBinaryArray::from(column.data());
+                typed.write_batch(
+                    get_large_binary_array(&array).as_slice(),
+                    Some(levels.definition.as_slice()),
+                    levels.repetition.as_deref(),
+                )?
+            }
+            _ => unreachable!("Currently unreachable because data type not supported"),
+        },
+        ColumnWriter::FixedLenByteArrayColumnWriter(ref mut _typed) => {
+            unreachable!("Currently unreachable because data type not supported")
+        }
+    };
+    Ok(written as i64)
+}
+
+/// A struct that represents definition and repetition levels.
+/// Repetition levels are only populated if the parent or current leaf is repeated
+#[derive(Debug)]
+struct Levels {
+    definition: Vec<i16>,
+    repetition: Option<Vec<i16>>,
+}
+
+/// Compute nested levels of the Arrow array, recursing into lists and structs
+fn get_levels(
+    array: &arrow_array::ArrayRef,
+    level: i16,
+    parent_def_levels: &[i16],
+    parent_rep_levels: Option<&[i16]>,
+) -> Vec<Levels> {
+    match array.data_type() {
+        ArrowDataType::Null => unimplemented!(),
+        ArrowDataType::Boolean
+        | ArrowDataType::Int8
+        | ArrowDataType::Int16
+        | ArrowDataType::Int32
+        | ArrowDataType::Int64
+        | ArrowDataType::UInt8
+        | ArrowDataType::UInt16
+        | ArrowDataType::UInt32
+        | ArrowDataType::UInt64
+        | ArrowDataType::Float16
+        | ArrowDataType::Float32
+        | ArrowDataType::Float64
+        | ArrowDataType::Utf8
+        | ArrowDataType::LargeUtf8
+        | ArrowDataType::Timestamp(_, _)
+        | ArrowDataType::Date32(_)
+        | ArrowDataType::Date64(_)
+        | ArrowDataType::Time32(_)
+        | ArrowDataType::Time64(_)
+        | ArrowDataType::Duration(_)
+        | ArrowDataType::Interval(_)
+        | ArrowDataType::Binary
+        | ArrowDataType::LargeBinary => vec![Levels {
+            definition: get_primitive_def_levels(array, parent_def_levels),
+            repetition: None,
+        }],
+        ArrowDataType::FixedSizeBinary(_) => unimplemented!(),
+        ArrowDataType::List(_) | ArrowDataType::LargeList(_) => {
+            let array_data = array.data();
+            let child_data = array_data.child_data().get(0).unwrap();
+            // get offsets, accounting for large offsets if present
+            let offsets: Vec<i64> = {
+                if let ArrowDataType::LargeList(_) = array.data_type() {
+                    unsafe { array_data.buffers()[0].typed_data::<i64>() }.to_vec()
+                } else {
+                    let offsets = unsafe { array_data.buffers()[0].typed_data::<i32>() };
+                    offsets.to_vec().into_iter().map(|v| v as i64).collect()
+                }
+            };
+            let child_array = arrow_array::make_array(child_data.clone());
+
+            let mut list_def_levels = Vec::with_capacity(child_array.len());
+            let mut list_rep_levels = Vec::with_capacity(child_array.len());
+            let rep_levels: Vec<i16> = parent_rep_levels
+                .map(|l| l.to_vec())
+                .unwrap_or_else(|| vec![0i16; parent_def_levels.len()]);
+            parent_def_levels
+                .iter()
+                .zip(rep_levels)
+                .zip(offsets.windows(2))
+                .for_each(|((parent_def_level, parent_rep_level), window)| {
+                    if *parent_def_level == 0 {
+                        // parent is null, list element must also be null
+                        list_def_levels.push(0);
+                        list_rep_levels.push(0);
+                    } else {
+                        // parent is not null, check if list is empty or null
+                        let start = window[0];
+                        let end = window[1];
+                        let len = end - start;
+                        if len == 0 {
+                            list_def_levels.push(*parent_def_level - 1);
+                            list_rep_levels.push(parent_rep_level);
+                        } else {
+                            list_def_levels.push(*parent_def_level);
+                            list_rep_levels.push(parent_rep_level);
+                            for _ in 1..len {
+                                list_def_levels.push(*parent_def_level);
+                                list_rep_levels.push(parent_rep_level + 1);
+                            }
+                        }
+                    }
+                });
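+
+            // Worked example (a sketch, tracing the loop above): for a list column
+            // [[1, 2], []] the parent definition levels are [1, 1] and the offsets
+            // are [0, 2, 2]. The first window has len 2, pushing def [1, 1] and
+            // rep [0, 1]; the second window is empty, pushing def [0] and rep [0].
+            // Result: def = [1, 1, 0], rep = [0, 1, 0].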
+
+            // if datatype is a primitive, we can construct levels of the child array
+            match child_array.data_type() {
+                ArrowDataType::Null => unimplemented!(),
+                ArrowDataType::Boolean => unimplemented!(),
+                ArrowDataType::Int8
+                | ArrowDataType::Int16
+                | ArrowDataType::Int32
+                | ArrowDataType::Int64
+                | ArrowDataType::UInt8
+                | ArrowDataType::UInt16
+                | ArrowDataType::UInt32
+                | ArrowDataType::UInt64
+                | ArrowDataType::Float16
+                | ArrowDataType::Float32
+                | ArrowDataType::Float64
+                | ArrowDataType::Timestamp(_, _)
+                | ArrowDataType::Date32(_)
+                | ArrowDataType::Date64(_)
+                | ArrowDataType::Time32(_)
+                | ArrowDataType::Time64(_)
+                | ArrowDataType::Duration(_)
+                | ArrowDataType::Interval(_) => {
+                    let def_levels =
+                        get_primitive_def_levels(&child_array, &list_def_levels[..]);
+                    vec![Levels {
+                        definition: def_levels,
+                        repetition: Some(list_rep_levels),
+                    }]
+                }
+                ArrowDataType::Binary
+                | ArrowDataType::Utf8
+                | ArrowDataType::LargeUtf8 => unimplemented!(),
+                ArrowDataType::FixedSizeBinary(_) => unimplemented!(),
+                ArrowDataType::LargeBinary => unimplemented!(),
+                ArrowDataType::List(_) | ArrowDataType::LargeList(_) => {
+                    // nested list
+                    unimplemented!()
+                }
+                ArrowDataType::FixedSizeList(_, _) => unimplemented!(),
+                ArrowDataType::Struct(_) => get_levels(
+                    array,
+                    level + 1, // indicates a nesting level of 2 (list + struct)
+                    &list_def_levels[..],
+                    Some(&list_rep_levels[..]),
+                ),
+                ArrowDataType::Union(_) => unimplemented!(),
+                ArrowDataType::Dictionary(_, _) => unimplemented!(),
+            }
+        }
+        ArrowDataType::FixedSizeList(_, _) => unimplemented!(),
+        ArrowDataType::Struct(_) => {
+            let struct_array: &arrow_array::StructArray = array
+                .as_any()
+                .downcast_ref::<arrow_array::StructArray>()
+                .expect("Unable to get struct array");
+            let mut struct_def_levels = Vec::with_capacity(struct_array.len());
+            for i in 0..array.len() {
+                struct_def_levels.push(level + struct_array.is_valid(i) as i16);
+            }
+            // trying to create levels for struct's fields
+            let mut struct_levels = vec![];
+            struct_array.columns().into_iter().for_each(|col| {
+                let mut levels =
+                    get_levels(col, level + 1, &struct_def_levels[..], parent_rep_levels);
+                struct_levels.append(&mut levels);
+            });
+            struct_levels
+        }
+        ArrowDataType::Union(_) => unimplemented!(),
+        ArrowDataType::Dictionary(_, _) => unimplemented!(),
+    }
+}
+
+/// Get the definition levels of the numeric array, with level 0 being null and 1 being not null
+/// In the case where the array in question is a child of either a list or struct, the levels
+/// are incremented in accordance with the `level` parameter.
+/// Parent levels are either 0 or 1, and are used to mark higher-level leaves as null
+fn get_primitive_def_levels(
+    array: &arrow_array::ArrayRef,
+    parent_def_levels: &[i16],
+) -> Vec<i16> {
+    let mut array_index = 0;
+    let max_def_level = parent_def_levels.iter().max().unwrap();
+    let mut primitive_def_levels = vec![];
+    parent_def_levels.iter().for_each(|def_level| {
+        if def_level < max_def_level {
+            primitive_def_levels.push(*def_level);
+        } else {
+            primitive_def_levels.push(def_level - array.is_null(array_index) as i16);
+            array_index += 1;
+        }
+    });
+    primitive_def_levels
+}
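+
+// For example (a sketch): with parent_def_levels [1, 1, 0] and an array
+// [Some(x), None], the first two parent levels consume array slots and yield
+// [1, 0], while the trailing 0 is passed through unchanged: result [1, 0, 0].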
+
+macro_rules! def_get_binary_array_fn {
+    ($name:ident, $ty:ty) => {
+        fn $name(array: &$ty) -> Vec<ByteArray> {
+            let mut values = Vec::with_capacity(array.len() - array.null_count());
+            for i in 0..array.len() {
+                if array.is_valid(i) {
+                    let bytes = ByteArray::from(array.value(i).to_vec());
+                    values.push(bytes);
+                }
+            }
+            values
+        }
+    };
+}
+
+def_get_binary_array_fn!(get_binary_array, arrow_array::BinaryArray);
+def_get_binary_array_fn!(get_large_binary_array, arrow_array::LargeBinaryArray);
+
+/// Get the underlying numeric array slice, skipping any null values.
+/// If there are no null values, it might be quicker to get the slice directly instead of
+/// calling this function.
+fn get_numeric_array_slice<T, A>(array: &arrow_array::PrimitiveArray<A>) -> Vec<T::T>
+where
+    T: DataType,
+    A: arrow::datatypes::ArrowNumericType,
+    T::T: From<A::Native>,
+{
+    let mut values = Vec::with_capacity(array.len() - array.null_count());
+    for i in 0..array.len() {
+        if array.is_valid(i) {
+            values.push(array.value(i).into())
+        }
+    }
+    values
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use std::io::Seek;
+    use std::sync::Arc;
+
+    use arrow::array::*;
+    use arrow::datatypes::ToByteSlice;
+    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::record_batch::{RecordBatch, RecordBatchReader};
+
+    use crate::arrow::{ArrowReader, ParquetFileArrowReader};
+    use crate::file::{metadata::KeyValue, reader::SerializedFileReader};
+    use crate::util::test_common::get_temp_file;
+
+    #[test]
+    fn arrow_writer() {
+        // define schema
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, true),
+        ]);
+
+        // create some data
+        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+        let b = Int32Array::from(vec![Some(1), None, None, Some(4), Some(5)]);
+
+        // build a record batch
+        let batch = RecordBatch::try_new(
+            Arc::new(schema.clone()),
+            vec![Arc::new(a), Arc::new(b)],
+        )
+        .unwrap();
+
+        let file = get_temp_file("test_arrow_writer.parquet", &[]);
+        let mut writer = ArrowWriter::try_new(file, Arc::new(schema), None).unwrap();
+        writer.write(&batch).unwrap();
+        writer.close().unwrap();
+    }
+
+    #[test]
+    fn arrow_writer_list() {
+        // define schema
+        let schema = Schema::new(vec![Field::new(
+            "a",
+            DataType::List(Box::new(DataType::Int32)),
+            false,
+        )]);
+
+        // create some data
+        let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+
+        // Construct a buffer for value offsets, for the nested array:
+        //  [[1], [2, 3], null, [4, 5, 6], [7, 8, 9, 10]]
+        let a_value_offsets =
+            arrow::buffer::Buffer::from(&[0, 1, 3, 3, 6, 10].to_byte_slice());
+
+        // Construct a list array from the above two
+        let a_list_data = ArrayData::builder(DataType::List(Box::new(DataType::Int32)))
+            .len(5)
+            .add_buffer(a_value_offsets)
+            .add_child_data(a_values.data())
+            .build();
+        let a = ListArray::from(a_list_data);
+
+        // build a record batch
+        let batch =
+            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)]).unwrap();
+
+        let file = get_temp_file("test_arrow_writer_list.parquet", &[]);
+        let mut writer = ArrowWriter::try_new(file, Arc::new(schema), None).unwrap();
+        writer.write(&batch).unwrap();
+        writer.close().unwrap();
+    }
+
+    #[test]
+    fn arrow_writer_binary() {
+        let string_field = Field::new("a", DataType::Utf8, false);
+        let binary_field = Field::new("b", DataType::Binary, false);
+        let schema = Schema::new(vec![string_field, binary_field]);
+
+        let raw_string_values = vec!["foo", "bar", "baz", "quux"];
+        let raw_binary_values = vec![
b"foo".to_vec(), + b"bar".to_vec(), + b"baz".to_vec(), + b"quux".to_vec(), + ]; + let raw_binary_value_refs = raw_binary_values + .iter() + .map(|x| x.as_slice()) + .collect::>(); + + let string_values = StringArray::from(raw_string_values.clone()); + let binary_values = BinaryArray::from(raw_binary_value_refs); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(string_values), Arc::new(binary_values)], + ) + .unwrap(); + + let mut file = get_temp_file("test_arrow_writer_binary.parquet", &[]); + let mut writer = + ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema), None) + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + file.seek(std::io::SeekFrom::Start(0)).unwrap(); + let file_reader = SerializedFileReader::new(file).unwrap(); + let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(file_reader)); + let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap(); + + let batch = record_batch_reader.next_batch().unwrap().unwrap(); + let string_col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let binary_col = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + for i in 0..batch.num_rows() { + assert_eq!(string_col.value(i), raw_string_values[i]); + assert_eq!(binary_col.value(i), raw_binary_values[i].as_slice()); + } + } + + #[test] + fn arrow_writer_complex() { + // define schema + let struct_field_d = Field::new("d", DataType::Float64, true); + let struct_field_f = Field::new("f", DataType::Float32, true); + let struct_field_g = + Field::new("g", DataType::List(Box::new(DataType::Int16)), false); + let struct_field_e = Field::new( + "e", + DataType::Struct(vec![struct_field_f.clone(), struct_field_g.clone()]), + true, + ); + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, true), + Field::new( + "c", + DataType::Struct(vec![struct_field_d.clone(), struct_field_e.clone()]), + false, + ), + ]); + + // create some data + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let b = Int32Array::from(vec![Some(1), None, None, Some(4), Some(5)]); + let d = Float64Array::from(vec![None, None, None, Some(1.0), None]); + let f = Float32Array::from(vec![Some(0.0), None, Some(333.3), None, Some(5.25)]); + + let g_value = Int16Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + + // Construct a buffer for value offsets, for the nested array: + // [[1], [2, 3], null, [4, 5, 6], [7, 8, 9, 10]] + let g_value_offsets = + arrow::buffer::Buffer::from(&[0, 1, 3, 3, 6, 10].to_byte_slice()); + + // Construct a list array from the above two + let g_list_data = ArrayData::builder(struct_field_g.data_type().clone()) + .len(5) + .add_buffer(g_value_offsets) + .add_child_data(g_value.data()) + .build(); + let g = ListArray::from(g_list_data); + + let e = StructArray::from(vec![ + (struct_field_f, Arc::new(f) as ArrayRef), + (struct_field_g, Arc::new(g) as ArrayRef), + ]); + + let c = StructArray::from(vec![ + (struct_field_d, Arc::new(d) as ArrayRef), + (struct_field_e, Arc::new(e) as ArrayRef), + ]); + + // build a record batch + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(a), Arc::new(b), Arc::new(c)], + ) + .unwrap(); + + let props = WriterProperties::builder() + .set_key_value_metadata(Some(vec![KeyValue { + key: "test_key".to_string(), + value: Some("test_value".to_string()), + }])) + .build(); + + let file = get_temp_file("test_arrow_writer_complex.parquet", &[]); + let mut writer = + 
+
+        let file = get_temp_file("test_arrow_writer_complex.parquet", &[]);
+        let mut writer =
+            ArrowWriter::try_new(file, Arc::new(schema), Some(props)).unwrap();
+        writer.write(&batch).unwrap();
+        writer.close().unwrap();
+    }
+
+    const SMALL_SIZE: usize = 100;
+
+    fn roundtrip(filename: &str, expected_batch: RecordBatch) {
+        let file = get_temp_file(filename, &[]);
+
+        let mut writer = ArrowWriter::try_new(
+            file.try_clone().unwrap(),
+            expected_batch.schema(),
+            None,
+        )
+        .unwrap();
+        writer.write(&expected_batch).unwrap();
+        writer.close().unwrap();
+
+        let reader = SerializedFileReader::new(file).unwrap();
+        let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(reader));
+        let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap();
+
+        let actual_batch = record_batch_reader.next_batch().unwrap().unwrap();
+
+        assert_eq!(expected_batch.schema(), actual_batch.schema());
+        assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
+        assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
+        for i in 0..expected_batch.num_columns() {
+            let expected_data = expected_batch.column(i).data();
+            let actual_data = actual_batch.column(i).data();
+
+            assert_eq!(expected_data.data_type(), actual_data.data_type());
+            assert_eq!(expected_data.len(), actual_data.len());
+            assert_eq!(expected_data.null_count(), actual_data.null_count());
+            assert_eq!(expected_data.offset(), actual_data.offset());
+            assert_eq!(expected_data.buffers(), actual_data.buffers());
+            assert_eq!(expected_data.child_data(), actual_data.child_data());
+            assert_eq!(expected_data.null_bitmap(), actual_data.null_bitmap());
+        }
+    }
+
+    fn one_column_roundtrip(filename: &str, values: ArrayRef, nullable: bool) {
+        let schema = Schema::new(vec![Field::new(
+            "col",
+            values.data_type().clone(),
+            nullable,
+        )]);
+        let expected_batch =
+            RecordBatch::try_new(Arc::new(schema), vec![values]).unwrap();
+
+        roundtrip(filename, expected_batch);
+    }
+
+    fn values_required<A, I>(iter: I, filename: &str)
+    where
+        A: From<Vec<I::Item>> + Array + 'static,
+        I: IntoIterator,
+    {
+        let raw_values: Vec<_> = iter.into_iter().collect();
+        let values = Arc::new(A::from(raw_values));
+        one_column_roundtrip(filename, values, false);
+    }
+
+    fn values_optional<A, I>(iter: I, filename: &str)
+    where
+        A: From<Vec<Option<I::Item>>> + Array + 'static,
+        I: IntoIterator,
+    {
+        let optional_raw_values: Vec<_> = iter
+            .into_iter()
+            .enumerate()
+            .map(|(i, v)| if i % 2 == 0 { None } else { Some(v) })
+            .collect();
+        let optional_values = Arc::new(A::from(optional_raw_values));
+        one_column_roundtrip(filename, optional_values, true);
+    }
+
+    fn required_and_optional<A, I>(iter: I, filename: &str)
+    where
+        A: From<Vec<I::Item>> + From<Vec<Option<I::Item>>> + Array + 'static,
+        I: IntoIterator + Clone,
+    {
+        values_required::<A, I>(iter.clone(), filename);
+        values_optional::<A, I>(iter, filename);
+    }
+
+    #[test]
+    #[should_panic(expected = "Null arrays not supported")]
+    fn null_single_column() {
+        let values = Arc::new(NullArray::new(SMALL_SIZE));
+        one_column_roundtrip("null_single_column", values.clone(), true);
+        one_column_roundtrip("null_single_column", values, false);
+    }
+
+    #[test]
+    #[should_panic(
+        expected = "Attempting to write an Arrow type that is not yet implemented"
+    )]
+    fn bool_single_column() {
+        required_and_optional::<BooleanArray, _>(
+            [true, false].iter().cycle().copied().take(SMALL_SIZE),
+            "bool_single_column",
+        );
+    }
+
+    #[test]
+    fn i8_single_column() {
+        required_and_optional::<Int8Array, _>(0..SMALL_SIZE as i8, "i8_single_column");
+    }
+
+    #[test]
+    fn i16_single_column() {
+        required_and_optional::<Int16Array, _>(0..SMALL_SIZE as i16, "i16_single_column");
+    }
+
+    #[test]
+    fn i32_single_column() {
+        required_and_optional::<Int32Array, _>(0..SMALL_SIZE as i32, "i32_single_column");
+    }
+
+    #[test]
+    fn i64_single_column() {
+        required_and_optional::<Int64Array, _>(0..SMALL_SIZE as i64, "i64_single_column");
+    }
+
+    #[test]
+    fn u8_single_column() {
+        required_and_optional::<UInt8Array, _>(0..SMALL_SIZE as u8, "u8_single_column");
+    }
+
+    #[test]
+    fn u16_single_column() {
+        required_and_optional::<UInt16Array, _>(
+            0..SMALL_SIZE as u16,
+            "u16_single_column",
+        );
+    }
+
+    #[test]
+    fn u32_single_column() {
+        required_and_optional::<UInt32Array, _>(
+            0..SMALL_SIZE as u32,
+            "u32_single_column",
+        );
+    }
+
+    #[test]
+    fn u64_single_column() {
+        required_and_optional::<UInt64Array, _>(
+            0..SMALL_SIZE as u64,
+            "u64_single_column",
+        );
+    }
+
+    // How to create Float16 values, which aren't supported in Rust?
+
+    #[test]
+    fn f32_single_column() {
+        required_and_optional::<Float32Array, _>(
+            (0..SMALL_SIZE).map(|i| i as f32),
+            "f32_single_column",
+        );
+    }
+
+    #[test]
+    fn f64_single_column() {
+        required_and_optional::<Float64Array, _>(
+            (0..SMALL_SIZE).map(|i| i as f64),
+            "f64_single_column",
+        );
+    }
+
+    // The timestamp array types don't implement From<Vec<i64>> because they need
+    // the timezone argument, and they also don't support building from a
+    // Vec<Option<i64>>, so call one_column_roundtrip manually instead of calling
+    // required_and_optional for these tests.
+
+    #[test]
+    #[ignore] // Timestamp support isn't correct yet
+    fn timestamp_second_single_column() {
+        let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect();
+        let values = Arc::new(TimestampSecondArray::from_vec(raw_values, None));
+
+        one_column_roundtrip("timestamp_second_single_column", values, false);
+    }
+
+    #[test]
+    fn timestamp_millisecond_single_column() {
+        let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect();
+        let values = Arc::new(TimestampMillisecondArray::from_vec(raw_values, None));
+
+        one_column_roundtrip("timestamp_millisecond_single_column", values, false);
+    }
+
+    #[test]
+    fn timestamp_microsecond_single_column() {
+        let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect();
+        let values = Arc::new(TimestampMicrosecondArray::from_vec(raw_values, None));
+
+        one_column_roundtrip("timestamp_microsecond_single_column", values, false);
+    }
+
+    #[test]
+    #[ignore] // Timestamp support isn't correct yet
+    fn timestamp_nanosecond_single_column() {
+        let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect();
+        let values = Arc::new(TimestampNanosecondArray::from_vec(raw_values, None));
+
+        one_column_roundtrip("timestamp_nanosecond_single_column", values, false);
+    }
+
+    #[test]
+    fn date32_single_column() {
+        required_and_optional::<Date32Array, _>(
+            0..SMALL_SIZE as i32,
+            "date32_single_column",
+        );
+    }
+
+    #[test]
+    #[ignore] // Date support isn't correct yet
+    fn date64_single_column() {
+        required_and_optional::<Date64Array, _>(
+            0..SMALL_SIZE as i64,
+            "date64_single_column",
+        );
+    }
+
+    #[test]
+    #[ignore] // DateUnit resolution mismatch
+    fn time32_second_single_column() {
+        required_and_optional::<Time32SecondArray, _>(
+            0..SMALL_SIZE as i32,
+            "time32_second_single_column",
+        );
+    }
+
+    #[test]
+    fn time32_millisecond_single_column() {
+        required_and_optional::<Time32MillisecondArray, _>(
+            0..SMALL_SIZE as i32,
+            "time32_millisecond_single_column",
+        );
+    }
+
+    #[test]
+    fn time64_microsecond_single_column() {
+        required_and_optional::<Time64MicrosecondArray, _>(
+            0..SMALL_SIZE as i64,
+            "time64_microsecond_single_column",
+        );
+    }
+
+    #[test]
+    #[ignore] // DateUnit resolution mismatch
+    fn time64_nanosecond_single_column() {
+        required_and_optional::<Time64NanosecondArray, _>(
+            0..SMALL_SIZE as i64,
+            "time64_nanosecond_single_column",
+        );
+    }
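+
+    // Duration is an Arrow-only type with no Parquet mapping in the schema
+    // converter, so the tests below expect a panic (presumably raised while
+    // converting the Arrow schema to a Parquet schema).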
+
+    #[test]
+    #[should_panic(expected = "Converting Duration to parquet not supported")]
+    fn duration_second_single_column() {
+        required_and_optional::<DurationSecondArray, _>(
+            0..SMALL_SIZE as i64,
+            "duration_second_single_column",
+        );
+    }
+
+    #[test]
+    #[should_panic(expected = "Converting Duration to parquet not supported")]
+    fn duration_millisecond_single_column() {
+        required_and_optional::<DurationMillisecondArray, _>(
+            0..SMALL_SIZE as i64,
+            "duration_millisecond_single_column",
+        );
+    }
+
+    #[test]
+    #[should_panic(expected = "Converting Duration to parquet not supported")]
+    fn duration_microsecond_single_column() {
+        required_and_optional::<DurationMicrosecondArray, _>(
+            0..SMALL_SIZE as i64,
+            "duration_microsecond_single_column",
+        );
+    }
+
+    #[test]
+    #[should_panic(expected = "Converting Duration to parquet not supported")]
+    fn duration_nanosecond_single_column() {
+        required_and_optional::<DurationNanosecondArray, _>(
+            0..SMALL_SIZE as i64,
+            "duration_nanosecond_single_column",
+        );
+    }
+
+    #[test]
+    #[should_panic(expected = "Currently unreachable because data type not supported")]
+    fn interval_year_month_single_column() {
+        required_and_optional::<IntervalYearMonthArray, _>(
+            0..SMALL_SIZE as i32,
+            "interval_year_month_single_column",
+        );
+    }
+
+    #[test]
+    #[should_panic(expected = "Currently unreachable because data type not supported")]
+    fn interval_day_time_single_column() {
+        required_and_optional::<IntervalDayTimeArray, _>(
+            0..SMALL_SIZE as i64,
+            "interval_day_time_single_column",
+        );
+    }
+
+    #[test]
+    #[ignore] // Binary support isn't correct yet - null_bitmap doesn't match
+    fn binary_single_column() {
+        let one_vec: Vec<u8> = (0..SMALL_SIZE as u8).collect();
+        let many_vecs: Vec<_> = std::iter::repeat(one_vec).take(SMALL_SIZE).collect();
+        let many_vecs_iter = many_vecs.iter().map(|v| v.as_slice());
+
+        // BinaryArrays can't be built from Vec<Option<&[u8]>>, so only call `values_required`
+        values_required::<BinaryArray, _>(many_vecs_iter, "binary_single_column");
+    }
+
+    #[test]
+    #[ignore] // Large Binary support isn't correct yet
+    fn large_binary_single_column() {
+        let one_vec: Vec<u8> = (0..SMALL_SIZE as u8).collect();
+        let many_vecs: Vec<_> = std::iter::repeat(one_vec).take(SMALL_SIZE).collect();
+        let many_vecs_iter = many_vecs.iter().map(|v| v.as_slice());
+
+        // LargeBinaryArrays can't be built from Vec<Option<&[u8]>>, so only call `values_required`
+        values_required::<LargeBinaryArray, _>(
+            many_vecs_iter,
+            "large_binary_single_column",
+        );
+    }
+
+    #[test]
+    #[ignore] // String support isn't correct yet - null_bitmap doesn't match
+    fn string_single_column() {
+        let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect();
+        let raw_strs = raw_values.iter().map(|s| s.as_str());
+
+        required_and_optional::<StringArray, _>(raw_strs, "string_single_column");
+    }
+
+    #[test]
+    #[ignore] // Large String support isn't correct yet - null_bitmap and buffers don't match
+    fn large_string_single_column() {
+        let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect();
+        let raw_strs = raw_values.iter().map(|s| s.as_str());
+
+        required_and_optional::<LargeStringArray, _>(
+            raw_strs,
+            "large_string_single_column",
+        );
+    }
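+
+    // For the list roundtrips below, writing succeeds; the panic comes from the
+    // read half of the roundtrip, which cannot yet read Parquet lists into Arrow.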
+
+    #[test]
+    #[should_panic(
+        expected = "Reading parquet list array into arrow is not supported yet!"
+    )]
+    fn list_single_column() {
+        let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+        let a_value_offsets =
+            arrow::buffer::Buffer::from(&[0, 1, 3, 3, 6, 10].to_byte_slice());
+        let a_list_data = ArrayData::builder(DataType::List(Box::new(DataType::Int32)))
+            .len(5)
+            .add_buffer(a_value_offsets)
+            .add_child_data(a_values.data())
+            .build();
+        let a = ListArray::from(a_list_data);
+
+        let values = Arc::new(a);
+        one_column_roundtrip("list_single_column", values, false);
+    }
+
+    #[test]
+    #[should_panic(
+        expected = "Reading parquet list array into arrow is not supported yet!"
+    )]
+    fn large_list_single_column() {
+        let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+        let a_value_offsets =
+            arrow::buffer::Buffer::from(&[0i64, 1, 3, 3, 6, 10].to_byte_slice());
+        let a_list_data =
+            ArrayData::builder(DataType::LargeList(Box::new(DataType::Int32)))
+                .len(5)
+                .add_buffer(a_value_offsets)
+                .add_child_data(a_values.data())
+                .build();
+        let a = LargeListArray::from(a_list_data);
+
+        let values = Arc::new(a);
+        one_column_roundtrip("large_list_single_column", values, false);
+    }
+
+    #[test]
+    #[ignore] // Struct support isn't correct yet - null_bitmap doesn't match
+    fn struct_single_column() {
+        let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+        let struct_field_a = Field::new("f", DataType::Int32, false);
+        let s = StructArray::from(vec![(struct_field_a, Arc::new(a_values) as ArrayRef)]);
+
+        let values = Arc::new(s);
+        one_column_roundtrip("struct_single_column", values, false);
+    }
+}
diff --git a/rust/parquet/src/arrow/converter.rs b/rust/parquet/src/arrow/converter.rs
index 9fbfa339168..c988aaeacfc 100644
--- a/rust/parquet/src/arrow/converter.rs
+++ b/rust/parquet/src/arrow/converter.rs
@@ -17,12 +17,19 @@

 use crate::arrow::record_reader::RecordReader;
 use crate::data_type::{ByteArray, DataType, Int96};
-use arrow::array::{
-    Array, ArrayRef, BinaryBuilder, BooleanArray, BooleanBufferBuilder,
-    BufferBuilderTrait, FixedSizeBinaryBuilder, StringBuilder,
-    TimestampNanosecondBuilder,
+// TODO: clean up imports (best done when there are few moving parts)
+use arrow::{
+    array::{
+        Array, ArrayRef, BinaryBuilder, BooleanArray, BooleanBufferBuilder,
+        BufferBuilderTrait, FixedSizeBinaryBuilder, StringBuilder,
+        TimestampNanosecondBuilder,
+    },
+    datatypes::Time32MillisecondType,
+};
+use arrow::{
+    compute::cast, datatypes::Time32SecondType, datatypes::Time64MicrosecondType,
+    datatypes::Time64NanosecondType,
 };
-use arrow::compute::cast;

 use std::convert::From;
 use std::sync::Arc;
@@ -226,6 +233,14 @@ pub type TimestampMillisecondConverter =
     CastConverter<ParquetInt64Type, TimestampMillisecondType, TimestampMillisecondType>;
 pub type TimestampMicrosecondConverter =
     CastConverter<ParquetInt64Type, TimestampMicrosecondType, TimestampMicrosecondType>;
+pub type Time32SecondConverter =
+    CastConverter<ParquetInt32Type, Time32SecondType, Time32SecondType>;
+pub type Time32MillisecondConverter =
+    CastConverter<ParquetInt32Type, Time32MillisecondType, Time32MillisecondType>;
+pub type Time64MicrosecondConverter =
+    CastConverter<ParquetInt64Type, Time64MicrosecondType, Time64MicrosecondType>;
+pub type Time64NanosecondConverter =
+    CastConverter<ParquetInt64Type, Time64NanosecondType, Time64NanosecondType>;
 pub type UInt64Converter = CastConverter<ParquetInt64Type, UInt64Type, UInt64Type>;
 pub type Float32Converter = CastConverter<ParquetFloatType, Float32Type, Float32Type>;
 pub type Float64Converter = CastConverter<ParquetDoubleType, Float64Type, Float64Type>;
diff --git a/rust/parquet/src/arrow/mod.rs b/rust/parquet/src/arrow/mod.rs
index 02f50fd3a90..2bdb07cfbbb 100644
--- a/rust/parquet/src/arrow/mod.rs
+++ b/rust/parquet/src/arrow/mod.rs
@@ -51,10 +51,17 @@

 pub(in crate::arrow) mod array_reader;
 pub mod arrow_reader;
+pub mod arrow_writer;
 pub(in crate::arrow) mod converter;
 pub(in crate::arrow) mod record_reader;
 pub mod schema;

 pub use self::arrow_reader::ArrowReader;
 pub use self::arrow_reader::ParquetFileArrowReader;
-pub use self::schema::{parquet_to_arrow_schema, parquet_to_arrow_schema_by_columns};
+pub use self::arrow_writer::ArrowWriter;
+pub use self::schema::{
+    arrow_to_parquet_schema, parquet_to_arrow_schema, parquet_to_arrow_schema_by_columns,
+};
+
+/// Schema metadata key used to store serialized Arrow IPC schema
+pub const ARROW_SCHEMA_META_KEY: &str = "ARROW:schema";
diff --git a/rust/parquet/src/arrow/schema.rs b/rust/parquet/src/arrow/schema.rs
index aebb9e776cc..d5a0ff9ca08 100644
--- a/rust/parquet/src/arrow/schema.rs
+++ b/rust/parquet/src/arrow/schema.rs
@@ -26,24 +26,34 @@
 use std::collections::{HashMap, HashSet};
 use std::rc::Rc;

+use arrow::datatypes::{DataType, DateUnit, Field, Schema, TimeUnit};
+use arrow::ipc::writer;
+
 use crate::basic::{LogicalType, Repetition, Type as PhysicalType};
 use crate::errors::{ParquetError::ArrowError, Result};
-use crate::file::metadata::KeyValue;
+use crate::file::{metadata::KeyValue, properties::WriterProperties};
 use crate::schema::types::{ColumnDescriptor, SchemaDescriptor, Type, TypePtr};

-use arrow::datatypes::TimeUnit;
-use arrow::datatypes::{DataType, DateUnit, Field, Schema};
-
-/// Convert parquet schema to arrow schema including optional metadata.
+/// Convert Parquet schema to Arrow schema including optional metadata.
+/// Attempts to decode any existing Arrow schema metadata, falling back
+/// to converting the Parquet schema column-wise
 pub fn parquet_to_arrow_schema(
     parquet_schema: &SchemaDescriptor,
-    metadata: &Option<Vec<KeyValue>>,
+    key_value_metadata: &Option<Vec<KeyValue>>,
 ) -> Result<Schema> {
-    parquet_to_arrow_schema_by_columns(
-        parquet_schema,
-        0..parquet_schema.columns().len(),
-        metadata,
-    )
+    let mut metadata = parse_key_value_metadata(key_value_metadata).unwrap_or_default();
+    let arrow_schema_metadata = metadata
+        .remove(super::ARROW_SCHEMA_META_KEY)
+        .map(|encoded| get_arrow_schema_from_metadata(&encoded));
+
+    match arrow_schema_metadata {
+        Some(Some(schema)) => Ok(schema),
+        _ => parquet_to_arrow_schema_by_columns(
+            parquet_schema,
+            0..parquet_schema.columns().len(),
+            key_value_metadata,
+        ),
+    }
 }

 /// Convert parquet schema to arrow schema including optional metadata, only preserving some leaf columns.
@@ -81,6 +91,81 @@ where
         .map(|fields| Schema::new_with_metadata(fields, metadata))
 }

+/// Try to convert Arrow schema metadata into a schema
+fn get_arrow_schema_from_metadata(encoded_meta: &str) -> Option<Schema> {
+    let decoded = base64::decode(encoded_meta);
+    match decoded {
+        Ok(bytes) => {
+            let slice = if bytes[0..4] == [255u8; 4] {
+                &bytes[8..]
+            } else {
+                bytes.as_slice()
+            };
+            let message = arrow::ipc::get_root_as_message(slice);
+            message
+                .header_as_schema()
+                .map(arrow::ipc::convert::fb_to_schema)
+        }
+        Err(err) => {
+            // The C++ implementation returns an error if the schema can't be parsed.
+            // To prevent this, we explicitly log this, then compute the schema without the metadata
+            eprintln!(
+                "Unable to decode the encoded schema stored in {}, {:?}",
+                super::ARROW_SCHEMA_META_KEY,
+                err
+            );
+            None
+        }
+    }
+}
+
+/// Encodes the Arrow schema into the IPC format, and base64 encodes it
+fn encode_arrow_schema(schema: &Schema) -> String {
+    let options = writer::IpcWriteOptions::default();
+    let mut serialized_schema = arrow::ipc::writer::schema_to_bytes(&schema, &options);
+
+    // manually prepending the length to the schema as arrow uses the legacy IPC format
+    // TODO: change after addressing ARROW-9777
+    let schema_len = serialized_schema.ipc_message.len();
+    let mut len_prefix_schema = Vec::with_capacity(schema_len + 8);
+    len_prefix_schema.append(&mut vec![255u8, 255, 255, 255]);
+    len_prefix_schema.append((schema_len as u32).to_le_bytes().to_vec().as_mut());
+    len_prefix_schema.append(&mut serialized_schema.ipc_message);
+
+    base64::encode(&len_prefix_schema)
+}
+
+/// Mutates writer metadata by storing the encoded Arrow schema.
+/// If there is an existing Arrow schema metadata, it is replaced.
+pub(crate) fn add_encoded_arrow_schema_to_metadata(
+    schema: &Schema,
+    props: &mut WriterProperties,
+) {
+    let encoded = encode_arrow_schema(schema);
+
+    let schema_kv = KeyValue {
+        key: super::ARROW_SCHEMA_META_KEY.to_string(),
+        value: Some(encoded),
+    };
+
+    let mut meta = props.key_value_metadata.clone().unwrap_or_default();
+    // check if ARROW:schema exists, and overwrite it
+    let schema_meta = meta
+        .iter()
+        .enumerate()
+        .find(|(_, kv)| kv.key.as_str() == super::ARROW_SCHEMA_META_KEY);
+    match schema_meta {
+        Some((i, _)) => {
+            meta.remove(i);
+            meta.push(schema_kv);
+        }
+        None => {
+            meta.push(schema_kv);
+        }
+    }
+    props.key_value_metadata = Some(meta);
+}
+
 /// Convert arrow schema to parquet schema
 pub fn arrow_to_parquet_schema(schema: &Schema) -> Result<SchemaDescriptor> {
     let fields: Result<Vec<TypePtr>> = schema
@@ -215,42 +300,48 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
             Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY)
                 .with_logical_type(LogicalType::INTERVAL)
                 .with_repetition(repetition)
-                .with_length(3)
+                .with_length(12)
+                .build()
+        }
+        DataType::Binary | DataType::LargeBinary => {
+            Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY)
+                .with_repetition(repetition)
                 .build()
         }
-        DataType::Binary => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY)
-            .with_repetition(repetition)
-            .build(),
         DataType::FixedSizeBinary(length) => {
             Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY)
                 .with_repetition(repetition)
                 .with_length(*length)
                 .build()
         }
-        DataType::Utf8 => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY)
-            .with_logical_type(LogicalType::UTF8)
-            .with_repetition(repetition)
-            .build(),
-        DataType::List(dtype) | DataType::FixedSizeList(dtype, _) => {
-            Type::group_type_builder(name)
-                .with_fields(&mut vec![Rc::new(
-                    Type::group_type_builder("list")
-                        .with_fields(&mut vec![Rc::new({
-                            let list_field = Field::new(
-                                "element",
-                                *dtype.clone(),
-                                field.is_nullable(),
-                            );
-                            arrow_to_parquet_type(&list_field)?
-                        })])
-                        .with_repetition(Repetition::REPEATED)
-                        .build()?,
-                )])
-                .with_logical_type(LogicalType::LIST)
-                .with_repetition(Repetition::REQUIRED)
+        DataType::Utf8 | DataType::LargeUtf8 => {
+            Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY)
+                .with_logical_type(LogicalType::UTF8)
+                .with_repetition(repetition)
                 .build()
         }
+        DataType::List(dtype)
+        | DataType::FixedSizeList(dtype, _)
+        | DataType::LargeList(dtype) => Type::group_type_builder(name)
+            .with_fields(&mut vec![Rc::new(
+                Type::group_type_builder("list")
+                    .with_fields(&mut vec![Rc::new({
+                        let list_field =
+                            Field::new("element", *dtype.clone(), field.is_nullable());
+                        arrow_to_parquet_type(&list_field)?
+                    })])
+                    .with_repetition(Repetition::REPEATED)
+                    .build()?,
+            )])
+            .with_logical_type(LogicalType::LIST)
+            .with_repetition(Repetition::REQUIRED)
+            .build(),
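+        // For reference (a sketch of the output, not exercised here): the group
+        // produced for e.g. a list of int32 follows the standard three-level
+        // Parquet LIST encoding, with repetition on the inner group:
+        //
+        //   required group <name> (LIST) {
+        //     repeated group list {
+        //       <repetition> int32 element;
+        //     }
+        //   }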
         DataType::Struct(fields) => {
+            if fields.is_empty() {
+                return Err(ArrowError(
+                    "Parquet does not support writing empty structs".to_string(),
+                ));
+            }
             // recursively convert children to types/nodes
             let fields: Result<Vec<TypePtr>> = fields
                 .iter()
@@ -267,9 +358,6 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
             let dict_field = Field::new(name, *value.clone(), field.is_nullable());
             arrow_to_parquet_type(&dict_field)
         }
-        DataType::LargeUtf8 | DataType::LargeBinary | DataType::LargeList(_) => {
-            Err(ArrowError("Large arrays not supported".to_string()))
-        }
     }
 }

 /// This struct is used to group methods and data structures used to convert parquet
@@ -555,12 +643,16 @@ impl ParquetTypeConverter<'_> {
 mod tests {
     use super::*;

-    use std::collections::HashMap;
+    use std::{collections::HashMap, convert::TryFrom, sync::Arc};

-    use arrow::datatypes::{DataType, DateUnit, Field, TimeUnit};
+    use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, TimeUnit};

-    use crate::file::metadata::KeyValue;
-    use crate::schema::{parser::parse_message_type, types::SchemaDescriptor};
+    use crate::file::{metadata::KeyValue, reader::SerializedFileReader};
+    use crate::{
+        arrow::{ArrowReader, ArrowWriter, ParquetFileArrowReader},
+        schema::{parser::parse_message_type, types::SchemaDescriptor},
+        util::test_common::get_temp_file,
+    };

     #[test]
     fn test_flat_primitives() {
@@ -1194,6 +1286,17 @@ mod tests {
         });
     }

+    #[test]
+    #[should_panic(expected = "Parquet does not support writing empty structs")]
+    fn test_empty_struct_field() {
+        let arrow_fields = vec![Field::new("struct", DataType::Struct(vec![]), false)];
+        let arrow_schema = Schema::new(arrow_fields);
+        let converted_arrow_schema = arrow_to_parquet_schema(&arrow_schema);
+
+        assert!(converted_arrow_schema.is_err());
+        converted_arrow_schema.unwrap();
+    }
+
     #[test]
     fn test_metadata() {
         let message_type = "
@@ -1216,4 +1319,123 @@ mod tests {

         assert_eq!(converted_arrow_schema.metadata(), &expected_metadata);
     }
+
+    #[test]
+    fn test_arrow_schema_roundtrip() -> Result<()> {
+        // This tests the roundtrip of an Arrow schema
+        // Fields that are commented out fail roundtrip tests or are unsupported by the writer
+        let metadata: HashMap<String, String> =
+            [("Key".to_string(), "Value".to_string())]
+                .iter()
+                .cloned()
+                .collect();
+
+        let schema = Schema::new_with_metadata(
+            vec![
+                Field::new("c1", DataType::Utf8, false),
+                Field::new("c2", DataType::Binary, false),
+                Field::new("c3", DataType::FixedSizeBinary(3), false),
+                Field::new("c4", DataType::Boolean, false),
+                Field::new("c5", DataType::Date32(DateUnit::Day), false),
+                Field::new("c6", DataType::Date64(DateUnit::Millisecond), false),
+                Field::new("c7", DataType::Time32(TimeUnit::Second), false),
+                Field::new("c8", DataType::Time32(TimeUnit::Millisecond), false),
+                Field::new("c13", DataType::Time64(TimeUnit::Microsecond), false),
+                Field::new("c14", DataType::Time64(TimeUnit::Nanosecond), false),
+                Field::new("c15", DataType::Timestamp(TimeUnit::Second, None), false),
+                Field::new(
+                    "c16",
+                    DataType::Timestamp(
+                        TimeUnit::Millisecond,
+                        Some(Arc::new("UTC".to_string())),
+                    ),
+                    false,
+                ),
+                Field::new(
+                    "c17",
+                    DataType::Timestamp(
+                        TimeUnit::Microsecond,
+                        Some(Arc::new("Africa/Johannesburg".to_string())),
+                    ),
+                    false,
+                ),
+                Field::new(
+                    "c18",
+                    DataType::Timestamp(TimeUnit::Nanosecond, None),
+                    false,
+                ),
+                Field::new("c19", DataType::Interval(IntervalUnit::DayTime), false),
+                Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), false),
+                Field::new("c21", DataType::List(Box::new(DataType::Boolean)), false),
+                Field::new(
+                    "c22",
+                    DataType::FixedSizeList(Box::new(DataType::Boolean), 5),
+                    false,
+                ),
+                Field::new(
+                    "c23",
+                    DataType::List(Box::new(DataType::List(Box::new(DataType::Struct(
+                        vec![
+                            Field::new("a", DataType::Int16, true),
+                            Field::new("b", DataType::Float64, false),
+                        ],
+                    ))))),
+                    true,
+                ),
+                Field::new(
+                    "c24",
+                    DataType::Struct(vec![
+                        Field::new("a", DataType::Utf8, false),
+                        Field::new("b", DataType::UInt16, false),
+                    ]),
+                    false,
+                ),
+                Field::new("c25", DataType::Interval(IntervalUnit::YearMonth), true),
+                Field::new("c26", DataType::Interval(IntervalUnit::DayTime), true),
+                // Field::new("c27", DataType::Duration(TimeUnit::Second), false),
+                // Field::new("c28", DataType::Duration(TimeUnit::Millisecond), false),
+                // Field::new("c29", DataType::Duration(TimeUnit::Microsecond), false),
+                // Field::new("c30", DataType::Duration(TimeUnit::Nanosecond), false),
+                // Field::new_dict(
+                //     "c31",
+                //     DataType::Dictionary(
+                //         Box::new(DataType::Int32),
+                //         Box::new(DataType::Utf8),
+                //     ),
+                //     true,
+                //     123,
+                //     true,
+                // ),
+                Field::new("c32", DataType::LargeBinary, true),
+                Field::new("c33", DataType::LargeUtf8, true),
+                Field::new(
+                    "c34",
+                    DataType::LargeList(Box::new(DataType::LargeList(Box::new(
+                        DataType::Struct(vec![
+                            Field::new("a", DataType::Int16, true),
+                            Field::new("b", DataType::Float64, true),
+                        ]),
+                    )))),
+                    true,
+                ),
+            ],
+            metadata,
+        );
+
+        // write to an empty parquet file so that schema is serialized
+        let file = get_temp_file("test_arrow_schema_roundtrip.parquet", &[]);
+        let mut writer = ArrowWriter::try_new(
+            file.try_clone().unwrap(),
+            Arc::new(schema.clone()),
+            None,
+        )?;
+        writer.close()?;
+
+        // read file back
+        let parquet_reader = SerializedFileReader::try_from(file)?;
+        let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(parquet_reader));
+        let read_schema = arrow_reader.get_schema()?;
+        assert_eq!(schema, read_schema);
+        Ok(())
+    }
 }
diff --git a/rust/parquet/src/file/properties.rs b/rust/parquet/src/file/properties.rs
index 188d6ec3c9e..b62ce7bbc38 100644
--- a/rust/parquet/src/file/properties.rs
+++ b/rust/parquet/src/file/properties.rs
@@ -89,8 +89,8 @@ pub type WriterPropertiesPtr = Rc<WriterProperties>;

 /// Writer properties.
 ///
-/// It is created as an immutable data structure, use [`WriterPropertiesBuilder`] to
-/// assemble the properties.
+/// All properties except the key-value metadata are immutable,
+/// use [`WriterPropertiesBuilder`] to assemble these properties.
 #[derive(Debug, Clone)]
 pub struct WriterProperties {
     data_pagesize_limit: usize,
@@ -99,7 +99,7 @@ pub struct WriterProperties {
     max_row_group_size: usize,
     writer_version: WriterVersion,
     created_by: String,
-    key_value_metadata: Option<Vec<KeyValue>>,
+    pub(crate) key_value_metadata: Option<Vec<KeyValue>>,
     default_column_properties: ColumnProperties,
     column_properties: HashMap<ColumnPath, ColumnProperties>,
 }
diff --git a/rust/parquet/src/schema/types.rs b/rust/parquet/src/schema/types.rs
index 416073af035..57999050ab3 100644
--- a/rust/parquet/src/schema/types.rs
+++ b/rust/parquet/src/schema/types.rs
@@ -788,7 +788,7 @@ impl SchemaDescriptor {
         result.clone()
     }

-    fn column_root_of(&self, i: usize) -> &Rc<Type> {
+    fn column_root_of(&self, i: usize) -> &TypePtr {
         assert!(
             i < self.leaves.len(),
             "Index out of bound: {} not in [0, {})",
@@ -810,6 +810,10 @@ impl SchemaDescriptor {
         self.schema.as_ref()
     }

+    pub fn root_schema_ptr(&self) -> TypePtr {
+        self.schema.clone()
+    }
+
     /// Returns schema name.
     pub fn name(&self) -> &str {
         self.schema.name()