diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs
index e9bf7af61e1c..b4e1c7ad4ee5 100644
--- a/arrow-avro/src/reader/mod.rs
+++ b/arrow-avro/src/reader/mod.rs
@@ -89,7 +89,6 @@
 //! }
 //! ```
 //!
-
 use crate::codec::{AvroField, AvroFieldBuilder};
 use crate::schema::{
     compare_schemas, generate_fingerprint, AvroSchema, Fingerprint, FingerprintAlgorithm, Schema,
@@ -130,6 +129,15 @@ fn read_header<R: BufRead>(mut reader: R) -> Result<Header, ArrowError> {
     })
 }
 
+// NOTE: The current `is_incomplete_data` check below is temporary and will be improved prior to public release.
+fn is_incomplete_data(err: &ArrowError) -> bool {
+    matches!(
+        err,
+        ArrowError::ParseError(msg)
+            if msg.contains("Unexpected EOF")
+    )
+}
+
 /// A low-level interface for decoding Avro-encoded bytes into Arrow `RecordBatch`.
 #[derive(Debug)]
 pub struct Decoder {
@@ -139,10 +147,10 @@ pub struct Decoder {
     remaining_capacity: usize,
     cache: IndexMap<Fingerprint, RecordDecoder>,
     fingerprint_algorithm: FingerprintAlgorithm,
-    expect_prefix: bool,
     utf8_view: bool,
     strict_mode: bool,
     pending_schema: Option<(Fingerprint, RecordDecoder)>,
+    awaiting_body: bool,
 }
 
 impl Decoder {
@@ -162,29 +170,33 @@ impl Decoder {
     ///
     /// Returns the number of bytes consumed.
     pub fn decode(&mut self, data: &[u8]) -> Result<usize, ArrowError> {
-        if self.expect_prefix
-            && data.len() >= SINGLE_OBJECT_MAGIC.len()
-            && !data.starts_with(&SINGLE_OBJECT_MAGIC)
-        {
-            return Err(ArrowError::ParseError(
-                "Expected single‑object encoding fingerprint prefix for first message \
-                 (writer_schema_store is set but active_fingerprint is None)"
-                    .into(),
-            ));
-        }
         let mut total_consumed = 0usize;
-        // The loop stops when the batch is full, a schema change is staged,
-        // or handle_prefix indicates we need more bytes (Some(0)).
         while total_consumed < data.len() && self.remaining_capacity > 0 {
-            if let Some(n) = self.handle_prefix(&data[total_consumed..])? {
-                // We either consumed a prefix (n > 0) and need a schema switch, or we need
-                // more bytes to make a decision. Either way, this decoding attempt is finished.
-                total_consumed += n;
+            if self.awaiting_body {
+                match self.active_decoder.decode(&data[total_consumed..], 1) {
+                    Ok(n) => {
+                        self.remaining_capacity -= 1;
+                        total_consumed += n;
+                        self.awaiting_body = false;
+                        continue;
+                    }
+                    Err(ref e) if is_incomplete_data(e) => break,
+                    err => return err,
+                };
+            }
+            match self.handle_prefix(&data[total_consumed..])? {
+                Some(0) => break, // insufficient bytes
+                Some(n) => {
+                    total_consumed += n;
+                    self.apply_pending_schema_if_batch_empty();
+                    self.awaiting_body = true;
+                }
+                None => {
+                    return Err(ArrowError::ParseError(
+                        "Missing magic bytes and fingerprint".to_string(),
+                    ))
+                }
             }
-            // No prefix: decode one row and keep going.
-            let n = self.active_decoder.decode(&data[total_consumed..], 1)?;
-            self.remaining_capacity -= 1;
-            total_consumed += n;
         }
         Ok(total_consumed)
     }
@@ -195,10 +207,6 @@ impl Decoder {
     // * Ok(Some(0)) – prefix detected, but the buffer is too short; caller should await more bytes.
     // * Ok(Some(n)) – consumed `n > 0` bytes of a complete prefix (magic and fingerprint).
     fn handle_prefix(&mut self, buf: &[u8]) -> Result<Option<usize>, ArrowError> {
-        // If there is no schema store, prefixes are unrecognized.
-        if !self.expect_prefix {
-            return Ok(None);
-        }
         // Need at least the magic bytes to decide (2 bytes).
         let Some(magic_bytes) = buf.get(..SINGLE_OBJECT_MAGIC.len()) else {
             return Ok(Some(0)); // Get more bytes
@@ -252,15 +260,7 @@ impl Decoder {
         Ok(Some(n))
     }
 
-    /// Produce a `RecordBatch` if at least one row is fully decoded, returning
-    /// `Ok(None)` if no new rows are available.
-    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
-        if self.remaining_capacity == self.batch_size {
-            return Ok(None);
-        }
-        let batch = self.active_decoder.flush()?;
-        self.remaining_capacity = self.batch_size;
-        // Apply any staged schema switch.
+    fn apply_pending_schema(&mut self) {
         if let Some((new_fingerprint, new_decoder)) = self.pending_schema.take() {
             if let Some(old_fingerprint) = self.active_fingerprint.replace(new_fingerprint) {
                 let old_decoder = std::mem::replace(&mut self.active_decoder, new_decoder);
@@ -270,9 +270,32 @@ impl Decoder {
                 self.active_decoder = new_decoder;
             }
         }
+    }
+
+    fn apply_pending_schema_if_batch_empty(&mut self) {
+        if self.batch_is_empty() {
+            self.apply_pending_schema();
+        }
+    }
+
+    fn flush_and_reset(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+        if self.batch_is_empty() {
+            return Ok(None);
+        }
+        let batch = self.active_decoder.flush()?;
+        self.remaining_capacity = self.batch_size;
         Ok(Some(batch))
     }
 
+    /// Produce a `RecordBatch` if at least one row is fully decoded, returning
+    /// `Ok(None)` if no new rows are available.
+    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+        // We must flush the active decoder before switching to the pending one.
+        let batch = self.flush_and_reset();
+        self.apply_pending_schema();
+        batch
+    }
+
     /// Returns the number of rows that can be added to this decoder before it is full.
     pub fn capacity(&self) -> usize {
         self.remaining_capacity
@@ -282,6 +305,31 @@ impl Decoder {
     pub fn batch_is_full(&self) -> bool {
         self.remaining_capacity == 0
     }
+
+    /// Returns true if no rows have been decoded into the current batch yet.
+    pub fn batch_is_empty(&self) -> bool {
+        self.remaining_capacity == self.batch_size
+    }
+
+    // Decode up to `count` records from `data` (an OCF block payload), bounded by the
+    // remaining capacity of the current batch.
+    //
+    // Returns the number of bytes consumed from `data` along with the number of records decoded.
+    fn decode_block(&mut self, data: &[u8], count: usize) -> Result<(usize, usize), ArrowError> {
+        // OCF decoding never interleaves records across blocks, so no chunking.
+        let to_decode = std::cmp::min(count, self.remaining_capacity);
+        if to_decode == 0 {
+            return Ok((0, 0));
+        }
+        let consumed = self.active_decoder.decode(data, to_decode)?;
+        self.remaining_capacity -= to_decode;
+        Ok((consumed, to_decode))
+    }
+
+    // Produce a `RecordBatch` if at least one row is fully decoded, returning
+    // `Ok(None)` if no new rows are available.
+    fn flush_block(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+        self.flush_and_reset()
+    }
 }
 
 /// A builder to create an [`Avro Reader`](Reader) that reads Avro data
@@ -342,7 +390,6 @@ impl ReaderBuilder {
         active_decoder: RecordDecoder,
         active_fingerprint: Option<Fingerprint>,
         cache: IndexMap<Fingerprint, RecordDecoder>,
-        expect_prefix: bool,
        fingerprint_algorithm: FingerprintAlgorithm,
    ) -> Decoder {
        Decoder {
@@ -351,11 +398,11 @@ impl ReaderBuilder {
            active_fingerprint,
            active_decoder,
            cache,
-            expect_prefix,
            utf8_view: self.utf8_view,
            fingerprint_algorithm,
            strict_mode: self.strict_mode,
            pending_schema: None,
+            awaiting_body: false,
        }
    }
 
@@ -376,7 +423,6 @@ impl ReaderBuilder {
                record_decoder,
                None,
                IndexMap::new(),
-                false,
                FingerprintAlgorithm::Rabin,
            ));
        }
@@ -423,7 +469,6 @@ impl ReaderBuilder {
            active_decoder,
            Some(start_fingerprint),
            cache,
-            true,
            store.fingerprint_algorithm(),
        ))
    }
@@ -496,6 +541,7 @@ impl ReaderBuilder {
            decoder,
            block_decoder: BlockDecoder::default(),
            block_data: Vec::new(),
+            block_count: 0,
            block_cursor: 0,
            finished: false,
        })
@@ -521,6 +567,7 @@ pub struct Reader<R> {
    decoder: Decoder,
    block_decoder: BlockDecoder,
    block_data: Vec<u8>,
+    block_count: usize,
    block_cursor: usize,
    finished: bool,
 }
 
@@ -550,12 +597,12 @@ impl<R: BufRead> Reader<R> {
                self.reader.consume(consumed);
                if let Some(block) = self.block_decoder.flush() {
                    // Successfully decoded a block.
-                    let block_data = if let Some(ref codec) = self.header.compression()? {
+                    self.block_data = if let Some(ref codec) = self.header.compression()? {
                        codec.decompress(&block.data)?
                    } else {
                        block.data
                    };
-                    self.block_data = block_data;
+                    self.block_count = block.count;
                    self.block_cursor = 0;
                } else if consumed == 0 {
                    // The block decoder made no progress on a non-empty buffer.
                    ));
                }
            }
-            // Try to decode more rows from the current block.
-            let consumed = self.decoder.decode(&self.block_data[self.block_cursor..])?;
-            self.block_cursor += consumed;
+            // Decode as many rows as will fit in the current batch
+            if self.block_cursor < self.block_data.len() {
+                let (consumed, records_decoded) = self
+                    .decoder
+                    .decode_block(&self.block_data[self.block_cursor..], self.block_count)?;
+                self.block_cursor += consumed;
+                self.block_count -= records_decoded;
+            }
        }
-        self.decoder.flush()
+        self.decoder.flush_block()
    }
 }
 
