diff --git a/rust/parquet/src/file/writer.rs b/rust/parquet/src/file/writer.rs
index ede6ce47a2e..8b8ed6dcb09 100644
--- a/rust/parquet/src/file/writer.rs
+++ b/rust/parquet/src/file/writer.rs
@@ -18,27 +18,32 @@
 //! Contains file writer API, and provides methods to write row groups and columns by
 //! using row group writers and column writers respectively.
 
+use std::fs::File;
 use std::{
     io::{Seek, SeekFrom, Write},
     rc::Rc,
 };
 
+use arrow::array;
+use arrow::datatypes::Schema;
 use byteorder::{ByteOrder, LittleEndian};
 use parquet_format as parquet;
 use thrift::protocol::{TCompactOutputProtocol, TOutputProtocol};
 
-use crate::basic::PageType;
+use crate::basic::{PageType, Repetition, Type};
 use crate::column::{
     page::{CompressedPage, Page, PageWriteSpec, PageWriter},
     writer::{get_column_writer, ColumnWriter},
 };
 use crate::errors::{ParquetError, Result};
+use crate::file::properties::WriterProperties;
 use crate::file::{
     metadata::*, properties::WriterPropertiesPtr, reader::TryClone,
     statistics::to_thrift as statistics_to_thrift, FOOTER_SIZE, PARQUET_MAGIC,
 };
 use crate::schema::types::{self, SchemaDescPtr, SchemaDescriptor, TypePtr};
 use crate::util::io::{FileSink, Position};
+use arrow::record_batch::RecordBatch;
 
 // ----------------------------------------------------------------------
 // APIs for file & row group writers
@@ -521,6 +526,75 @@ impl PageWriter for SerializedPageWriter {
     }
 }
 
+struct ArrowWriter {
+    writer: SerializedFileWriter,
+    rows: i64,
+}
+
+impl ArrowWriter {
+    pub fn new(file: File, _arrow_schema: &Schema) -> Self {
+        //TODO convert Arrow schema to Parquet schema
+        let schema = Rc::new(
+            types::Type::group_type_builder("schema")
+                .with_fields(&mut vec![
+                    Rc::new(
+                        types::Type::primitive_type_builder("a", Type::INT32)
+                            .with_repetition(Repetition::REQUIRED)
+                            .build()
+                            .unwrap(),
+                    ),
+                    Rc::new(
+                        types::Type::primitive_type_builder("b", Type::INT32)
+                            .with_repetition(Repetition::REQUIRED)
+                            .build()
+                            .unwrap(),
+                    ),
+                ])
+                .build()
+                .unwrap(),
+        );
+        let props = Rc::new(WriterProperties::builder().build());
+        let file_writer =
+            SerializedFileWriter::new(file.try_clone().unwrap(), schema, props).unwrap();
+
+        Self {
+            writer: file_writer,
+            rows: 0,
+        }
+    }
+
+    pub fn write(&mut self, batch: &RecordBatch) {
+        let mut row_group_writer = self.writer.next_row_group().unwrap();
+        for i in 0..batch.schema().fields().len() {
+            let col_writer = row_group_writer.next_column().unwrap();
+            if let Some(mut writer) = col_writer {
+                match writer {
+                    ColumnWriter::Int32ColumnWriter(ref mut typed) => {
+                        let array = batch
+                            .column(i)
+                            .as_any()
+                            .downcast_ref::<array::Int32Array>()
+                            .unwrap();
+                        self.rows += typed
+                            .write_batch(array.value_slice(0, array.len()), None, None)
+                            .unwrap() as i64;
+                    }
+                    //TODO add other types
+                    _ => {
+                        unimplemented!();
+                    }
+                }
+                row_group_writer.close_column(writer).unwrap();
+            }
+        }
+        self.writer.close_row_group(row_group_writer).unwrap();
+    }
+
+    pub fn close(&mut self) {
+        self.writer.close().unwrap();
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -538,6 +612,36 @@ mod tests {
     use crate::record::RowAccessor;
     use crate::util::{memory::ByteBufferPtr, test_common::get_temp_file};
 
+    use arrow::array::Int32Array;
+    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::record_batch::RecordBatch;
+    use std::sync::Arc;
+
+    #[test]
+    fn arrow_writer() {
+        // define schema
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+        ]);
+
+        // create some data
+        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
+
+        // build a record batch
+        let batch = RecordBatch::try_new(
+            Arc::new(schema.clone()),
+            vec![Arc::new(a), Arc::new(b)],
+        )
+        .unwrap();
+
+        let file = File::create("test.parquet").unwrap();
+        let mut writer = ArrowWriter::new(file, &schema);
+        writer.write(&batch);
+        writer.close();
+    }
+
     #[test]
     fn test_file_writer_error_after_close() {
         let file = get_temp_file("test_file_writer_error_after_close", &[]);
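
The `//TODO convert Arrow schema to Parquet schema` in `ArrowWriter::new` is the main piece of hard-coding left in the prototype: the two INT32 fields are built by hand regardless of the Arrow schema passed in. Below is a minimal sketch of what that conversion could look like, covering only a few primitive types; `arrow_to_parquet_schema` is a hypothetical helper name, not part of this patch or of either crate's API, and the `parquet::` paths are written as an external consumer would use them.

```rust
use std::rc::Rc;

use arrow::datatypes::{DataType, Schema};
use parquet::basic::{Repetition, Type as PhysicalType};
use parquet::schema::types;

/// Hypothetical helper for the TODO above: map an Arrow schema to a
/// Parquet group type. Only a few primitive types are handled here.
fn arrow_to_parquet_schema(schema: &Schema) -> Rc<types::Type> {
    let mut fields: Vec<Rc<types::Type>> = schema
        .fields()
        .iter()
        .map(|field| {
            // Map the Arrow data type to a Parquet physical type.
            let physical_type = match field.data_type() {
                DataType::Boolean => PhysicalType::BOOLEAN,
                DataType::Int32 => PhysicalType::INT32,
                DataType::Int64 => PhysicalType::INT64,
                DataType::Float32 => PhysicalType::FLOAT,
                DataType::Float64 => PhysicalType::DOUBLE,
                other => unimplemented!("unsupported Arrow type {:?}", other),
            };
            // Nullable Arrow fields become OPTIONAL Parquet columns.
            let repetition = if field.is_nullable() {
                Repetition::OPTIONAL
            } else {
                Repetition::REQUIRED
            };
            Rc::new(
                types::Type::primitive_type_builder(field.name(), physical_type)
                    .with_repetition(repetition)
                    .build()
                    .unwrap(),
            )
        })
        .collect();

    Rc::new(
        types::Type::group_type_builder("schema")
            .with_fields(&mut fields)
            .build()
            .unwrap(),
    )
}
```

With a helper like this, `ArrowWriter::new` could call it on `_arrow_schema` instead of building the hard-coded two-column schema. Note that OPTIONAL columns would also require definition levels to be passed to `write_batch`, which the current Int32 path does not yet produce.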
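For completeness, the `arrow_writer` test writes `test.parquet` but never reads it back. Here is a sketch of round-trip verification using reader APIs that already exist in `rust/parquet` (`SerializedFileReader` and the `RowAccessor` trait, both already imported in this test module); the hard-coded `test.parquet` path simply mirrors the test above.

```rust
use std::fs::File;

use parquet::file::reader::{FileReader, SerializedFileReader};
use parquet::record::RowAccessor;

fn main() {
    // Open the file produced by the arrow_writer test.
    let file = File::open("test.parquet").unwrap();
    let reader = SerializedFileReader::new(file).unwrap();

    // Expect one row group holding five rows of two REQUIRED INT32 columns.
    let mut iter = reader.get_row_iter(None).unwrap();
    while let Some(row) = iter.next() {
        println!("a = {}, b = {}", row.get_int(0).unwrap(), row.get_int(1).unwrap());
    }
}
```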