diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index 2ab618177754..ca0d09e2282f 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -1484,10 +1484,11 @@ mod tests {
     use super::*;

-    use crate::root_as_message;
+    use crate::convert::fb_to_schema;
+    use crate::{root_as_footer, root_as_message};
     use arrow_array::builder::{PrimitiveRunBuilder, UnionBuilder};
     use arrow_array::types::*;
-    use arrow_buffer::NullBuffer;
+    use arrow_buffer::{NullBuffer, OffsetBuffer};
     use arrow_data::ArrayDataBuilder;

     fn create_test_projection_schema() -> Schema {
@@ -1724,27 +1725,73 @@ mod tests {
         });
     }

-    fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
+    /// Write the record batch to an in-memory buffer in IPC File format
+    fn write_ipc(rb: &RecordBatch) -> Vec<u8> {
         let mut buf = Vec::new();
         let mut writer = crate::writer::FileWriter::try_new(&mut buf, rb.schema_ref()).unwrap();
         writer.write(rb).unwrap();
         writer.finish().unwrap();
-        drop(writer);
+        buf
+    }

-        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
-        reader.next().unwrap().unwrap()
+    /// Return the first record batch read from the IPC File buffer
+    fn read_ipc(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
+        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None)?;
+        reader.next().unwrap()
     }

-    fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
+    fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
+        let buf = write_ipc(rb);
+        read_ipc(&buf).unwrap()
+    }
+
+    /// Return the first record batch read from the IPC File buffer
+    /// using the FileDecoder API
+    fn read_ipc_with_decoder(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
+        let buffer = Buffer::from_vec(buf);
+        // The file ends with a 4-byte footer length followed by the
+        // 6-byte `ARROW1` magic, so the trailer is the last 10 bytes
+        let trailer_start = buffer.len() - 10;
+        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap())?;
+        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start])
+            .map_err(|e| ArrowError::InvalidArgumentError(format!("Invalid footer: {e}")))?;
+
+        let schema = fb_to_schema(footer.schema().unwrap());
+
+        let mut decoder = FileDecoder::new(Arc::new(schema), footer.version());
+        // Read dictionaries
+        for block in footer.dictionaries().iter().flatten() {
+            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
+            let data = buffer.slice_with_length(block.offset() as _, block_len);
+            decoder.read_dictionary(block, &data)?
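+            // (each dictionary must be decoded before read_record_batch below
+            // so the decoder can resolve dictionary-encoded columns)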
+        }
+
+        // Read record batch
+        let batches = footer.recordBatches().unwrap();
+        assert_eq!(batches.len(), 1); // Only wrote a single batch
+
+        let block = batches.get(0);
+        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
+        let data = buffer.slice_with_length(block.offset() as _, block_len);
+        Ok(decoder.read_record_batch(block, &data)?.unwrap())
+    }
+
+    /// Write the record batch to an in-memory buffer in IPC Stream format
+    fn write_stream(rb: &RecordBatch) -> Vec<u8> {
         let mut buf = Vec::new();
         let mut writer = crate::writer::StreamWriter::try_new(&mut buf, rb.schema_ref()).unwrap();
         writer.write(rb).unwrap();
         writer.finish().unwrap();
-        drop(writer);
+        buf
+    }
+
+    /// Return the first record batch read from the IPC Stream buffer
+    fn read_stream(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
+        let mut reader = StreamReader::try_new(std::io::Cursor::new(buf), None)?;
+        reader.next().unwrap()
+    }

-        let mut reader =
-            crate::reader::StreamReader::try_new(std::io::Cursor::new(buf), None).unwrap();
-        reader.next().unwrap().unwrap()
+    fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
+        let buf = write_stream(rb);
+        read_stream(&buf).unwrap()
     }

     #[test]
@@ -2403,17 +2450,10 @@ mod tests {
                 .build_unchecked(),
             )
         };
-
-        let batch = RecordBatch::try_new(schema.clone(), vec![invalid_struct_arr]).unwrap();
-
-        let mut buf = Vec::new();
-        let mut writer = crate::writer::FileWriter::try_new(&mut buf, schema.as_ref()).unwrap();
-        writer.write(&batch).unwrap();
-        writer.finish().unwrap();
-
-        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
-        let err = reader.next().unwrap().unwrap_err();
-        assert!(matches!(err, ArrowError::InvalidArgumentError(_)));
+        expect_ipc_validation_error(
+            Arc::new(invalid_struct_arr),
+            "Invalid argument error: Incorrect array length for StructArray field \"b\", expected 4 got 3",
+        );
     }

     #[test]
@@ -2472,4 +2512,109 @@ mod tests {
             assert_eq!(decoded_batch.expect("Failed to read RecordBatch"), batch);
         });
     }
+
+    #[test]
+    fn test_validation_of_invalid_list_array() {
+        // ListArray with invalid offsets
+        let array = unsafe {
+            let values = Int32Array::from(vec![1, 2, 3]);
+            let bad_offsets = ScalarBuffer::<i32>::from(vec![0, 2, 4, 2]); // offsets can't go backwards
+            let offsets = OffsetBuffer::new_unchecked(bad_offsets); // INVALID array created
+            let field = Field::new_list_field(DataType::Int32, true);
+            let nulls = None;
+            ListArray::new(Arc::new(field), offsets, Arc::new(values), nulls)
+        };
+
+        expect_ipc_validation_error(
+            Arc::new(array),
+            "Invalid argument error: Offset invariant failure: offset at position 2 out of bounds: 4 > 2"
+        );
+    }
+
+    #[test]
+    fn test_validation_of_invalid_string_array() {
+        let valid: &[u8] = b" ";
+        let mut invalid = vec![];
+        invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
+        invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
+        let binary_array = BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
+        // the data is not valid utf-8, so we can not safely construct a correct
+        // StringArray; purposely create an invalid StringArray instead
+        let array = unsafe {
+            StringArray::new_unchecked(
+                binary_array.offsets().clone(),
+                binary_array.values().clone(),
+                binary_array.nulls().cloned(),
+            )
+        };
+        expect_ipc_validation_error(
+            Arc::new(array),
+            "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..45): invalid utf-8 sequence of 1 bytes from index 38"
+        );
+    }
+
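+    // Note: "ThisStringIsCertainlyLongerThan12Bytes" matters for the view test
+    // below: strings longer than 12 bytes are stored out of line in a data
+    // buffer (only a short prefix is inlined in the view), so the invalid
+    // UTF-8 can only be caught by validating the data buffers themselves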
+    #[test]
+    fn test_validation_of_invalid_string_view_array() {
+        let valid: &[u8] = b" ";
+        let mut invalid = vec![];
+        invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
+        invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
+        let binary_view_array =
+            BinaryViewArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
+        // the data is not valid utf-8, so we can not safely construct a correct
+        // StringViewArray; purposely create an invalid StringViewArray instead
+        let array = unsafe {
+            StringViewArray::new_unchecked(
+                binary_view_array.views().clone(),
+                binary_view_array.data_buffers().to_vec(),
+                binary_view_array.nulls().cloned(),
+            )
+        };
+        expect_ipc_validation_error(
+            Arc::new(array),
+            "Invalid argument error: Encountered non-UTF-8 data at index 3: invalid utf-8 sequence of 1 bytes from index 38"
+        );
+    }
+
+    /// Test a DictionaryArray with an invalid key
+    /// (the key is out of bounds for the values)
+    #[test]
+    fn test_validation_of_invalid_dictionary_array() {
+        let array = unsafe {
+            let values = StringArray::from_iter_values(["a", "b", "c"]);
+            let keys = Int32Array::from(vec![1, 200]); // keys are not valid for values
+            DictionaryArray::new_unchecked(keys, Arc::new(values))
+        };
+
+        expect_ipc_validation_error(
+            Arc::new(array),
+            "Invalid argument error: Value at position 1 out of bounds: 200 (should be in [0, 2])",
+        );
+    }
+
+    /// Invalid UTF-8 sequence in the first character
+    const INVALID_UTF8_FIRST_CHAR: &[u8] = &[0xa0, 0xa1, 0x20, 0x20];
+
+    /// Expect an error when reading the record batch back in both the
+    /// IPC File and IPC Stream formats
+    fn expect_ipc_validation_error(array: ArrayRef, expected_err: &str) {
+        let rb = RecordBatch::try_from_iter([("a", array)]).unwrap();
+
+        // IPC Stream format
+        let buf = write_stream(&rb); // write is ok
+        let err = read_stream(&buf).unwrap_err();
+        assert_eq!(err.to_string(), expected_err);
+
+        // IPC File format
+        let buf = write_ipc(&rb); // write is ok
+        let err = read_ipc(&buf).unwrap_err();
+        assert_eq!(err.to_string(), expected_err);
+
+        // TODO verify there is no error when validation is disabled
+        // see https://github.com/apache/arrow-rs/issues/3287
+
+        // IPC File format with FileDecoder
+        let err = read_ipc_with_decoder(buf).unwrap_err();
+        assert_eq!(err.to_string(), expected_err);
+    }
 }
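// For reference: a sketch (not part of this diff) of how the helpers above
// compose for the happy path. A hypothetical test like the following would
// assert that a valid array survives a roundtrip through both IPC formats:
//
//     #[test]
//     fn test_roundtrip_of_valid_array() {
//         let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
//         let rb = RecordBatch::try_from_iter([("a", array)]).unwrap();
//         assert_eq!(roundtrip_ipc(&rb), rb);
//         assert_eq!(roundtrip_ipc_stream(&rb), rb);
//     }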