@@ -709,6 +761,35 @@ mod test {
        .expect("decoder")
    }
 
+    fn make_value_schema(pt: PrimitiveType) -> AvroSchema {
+        let json_schema = format!(
+            r#"{{"type":"record","name":"S","fields":[{{"name":"v","type":"{}"}}]}}"#,
+            pt.as_ref()
+        );
+        AvroSchema::new(json_schema)
+    }
+
+    fn encode_zigzag(value: i64) -> Vec<u8> {
+        let mut n = ((value << 1) ^ (value >> 63)) as u64;
+        let mut out = Vec::new();
+        loop {
+            if (n & !0x7F) == 0 {
+                out.push(n as u8);
+                break;
+            } else {
+                out.push(((n & 0x7F) | 0x80) as u8);
+                n >>= 7;
+            }
+        }
+        out
+    }
+
+    fn make_message(fp: Fingerprint, value: i64) -> Vec<u8> {
+        let mut msg = make_prefix(fp);
+        msg.extend_from_slice(&encode_zigzag(value));
+        msg
+    }
+
    #[test]
    fn test_schema_store_register_lookup() {
        let schema_int = make_record_schema(PrimitiveType::Int);
@@ -735,35 +816,6 @@ mod test {
        );
    }
 
-    #[test]
-    fn test_missing_initial_fingerprint_error() {
-        let (store, _fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store();
-        let mut decoder = ReaderBuilder::new()
-            .with_batch_size(8)
-            .with_reader_schema(schema_int.clone())
-            .with_writer_schema_store(store)
-            .build_decoder()
-            .unwrap();
-        let buf = [0x02u8, 0x00u8];
-        let err = decoder.decode(&buf).expect_err("decode should error");
-        let msg = err.to_string();
-        assert!(
-            msg.contains("Expected single‑object encoding fingerprint"),
-            "unexpected message: {msg}"
-        );
-    }
-
-    #[test]
-    fn test_handle_prefix_no_schema_store() {
-        let (store, fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store();
-        let mut decoder = make_decoder(&store, fp_int, &schema_int);
-        decoder.expect_prefix = false;
-        let res = decoder
-            .handle_prefix(&SINGLE_OBJECT_MAGIC[..])
-            .expect("handle_prefix");
-        assert!(res.is_none(), "Expected None when expect_prefix is false");
-    }
-
    #[test]
    fn test_handle_prefix_incomplete_magic() {
        let (store, fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store();
@@ -815,6 +867,219 @@ mod test {
        assert_eq!(decoder.pending_schema.as_ref().unwrap().0, fp_long);
    }
 
+    #[test]
+    fn test_two_messages_same_schema() {
+        let writer_schema = make_value_schema(PrimitiveType::Int);
+        let reader_schema = writer_schema.clone();
+        let mut store = SchemaStore::new();
+        let fp = store.register(writer_schema).unwrap();
+        let msg1 = make_message(fp, 42);
+        let msg2 = make_message(fp, 11);
+        let input = [msg1.clone(), msg2.clone()].concat();
+        let mut decoder = ReaderBuilder::new()
+            .with_batch_size(8)
+            .with_reader_schema(reader_schema.clone())
+            .with_writer_schema_store(store)
+            .with_active_fingerprint(fp)
+            .build_decoder()
+            .unwrap();
+        let _ = decoder.decode(&input).unwrap();
+        let batch = decoder.flush().unwrap().expect("batch");
+        assert_eq!(batch.num_rows(), 2);
+        let col = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert_eq!(col.value(0), 42);
+        assert_eq!(col.value(1), 11);
+    }
+
+    #[test]
+    fn test_two_messages_schema_switch() {
+        let w_int = make_value_schema(PrimitiveType::Int);
+        let w_long = make_value_schema(PrimitiveType::Long);
+        let r_long = w_long.clone();
+        let mut store = SchemaStore::new();
+        let fp_int = store.register(w_int).unwrap();
+        let fp_long = store.register(w_long).unwrap();
+        let msg_int = make_message(fp_int, 1);
+        let msg_long = make_message(fp_long, 123456789_i64);
+        let mut decoder = ReaderBuilder::new()
+            .with_batch_size(8)
+            .with_writer_schema_store(store)
+            .with_active_fingerprint(fp_int)
+            .build_decoder()
+            .unwrap();
+        let _ = decoder.decode(&msg_int).unwrap();
+        let batch1 = decoder.flush().unwrap().expect("batch1");
+        assert_eq!(batch1.num_rows(), 1);
+        assert_eq!(
+            batch1
+                .column(0)
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap()
+                .value(0),
+            1
+        );
+        let _ = decoder.decode(&msg_long).unwrap();
+        let batch2 = decoder.flush().unwrap().expect("batch2");
+        assert_eq!(batch2.num_rows(), 1);
+        assert_eq!(
+            batch2
+                .column(0)
+                .as_any()
+                .downcast_ref::<Int64Array>()
+                .unwrap()
+                .value(0),
+            123456789_i64
+        );
+    }
+
+    #[test]
+    fn test_split_message_across_chunks() {
+        let writer_schema = make_value_schema(PrimitiveType::Int);
+        let reader_schema = writer_schema.clone();
+        let mut store = SchemaStore::new();
+        let fp = store.register(writer_schema).unwrap();
+        let msg1 = make_message(fp, 7);
+        let msg2 = make_message(fp, 8);
+        let msg3 = make_message(fp, 9);
+        let (pref2, body2) = msg2.split_at(10);
+        let (pref3, body3) = msg3.split_at(10);
+        let mut decoder = ReaderBuilder::new()
+            .with_batch_size(8)
+            .with_reader_schema(reader_schema)
+            .with_writer_schema_store(store)
+            .with_active_fingerprint(fp)
+            .build_decoder()
+            .unwrap();
+        let _ = decoder.decode(&msg1).unwrap();
+        let batch1 = decoder.flush().unwrap().expect("batch1");
+        assert_eq!(batch1.num_rows(), 1);
+        assert_eq!(
+            batch1
+                .column(0)
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap()
+                .value(0),
+            7
+        );
+        let _ = decoder.decode(pref2).unwrap();
+        assert!(decoder.flush().unwrap().is_none());
+        let mut chunk3 = Vec::from(body2);
+        chunk3.extend_from_slice(pref3);
+        let _ = decoder.decode(&chunk3).unwrap();
+        let batch2 = decoder.flush().unwrap().expect("batch2");
+        assert_eq!(batch2.num_rows(), 1);
+        assert_eq!(
+            batch2
+                .column(0)
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap()
+                .value(0),
+            8
+        );
+        let _ = decoder.decode(body3).unwrap();
+        let batch3 = decoder.flush().unwrap().expect("batch3");
+        assert_eq!(batch3.num_rows(), 1);
+        assert_eq!(
+            batch3
+                .column(0)
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap()
+                .value(0),
+            9
+        );
+    }
+
+    #[test]
+    fn test_decode_stream_with_schema() {
+        struct TestCase<'a> {
+            name: &'a str,
+            schema: &'a str,
+            expected_error: Option<&'a str>,
+        }
+        let tests = vec![
+            TestCase {
+                name: "success",
+                schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"string"}]}"#,
+                expected_error: None,
+            },
+            TestCase {
+                name: "valid schema invalid data",
+                schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"long"}]}"#,
+                expected_error: Some("did not consume all bytes"),
+            },
+        ];
+        for test in tests {
+            let avro_schema = AvroSchema::new(test.schema.to_string());
+            let mut store = SchemaStore::new();
+            let fp = store.register(avro_schema.clone()).unwrap();
+            let prefix = make_prefix(fp);
+            let record_val = "some_string";
+            let mut body = prefix;
+            body.push((record_val.len() as u8) << 1);
+            body.extend_from_slice(record_val.as_bytes());
+            let decoder_res = ReaderBuilder::new()
+                .with_batch_size(1)
+                .with_writer_schema_store(store)
+                .with_active_fingerprint(fp)
+                .build_decoder();
+            let decoder = match decoder_res {
+                Ok(d) => d,
+                Err(e) => {
+                    if let Some(expected) = test.expected_error {
+                        assert!(
+                            e.to_string().contains(expected),
+                            "Test '{}' failed at build – expected '{expected}', got '{e}'",
+                            test.name
+                        );
+                        continue;
+                    } else {
+                        panic!("Test '{}' failed during build: {e}", test.name);
+                    }
+                }
+            };
+            let stream = Box::pin(stream::once(async { Bytes::from(body) }));
+            let decoded_stream = decode_stream(decoder, stream);
+            let batches_result: Result<Vec<RecordBatch>, ArrowError> =
+                block_on(decoded_stream.try_collect());
+            match (batches_result, test.expected_error) {
+                (Ok(batches), None) => {
+                    let batch =
+                        arrow::compute::concat_batches(&batches[0].schema(), &batches).unwrap();
+                    let expected_field = Field::new("f2", DataType::Utf8, false);
+                    let expected_schema = Arc::new(Schema::new(vec![expected_field]));
+                    let expected_array = Arc::new(StringArray::from(vec![record_val]));
+                    let expected_batch =
+                        RecordBatch::try_new(expected_schema, vec![expected_array]).unwrap();
+                    assert_eq!(batch, expected_batch, "Test '{}'", test.name);
+                }
+                (Err(e), Some(expected)) => {
+                    assert!(
+                        e.to_string().contains(expected),
+                        "Test '{}' – expected error containing '{expected}', got '{e}'",
+                        test.name
+                    );
+                }
+                (Ok(_), Some(expected)) => {
+                    panic!(
+                        "Test '{}' expected failure ('{expected}') but succeeded",
+                        test.name
+                    );
+                }
+                (Err(e), None) => {
+                    panic!("Test '{}' unexpectedly failed with '{e}'", test.name);
+                }
+            }
+        }
+    }
+
    #[test]
    fn test_utf8view_support() {
        let schema_json = r#"{
@@ -1128,89 +1393,6 @@ mod test {
        assert_eq!(batch, expected);
    }
 
-    #[test]
-    fn test_decode_stream_with_schema() {
-        struct TestCase<'a> {
-            name: &'a str,
-            schema: &'a str,
-            expected_error: Option<&'a str>,
-        }
-        let tests = vec![
-            TestCase {
-                name: "success",
-                schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"string"}]}"#,
-                expected_error: None,
-            },
-            TestCase {
-                name: "valid schema invalid data",
-                schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"long"}]}"#,
-                expected_error: Some("did not consume all bytes"),
-            },
-        ];
-        for test in tests {
-            let avro_schema = AvroSchema::new(test.schema.to_string());
-            let mut store = SchemaStore::new();
-            let fp = store.register(avro_schema.clone()).unwrap();
-            let prefix = make_prefix(fp);
-            let record_val = "some_string";
-            let mut body = prefix;
-            body.push((record_val.len() as u8) << 1);
-            body.extend_from_slice(record_val.as_bytes());
-            let decoder_res = ReaderBuilder::new()
-                .with_batch_size(1)
-                .with_writer_schema_store(store)
-                .with_active_fingerprint(fp)
-                .build_decoder();
-            let decoder = match decoder_res {
-                Ok(d) => d,
-                Err(e) => {
-                    if let Some(expected) = test.expected_error {
-                        assert!(
-                            e.to_string().contains(expected),
-                            "Test '{}' failed at build – expected '{expected}', got '{e}'",
-                            test.name
-                        );
-                        continue;
-                    } else {
-                        panic!("Test '{}' failed during build: {e}", test.name);
-                    }
-                }
-            };
-            let stream = Box::pin(stream::once(async { Bytes::from(body) }));
-            let decoded_stream = decode_stream(decoder, stream);
-            let batches_result: Result<Vec<RecordBatch>, ArrowError> =
-                block_on(decoded_stream.try_collect());
-            match (batches_result, test.expected_error) {
-                (Ok(batches), None) => {
-                    let batch =
-                        arrow::compute::concat_batches(&batches[0].schema(), &batches).unwrap();
-                    let expected_field = Field::new("f2", DataType::Utf8, false);
-                    let expected_schema = Arc::new(Schema::new(vec![expected_field]));
-                    let expected_array = Arc::new(StringArray::from(vec![record_val]));
-                    let expected_batch =
-                        RecordBatch::try_new(expected_schema, vec![expected_array]).unwrap();
-                    assert_eq!(batch, expected_batch, "Test '{}'", test.name);
"Test '{}'", test.name); - } - (Err(e), Some(expected)) => { - assert!( - e.to_string().contains(expected), - "Test '{}' – expected error containing '{expected}', got '{e}'", - test.name - ); - } - (Ok(_), Some(expected)) => { - panic!( - "Test '{}' expected failure ('{expected}') but succeeded", - test.name - ); - } - (Err(e), None) => { - panic!("Test '{}' unexpectedly failed with '{e}'", test.name); - } - } - } - } - #[test] fn test_decimal() { let files = [