diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs
index 4314b550d680..f62a5de68e22 100644
--- a/arrow-data/src/data.rs
+++ b/arrow-data/src/data.rs
@@ -28,7 +28,6 @@ use std::mem;
 use std::ops::Range;
 use std::sync::Arc;
 
-use crate::data::private::UnsafeFlag;
 use crate::{equal, validate_binary_view, validate_string_view};
 
 #[inline]
@@ -1783,33 +1782,55 @@ impl PartialEq for ArrayData {
     }
 }
 
-mod private {
-    /// A boolean flag that cannot be mutated outside of unsafe code.
-    ///
-    /// Defaults to a value of false.
+/// A boolean flag that cannot be mutated outside of unsafe code.
+///
+/// Defaults to a value of false.
+///
+/// This structure is used to enforce safety in the [`ArrayDataBuilder`]
+///
+/// [`ArrayDataBuilder`]: super::ArrayDataBuilder
+///
+/// # Example
+/// ```rust
+/// use arrow_data::UnsafeFlag;
+/// assert!(!UnsafeFlag::default().get()); // default is false
+/// let mut flag = UnsafeFlag::new();
+/// assert!(!flag.get()); // defaults to false
+/// // can only set it to true in unsafe code
+/// unsafe { flag.set(true) };
+/// assert!(flag.get()); // now true
+/// ```
+#[derive(Debug, Copy, Clone)]
+pub struct UnsafeFlag(bool);
+
+impl UnsafeFlag {
+    /// Creates a new `UnsafeFlag` with the value set to `false`
     ///
-    /// This structure is used to enforce safety in the [`ArrayDataBuilder`]
+    /// See examples on [`UnsafeFlag`]
+    #[inline]
+    pub const fn new() -> Self {
+        Self(false)
+    }
+
+    /// Sets the value of the flag to the given value
     ///
-    /// [`ArrayDataBuilder`]: super::ArrayDataBuilder
-    #[derive(Debug)]
-    pub struct UnsafeFlag(bool);
-
-    impl UnsafeFlag {
-        /// Creates a new `UnsafeFlag` with the value set to `false`
-        #[inline]
-        pub const fn new() -> Self {
-            Self(false)
-        }
+    /// Note this can only be done in `unsafe` code
+    #[inline]
+    pub unsafe fn set(&mut self, val: bool) {
+        self.0 = val;
+    }
 
-        #[inline]
-        pub unsafe fn set(&mut self, val: bool) {
-            self.0 = val;
-        }
+    /// Returns the value of the flag
+    #[inline]
+    pub fn get(&self) -> bool {
+        self.0
+    }
+}
 
-        #[inline]
-        pub fn get(&self) -> bool {
-            self.0
-        }
+// Manual impl to make it clear you cannot construct an `UnsafeFlag` set to `true`
+impl Default for UnsafeFlag {
+    fn default() -> Self {
+        Self::new()
     }
 }
 
@@ -2040,6 +2061,16 @@ impl ArrayDataBuilder {
         self.skip_validation.set(skip_validation);
         self
     }
+
+    /// Specifies skipping validation of the data based on an [`UnsafeFlag`]
+    ///
+    /// # Safety
+    /// While this function is safe, setting the flag to true can only be done
+    /// in `unsafe` code. See [`Self::skip_validation`] for more details.
+    pub fn with_skip_validation(mut self, skip_validation: UnsafeFlag) -> Self {
+        self.skip_validation = skip_validation;
+        self
+    }
 }
 
 impl From<ArrayData> for ArrayDataBuilder {
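For context, a minimal sketch (not part of the patch) of how the newly public `UnsafeFlag` composes with `ArrayDataBuilder::with_skip_validation`; the buffer contents are invented for illustration:

```rust
use arrow_buffer::Buffer;
use arrow_data::{ArrayData, UnsafeFlag};
use arrow_schema::DataType;

fn main() {
    // Three Int32 values in a correctly sized, trusted buffer.
    let values = Buffer::from_slice_ref([1i32, 2, 3]);

    let mut skip_validation = UnsafeFlag::new();
    // SAFETY: the buffer above is known to contain valid Int32 data, so
    // skipping validation cannot introduce undefined behavior here.
    unsafe { skip_validation.set(true) };

    let data = ArrayData::builder(DataType::Int32)
        .len(3)
        .add_buffer(values)
        .with_skip_validation(skip_validation)
        .build()
        .unwrap();
    assert_eq!(data.len(), 3);
}
```

Because the flag is `Copy`, one flag can be handed to any number of builders, which is how the IPC decoder below threads a single flag through every array it constructs.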
diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index e79ab2321147..6faba96d46f8 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -36,7 +36,7 @@ use std::sync::Arc;
 
 use arrow_array::*;
 use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, ScalarBuffer};
-use arrow_data::ArrayData;
+use arrow_data::{ArrayData, UnsafeFlag};
 use arrow_schema::*;
 
 use crate::compression::CompressionCodec;
@@ -65,6 +65,7 @@ fn read_buffer(
         (false, Some(decompressor)) => decompressor.decompress_to_buffer(&buf_data),
     }
 }
+
 impl RecordBatchDecoder<'_> {
     /// Coordinates reading arrays based on data types.
     ///
@@ -85,16 +86,15 @@ impl RecordBatchDecoder<'_> {
     ) -> Result<ArrayRef, ArrowError> {
         let data_type = field.data_type();
         match data_type {
-            Utf8 | Binary | LargeBinary | LargeUtf8 => create_primitive_array(
-                self.next_node(field)?,
-                data_type,
-                &[
+            Utf8 | Binary | LargeBinary | LargeUtf8 => {
+                let field_node = self.next_node(field)?;
+                let buffers = [
                     self.next_buffer()?,
                     self.next_buffer()?,
                     self.next_buffer()?,
-                ],
-                self.require_alignment,
-            ),
+                ];
+                self.create_primitive_array(field_node, data_type, &buffers)
+            }
             BinaryView | Utf8View => {
                 let count = variadic_counts
                     .pop_front()
@@ -105,42 +105,25 @@ impl RecordBatchDecoder<'_> {
                 let buffers = (0..count)
                     .map(|_| self.next_buffer())
                     .collect::<Result<Vec<_>, _>>()?;
-                create_primitive_array(
-                    self.next_node(field)?,
-                    data_type,
-                    &buffers,
-                    self.require_alignment,
-                )
+                let field_node = self.next_node(field)?;
+                self.create_primitive_array(field_node, data_type, &buffers)
+            }
+            FixedSizeBinary(_) => {
+                let field_node = self.next_node(field)?;
+                let buffers = [self.next_buffer()?, self.next_buffer()?];
+                self.create_primitive_array(field_node, data_type, &buffers)
             }
-            FixedSizeBinary(_) => create_primitive_array(
-                self.next_node(field)?,
-                data_type,
-                &[self.next_buffer()?, self.next_buffer()?],
-                self.require_alignment,
-            ),
             List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
                 let list_node = self.next_node(field)?;
                 let list_buffers = [self.next_buffer()?, self.next_buffer()?];
                 let values = self.create_array(list_field, variadic_counts)?;
-                create_list_array(
-                    list_node,
-                    data_type,
-                    &list_buffers,
-                    values,
-                    self.require_alignment,
-                )
+                self.create_list_array(list_node, data_type, &list_buffers, values)
             }
             FixedSizeList(ref list_field, _) => {
                 let list_node = self.next_node(field)?;
                 let list_buffers = [self.next_buffer()?];
                 let values = self.create_array(list_field, variadic_counts)?;
-                create_list_array(
-                    list_node,
-                    data_type,
-                    &list_buffers,
-                    values,
-                    self.require_alignment,
-                )
+                self.create_list_array(list_node, data_type, &list_buffers, values)
             }
             Struct(struct_fields) => {
                 let struct_node = self.next_node(field)?;
@@ -185,6 +168,7 @@ impl RecordBatchDecoder<'_> {
                 .add_child_data(run_ends.into_data())
                 .add_child_data(values.into_data())
                 .align_buffers(!self.require_alignment)
+                .with_skip_validation(self.skip_validation)
                 .build()?;
 
             Ok(make_array(array_data))
@@ -205,12 +189,11 @@ impl RecordBatchDecoder<'_> {
                     ))
                 })?;
 
-                create_dictionary_array(
+                self.create_dictionary_array(
                     index_node,
                     data_type,
                     &index_buffers,
                     value_array.clone(),
-                    self.require_alignment,
                 )
             }
             Union(fields, mode) => {
@@ -260,112 +243,119 @@ impl RecordBatchDecoder<'_> {
                 .len(length as usize)
                 .offset(0)
                 .align_buffers(!self.require_alignment)
+                .with_skip_validation(self.skip_validation)
                 .build()?;
 
                 // no buffer increases
                 Ok(Arc::new(NullArray::from(array_data)))
             }
-            _ => create_primitive_array(
-                self.next_node(field)?,
-                data_type,
-                &[self.next_buffer()?, self.next_buffer()?],
-                self.require_alignment,
-            ),
+            _ => {
+                let field_node = self.next_node(field)?;
+                let buffers = [self.next_buffer()?, self.next_buffer()?];
+                self.create_primitive_array(field_node, data_type, &buffers)
+            }
         }
     }
-}
 
-/// Reads the correct number of buffers based on data type and null_count, and creates a
-/// primitive array ref
-fn create_primitive_array(
-    field_node: &FieldNode,
-    data_type: &DataType,
-    buffers: &[Buffer],
-    require_alignment: bool,
-) -> Result<ArrayRef, ArrowError> {
-    let length = field_node.length() as usize;
-    let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
-    let builder = match data_type {
-        Utf8 | Binary | LargeBinary | LargeUtf8 => {
-            // read 3 buffers: null buffer (optional), offsets buffer and data buffer
-            ArrayData::builder(data_type.clone())
-                .len(length)
-                .buffers(buffers[1..3].to_vec())
-                .null_bit_buffer(null_buffer)
-        }
-        BinaryView | Utf8View => ArrayData::builder(data_type.clone())
-            .len(length)
-            .buffers(buffers[1..].to_vec())
-            .null_bit_buffer(null_buffer),
-        _ if data_type.is_primitive() || matches!(data_type, Boolean | FixedSizeBinary(_)) => {
-            // read 2 buffers: null buffer (optional) and data buffer
-            ArrayData::builder(data_type.clone())
+    /// Reads the correct number of buffers based on data type and null_count, and creates a
+    /// primitive array ref
+    fn create_primitive_array(
+        &self,
+        field_node: &FieldNode,
+        data_type: &DataType,
+        buffers: &[Buffer],
+    ) -> Result<ArrayRef, ArrowError> {
+        let length = field_node.length() as usize;
+        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
+        let builder = match data_type {
+            Utf8 | Binary | LargeBinary | LargeUtf8 => {
+                // read 3 buffers: null buffer (optional), offsets buffer and data buffer
+                ArrayData::builder(data_type.clone())
+                    .len(length)
+                    .buffers(buffers[1..3].to_vec())
+                    .null_bit_buffer(null_buffer)
+            }
+            BinaryView | Utf8View => ArrayData::builder(data_type.clone())
                 .len(length)
-                .add_buffer(buffers[1].clone())
-                .null_bit_buffer(null_buffer)
-        }
-        t => unreachable!("Data type {:?} either unsupported or not primitive", t),
-    };
+                .buffers(buffers[1..].to_vec())
+                .null_bit_buffer(null_buffer),
+            _ if data_type.is_primitive() || matches!(data_type, Boolean | FixedSizeBinary(_)) => {
+                // read 2 buffers: null buffer (optional) and data buffer
+                ArrayData::builder(data_type.clone())
+                    .len(length)
+                    .add_buffer(buffers[1].clone())
+                    .null_bit_buffer(null_buffer)
+            }
+            t => unreachable!("Data type {:?} either unsupported or not primitive", t),
+        };
 
-    let array_data = builder.align_buffers(!require_alignment).build()?;
+        let array_data = builder
+            .align_buffers(!self.require_alignment)
+            .with_skip_validation(self.skip_validation)
+            .build()?;
 
-    Ok(make_array(array_data))
-}
+        Ok(make_array(array_data))
+    }
 
-/// Reads the correct number of buffers based on list type and null_count, and creates a
-/// list array ref
-fn create_list_array(
-    field_node: &FieldNode,
-    data_type: &DataType,
-    buffers: &[Buffer],
-    child_array: ArrayRef,
-    require_alignment: bool,
-) -> Result<ArrayRef, ArrowError> {
-    let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
-    let length = field_node.length() as usize;
-    let child_data = child_array.into_data();
-    let builder = match data_type {
-        List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone())
-            .len(length)
-            .add_buffer(buffers[1].clone())
-            .add_child_data(child_data)
-            .null_bit_buffer(null_buffer),
-
-        FixedSizeList(_, _) => ArrayData::builder(data_type.clone())
-            .len(length)
-            .add_child_data(child_data)
-            .null_bit_buffer(null_buffer),
-
-        _ => unreachable!("Cannot create list or map array from {:?}", data_type),
-    };
+    /// Reads the correct number of buffers based on list type and null_count, and creates a
+    /// list array ref
+    fn create_list_array(
+        &self,
+        field_node: &FieldNode,
+        data_type: &DataType,
+        buffers: &[Buffer],
+        child_array: ArrayRef,
+    ) -> Result<ArrayRef, ArrowError> {
+        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
+        let length = field_node.length() as usize;
+        let child_data = child_array.into_data();
+        let builder = match data_type {
+            List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone())
+                .len(length)
+                .add_buffer(buffers[1].clone())
+                .add_child_data(child_data)
+                .null_bit_buffer(null_buffer),
 
-    let array_data = builder.align_buffers(!require_alignment).build()?;
+            FixedSizeList(_, _) => ArrayData::builder(data_type.clone())
+                .len(length)
+                .add_child_data(child_data)
+                .null_bit_buffer(null_buffer),
 
-    Ok(make_array(array_data))
-}
+            _ => unreachable!("Cannot create list or map array from {:?}", data_type),
+        };
 
-/// Reads the correct number of buffers based on list type and null_count, and creates a
-/// list array ref
-fn create_dictionary_array(
-    field_node: &FieldNode,
-    data_type: &DataType,
-    buffers: &[Buffer],
-    value_array: ArrayRef,
-    require_alignment: bool,
-) -> Result<ArrayRef, ArrowError> {
-    if let Dictionary(_, _) = *data_type {
-        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
-        let array_data = ArrayData::builder(data_type.clone())
-            .len(field_node.length() as usize)
-            .add_buffer(buffers[1].clone())
-            .add_child_data(value_array.into_data())
-            .null_bit_buffer(null_buffer)
-            .align_buffers(!require_alignment)
+        let array_data = builder
+            .align_buffers(!self.require_alignment)
+            .with_skip_validation(self.skip_validation)
            .build()?;
 
         Ok(make_array(array_data))
-    } else {
-        unreachable!("Cannot create dictionary array from {:?}", data_type)
+    }
+
+    /// Reads the correct number of buffers based on dictionary type and null_count, and
+    /// creates a dictionary array ref
+    fn create_dictionary_array(
+        &self,
+        field_node: &FieldNode,
+        data_type: &DataType,
+        buffers: &[Buffer],
+        value_array: ArrayRef,
+    ) -> Result<ArrayRef, ArrowError> {
+        if let Dictionary(_, _) = *data_type {
+            let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
+            let array_data = ArrayData::builder(data_type.clone())
+                .len(field_node.length() as usize)
+                .add_buffer(buffers[1].clone())
+                .add_child_data(value_array.into_data())
+                .null_bit_buffer(null_buffer)
+                .align_buffers(!self.require_alignment)
+                .with_skip_validation(self.skip_validation)
+                .build()?;
+
+            Ok(make_array(array_data))
+        } else {
+            unreachable!("Cannot create dictionary array from {:?}", data_type)
+        }
    }
 }
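To make the buffer-count comments above concrete: a `Utf8` array is backed by an optional null bitmap plus an offsets buffer and a data buffer, which is why `create_primitive_array` takes `buffers[1..3]` after peeling off the null buffer. A standalone sketch using only the public builder API (values invented for illustration):

```rust
use arrow_buffer::Buffer;
use arrow_data::ArrayData;
use arrow_schema::DataType;

fn main() {
    // The strings "ab" and "c" encode as offsets [0, 2, 3] into b"abc".
    let offsets = Buffer::from_slice_ref([0i32, 2, 3]);
    let values = Buffer::from_slice_ref(b"abc");

    let array = ArrayData::builder(DataType::Utf8)
        .len(2)
        .add_buffer(offsets) // offsets buffer (buffers[1] in the decoder)
        .add_buffer(values) // data buffer (buffers[2] in the decoder)
        .build()
        .unwrap();
    assert!(array.nulls().is_none()); // no null bitmap was supplied
}
```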
@@ -396,6 +386,18 @@ struct RecordBatchDecoder<'a> {
     /// Are buffers required to already be aligned? See
     /// [`RecordBatchDecoder::with_require_alignment`] for details
     require_alignment: bool,
+    /// Should validation be skipped when reading data?
+    ///
+    /// Defaults to false.
+    ///
+    /// If true, [`ArrayData::validate`] is not called after reading
+    ///
+    /// # Safety
+    ///
+    /// This flag can only be set to true using `unsafe` APIs. However, once true,
+    /// subsequent calls to `build()` may result in undefined behavior if the data
+    /// is not valid.
+    skip_validation: UnsafeFlag,
 }
 
 impl<'a> RecordBatchDecoder<'a> {
@@ -430,6 +432,7 @@ impl<'a> RecordBatchDecoder<'a> {
             buffers: buffers.iter(),
             projection: None,
             require_alignment: false,
+            skip_validation: UnsafeFlag::new(),
         })
     }
 
@@ -452,6 +455,21 @@ impl<'a> RecordBatchDecoder<'a> {
         self
     }
 
+    /// Set skip_validation (default: false)
+    ///
+    /// Note this is a pub(crate) API and cannot be used outside of this crate
+    ///
+    /// If true, validation is skipped.
+    ///
+    /// # Safety
+    ///
+    /// Relies on `UnsafeFlag` to enforce safety -- can only be enabled via
+    /// unsafe APIs.
+    pub(crate) fn with_skip_validation(mut self, skip_validation: UnsafeFlag) -> Self {
+        self.skip_validation = skip_validation;
+        self
+    }
+
     /// Read the record batch, consuming the reader
     fn read_record_batch(mut self) -> Result<RecordBatch, ArrowError> {
         let mut variadic_counts: VecDeque<i64> = self
@@ -621,7 +639,16 @@ pub fn read_dictionary(
     dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
     metadata: &MetadataVersion,
 ) -> Result<(), ArrowError> {
-    read_dictionary_impl(buf, batch, schema, dictionaries_by_id, metadata, false)
+    let skip_validation = UnsafeFlag::new(); // do not skip validation
+    read_dictionary_impl(
+        buf,
+        batch,
+        schema,
+        dictionaries_by_id,
+        metadata,
+        false,
+        skip_validation,
+    )
 }
 
 fn read_dictionary_impl(
@@ -631,6 +658,7 @@
     dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
     metadata: &MetadataVersion,
     require_alignment: bool,
+    skip_validation: UnsafeFlag,
 ) -> Result<(), ArrowError> {
     if batch.isDelta() {
         return Err(ArrowError::InvalidArgumentError(
@@ -662,6 +690,7 @@ fn read_dictionary_impl(
                 metadata,
             )?
             .with_require_alignment(require_alignment)
+            .with_skip_validation(skip_validation)
             .read_record_batch()?;
 
             Some(record_batch.column(0).clone())
@@ -792,6 +821,7 @@ pub struct FileDecoder {
     version: MetadataVersion,
     projection: Option<Vec<usize>>,
     require_alignment: bool,
+    skip_validation: UnsafeFlag,
 }
 
 impl FileDecoder {
@@ -803,6 +833,7 @@ impl FileDecoder {
             dictionaries: Default::default(),
             projection: None,
             require_alignment: false,
+            skip_validation: UnsafeFlag::new(),
         }
     }
 
@@ -822,13 +853,28 @@ impl FileDecoder {
     /// If `require_alignment` is false (the default), this decoder will automatically allocate a
     /// new aligned buffer and copy over the data if any array data in the input `buf` is not
     /// properly aligned. (Properly aligned array data will remain zero-copy.)
-    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct
+    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::align_buffers`] to construct
     /// [`arrow_data::ArrayData`].
     pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
         self.require_alignment = require_alignment;
         self
     }
 
+    /// Specifies whether validation should be skipped when reading data (defaults to `false`)
+    ///
+    /// # Safety
+    ///
+    /// This flag must only be set to `true` when you trust that the data you
+    /// are reading is a valid Arrow IPC file, otherwise undefined behavior may
+    /// result.
+    ///
+    /// For example, some programs may wish to trust reading IPC files written
+    /// by the same process that created the files.
+    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
+        self.skip_validation.set(skip_validation);
+        self
+    }
+
     fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message<'a>, ArrowError> {
         let message = parse_message(buf)?;
 
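As a usage sketch for the new `FileDecoder` API (the helper name and `MetadataVersion::V5` here are assumptions for illustration, not part of the patch):

```rust
use arrow_ipc::reader::FileDecoder;
use arrow_ipc::MetadataVersion;
use arrow_schema::SchemaRef;

// Hypothetical helper: build a decoder for files this process wrote itself.
fn trusted_decoder(schema: SchemaRef) -> FileDecoder {
    let decoder = FileDecoder::new(schema, MetadataVersion::V5);
    // SAFETY: only files produced by this same process are decoded, so the
    // data is known to be valid and skipping validation cannot cause UB.
    unsafe { decoder.with_skip_validation(true) }
}
```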
@@ -854,6 +900,7 @@
                     &mut self.dictionaries,
                     &message.version(),
                     self.require_alignment,
+                    self.skip_validation,
                 )
             }
             t => Err(ArrowError::ParseError(format!(
@@ -1270,6 +1317,9 @@ pub struct StreamReader<R: Read> {
 
     /// Optional projection
     projection: Option<(Vec<usize>, Schema)>,
+
+    /// Should the reader skip validation
+    skip_validation: UnsafeFlag,
 }
 
 impl<R: Read> fmt::Debug for StreamReader<R> {
@@ -1349,6 +1399,7 @@ impl<R: Read> StreamReader<R> {
             finished: false,
             dictionaries_by_id,
             projection,
+            skip_validation: UnsafeFlag::new(),
         })
     }
 
@@ -1457,6 +1508,7 @@ impl<R: Read> StreamReader<R> {
                     &mut self.dictionaries_by_id,
                     &message.version(),
                     false,
+                    self.skip_validation,
                 )?;
 
                 // read the next message until we encounter a RecordBatch
@@ -1482,6 +1534,21 @@ impl<R: Read> StreamReader<R> {
     pub fn get_mut(&mut self) -> &mut R {
         &mut self.reader
     }
+
+    /// Specifies whether validation should be skipped when reading data (defaults to `false`)
+    ///
+    /// # Safety
+    ///
+    /// This flag must only be set to `true` when you trust that the data you
+    /// are reading is a valid Arrow IPC stream, otherwise undefined behavior
+    /// may result.
+    ///
+    /// For example, some programs may wish to trust reading IPC streams
+    /// written by the same process that created them.
+    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
+        self.skip_validation.set(skip_validation);
+        self
+    }
 }
 
 impl<R: Read> Iterator for StreamReader<R> {
diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs
index 174e69c1f670..65528b50a198 100644
--- a/arrow-ipc/src/reader/stream.rs
+++ b/arrow-ipc/src/reader/stream.rs
@@ -21,6 +21,7 @@ use std::sync::Arc;
 
 use arrow_array::{ArrayRef, RecordBatch};
 use arrow_buffer::{Buffer, MutableBuffer};
+use arrow_data::UnsafeFlag;
 use arrow_schema::{ArrowError, SchemaRef};
 
 use crate::convert::MessageBuffer;
@@ -42,13 +43,15 @@ pub struct StreamDecoder {
     buf: MutableBuffer,
     /// Whether or not array data in input buffers are required to be aligned
     require_alignment: bool,
+    /// Should we skip validation when reading arrays?
+    skip_validation: UnsafeFlag,
 }
 
 #[derive(Debug)]
 enum DecoderState {
     /// Decoding the message header
     Header {
-        /// Temporaray buffer
+        /// Temporary buffer
         buf: [u8; 4],
         /// Number of bytes read into buf
         read: u8,
@@ -102,6 +105,21 @@ impl StreamDecoder {
         self
     }
 
+    /// Specifies whether validation should be skipped when reading data (defaults to `false`)
+    ///
+    /// # Safety
+    ///
+    /// This flag must only be set to `true` when you trust that the data you
+    /// are reading is a valid Arrow IPC stream, otherwise undefined behavior
+    /// may result.
+    ///
+    /// For example, some programs may wish to trust reading IPC streams
+    /// written by the same process that created them.
+    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
+        self.skip_validation.set(skip_validation);
+        self
+    }
+
     /// Try to read the next [`RecordBatch`] from the provided [`Buffer`]
     ///
     /// [`Buffer::advance`] will be called on `buffer` for any consumed bytes.
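The `StreamDecoder` entry point follows the same pattern; a hedged sketch (buffer plumbing elided):

```rust
use arrow_ipc::reader::StreamDecoder;

// SAFETY: sound only if every byte fed to `decode` comes from a trusted
// writer; otherwise the skipped checks may lead to undefined behavior.
let mut decoder = unsafe { StreamDecoder::new().with_skip_validation(true) };
// ...then feed bytes through `decoder.decode(&mut buffer)` as usual...
```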
@@ -219,6 +237,7 @@ impl StreamDecoder {
                             &version,
                         )?
                         .with_require_alignment(self.require_alignment)
+                        .with_skip_validation(self.skip_validation)
                         .read_record_batch()?;
                         self.state = DecoderState::default();
                         return Ok(Some(batch));
@@ -235,6 +254,7 @@
                             &mut self.dictionaries,
                             &version,
                             self.require_alignment,
+                            self.skip_validation,
                         )?;
                         self.state = DecoderState::default();
                     }
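Finally, an end-to-end sketch of the `StreamReader` path (illustrative only): write a batch to an in-memory IPC stream, then read it back with validation skipped, which is sound here only because the same process produced the bytes:

```rust
use std::io::Cursor;
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use arrow_ipc::reader::StreamReader;
use arrow_ipc::writer::StreamWriter;
use arrow_schema::{ArrowError, DataType, Field, Schema};

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let column = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
    let batch = RecordBatch::try_new(schema.clone(), vec![column])?;

    // Write the batch to an in-memory IPC stream.
    let mut writer = StreamWriter::try_new(Vec::new(), &schema)?;
    writer.write(&batch)?;
    writer.finish()?;
    let bytes = writer.into_inner()?;

    // Read it back. SAFETY: this process just wrote the stream, so the
    // bytes are known to be valid and validation may be skipped.
    let reader = StreamReader::try_new(Cursor::new(bytes), None)?;
    let mut reader = unsafe { reader.with_skip_validation(true) };
    assert_eq!(reader.next().unwrap()?, batch);
    Ok(())
}
```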