diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index dbb5781aec9..fcc58470e21 100644 --- a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -129,21 +129,29 @@ impl TryFrom<&SchemaResult> for Schema { pub fn flight_data_to_arrow_batch( data: &FlightData, schema: SchemaRef, -) -> Result> { +) -> Option> { // check that the data_header is a record batch message let message = arrow::ipc::get_root_as_message(&data.data_header[..]); let dictionaries_by_field = Vec::new(); - let batch_header = message.header_as_record_batch().ok_or_else(|| { - ArrowError::ParseError( - "Unable to convert flight data header to a record batch".to_string(), + + message + .header_as_record_batch() + .ok_or_else(|| { + ArrowError::ParseError( + "Unable to convert flight data header to a record batch".to_string(), + ) + }) + .map_or_else( + |err| Some(Err(err)), + |batch| { + Some(reader::read_record_batch( + &data.data_body, + batch, + schema, + &dictionaries_by_field, + )) + }, ) - })?; - reader::read_record_batch( - &data.data_body, - batch_header, - schema, - &dictionaries_by_field, - ) } // TODO: add more explicit conversion that expoess flight descriptor and metadata options diff --git a/rust/arrow/src/csv/reader.rs b/rust/arrow/src/csv/reader.rs index 0f36648c506..c53104c4efc 100644 --- a/rust/arrow/src/csv/reader.rs +++ b/rust/arrow/src/csv/reader.rs @@ -313,32 +313,7 @@ impl Reader { } } - /// Read the next batch of rows - #[allow(clippy::should_implement_trait)] - pub fn next(&mut self) -> Result> { - // read a batch of rows into memory - let mut rows: Vec = Vec::with_capacity(self.batch_size); - for i in 0..self.batch_size { - match self.record_iter.next() { - Some(Ok(r)) => { - rows.push(r); - } - Some(Err(e)) => { - return Err(ArrowError::ParseError(format!( - "Error parsing line {}: {:?}", - self.line_number + i, - e - ))); - } - None => break, - } - } - - // return early if no data was loaded - if rows.is_empty() { - return Ok(None); - } - + fn parse(&self, rows: &[StringRecord]) -> Result { let projection: Vec = match self.projection { Some(ref v) => v.clone(), None => self @@ -350,7 +325,6 @@ impl Reader { .collect(), }; - let rows = &rows[..]; let arrays: Result> = projection .iter() .map(|i| { @@ -398,8 +372,6 @@ impl Reader { }) .collect(); - self.line_number += rows.len(); - let schema_fields = self.schema.fields(); let projected_fields: Vec = projection @@ -409,7 +381,7 @@ impl Reader { let projected_schema = Arc::new(Schema::new(projected_fields)); - arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr).map(Some)) + arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr)) } fn build_primitive_array( @@ -448,6 +420,42 @@ impl Reader { } } +impl Iterator for Reader { + type Item = Result; + + fn next(&mut self) -> Option { + // read a batch of rows into memory + let mut rows: Vec = Vec::with_capacity(self.batch_size); + for i in 0..self.batch_size { + match self.record_iter.next() { + Some(Ok(r)) => { + rows.push(r); + } + Some(Err(e)) => { + return Some(Err(ArrowError::ParseError(format!( + "Error parsing line {}: {:?}", + self.line_number + i, + e + )))); + } + None => break, + } + } + + // return early if no data was loaded + if rows.is_empty() { + return None; + } + + // parse the batches into a RecordBatch + let result = self.parse(&rows); + + self.line_number += rows.len(); + + Some(result) + } +} + /// CSV file reader builder #[derive(Debug)] pub struct ReaderBuilder { @@ -832,11 +840,14 @@ mod tests { let mut csv = builder.build(file).unwrap(); match csv.next() { - Err(e) => assert_eq!( - "ParseError(\"Error while parsing value 4.x4 for column 1 at line 4\")", - format!("{:?}", e) - ), - Ok(_) => panic!("should have failed"), + Some(e) => match e { + Err(e) => assert_eq!( + "ParseError(\"Error while parsing value 4.x4 for column 1 at line 4\")", + format!("{:?}", e) + ), + Ok(_) => panic!("should have failed"), + } + None => panic!("should have failed"), } } diff --git a/rust/arrow/src/ipc/reader.rs b/rust/arrow/src/ipc/reader.rs index af0b4e66a39..53c422d481c 100644 --- a/rust/arrow/src/ipc/reader.rs +++ b/rust/arrow/src/ipc/reader.rs @@ -414,7 +414,7 @@ pub fn read_record_batch( batch: ipc::RecordBatch, schema: SchemaRef, dictionaries: &[Option], -) -> Result> { +) -> Result { let buffers = batch.buffers().ok_or_else(|| { ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string()) })?; @@ -442,7 +442,7 @@ pub fn read_record_batch( arrays.push(triple.0); } - RecordBatch::try_new(schema, arrays).map(Some) + RecordBatch::try_new(schema, arrays) } // Linear search for the first dictionary field with a dictionary id. @@ -592,8 +592,7 @@ impl FileReader { batch.data().unwrap(), Arc::new(schema), &dictionaries_by_field, - )? - .unwrap(); + )?; Some(record_batch.column(0).clone()) } _ => None, @@ -662,81 +661,88 @@ impl FileReader { Ok(()) } } -} -impl RecordBatchReader for FileReader { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } + fn maybe_next(&mut self) -> Result> { + let block = self.blocks[self.current_block]; + self.current_block += 1; - fn next_batch(&mut self) -> Result> { - // get current block - if self.current_block < self.total_blocks { - let block = self.blocks[self.current_block]; - self.current_block += 1; - - // read length - self.reader.seek(SeekFrom::Start(block.offset() as u64))?; - let mut meta_buf = [0; 4]; + // read length + self.reader.seek(SeekFrom::Start(block.offset() as u64))?; + let mut meta_buf = [0; 4]; + self.reader.read_exact(&mut meta_buf)?; + if meta_buf == CONTINUATION_MARKER { + // continuation marker encountered, read message next self.reader.read_exact(&mut meta_buf)?; - if meta_buf == CONTINUATION_MARKER { - // continuation marker encountered, read message next - self.reader.read_exact(&mut meta_buf)?; - } - let meta_len = i32::from_le_bytes(meta_buf); + } + let meta_len = i32::from_le_bytes(meta_buf); - let mut block_data = vec![0; meta_len as usize]; - self.reader.read_exact(&mut block_data)?; + let mut block_data = vec![0; meta_len as usize]; + self.reader.read_exact(&mut block_data)?; - let message = ipc::get_root_as_message(&block_data[..]); - - // some old test data's footer metadata is not set, so we account for that - if self.metadata_version != ipc::MetadataVersion::V1 - && message.version() != self.metadata_version - { - return Err(ArrowError::IoError( - "Could not read IPC message as metadata versions mismatch" - .to_string(), - )); - } + let message = ipc::get_root_as_message(&block_data[..]); - match message.header_type() { - ipc::MessageHeader::Schema => Err(ArrowError::IoError( - "Not expecting a schema when messages are read".to_string(), - )), - ipc::MessageHeader::RecordBatch => { - let batch = message.header_as_record_batch().ok_or_else(|| { - ArrowError::IoError( - "Unable to read IPC message as record batch".to_string(), - ) - })?; - // read the block that makes up the record batch into a buffer - let mut buf = vec![0; block.bodyLength() as usize]; - self.reader.seek(SeekFrom::Start( - block.offset() as u64 + block.metaDataLength() as u64, - ))?; - self.reader.read_exact(&mut buf)?; + // some old test data's footer metadata is not set, so we account for that + if self.metadata_version != ipc::MetadataVersion::V1 + && message.version() != self.metadata_version + { + return Err(ArrowError::IoError( + "Could not read IPC message as metadata versions mismatch".to_string(), + )); + } - read_record_batch( - &buf, - batch, - self.schema(), - &self.dictionaries_by_field, + match message.header_type() { + ipc::MessageHeader::Schema => Err(ArrowError::IoError( + "Not expecting a schema when messages are read".to_string(), + )), + ipc::MessageHeader::RecordBatch => { + let batch = message.header_as_record_batch().ok_or_else(|| { + ArrowError::IoError( + "Unable to read IPC message as record batch".to_string(), ) - } - ipc::MessageHeader::NONE => { - Ok(None) - } - t => Err(ArrowError::IoError(format!( - "Reading types other than record batches not yet supported, unable to read {:?}", t - ))), + })?; + // read the block that makes up the record batch into a buffer + let mut buf = vec![0; block.bodyLength() as usize]; + self.reader.seek(SeekFrom::Start( + block.offset() as u64 + block.metaDataLength() as u64, + ))?; + self.reader.read_exact(&mut buf)?; + + read_record_batch( + &buf, + batch, + self.schema(), + &self.dictionaries_by_field, + ).map(Some) } + ipc::MessageHeader::NONE => { + Ok(None) + } + t => Err(ArrowError::IoError(format!( + "Reading types other than record batches not yet supported, unable to read {:?}", t + ))), + } + } +} + +impl Iterator for FileReader { + type Item = Result; + + fn next(&mut self) -> Option { + // get current block + if self.current_block < self.total_blocks { + self.maybe_next().transpose() } else { - Ok(None) + None } } } +impl RecordBatchReader for FileReader { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + /// Arrow Stream reader pub struct StreamReader { /// Buffered stream reader @@ -805,14 +811,8 @@ impl StreamReader { pub fn is_finished(&self) -> bool { self.finished } -} -impl RecordBatchReader for StreamReader { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn next_batch(&mut self) -> Result> { + fn maybe_next(&mut self) -> Result> { if self.finished { return Ok(None); } @@ -869,7 +869,7 @@ impl RecordBatchReader for StreamReader { let mut buf = vec![0; message.bodyLength() as usize]; self.reader.read_exact(&mut buf)?; - read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_field) + read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_field).map(Some) } ipc::MessageHeader::NONE => { Ok(None) @@ -881,6 +881,20 @@ impl RecordBatchReader for StreamReader { } } +impl Iterator for StreamReader { + type Item = Result; + + fn next(&mut self) -> Option { + self.maybe_next().transpose() + } +} + +impl RecordBatchReader for StreamReader { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + #[cfg(test)] mod tests { use super::*; @@ -945,7 +959,7 @@ mod tests { let arrow_json = read_gzip_json(path); assert!(arrow_json.equals_reader(&mut reader)); // the next batch must be empty - assert!(reader.next_batch().unwrap().is_none()); + assert!(reader.next().is_none()); // the stream must indicate that it's finished assert!(reader.is_finished()); }); @@ -975,8 +989,10 @@ mod tests { // read stream back let file = File::open("target/debug/testdata/float.stream").unwrap(); - let mut reader = StreamReader::try_new(file).unwrap(); - while let Some(batch) = reader.next_batch().unwrap() { + let reader = StreamReader::try_new(file).unwrap(); + + reader.for_each(|batch| { + let batch = batch.unwrap(); assert!( batch .column(0) @@ -995,7 +1011,7 @@ mod tests { .value(0) != 0.0 ); - } + }) } /// Read gzipped JSON file diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index effbc7168e5..7678a5cd200 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -593,7 +593,6 @@ mod tests { use crate::array::*; use crate::datatypes::Field; use crate::ipc::reader::*; - use crate::record_batch::RecordBatchReader; use crate::util::integration_util::*; use std::env; use std::fs::File; @@ -633,7 +632,7 @@ mod tests { File::open(format!("target/debug/testdata/{}.arrow_file", "arrow")) .unwrap(); let mut reader = FileReader::try_new(file).unwrap(); - while let Ok(Some(read_batch)) = reader.next_batch() { + while let Some(Ok(read_batch)) = reader.next() { read_batch .columns() .iter() @@ -679,9 +678,10 @@ mod tests { { let file = File::open("target/debug/testdata/nulls.arrow_file").unwrap(); - let mut reader = FileReader::try_new(file).unwrap(); - while let Ok(Some(read_batch)) = reader.next_batch() { - read_batch + let reader = FileReader::try_new(file).unwrap(); + reader.for_each(|maybe_batch| { + maybe_batch + .unwrap() .columns() .iter() .zip(batch.columns()) @@ -690,7 +690,7 @@ mod tests { assert_eq!(a.len(), b.len()); assert_eq!(a.null_count(), b.null_count()); }); - } + }); } } @@ -721,7 +721,7 @@ mod tests { File::create(format!("target/debug/testdata/{}.arrow_file", path)) .unwrap(); let mut writer = FileWriter::try_new(file, &reader.schema()).unwrap(); - while let Ok(Some(batch)) = reader.next_batch() { + while let Some(Ok(batch)) = reader.next() { writer.write(&batch).unwrap(); } writer.finish().unwrap(); @@ -756,16 +756,16 @@ mod tests { )) .unwrap(); - let mut reader = StreamReader::try_new(file).unwrap(); + let reader = StreamReader::try_new(file).unwrap(); // read and rewrite the stream to a temp location { let file = File::create(format!("target/debug/testdata/{}.stream", path)) .unwrap(); let mut writer = StreamWriter::try_new(file, &reader.schema()).unwrap(); - while let Ok(Some(batch)) = reader.next_batch() { - writer.write(&batch).unwrap(); - } + reader.for_each(|batch| { + writer.write(&batch.unwrap()).unwrap(); + }); writer.finish().unwrap(); } diff --git a/rust/arrow/src/record_batch.rs b/rust/arrow/src/record_batch.rs index 76b154ddb25..9bfe1873198 100644 --- a/rust/arrow/src/record_batch.rs +++ b/rust/arrow/src/record_batch.rs @@ -217,7 +217,7 @@ impl Into for RecordBatch { } /// Trait for types that can read `RecordBatch`'s. -pub trait RecordBatchReader { +pub trait RecordBatchReader: Iterator> { /// Returns the schema of this `RecordBatchReader`. /// /// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this @@ -225,7 +225,13 @@ pub trait RecordBatchReader { fn schema(&self) -> SchemaRef; /// Reads the next `RecordBatch`. - fn next_batch(&mut self) -> Result>; + #[deprecated( + since = "2.0.0", + note = "This method is deprecated in favour of `next` from the trait Iterator." + )] + fn next_batch(&mut self) -> Result> { + self.next().transpose() + } } #[cfg(test)] diff --git a/rust/arrow/src/util/integration_util.rs b/rust/arrow/src/util/integration_util.rs index db3fc55ef88..c1bba13aee6 100644 --- a/rust/arrow/src/util/integration_util.rs +++ b/rust/arrow/src/util/integration_util.rs @@ -78,9 +78,9 @@ impl ArrowJson { return false; } self.batches.iter().all(|col| { - let batch = reader.next_batch(); + let batch = reader.next(); match batch { - Ok(Some(batch)) => col.equals_batch(&batch), + Some(Ok(batch)) => col.equals_batch(&batch), _ => false, } }) diff --git a/rust/datafusion/examples/flight_client.rs b/rust/datafusion/examples/flight_client.rs index 3bc2a04a499..64103445a93 100644 --- a/rust/datafusion/examples/flight_client.rs +++ b/rust/datafusion/examples/flight_client.rs @@ -63,7 +63,7 @@ async fn main() -> Result<(), Box> { while let Some(flight_data) = stream.message().await? { // the unwrap is infallible and thus safe let record_batch = - flight_data_to_arrow_batch(&flight_data, schema.clone())?.unwrap(); + flight_data_to_arrow_batch(&flight_data, schema.clone()).unwrap()?; results.push(record_batch); } diff --git a/rust/datafusion/src/datasource/memory.rs b/rust/datafusion/src/datasource/memory.rs index 4f0f8468280..fcff7d61ac5 100644 --- a/rust/datafusion/src/datasource/memory.rs +++ b/rust/datafusion/src/datasource/memory.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use crate::datasource::TableProvider; @@ -64,10 +65,7 @@ impl MemTable { for partition in 0..exec.output_partitioning().partition_count() { let it = exec.execute(partition).await?; let mut it = it.lock().unwrap(); - let mut partition_batches = vec![]; - while let Ok(Some(batch)) = it.next_batch() { - partition_batches.push(batch); - } + let partition_batches = it.into_iter().collect::>>()?; data.push(partition_batches); } @@ -148,7 +146,7 @@ mod tests { // scan with projection let exec = provider.scan(&Some(vec![2, 1]), 1024)?; let it = exec.execute(0).await?; - let batch2 = it.lock().expect("mutex lock").next_batch()?.unwrap(); + let batch2 = it.lock().expect("mutex lock").next().unwrap()?; assert_eq!(2, batch2.schema().fields().len()); assert_eq!("c", batch2.schema().field(0).name()); assert_eq!("b", batch2.schema().field(1).name()); @@ -178,7 +176,7 @@ mod tests { let exec = provider.scan(&None, 1024)?; let it = exec.execute(0).await?; - let batch1 = it.lock().expect("mutex lock").next_batch()?.unwrap(); + let batch1 = it.lock().expect("mutex lock").next().unwrap()?; assert_eq!(3, batch1.schema().fields().len()); assert_eq!(3, batch1.num_columns()); diff --git a/rust/datafusion/src/datasource/parquet.rs b/rust/datafusion/src/datasource/parquet.rs index cdefaad6459..9f11372b14a 100644 --- a/rust/datafusion/src/datasource/parquet.rs +++ b/rust/datafusion/src/datasource/parquet.rs @@ -85,12 +85,14 @@ mod tests { let it = exec.execute(0).await?; let mut it = it.lock().unwrap(); - let mut count = 0; - while let Some(batch) = it.next_batch()? { - assert_eq!(11, batch.num_columns()); - assert_eq!(2, batch.num_rows()); - count += 1; - } + let count = it + .into_iter() + .map(|batch| { + let batch = batch.unwrap(); + assert_eq!(11, batch.num_columns()); + assert_eq!(2, batch.num_rows()); + }) + .count(); // we should have seen 4 batches of 2 rows assert_eq!(4, count); @@ -304,8 +306,8 @@ mod tests { let exec = table.scan(projection, 1024)?; let it = exec.execute(0).await?; let mut it = it.lock().expect("failed to lock mutex"); - Ok(it - .next_batch()? - .expect("should have received at least one batch")) + it.next() + .expect("should have received at least one batch") + .map_err(|e| e.into()) } } diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 7bc2d123b04..3df5e80f4b1 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -25,6 +25,7 @@ use std::sync::Arc; use arrow::csv; use arrow::datatypes::*; +use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use crate::datasource::csv::CsvFile; @@ -363,15 +364,13 @@ impl ExecutionContext { let mut writer = csv::Writer::new(file); let reader = plan.execute(i).await.unwrap(); let mut reader = reader.lock().unwrap(); - loop { - match reader.next_batch() { - Ok(Some(batch)) => writer.write(&batch).unwrap(), - Ok(None) => break, - Err(e) => return Err(ExecutionError::from(e)), - } - } - } + reader + .into_iter() + .map(|batch| writer.write(&batch?)) + .collect::>() + .map_err(|e| ExecutionError::from(e))? + } Ok(()) } diff --git a/rust/datafusion/src/physical_plan/common.rs b/rust/datafusion/src/physical_plan/common.rs index e13e64c2f56..3008d9b1bc7 100644 --- a/rust/datafusion/src/physical_plan/common.rs +++ b/rust/datafusion/src/physical_plan/common.rs @@ -54,39 +54,34 @@ impl RecordBatchIterator { } } -impl RecordBatchReader for RecordBatchIterator { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } +impl Iterator for RecordBatchIterator { + type Item = ArrowResult; - fn next_batch(&mut self) -> ArrowResult> { + fn next(&mut self) -> Option { if self.index < self.batches.len() { self.index += 1; - Ok(Some(self.batches[self.index - 1].as_ref().clone())) + Some(Ok(self.batches[self.index - 1].as_ref().clone())) } else { - Ok(None) + None } } } +impl RecordBatchReader for RecordBatchIterator { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + /// Create a vector of record batches from an iterator pub fn collect( it: Arc>, ) -> Result> { - let mut reader = it.lock().unwrap(); - let mut results: Vec = vec![]; - loop { - match reader.next_batch() { - Ok(Some(batch)) => { - results.push(batch); - } - Ok(None) => { - // end of result set - return Ok(results); - } - Err(e) => return Err(ExecutionError::from(e)), - } - } + it.lock() + .unwrap() + .into_iter() + .collect::>>() + .map_err(|e| ExecutionError::from(e)) } /// Recursively build a list of files in a directory with a given extension diff --git a/rust/datafusion/src/physical_plan/csv.rs b/rust/datafusion/src/physical_plan/csv.rs index 436f5c4b64e..acbf92d87b6 100644 --- a/rust/datafusion/src/physical_plan/csv.rs +++ b/rust/datafusion/src/physical_plan/csv.rs @@ -262,16 +262,19 @@ impl CsvIterator { } } +impl Iterator for CsvIterator { + type Item = ArrowResult; + + fn next(&mut self) -> Option { + self.reader.next() + } +} + impl RecordBatchReader for CsvIterator { /// Get the schema fn schema(&self) -> SchemaRef { self.reader.schema() } - - /// Get the next RecordBatch - fn next_batch(&mut self) -> ArrowResult> { - Ok(self.reader.next()?) - } } #[cfg(test)] @@ -296,7 +299,7 @@ mod tests { assert_eq!(3, csv.schema().fields().len()); let it = csv.execute(0).await?; let mut it = it.lock().unwrap(); - let batch = it.next_batch()?.unwrap(); + let batch = it.next().unwrap()?; assert_eq!(3, batch.num_columns()); let batch_schema = batch.schema(); assert_eq!(3, batch_schema.fields().len()); @@ -319,7 +322,7 @@ mod tests { assert_eq!(13, csv.schema().fields().len()); let it = csv.execute(0).await?; let mut it = it.lock().unwrap(); - let batch = it.next_batch()?.unwrap(); + let batch = it.next().unwrap()?; assert_eq!(13, batch.num_columns()); let batch_schema = batch.schema(); assert_eq!(13, batch_schema.fields().len()); diff --git a/rust/datafusion/src/physical_plan/filter.rs b/rust/datafusion/src/physical_plan/filter.rs index ffc9fa5ff6f..26c04cb8d13 100644 --- a/rust/datafusion/src/physical_plan/filter.rs +++ b/rust/datafusion/src/physical_plan/filter.rs @@ -120,48 +120,58 @@ struct FilterExecIter { input: Arc>, } -impl RecordBatchReader for FilterExecIter { - /// Get the schema - fn schema(&self) -> SchemaRef { - self.schema.clone() - } +impl Iterator for FilterExecIter { + type Item = ArrowResult; /// Get the next batch - fn next_batch(&mut self) -> ArrowResult> { + fn next(&mut self) -> Option> { let mut input = self.input.lock().unwrap(); - match input.next_batch()? { - Some(batch) => { + match input.next() { + Some(Ok(batch)) => { // evaluate the filter predicate to get a boolean array indicating which rows // to include in the output - let result = self - .predicate - .evaluate(&batch) - .map_err(ExecutionError::into_arrow_external_error)?; - - if let Some(f) = result.as_any().downcast_ref::() { - // filter each array - let mut filtered_arrays = Vec::with_capacity(batch.num_columns()); - for i in 0..batch.num_columns() { - let array = batch.column(i); - let filtered_array = filter(array.as_ref(), f)?; - filtered_arrays.push(filtered_array); - } - Ok(Some(RecordBatch::try_new( - batch.schema().clone(), - filtered_arrays, - )?)) - } else { - Err(ExecutionError::InternalError( - "Filter predicate evaluated to non-boolean value".to_string(), - ) - .into_arrow_external_error()) - } + Some( + self.predicate + .evaluate(&batch) + .map_err(ExecutionError::into_arrow_external_error) + .and_then(|array| { + array + .as_any() + .downcast_ref::() + .ok_or( + ExecutionError::InternalError( + "Filter predicate evaluated to non-boolean value" + .to_string(), + ) + .into_arrow_external_error(), + ) + // apply predicate to each column + .and_then(|predicate| { + batch + .columns() + .iter() + .map(|column| filter(column.as_ref(), predicate)) + .collect::>>() + }) + }) + // build RecordBatch + .and_then(|columns| { + RecordBatch::try_new(batch.schema().clone(), columns) + }), + ) } - None => Ok(None), + other => other, } } } +impl RecordBatchReader for FilterExecIter { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + #[cfg(test)] mod tests { diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index 8bb679b3f32..5420135812b 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -240,26 +240,124 @@ impl GroupedHashAggregateIterator { finished: false, } } + + fn aggregate_batch( + &self, + batch: &RecordBatch, + accumulators: &mut FnvHashMap< + Vec, + (AccumulatorSet, Box>), + >, + aggregate_expressions: &Vec>>, + ) -> Result<()> { + // evaluate the grouping expressions + let group_values = evaluate(&self.group_expr, batch)?; + + // evaluate the aggregation expressions. + // We could evaluate them after the `take`, but since we need to evaluate all + // of them anyways, it is more performant to do it while they are together. + let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?; + + // create vector large enough to hold the grouping key + // this is an optimization to avoid allocating `key` on every row. + // it will be overwritten on every iteration of the loop below + let mut key = Vec::with_capacity(group_values.len()); + for _ in 0..group_values.len() { + key.push(GroupByScalar::UInt32(0)); + } + + // 1.1 construct the key from the group values + // 1.2 construct the mapping key if it does not exist + // 1.3 add the row' index to `indices` + for row in 0..batch.num_rows() { + // 1.1 + create_key(&group_values, row, &mut key) + .map_err(ExecutionError::into_arrow_external_error)?; + + match accumulators.get_mut(&key) { + // 1.2 + None => { + let accumulator_set = create_accumulators(&self.aggr_expr) + .map_err(ExecutionError::into_arrow_external_error)?; + + accumulators.insert( + key.clone(), + (accumulator_set, Box::new(vec![row as u32])), + ); + } + // 1.3 + Some((_, v)) => v.push(row as u32), + } + } + + // 2.1 for each key + // 2.2 for each aggregation + // 2.3 `take` from each of its arrays the keys' values + // 2.4 update / merge the accumulator with the values + // 2.5 clear indices + accumulators + .iter_mut() + // 2.1 + .map(|(_, (accumulator_set, indices))| { + // 2.2 + accumulator_set + .iter() + .zip(&aggr_input_values) + .into_iter() + .map(|(accumulator, aggr_array)| { + ( + accumulator, + aggr_array + .iter() + .map(|array| { + // 2.3 + compute::take( + array, + &UInt32Array::from(*indices.clone()), + None, // None: no index check + ) + .unwrap() + }) + .collect::>(), + ) + }) + // 2.4 + .map(|(accumulator, values)| match self.mode { + AggregateMode::Partial => { + accumulator.borrow_mut().update_batch(&values) + } + AggregateMode::Final => { + // note: the aggregation here is over states, not values, thus the merge + accumulator.borrow_mut().merge_batch(&values) + } + }) + .collect::>() + // 2.5 + .and(Ok(indices.clear())) + }) + .collect::>() + } } type AccumulatorSet = Vec>>; -impl RecordBatchReader for GroupedHashAggregateIterator { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } +impl Iterator for GroupedHashAggregateIterator { + type Item = ArrowResult; - fn next_batch(&mut self) -> ArrowResult> { + fn next(&mut self) -> Option { if self.finished { - return Ok(None); + return None; } // return single batch self.finished = true; // the expressions to evaluate the batch, one vec of expressions per aggregation - let aggregate_expressions = aggregate_expressions(&self.aggr_expr, &self.mode) - .map_err(ExecutionError::into_arrow_external_error)?; + let aggregate_expressions = + match aggregate_expressions(&self.aggr_expr, &self.mode) { + Ok(e) => e, + Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))), + }; // mapping key -> (set of accumulators, indices of the key in the batch) // * the indexes are updated at each row @@ -271,110 +369,36 @@ impl RecordBatchReader for GroupedHashAggregateIterator { > = FnvHashMap::default(); // iterate over all input batches and update the accumulators - let mut input = self.input.lock().unwrap(); - - // iterate over input and perform aggregation - while let Some(batch) = &input.next_batch()? { - // evaluate the grouping expressions - let group_values = evaluate(&self.group_expr, batch) - .map_err(ExecutionError::into_arrow_external_error)?; - - // evaluate the aggregation expressions. - // We could evaluate them after the `take`, but since we need to evaluate all - // of them anyways, it is more performant to do it while they are together. - let aggr_input_values = evaluate_many(&aggregate_expressions, &batch) - .map_err(ExecutionError::into_arrow_external_error)?; - - // create vector large enough to hold the grouping key - // this is an optimization to avoid allocating `key` on every row. - // it will be overwritten on every iteration of the loop below - let mut key = Vec::with_capacity(group_values.len()); - for _ in 0..group_values.len() { - key.push(GroupByScalar::UInt32(0)); - } - - // 1.1 construct the key from the group values - // 1.2 construct the mapping key if it does not exist - // 1.3 add the row' index to `indices` - for row in 0..batch.num_rows() { - // 1.1 - create_key(&group_values, row, &mut key) - .map_err(ExecutionError::into_arrow_external_error)?; - - match accumulators.get_mut(&key) { - // 1.2 - None => { - let accumulator_set = create_accumulators(&self.aggr_expr) - .map_err(ExecutionError::into_arrow_external_error)?; - - accumulators.insert( - key.clone(), - (accumulator_set, Box::new(vec![row as u32])), - ); - } - // 1.3 - Some((_, v)) => v.push(row as u32), - } - } - - // 2.1 for each key - // 2.2 for each aggregation - // 2.3 `take` from each of its arrays the keys' values - // 2.4 update / merge the accumulator with the values - // 2.5 clear indices - accumulators - .iter_mut() - // 2.1 - .map(|(_, (accumulator_set, indices))| { - // 2.2 - accumulator_set - .iter() - .zip(&aggr_input_values) - .into_iter() - .map(|(accumulator, aggr_array)| { - ( - accumulator, - aggr_array - .iter() - .map(|array| { - // 2.3 - compute::take( - array, - &UInt32Array::from(*indices.clone()), - None, // None: no index check - ) - .unwrap() - }) - .collect::>(), - ) - }) - // 2.4 - .map(|(accumulator, values)| match self.mode { - AggregateMode::Partial => { - accumulator.borrow_mut().update_batch(&values) - } - AggregateMode::Final => { - // note: the aggregation here is over states, not values, thus the merge - accumulator.borrow_mut().merge_batch(&values) - } - }) - .collect::>() - // 2.5 - .and(Ok(indices.clear())) - }) - .collect::>() - .map_err(ExecutionError::into_arrow_external_error)?; + match self + .input + .lock() + .unwrap() + .into_iter() + .map(|batch| { + self.aggregate_batch(&batch?, &mut accumulators, &aggregate_expressions) + .map_err(ExecutionError::into_arrow_external_error) + }) + .collect::>() + { + Err(e) => return Some(Err(e)), + Ok(_) => {} } - let batch = create_batch_from_map( - &self.mode, - &accumulators, - self.group_expr.len(), - &self.schema, + Some( + create_batch_from_map( + &self.mode, + &accumulators, + self.group_expr.len(), + &self.schema, + ) + .map_err(ExecutionError::into_arrow_external_error), ) - .map_err(ExecutionError::into_arrow_external_error)?; + } +} - Ok(Some(batch)) +impl RecordBatchReader for GroupedHashAggregateIterator { + fn schema(&self) -> SchemaRef { + self.schema.clone() } } @@ -456,62 +480,87 @@ impl HashAggregateIterator { finished: false, } } -} -impl RecordBatchReader for HashAggregateIterator { - fn schema(&self) -> SchemaRef { - self.schema.clone() + fn aggregate_batch( + &self, + batch: &RecordBatch, + accumulators: &AccumulatorSet, + expressions: &Vec>>, + ) -> Result<()> { + // 1.1 iterate accumulators and respective expressions together + // 1.2 evaluate expressions + // 1.3 update / merge accumulators with the expressions' values + + // 1.1 + accumulators + .iter() + .zip(expressions) + .map(|(accum, expr)| { + // 1.2 + let values = &expr + .iter() + .map(|e| e.evaluate(batch)) + .collect::>>()?; + + // 1.3 + match self.mode { + AggregateMode::Partial => accum.borrow_mut().update_batch(values), + AggregateMode::Final => accum.borrow_mut().merge_batch(values), + } + }) + .collect::>() } +} + +impl Iterator for HashAggregateIterator { + type Item = ArrowResult; - fn next_batch(&mut self) -> ArrowResult> { + fn next(&mut self) -> Option { if self.finished { - return Ok(None); + return None; } // return single batch self.finished = true; - let accumulators = create_accumulators(&self.aggr_expr) - .map_err(ExecutionError::into_arrow_external_error)?; - - let expressions = aggregate_expressions(&self.aggr_expr, &self.mode) - .map_err(ExecutionError::into_arrow_external_error)?; - - let mut input = self.input.lock().unwrap(); - - // 1 for each batch: - // 1.1 iterate accumulators and respective expressions together - // 1.2 evaluate expressions - // 1.3 update / merge accumulators with the expressions' values - // 2 convert values to a record batch - while let Some(batch) = input.next_batch()? { - // 1.1 - accumulators - .iter() - .zip(&expressions) - .map(|(accum, expr)| { - // 1.2 - let values = &expr - .iter() - .map(|e| e.evaluate(&batch)) - .collect::>>()?; - - // 1.3 - match self.mode { - AggregateMode::Partial => accum.borrow_mut().update_batch(values), - AggregateMode::Final => accum.borrow_mut().merge_batch(values), - } - }) - .collect::>() - .map_err(ExecutionError::into_arrow_external_error)?; + let accumulators = match create_accumulators(&self.aggr_expr) { + Ok(e) => e, + Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))), + }; + + let expressions = match aggregate_expressions(&self.aggr_expr, &self.mode) { + Ok(e) => e, + Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))), + }; + + // 1 for each batch, update / merge accumulators with the expressions' values + match self + .input + .lock() + .unwrap() + .into_iter() + .map(|batch| { + self.aggregate_batch(&batch?, &accumulators, &expressions) + .map_err(ExecutionError::into_arrow_external_error) + }) + .collect::>() + { + Err(e) => return Some(Err(e)), + Ok(_) => {} } - // 2 - let columns = finalize_aggregation(&accumulators, &self.mode) - .map_err(ExecutionError::into_arrow_external_error)?; + // 2 convert values to a record batch + Some( + finalize_aggregation(&accumulators, &self.mode) + .map_err(ExecutionError::into_arrow_external_error) + .and_then(|columns| RecordBatch::try_new(self.schema.clone(), columns)), + ) + } +} - let batch = RecordBatch::try_new(self.schema.clone(), columns)?; - Ok(Some(batch)) +impl RecordBatchReader for HashAggregateIterator { + fn schema(&self) -> SchemaRef { + self.schema.clone() } } diff --git a/rust/datafusion/src/physical_plan/limit.rs b/rust/datafusion/src/physical_plan/limit.rs index a42140e9cae..fee774f133b 100644 --- a/rust/datafusion/src/physical_plan/limit.rs +++ b/rust/datafusion/src/physical_plan/limit.rs @@ -202,8 +202,8 @@ fn collect_with_limit( let mut reader = reader.lock().unwrap(); let mut results: Vec = vec![]; loop { - match reader.next_batch() { - Ok(Some(batch)) => { + match reader.next() { + Some(Ok(batch)) => { let capacity = limit - count; if batch.num_rows() <= capacity { count += batch.num_rows(); @@ -217,11 +217,10 @@ fn collect_with_limit( return Ok(results); } } - Ok(None) => { - // end of result set + None => { return Ok(results); } - Err(e) => return Err(ExecutionError::from(e)), + Some(Err(e)) => return Err(ExecutionError::from(e)), } } } diff --git a/rust/datafusion/src/physical_plan/memory.rs b/rust/datafusion/src/physical_plan/memory.rs index fd86006885b..b64976f4eac 100644 --- a/rust/datafusion/src/physical_plan/memory.rs +++ b/rust/datafusion/src/physical_plan/memory.rs @@ -126,27 +126,30 @@ impl MemoryIterator { } } -impl RecordBatchReader for MemoryIterator { - /// Get the schema - fn schema(&self) -> SchemaRef { - self.schema.clone() - } +impl Iterator for MemoryIterator { + type Item = ArrowResult; - /// Get the next RecordBatch - fn next_batch(&mut self) -> ArrowResult> { + fn next(&mut self) -> Option { if self.index < self.data.len() { self.index += 1; let batch = &self.data[self.index - 1]; // apply projection match &self.projection { - Some(columns) => Ok(Some(RecordBatch::try_new( + Some(columns) => Some(RecordBatch::try_new( self.schema.clone(), columns.iter().map(|i| batch.column(*i).clone()).collect(), - )?)), - None => Ok(Some(batch.clone())), + )), + None => Some(Ok(batch.clone())), } } else { - Ok(None) + None } } } + +impl RecordBatchReader for MemoryIterator { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} diff --git a/rust/datafusion/src/physical_plan/parquet.rs b/rust/datafusion/src/physical_plan/parquet.rs index ddffdf8d606..83525f9364c 100644 --- a/rust/datafusion/src/physical_plan/parquet.rs +++ b/rust/datafusion/src/physical_plan/parquet.rs @@ -131,8 +131,8 @@ impl ExecutionPlan for ParquetExec { // because the parquet implementation is not thread-safe, it is necessary to execute // on a thread and communicate with channels let (response_tx, response_rx): ( - Sender>>, - Receiver>>, + Sender>>, + Receiver>>, ) = bounded(2); let filename = self.filenames[partition].clone(); @@ -155,8 +155,8 @@ impl ExecutionPlan for ParquetExec { } fn send_result( - response_tx: &Sender>>, - result: ArrowResult>, + response_tx: &Sender>>, + result: Option>, ) -> Result<()> { response_tx .send(result) @@ -168,7 +168,7 @@ fn read_file( filename: &str, projection: Vec, batch_size: usize, - response_tx: Sender>>, + response_tx: Sender>>, ) -> Result<()> { let file = File::open(&filename)?; let file_reader = Rc::new(SerializedFileReader::new(file)?); @@ -176,20 +176,20 @@ fn read_file( let mut batch_reader = arrow_reader.get_record_reader_by_columns(projection.clone(), batch_size)?; loop { - match batch_reader.next_batch() { - Ok(Some(batch)) => send_result(&response_tx, Ok(Some(batch)))?, - Ok(None) => { + match batch_reader.next() { + Some(Ok(batch)) => send_result(&response_tx, Some(Ok(batch)))?, + None => { // finished reading file - send_result(&response_tx, Ok(None))?; + send_result(&response_tx, None)?; break; } - Err(e) => { + Some(Err(e)) => { let err_msg = format!("Error reading batch from {}: {}", filename, e.to_string()); // send error to operator send_result( &response_tx, - Err(ArrowError::ParquetError(err_msg.clone())), + Some(Err(ArrowError::ParquetError(err_msg.clone()))), )?; // terminate thread with error return Err(ExecutionError::ExecutionError(err_msg)); @@ -201,23 +201,27 @@ fn read_file( struct ParquetIterator { schema: SchemaRef, - response_rx: Receiver>>, + response_rx: Receiver>>, } -impl RecordBatchReader for ParquetIterator { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } +impl Iterator for ParquetIterator { + type Item = ArrowResult; - fn next_batch(&mut self) -> ArrowResult> { + fn next(&mut self) -> Option { match self.response_rx.recv() { Ok(batch) => batch, // RecvError means receiver has exited and closed the channel - Err(RecvError) => Ok(None), + Err(RecvError) => None, } } } +impl RecordBatchReader for ParquetIterator { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + #[cfg(test)] mod tests { use super::*; @@ -233,7 +237,7 @@ mod tests { let results = parquet_exec.execute(0).await?; let mut results = results.lock().unwrap(); - let batch = results.next_batch()?.unwrap(); + let batch = results.next().unwrap()?; assert_eq!(8, batch.num_rows()); assert_eq!(3, batch.num_columns()); @@ -243,13 +247,13 @@ mod tests { schema.fields().iter().map(|f| f.name().as_str()).collect(); assert_eq!(vec!["id", "bool_col", "tinyint_col"], field_names); - let batch = results.next_batch()?; + let batch = results.next(); assert!(batch.is_none()); - let batch = results.next_batch()?; + let batch = results.next(); assert!(batch.is_none()); - let batch = results.next_batch()?; + let batch = results.next(); assert!(batch.is_none()); Ok(()) diff --git a/rust/datafusion/src/physical_plan/projection.rs b/rust/datafusion/src/physical_plan/projection.rs index 43f60510b67..9c74dbd91ad 100644 --- a/rust/datafusion/src/physical_plan/projection.rs +++ b/rust/datafusion/src/physical_plan/projection.rs @@ -126,27 +126,32 @@ struct ProjectionIterator { input: Arc>, } +impl Iterator for ProjectionIterator { + type Item = ArrowResult; + + fn next(&mut self) -> Option { + let mut input = self.input.lock().unwrap(); + match input.next() { + Some(Ok(batch)) => Some( + self.expr + .iter() + .map(|expr| expr.evaluate(&batch)) + .collect::>>() + .map_or_else( + |e| Err(ExecutionError::into_arrow_external_error(e)), + |arrays| RecordBatch::try_new(self.schema.clone(), arrays), + ), + ), + other => other, + } + } +} + impl RecordBatchReader for ProjectionIterator { /// Get the schema fn schema(&self) -> SchemaRef { self.schema.clone() } - - /// Get the next batch - fn next_batch(&mut self) -> ArrowResult> { - let mut input = self.input.lock().unwrap(); - match input.next_batch()? { - Some(batch) => { - let arrays: Result> = - self.expr.iter().map(|expr| expr.evaluate(&batch)).collect(); - Ok(Some(RecordBatch::try_new( - self.schema.clone(), - arrays.map_err(ExecutionError::into_arrow_external_error)?, - )?)) - } - None => Ok(None), - } - } } #[cfg(test)] @@ -177,10 +182,15 @@ mod tests { partition_count += 1; let iterator = projection.execute(partition).await?; let mut iterator = iterator.lock().unwrap(); - while let Some(batch) = iterator.next_batch()? { - assert_eq!(1, batch.num_columns()); - row_count += batch.num_rows(); - } + + row_count += iterator + .into_iter() + .map(|batch| { + let batch = batch.unwrap(); + assert_eq!(1, batch.num_columns()); + batch.num_rows() + }) + .sum::(); } assert_eq!(partitions, partition_count); assert_eq!(100, row_count); diff --git a/rust/datafusion/tests/user_defined_plan.rs b/rust/datafusion/tests/user_defined_plan.rs index aa55238c166..03bc173b2b1 100644 --- a/rust/datafusion/tests/user_defined_plan.rs +++ b/rust/datafusion/tests/user_defined_plan.rs @@ -59,7 +59,6 @@ //! use arrow::{ - array::StringBuilder, array::{Int64Array, PrimitiveArrayOps, StringArray}, datatypes::SchemaRef, error::ArrowError, @@ -450,73 +449,69 @@ impl TopKReader { self.top_values.remove(&smallest_revenue); } } -} -impl RecordBatchReader for TopKReader { - fn schema(&self) -> SchemaRef { - self.input.lock().expect("locked input reader").schema() + // how we process a whole batch + fn accumulate_batch(&mut self, input_batch: &RecordBatch) -> Result<()> { + let num_rows = input_batch.num_rows(); + // Assuming the input columns are + // column[0]: customer_id / UTF8 + // column[1]: revenue: Int64 + let customer_id = input_batch + .column(0) + .as_any() + .downcast_ref::() + .expect("Column 0 is not customer_id"); + + let revenue = input_batch + .column(1) + .as_any() + .downcast_ref::() + .expect("Column 1 is not revenue"); + + for row in 0..num_rows { + self.add_row(customer_id.value(row), revenue.value(row)); + } + Ok(()) } +} + +impl Iterator for TopKReader { + type Item = std::result::Result; /// Reads the next `RecordBatch`. - fn next_batch(&mut self) -> std::result::Result, ArrowError> { + fn next(&mut self) -> Option { if self.done { - return Ok(None); + return None; } - // use a loop so that we release the mutex once we have read each input_batch - loop { - let input_batch = self - .input - .lock() - .expect("locked input mutex") - .next_batch()?; - - match input_batch { - Some(input_batch) => { - println!("Got an input batch"); - let num_rows = input_batch.num_rows(); - - // Assuming the input columns are - // column[0]: customer_id / UTF8 - // column[1]: revenue: Int64 - let customer_id = input_batch - .column(0) - .as_any() - .downcast_ref::() - .expect("Column 0 is not customer_id"); - - let revenue = input_batch - .column(1) - .as_any() - .downcast_ref::() - .expect("Column 1 is not revenue"); - - for row in 0..num_rows { - self.add_row(customer_id.value(row), revenue.value(row)); - } - } - None => break, - } - } - - let mut revenue_builder = Int64Array::builder(self.top_values.len()); - let mut customer_id_builder = StringBuilder::new(self.top_values.len()); + self.input + .clone() + .lock() + .unwrap() + .into_iter() + .map(|batch| self.accumulate_batch(&batch?)) + .collect::>() + .unwrap(); // make output by walking over the map backwards (so values are descending) - for (revenue, customer_id) in self.top_values.iter().rev() { - revenue_builder.append_value(*revenue)?; - customer_id_builder.append_value(customer_id)?; - } + let (revenue, customer): (Vec, Vec<&String>) = + self.top_values.iter().rev().unzip(); + + let customer: Vec<&str> = customer.iter().map(|&s| &**s).collect(); - let record_batch = RecordBatch::try_new( + self.done = true; + Some(RecordBatch::try_new( self.schema().clone(), vec![ - Arc::new(customer_id_builder.finish()), - Arc::new(revenue_builder.finish()), + Arc::new(StringArray::from(customer)), + Arc::new(Int64Array::from(revenue)), ], - )?; + )) + } +} - self.done = true; - Ok(Some(record_batch)) +impl RecordBatchReader for TopKReader { + fn schema(&self) -> SchemaRef { + self.input.lock().expect("locked input reader").schema() } } diff --git a/rust/integration-testing/src/bin/arrow-file-to-stream.rs b/rust/integration-testing/src/bin/arrow-file-to-stream.rs index d538b4f49aa..ded1972e40c 100644 --- a/rust/integration-testing/src/bin/arrow-file-to-stream.rs +++ b/rust/integration-testing/src/bin/arrow-file-to-stream.rs @@ -22,7 +22,6 @@ use std::io::{self, BufReader}; use arrow::error::Result; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::StreamWriter; -use arrow::record_batch::RecordBatchReader; fn main() -> Result<()> { let args: Vec = env::args().collect(); @@ -33,14 +32,17 @@ fn main() -> Result<()> { let f = File::open(filename)?; let reader = BufReader::new(f); - let mut reader = FileReader::try_new(reader)?; + let reader = FileReader::try_new(reader)?; let schema = reader.schema(); let mut writer = StreamWriter::try_new(io::stdout(), &schema)?; - while let Some(batch) = reader.next_batch()? { - writer.write(&batch)?; - } + reader + .map(|batch| { + let batch = batch?; + writer.write(&batch) + }) + .collect::>()?; writer.finish()?; eprintln!("Completed without error"); diff --git a/rust/integration-testing/src/bin/arrow-json-integration-test.rs b/rust/integration-testing/src/bin/arrow-json-integration-test.rs index 196550443f3..5556c4cebc8 100644 --- a/rust/integration-testing/src/bin/arrow-json-integration-test.rs +++ b/rust/integration-testing/src/bin/arrow-json-integration-test.rs @@ -25,7 +25,7 @@ use arrow::datatypes::{DataType, DateUnit, IntervalUnit, Schema}; use arrow::error::{ArrowError, Result}; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::FileWriter; -use arrow::record_batch::{RecordBatch, RecordBatchReader}; +use arrow::record_batch::RecordBatch; use hex::decode; use std::env; @@ -423,7 +423,7 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> } let arrow_file = File::open(arrow_name)?; - let mut reader = FileReader::try_new(arrow_file)?; + let reader = FileReader::try_new(arrow_file)?; let mut fields = vec![]; for f in reader.schema().fields() { @@ -431,10 +431,9 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> } let schema = ArrowJsonSchema { fields }; - let mut batches = vec![]; - while let Ok(Some(batch)) = reader.next_batch() { - batches.push(ArrowJsonBatch::from_batch(&batch)); - } + let batches = reader + .map(|batch| Ok(ArrowJsonBatch::from_batch(&batch?))) + .collect::>>()?; let arrow_json = ArrowJson { schema, @@ -483,7 +482,7 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { } for json_batch in &json_batches { - if let Some(arrow_batch) = arrow_reader.next_batch()? { + if let Some(Ok(arrow_batch)) = arrow_reader.next() { // compare batches assert!(arrow_batch.num_columns() == json_batch.num_columns()); assert!(arrow_batch.num_rows() == json_batch.num_rows()); @@ -500,7 +499,7 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { } } - if arrow_reader.next_batch()?.is_some() { + if arrow_reader.next().is_some() { return Err(ArrowError::ComputeError( "no more json batches left".to_owned(), )); diff --git a/rust/integration-testing/src/bin/arrow-stream-to-file.rs b/rust/integration-testing/src/bin/arrow-stream-to-file.rs index 2468970e6b7..87101c3ec89 100644 --- a/rust/integration-testing/src/bin/arrow-stream-to-file.rs +++ b/rust/integration-testing/src/bin/arrow-stream-to-file.rs @@ -21,20 +21,19 @@ use std::io; use arrow::error::Result; use arrow::ipc::reader::StreamReader; use arrow::ipc::writer::FileWriter; -use arrow::record_batch::RecordBatchReader; fn main() -> Result<()> { let args: Vec = env::args().collect(); eprintln!("{:?}", args); - let mut arrow_stream_reader = StreamReader::try_new(io::stdin())?; + let arrow_stream_reader = StreamReader::try_new(io::stdin())?; let schema = arrow_stream_reader.schema(); let mut writer = FileWriter::try_new(io::stdout(), &schema)?; - while let Some(batch) = arrow_stream_reader.next_batch()? { - writer.write(&batch)?; - } + arrow_stream_reader + .map(|batch| writer.write(&batch?)) + .collect::>()?; writer.finish()?; eprintln!("Completed without error"); diff --git a/rust/parquet/src/arrow/arrow_reader.rs b/rust/parquet/src/arrow/arrow_reader.rs index 106e6cc80f8..b654de1ad0a 100644 --- a/rust/parquet/src/arrow/arrow_reader.rs +++ b/rust/parquet/src/arrow/arrow_reader.rs @@ -22,10 +22,10 @@ use crate::arrow::schema::parquet_to_arrow_schema; use crate::arrow::schema::parquet_to_arrow_schema_by_columns; use crate::errors::{ParquetError, Result}; use crate::file::reader::FileReader; -use arrow::array::StructArray; use arrow::datatypes::{DataType as ArrowType, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::{RecordBatch, RecordBatchReader}; +use arrow::{array::StructArray, error::ArrowError}; use std::rc::Rc; use std::sync::Arc; @@ -143,38 +143,43 @@ pub struct ParquetRecordBatchReader { schema: SchemaRef, } +impl Iterator for ParquetRecordBatchReader { + type Item = ArrowResult; + + fn next(&mut self) -> Option { + match self.array_reader.next_batch(self.batch_size) { + Err(error) => Some(Err(error.into())), + Ok(array) => { + let struct_array = + array.as_any().downcast_ref::().ok_or_else(|| { + ArrowError::ParquetError( + "Struct array reader should return struct array".to_string(), + ) + }); + match struct_array { + Err(err) => Some(Err(err)), + Ok(e) => { + match RecordBatch::try_new(self.schema.clone(), e.columns_ref()) { + Err(err) => Some(Err(err)), + Ok(record_batch) => { + if record_batch.num_rows() > 0 { + Some(Ok(record_batch)) + } else { + None + } + } + } + } + } + } + } + } +} + impl RecordBatchReader for ParquetRecordBatchReader { fn schema(&self) -> SchemaRef { self.schema.clone() } - - fn next_batch(&mut self) -> ArrowResult> { - self.array_reader - .next_batch(self.batch_size) - .map_err(|err| err.into()) - .and_then(|array| { - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - general_err!("Struct array reader should return struct array") - .into() - }) - .and_then(|struct_array| { - RecordBatch::try_new( - self.schema.clone(), - struct_array.columns_ref(), - ) - }) - }) - .map(|record_batch| { - if record_batch.num_rows() > 0 { - Some(record_batch) - } else { - None - } - }) - } } impl ParquetRecordBatchReader { @@ -472,7 +477,7 @@ mod tests { for i in 0..opts.num_iterations { let start = i * opts.record_batch_size; - let batch = record_reader.next_batch().unwrap(); + let batch = record_reader.next(); if start < expected_data.len() { let end = min(start + opts.record_batch_size, expected_data.len()); assert!(batch.is_some()); @@ -483,6 +488,7 @@ mod tests { assert_eq!( &converter.convert(data).unwrap(), batch + .unwrap() .unwrap() .column(0) .as_any() @@ -554,9 +560,8 @@ mod tests { ) { for i in 0..20 { let array: Option = record_batch_reader - .next_batch() - .expect("Failed to read record batch!") - .map(|r| r.into()); + .next() + .map(|r| r.expect("Failed to read record batch!").into()); let (start, end) = (i * 60 as usize, (i + 1) * 60 as usize); diff --git a/rust/parquet/src/arrow/mod.rs b/rust/parquet/src/arrow/mod.rs index 02f50fd3a90..ef1544d65bb 100644 --- a/rust/parquet/src/arrow/mod.rs +++ b/rust/parquet/src/arrow/mod.rs @@ -39,8 +39,8 @@ //! //! let mut record_batch_reader = arrow_reader.get_record_reader(2048).unwrap(); //! -//! loop { -//! let record_batch = record_batch_reader.next_batch().unwrap().unwrap(); +//! for maybe_record_batch in record_batch_reader { +//! let record_batch = maybe_record_batch.unwrap(); //! if record_batch.num_rows() > 0 { //! println!("Read {} records.", record_batch.num_rows()); //! } else {