Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions rust/lance-arrow/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,20 +141,28 @@ pub fn decode_scalar_from_inline_value(
data_type: &DataType,
inline_value: &[u8],
) -> Result<ArrayRef> {
let byte_width = data_type.byte_width_opt().ok_or_else(|| {
ArrowError::InvalidArgumentError(format!(
"Inline constant is not supported for non-fixed-stride data type {:?}",
data_type
))
})?;

if inline_value.len() != byte_width {
return Err(ArrowError::InvalidArgumentError(format!(
// I expect our input to be safe here, but I added some debug_assert_eq statements just in case.
// If they are triggered, we may need to change them to return actual errors.
//
// Boolean values are bit-packed in Arrow and therefore are not "fixed-stride" in bytes.
// As a result, `byte_width_opt()` returns `None` for `DataType::Boolean`, even though a
// length-1 scalar can be represented inline using a single byte (matching `try_inline_value`).
if matches!(data_type, DataType::Boolean) {
debug_assert_eq!(
inline_value.len(),
1,
"Invalid boolean inline scalar length (expected 1 byte, got {})",
inline_value.len()
);
} else if let Some(byte_width) = data_type.byte_width_opt() {
debug_assert_eq!(
inline_value.len(),
byte_width,
"Inline constant length mismatch for {:?}: expected {} bytes but got {}",
data_type,
byte_width,
inline_value.len()
)));
);
}

let data = ArrayDataBuilder::new(data_type.clone())
Expand Down Expand Up @@ -187,7 +195,7 @@ pub fn try_inline_value(scalar: &ArrayRef) -> Option<Vec<u8>> {
mod tests {
use std::sync::Arc;

use arrow_array::{cast::AsArray, FixedSizeBinaryArray, Int32Array, StringArray};
use arrow_array::{cast::AsArray, BooleanArray, FixedSizeBinaryArray, Int32Array, StringArray};

use super::*;

Expand Down Expand Up @@ -231,6 +239,16 @@ mod tests {
assert_eq!(decoded.as_fixed_size_binary().value(0), val.as_slice());
}

#[test]
fn test_inline_value_boolean_round_trip() {
    // Encode a single non-null `true` into its inline representation...
    let original = BooleanArray::from_iter([Some(true)]);
    let scalar: ArrayRef = Arc::new(original);
    let inline_bytes = try_inline_value(&scalar).unwrap();

    // ...and decode it back as a `DataType::Boolean` scalar.
    let round_tripped =
        decode_scalar_from_inline_value(&DataType::Boolean, &inline_bytes).unwrap();

    // The result must be exactly one valid (non-null) `true`.
    assert_eq!(round_tripped.len(), 1);
    assert_eq!(round_tripped.null_count(), 0);
    assert!(round_tripped.as_boolean().value(0));
}

#[test]
fn test_scalar_value_buffer_rejects_nested_type() {
let field = Arc::new(arrow_schema::Field::new("item", DataType::Int32, false));
Expand Down
56 changes: 54 additions & 2 deletions rust/lance/src/dataset/tests/dataset_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use arrow_array::RecordBatchReader;
use arrow_array::{
cast::as_string_array,
types::{Float32Type, Int32Type},
ArrayRef, Int32Array, Int64Array, Int8Array, Int8DictionaryArray, RecordBatchIterator,
StringArray,
ArrayRef, BooleanArray, Int32Array, Int64Array, Int8Array, Int8DictionaryArray,
RecordBatchIterator, StringArray,
};
use arrow_array::{Array, FixedSizeListArray, Int16Array, Int16DictionaryArray, StructArray};
use arrow_ord::sort::sort_to_indices;
Expand Down Expand Up @@ -178,6 +178,58 @@ async fn test_create_and_fill_empty_dataset(
assert_eq!(&expected_struct_arr, as_struct_array(sorted_arr.as_ref()));
}

#[tokio::test]
async fn test_scan_constant_boolean_inline_value_v2_2() {
    // Writes a non-nullable boolean column in which every value is `true`
    // using the V2_2 storage version, then scans it back and checks that the
    // row count and every value survive the round trip.
    let test_uri = TempStrDir::default();
    let rows = 1024usize;

    let flag_field = ArrowField::new("flag", DataType::Boolean, false);
    let schema = Arc::new(ArrowSchema::new(vec![flag_field]));

    let all_true = BooleanArray::from_iter(std::iter::repeat_n(true, rows));
    let column: ArrayRef = Arc::new(all_true);
    let batch = RecordBatch::try_new(schema.clone(), vec![column]).unwrap();
    let reader = RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema.clone());

    let write_params = WriteParams {
        data_storage_version: Some(LanceFileVersion::V2_2),
        ..Default::default()
    };
    Dataset::write(reader, &test_uri, Some(write_params))
        .await
        .unwrap();

    // Re-open the dataset and read back only the `flag` column.
    let ds = Dataset::open(&test_uri).await.unwrap();
    let batches = ds
        .scan()
        .project(&["flag"])
        .unwrap()
        .try_into_stream()
        .await
        .unwrap()
        .try_collect::<Vec<_>>()
        .await
        .unwrap();

    let total_rows = batches.iter().map(|b| b.num_rows()).sum::<usize>();
    assert_eq!(total_rows, rows);

    // Every value in every returned batch must be `true`.
    for batch in &batches {
        let flags = batch
            .column_by_name("flag")
            .unwrap()
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        (0..flags.len()).for_each(|i| {
            assert!(flags.value(i));
        });
    }
}

#[rstest]
#[lance_test_macros::test(tokio::test)]
async fn test_create_with_empty_iter(
Expand Down
Loading