diff --git a/rust/lance-arrow/src/scalar.rs b/rust/lance-arrow/src/scalar.rs index e475b0561a7..f32a831648b 100644 --- a/rust/lance-arrow/src/scalar.rs +++ b/rust/lance-arrow/src/scalar.rs @@ -141,20 +141,28 @@ pub fn decode_scalar_from_inline_value( data_type: &DataType, inline_value: &[u8], ) -> Result { - let byte_width = data_type.byte_width_opt().ok_or_else(|| { - ArrowError::InvalidArgumentError(format!( - "Inline constant is not supported for non-fixed-stride data type {:?}", - data_type - )) - })?; - - if inline_value.len() != byte_width { - return Err(ArrowError::InvalidArgumentError(format!( + // I expect our input to be safe here, but I added some debug_assert_eq statements just in case. + // If they are triggered, we may need to change them to return actual errors. + // + // Boolean values are bit-packed in Arrow and therefore are not "fixed-stride" in bytes. + // As a result, `byte_width_opt()` returns `None` for `DataType::Boolean`, even though a + // length-1 scalar can be represented inline using a single byte (matching `try_inline_value`). + if matches!(data_type, DataType::Boolean) { + debug_assert_eq!( + inline_value.len(), + 1, + "Invalid boolean inline scalar length (expected 1 byte, got {})", + inline_value.len() + ); + } else if let Some(byte_width) = data_type.byte_width_opt() { + debug_assert_eq!( + inline_value.len(), + byte_width, "Inline constant length mismatch for {:?}: expected {} bytes but got {}", data_type, byte_width, inline_value.len() - ))); + ); } let data = ArrayDataBuilder::new(data_type.clone()) @@ -187,7 +195,7 @@ pub fn try_inline_value(scalar: &ArrayRef) -> Option> { mod tests { use std::sync::Arc; - use arrow_array::{cast::AsArray, FixedSizeBinaryArray, Int32Array, StringArray}; + use arrow_array::{cast::AsArray, BooleanArray, FixedSizeBinaryArray, Int32Array, StringArray}; use super::*; @@ -231,6 +239,16 @@ mod tests { assert_eq!(decoded.as_fixed_size_binary().value(0), val.as_slice()); } + #[test] + fn test_inline_value_boolean_round_trip() { + let scalar: ArrayRef = Arc::new(BooleanArray::from_iter([Some(true)])); + let inline = try_inline_value(&scalar).unwrap(); + let decoded = decode_scalar_from_inline_value(&DataType::Boolean, &inline).unwrap(); + assert_eq!(decoded.len(), 1); + assert_eq!(decoded.null_count(), 0); + assert!(decoded.as_boolean().value(0)); + } + #[test] fn test_scalar_value_buffer_rejects_nested_type() { let field = Arc::new(arrow_schema::Field::new("item", DataType::Int32, false)); diff --git a/rust/lance/src/dataset/tests/dataset_io.rs b/rust/lance/src/dataset/tests/dataset_io.rs index cffbb97c706..5aade47d9e1 100644 --- a/rust/lance/src/dataset/tests/dataset_io.rs +++ b/rust/lance/src/dataset/tests/dataset_io.rs @@ -22,8 +22,8 @@ use arrow_array::RecordBatchReader; use arrow_array::{ cast::as_string_array, types::{Float32Type, Int32Type}, - ArrayRef, Int32Array, Int64Array, Int8Array, Int8DictionaryArray, RecordBatchIterator, - StringArray, + ArrayRef, BooleanArray, Int32Array, Int64Array, Int8Array, Int8DictionaryArray, + RecordBatchIterator, StringArray, }; use arrow_array::{Array, FixedSizeListArray, Int16Array, Int16DictionaryArray, StructArray}; use arrow_ord::sort::sort_to_indices; @@ -178,6 +178,58 @@ async fn test_create_and_fill_empty_dataset( assert_eq!(&expected_struct_arr, as_struct_array(sorted_arr.as_ref())); } +#[tokio::test] +async fn test_scan_constant_boolean_inline_value_v2_2() { + let test_uri = TempStrDir::default(); + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "flag", + DataType::Boolean, + false, + )])); + + let rows = 1024usize; + let flags: ArrayRef = Arc::new(BooleanArray::from_iter(std::iter::repeat_n(true, rows))); + let batch = RecordBatch::try_new(schema.clone(), vec![flags]).unwrap(); + let reader = RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema.clone()); + + Dataset::write( + reader, + &test_uri, + Some(WriteParams { + data_storage_version: Some(LanceFileVersion::V2_2), + ..Default::default() + }), + ) + .await + .unwrap(); + + let ds = Dataset::open(&test_uri).await.unwrap(); + let batches = ds + .scan() + .project(&["flag"]) + .unwrap() + .try_into_stream() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, rows); + for batch in batches { + let flags = batch + .column_by_name("flag") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..flags.len() { + assert!(flags.value(i)); + } + } +} + #[rstest] #[lance_test_macros::test(tokio::test)] async fn test_create_with_empty_iter(