diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs
index 4782efda9c4a..fbb0a7213de1 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -198,10 +198,12 @@ impl<W: Write + Send> ArrowWriter<W> {
 
         let max_row_group_size = props.max_row_group_size();
 
+        let props_ptr = Arc::new(props);
         let file_writer =
-            SerializedFileWriter::new(writer, schema.root_schema_ptr(), Arc::new(props))?;
+            SerializedFileWriter::new(writer, schema.root_schema_ptr(), Arc::clone(&props_ptr))?;
 
-        let row_group_writer_factory = ArrowRowGroupWriterFactory::new(&file_writer);
+        let row_group_writer_factory =
+            ArrowRowGroupWriterFactory::new(&file_writer, schema, arrow_schema.clone(), props_ptr);
 
         Ok(Self {
             writer: file_writer,
@@ -272,12 +274,10 @@ impl<W: Write + Send> ArrowWriter<W> {
 
         let in_progress = match &mut self.in_progress {
             Some(in_progress) => in_progress,
-            x => x.insert(self.row_group_writer_factory.create_row_group_writer(
-                self.writer.schema_descr(),
-                self.writer.properties(),
-                &self.arrow_schema,
-                self.writer.flushed_row_groups().len(),
-            )?),
+            x => x.insert(
+                self.row_group_writer_factory
+                    .create_row_group_writer(self.writer.flushed_row_groups().len())?,
+            ),
         };
 
         // If would exceed max_row_group_size, split batch
@@ -755,7 +755,7 @@ impl ArrowColumnWriter {
 }
 
 /// Encodes [`RecordBatch`] to a parquet row group
-struct ArrowRowGroupWriter {
+pub struct ArrowRowGroupWriter {
     writers: Vec<ArrowColumnWriter>,
     schema: SchemaRef,
     buffered_rows: usize,
@@ -787,54 +787,72 @@ impl ArrowRowGroupWriter {
             .map(|writer| writer.close())
             .collect()
     }
+
+    /// Returns the [`ArrowColumnWriter`]s for all columns in this row group
+    pub fn into_column_writers(self) -> Vec<ArrowColumnWriter> {
+        self.writers
+    }
 }
 
-struct ArrowRowGroupWriterFactory {
+/// Factory for creating [`ArrowRowGroupWriter`] instances.
+/// This is used by [`ArrowWriter`] to create row group writers, but can also be
+/// used directly as a lower-level API.
+pub struct ArrowRowGroupWriterFactory {
+    schema: SchemaDescriptor,
+    arrow_schema: SchemaRef,
+    props: WriterPropertiesPtr,
     #[cfg(feature = "encryption")]
     file_encryptor: Option<Arc<FileEncryptor>>,
 }
 
 impl ArrowRowGroupWriterFactory {
+    /// Creates a new [`ArrowRowGroupWriterFactory`] using the provided [`SerializedFileWriter`].
     #[cfg(feature = "encryption")]
-    fn new<W: Write + Send>(file_writer: &SerializedFileWriter<W>) -> Self {
+    pub fn new<W: Write + Send>(
+        file_writer: &SerializedFileWriter<W>,
+        schema: SchemaDescriptor,
+        arrow_schema: SchemaRef,
+        props: WriterPropertiesPtr,
+    ) -> Self {
         Self {
+            schema,
+            arrow_schema,
+            props,
            file_encryptor: file_writer.file_encryptor(),
         }
     }
 
     #[cfg(not(feature = "encryption"))]
-    fn new<W: Write + Send>(_file_writer: &SerializedFileWriter<W>) -> Self {
-        Self {}
+    pub fn new<W: Write + Send>(
+        _file_writer: &SerializedFileWriter<W>,
+        schema: SchemaDescriptor,
+        arrow_schema: SchemaRef,
+        props: WriterPropertiesPtr,
+    ) -> Self {
+        Self {
+            schema,
+            arrow_schema,
+            props,
+        }
     }
 
+    /// Creates a new [`ArrowRowGroupWriter`] for the row group at the given index,
+    /// using the factory's Parquet schema and writer properties.
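+    ///
+    /// The following sketch is illustrative only (it is marked `ignore` and not run as a
+    /// doc test); it assumes a `SerializedFileWriter` named `file_writer`, a converted
+    /// Parquet `SchemaDescriptor` named `parquet_schema`, an Arrow `SchemaRef` named
+    /// `arrow_schema`, and shared writer properties named `props` are already in scope.
+    ///
+    /// ```ignore
+    /// let factory =
+    ///     ArrowRowGroupWriterFactory::new(&file_writer, parquet_schema, arrow_schema, props);
+    /// // Writer for the first row group, split into one writer per leaf column
+    /// let row_group_writer = factory.create_row_group_writer(0)?;
+    /// let column_writers = row_group_writer.into_column_writers();
+    /// ```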
     #[cfg(feature = "encryption")]
-    fn create_row_group_writer(
-        &self,
-        parquet: &SchemaDescriptor,
-        props: &WriterPropertiesPtr,
-        arrow: &SchemaRef,
-        row_group_index: usize,
-    ) -> Result<ArrowRowGroupWriter> {
+    pub fn create_row_group_writer(&self, row_group_index: usize) -> Result<ArrowRowGroupWriter> {
         let writers = get_column_writers_with_encryptor(
-            parquet,
-            props,
-            arrow,
+            &self.schema,
+            &self.props,
+            &self.arrow_schema,
             self.file_encryptor.clone(),
             row_group_index,
         )?;
-        Ok(ArrowRowGroupWriter::new(writers, arrow))
+        Ok(ArrowRowGroupWriter::new(writers, &self.arrow_schema))
     }
 
     #[cfg(not(feature = "encryption"))]
-    fn create_row_group_writer(
-        &self,
-        parquet: &SchemaDescriptor,
-        props: &WriterPropertiesPtr,
-        arrow: &SchemaRef,
-        _row_group_index: usize,
-    ) -> Result<ArrowRowGroupWriter> {
-        let writers = get_column_writers(parquet, props, arrow)?;
-        Ok(ArrowRowGroupWriter::new(writers, arrow))
+    pub fn create_row_group_writer(&self, _row_group_index: usize) -> Result<ArrowRowGroupWriter> {
+        let writers = get_column_writers(&self.schema, &self.props, &self.arrow_schema)?;
+        Ok(ArrowRowGroupWriter::new(writers, &self.arrow_schema))
     }
 }
 
@@ -890,7 +908,14 @@ struct ArrowColumnWriterFactory {
     file_encryptor: Option<Arc<FileEncryptor>>,
 }
 
+impl Default for ArrowColumnWriterFactory {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 impl ArrowColumnWriterFactory {
+    /// Create a new [`ArrowColumnWriterFactory`]
     pub fn new() -> Self {
         Self {
             #[cfg(feature = "encryption")]
@@ -939,7 +964,7 @@ impl ArrowColumnWriterFactory {
     }
 
     /// Gets the [`ArrowColumnWriter`] for the given `data_type`
-    fn get_arrow_column_writer(
+    pub fn get_arrow_column_writer(
         &self,
         data_type: &ArrowDataType,
         props: &WriterPropertiesPtr,
diff --git a/parquet/src/file/properties.rs b/parquet/src/file/properties.rs
index 26177b69a577..3ab437f168e6 100644
--- a/parquet/src/file/properties.rs
+++ b/parquet/src/file/properties.rs
@@ -457,7 +457,7 @@ pub struct WriterPropertiesBuilder {
 
 impl WriterPropertiesBuilder {
     /// Returns default state of the builder.
-    fn with_defaults() -> Self {
+    pub fn with_defaults() -> Self {
         Self {
             data_page_size_limit: DEFAULT_PAGE_SIZE,
             data_page_row_count_limit: DEFAULT_DATA_PAGE_ROW_COUNT_LIMIT,
diff --git a/parquet/tests/encryption/encryption.rs b/parquet/tests/encryption/encryption.rs
index 7079e91d1209..63c92c2f1549 100644
--- a/parquet/tests/encryption/encryption.rs
+++ b/parquet/tests/encryption/encryption.rs
@@ -28,13 +28,14 @@ use parquet::arrow::arrow_reader::{
     ArrowReaderMetadata, ArrowReaderOptions, ParquetRecordBatchReaderBuilder, RowSelection,
     RowSelector,
 };
-use parquet::arrow::ArrowWriter;
+use parquet::arrow::arrow_writer::{compute_leaves, ArrowLeafColumn, ArrowRowGroupWriterFactory};
+use parquet::arrow::{ArrowSchemaConverter, ArrowWriter};
 use parquet::data_type::{ByteArray, ByteArrayType};
 use parquet::encryption::decrypt::FileDecryptionProperties;
 use parquet::encryption::encrypt::FileEncryptionProperties;
 use parquet::errors::ParquetError;
 use parquet::file::metadata::ParquetMetaData;
-use parquet::file::properties::WriterProperties;
+use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder};
 use parquet::file::writer::SerializedFileWriter;
 use parquet::schema::parser::parse_message_type;
 use std::fs::File;
@@ -1062,14 +1063,10 @@ fn test_decrypt_page_index(
     Ok(())
 }
 
-fn read_and_roundtrip_to_encrypted_file(
+fn read_encrypted_file(
     path: &str,
     decryption_properties: FileDecryptionProperties,
-    encryption_properties: FileEncryptionProperties,
-) {
-    let temp_file = tempfile::tempfile().unwrap();
-
-    // read example data
+) -> Result<(Vec<RecordBatch>, ArrowReaderMetadata), ParquetError> {
     let file = File::open(path).unwrap();
     let options = ArrowReaderOptions::default()
         .with_file_decryption_properties(decryption_properties.clone());
@@ -1080,7 +1077,18 @@ fn read_and_roundtrip_to_encrypted_file(
     let batches = batch_reader
         .collect::<Result<Vec<_>, _>>()
         .unwrap();
+    Ok((batches, metadata))
+}
+
+fn read_and_roundtrip_to_encrypted_file(
+    path: &str,
+    decryption_properties: FileDecryptionProperties,
+    encryption_properties: FileEncryptionProperties,
+) {
+    // read example data
+    let (batches, metadata) = read_encrypted_file(path, decryption_properties.clone()).unwrap();
 
+    let temp_file = tempfile::tempfile().unwrap();
     // write example data
     let props = WriterProperties::builder()
         .with_file_encryption_properties(encryption_properties)
@@ -1101,3 +1109,118 @@ fn read_and_roundtrip_to_encrypted_file(
     // check re-written example data
     verify_encryption_test_file_read(temp_file, decryption_properties);
 }
+
+#[tokio::test]
+async fn test_multi_threaded_encrypted_writing() {
+    // Read example data and set up encryption/decryption properties
+    let testdata = arrow::util::test_util::parquet_test_data();
+    let path = format!("{testdata}/encrypt_columns_and_footer.parquet.encrypted");
+
+    let file_encryption_properties = FileEncryptionProperties::builder(b"0123456789012345".into())
+        .with_column_key("double_field", b"1234567890123450".into())
+        .with_column_key("float_field", b"1234567890123451".into())
+        .build()
+        .unwrap();
+    let decryption_properties = FileDecryptionProperties::builder(b"0123456789012345".into())
+        .with_column_key("double_field", b"1234567890123450".into())
+        .with_column_key("float_field", b"1234567890123451".into())
+        .build()
+        .unwrap();
+
+    let (record_batches, metadata) =
+        read_encrypted_file(&path, decryption_properties.clone()).unwrap();
+    let to_write: Vec<_> = record_batches
+        .iter()
+        .flat_map(|rb| rb.columns().to_vec())
+        .collect();
+    let schema = metadata.schema().clone();
+
+    let props = Arc::new(
+        WriterPropertiesBuilder::with_defaults()
+            .with_file_encryption_properties(file_encryption_properties)
+            .build(),
+    );
+    let parquet_schema = ArrowSchemaConverter::new()
+        .with_coerce_types(props.coerce_types())
+        .convert(&schema)
+        .unwrap();
+    let root_schema = parquet_schema.root_schema_ptr();
+
+    // Create a temporary file to write the encrypted data
+    let temp_file = tempfile::NamedTempFile::new().unwrap();
+    let mut file_writer =
+        SerializedFileWriter::new(&temp_file, root_schema.clone(), props.clone()).unwrap();
+
+    let arrow_row_group_writer_factory = ArrowRowGroupWriterFactory::new(
+        &file_writer,
+        parquet_schema,
+        schema.clone(),
+        props.clone(),
+    );
+    let arrow_row_group_writer = arrow_row_group_writer_factory
+        .create_row_group_writer(0)
+        .unwrap();
+
+    // Get column writers with encryptor from the ArrowRowGroupWriter
+    let col_writers = arrow_row_group_writer.into_column_writers();
+    let num_columns = col_writers.len();
+
+    // Create a channel for each column writer to send ArrowLeafColumn data to
+    let mut col_writer_tasks = Vec::with_capacity(num_columns);
+    let mut col_array_channels = Vec::with_capacity(num_columns);
+    for mut writer in col_writers.into_iter() {
+        let (send_array, mut receive_array) = tokio::sync::mpsc::channel::<ArrowLeafColumn>(100);
+        col_array_channels.push(send_array);
+        let handle = tokio::spawn(async move {
+            while let Some(col) = receive_array.recv().await {
+                let _ = writer.write(&col);
+            }
+            writer.close().unwrap()
+        });
+        col_writer_tasks.push(handle);
+    }
+
+    // Send the ArrowLeafColumn data to the respective column writer channels
+    let mut worker_iter = col_array_channels.iter_mut();
+    for (array, field) in to_write.iter().zip(schema.fields()) {
+        for leaves in compute_leaves(field, array).unwrap() {
+            worker_iter.next().unwrap().send(leaves).await.unwrap();
+        }
+    }
+    drop(col_array_channels);
+
+    // Wait for all column writers to finish writing
+    let mut finalized_rg = Vec::with_capacity(num_columns);
+    for task in col_writer_tasks.into_iter() {
+        finalized_rg.push(task.await.unwrap());
+    }
+
+    // Now that the workers are done, append the resulting column chunks
+    // to the row group (and the file)
+    let mut row_group_writer = file_writer.next_row_group().unwrap();
+    for chunk in finalized_rg {
+        chunk.append_to_row_group(&mut row_group_writer).unwrap();
+    }
+
+    // Close the row group which writes to the underlying file
+    row_group_writer.close().unwrap();
+
+    // Close the file writer which writes the footer
+    let metadata = file_writer.close().unwrap();
+    assert_eq!(metadata.num_rows, 50);
+
+    // Check that the file was written correctly
+    let (read_record_batches, read_metadata) = read_encrypted_file(
+        temp_file.path().to_str().unwrap(),
+        decryption_properties.clone(),
+    )
+    .unwrap();
+    verify_encryption_test_data(read_record_batches, read_metadata.metadata());
+
+    // Check that the file was encrypted
+    let result = ArrowReaderMetadata::load(&temp_file.into_file(), ArrowReaderOptions::default());
+    assert_eq!(
+        result.unwrap_err().to_string(),
+        "Parquet error: Parquet file has an encrypted footer but decryption properties were not provided"
+    );
+}
diff --git a/parquet/tests/encryption/encryption_async.rs b/parquet/tests/encryption/encryption_async.rs
index e0fbbcdfafe3..4aca4bacb057 100644
--- a/parquet/tests/encryption/encryption_async.rs
+++ b/parquet/tests/encryption/encryption_async.rs
@@ -20,6 +20,7 @@ use crate::encryption_util::{
     verify_column_indexes,
     verify_encryption_test_data, TestKeyRetriever,
 };
+use arrow_array::RecordBatch;
 use futures::TryStreamExt;
 use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions};
 use parquet::arrow::arrow_writer::ArrowWriterOptions;
@@ -436,14 +437,13 @@ async fn test_decrypt_page_index(
     Ok(())
 }
 
-async fn verify_encryption_test_file_read_async(
+async fn read_encrypted_file_async(
     file: &mut tokio::fs::File,
     decryption_properties: FileDecryptionProperties,
-) -> Result<(), ParquetError> {
+) -> Result<(Vec<RecordBatch>, ArrowReaderMetadata), ParquetError> {
     let options =
         ArrowReaderOptions::new().with_file_decryption_properties(decryption_properties);
     let arrow_metadata = ArrowReaderMetadata::load_async(file, options).await?;
-    let metadata = arrow_metadata.metadata();
 
     let record_reader = ParquetRecordBatchStreamBuilder::new_with_metadata(
         file.try_clone().await?,
@@ -451,8 +451,15 @@ async fn verify_encryption_test_file_read_async(
     )
     .build()?;
     let record_batches = record_reader.try_collect::<Vec<_>>().await?;
+    Ok((record_batches, arrow_metadata.clone()))
+}
 
-    verify_encryption_test_data(record_batches, metadata);
+async fn verify_encryption_test_file_read_async(
+    file: &mut tokio::fs::File,
+    decryption_properties: FileDecryptionProperties,
+) -> Result<(), ParquetError> {
+    let (record_batches, metadata) = read_encrypted_file_async(file, decryption_properties).await?;
+    verify_encryption_test_data(record_batches, metadata.metadata());
     Ok(())
 }
 
@@ -464,15 +471,8 @@ async fn read_and_roundtrip_to_encrypted_file_async(
     let temp_file = tempfile::tempfile().unwrap();
 
     let mut file = File::open(&path).await.unwrap();
-    let options =
-        ArrowReaderOptions::new().with_file_decryption_properties(decryption_properties.clone());
-    let arrow_metadata = ArrowReaderMetadata::load_async(&mut file, options).await?;
-    let record_reader = ParquetRecordBatchStreamBuilder::new_with_metadata(
-        file.try_clone().await?,
-        arrow_metadata.clone(),
-    )
-    .build()?;
-    let record_batches = record_reader.try_collect::<Vec<_>>().await?;
+    let (record_batches, arrow_metadata) =
+        read_encrypted_file_async(&mut file, decryption_properties.clone()).await?;
 
     let props = WriterProperties::builder()
         .with_file_encryption_properties(encryption_properties)