diff --git a/docs/src/guide/migration.md b/docs/src/guide/migration.md index 92c06129307..9b7471ed07c 100644 --- a/docs/src/guide/migration.md +++ b/docs/src/guide/migration.md @@ -6,6 +6,13 @@ stable and breaking changes should generally be communicated (via warnings) for give users a chance to migrate. This page documents the breaking changes between releases and gives advice on how to migrate. +## 1.0.0 + +* The `SearchResult` returned by scalar indices must now output information about null values. + Instead of containing a `RowIdTreeMap`, it now contains a `NullableRowIdSet`. Expressions that + resolve to null values must be included in search results in the null set. This ensures that + `NOT` can be applied to index search results correctly. + ## 0.39 * The `lance` crate no longer re-exports utilities from `lance-arrow` such as `RecordBatchExt` or `SchemaExt`. In the diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index 83aadfe8558..90b2fa02ec4 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -1856,13 +1856,14 @@ def test_json_index(): ) -def test_null_handling(tmp_path: Path): +def test_null_handling(): tbl = pa.table( { "x": [1, 2, None, 3], + "y": ["a", "b", "c", None], } ) - dataset = lance.write_dataset(tbl, tmp_path / "dataset") + dataset = lance.write_dataset(tbl, "memory://test") def check(): assert dataset.to_table(filter="x IS NULL").num_rows == 1 @@ -1871,11 +1872,19 @@ def check(): assert dataset.to_table(filter="x < 5").num_rows == 3 assert dataset.to_table(filter="x IN (1, 2)").num_rows == 2 assert dataset.to_table(filter="x IN (1, 2, NULL)").num_rows == 2 + assert dataset.to_table(filter="x > 0 OR (y != 'a')").num_rows == 4 + assert dataset.to_table(filter="x > 0 AND (y != 'a')").num_rows == 1 + assert dataset.to_table(filter="y != 'a'").num_rows == 2 + # NOT should exclude nulls (issue #4756) + assert dataset.to_table(filter="NOT (x < 
2)").num_rows == 2 + assert dataset.to_table(filter="NOT (x IN (1, 2))").num_rows == 1 + # Double NOT + assert dataset.to_table(filter="NOT (NOT (x < 2))").num_rows == 1 check() dataset.create_scalar_index("x", index_type="BITMAP") check() - dataset.create_scalar_index("x", index_type="BTREE") + dataset.create_scalar_index("y", index_type="BTREE") check() diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs index 595825d26e8..c0d5347026e 100644 --- a/rust/lance-core/src/utils/mask.rs +++ b/rust/lance-core/src/utils/mask.rs @@ -3,8 +3,7 @@ use std::collections::HashSet; use std::io::Write; -use std::iter; -use std::ops::{Range, RangeBounds}; +use std::ops::{Range, RangeBounds, RangeInclusive}; use std::{collections::BTreeMap, io::Read}; use arrow_array::{Array, BinaryArray, GenericBinaryArray}; @@ -17,20 +16,22 @@ use crate::Result; use super::address::RowAddress; -/// A row id mask to select or deselect particular row ids -/// -/// If both the allow_list and the block_list are Some then the only selected -/// row ids are those that are in the allow_list but not in the block_list -/// (the block_list takes precedence) -/// -/// If both the allow_list and the block_list are None (the default) then -/// all row ids are selected -#[derive(Clone, Debug, Default, DeepSizeOf)] -pub struct RowIdMask { - /// If Some then only these row ids are selected - pub allow_list: Option, - /// If Some then these row ids are not selected. - pub block_list: Option, +mod nullable; + +pub use nullable::{NullableRowAddrSet, NullableRowIdMask}; + +/// A mask that selects or deselects rows based on an allow-list or block-list. 
+#[derive(Clone, Debug, DeepSizeOf, PartialEq)] +pub enum RowIdMask { + AllowList(RowAddrTreeMap), + BlockList(RowAddrTreeMap), +} + +impl Default for RowIdMask { + fn default() -> Self { + // Empty block list means all rows are allowed + Self::BlockList(RowAddrTreeMap::new()) + } } impl RowIdMask { @@ -41,124 +42,68 @@ impl RowIdMask { // Create a mask that doesn't allow anything pub fn allow_nothing() -> Self { - Self { - allow_list: Some(RowAddrTreeMap::new()), - block_list: None, - } + Self::AllowList(RowAddrTreeMap::new()) } // Create a mask from an allow list pub fn from_allowed(allow_list: RowAddrTreeMap) -> Self { - Self { - allow_list: Some(allow_list), - block_list: None, - } + Self::AllowList(allow_list) } // Create a mask from a block list pub fn from_block(block_list: RowAddrTreeMap) -> Self { - Self { - allow_list: None, - block_list: Some(block_list), + Self::BlockList(block_list) + } + + pub fn block_list(&self) -> Option<&RowAddrTreeMap> { + match self { + Self::BlockList(block_list) => Some(block_list), + _ => None, } } - // If there is both a block list and an allow list then collapse into just an allow list - pub fn normalize(self) -> Self { - if let Self { - allow_list: Some(mut allow_list), - block_list: Some(block_list), - } = self - { - allow_list -= &block_list; - Self { - allow_list: Some(allow_list), - block_list: None, - } - } else { - self + pub fn allow_list(&self) -> Option<&RowAddrTreeMap> { + match self { + Self::AllowList(allow_list) => Some(allow_list), + _ => None, } } /// True if the row_id is selected by the mask, false otherwise pub fn selected(&self, row_id: u64) -> bool { - match (&self.allow_list, &self.block_list) { - (None, None) => true, - (Some(allow_list), None) => allow_list.contains(row_id), - (None, Some(block_list)) => !block_list.contains(row_id), - (Some(allow_list), Some(block_list)) => { - allow_list.contains(row_id) && !block_list.contains(row_id) - } + match self { + Self::AllowList(allow_list) => 
allow_list.contains(row_id), + Self::BlockList(block_list) => !block_list.contains(row_id), } } /// Return the indices of the input row ids that were valid pub fn selected_indices<'a>(&self, row_ids: impl Iterator + 'a) -> Vec { - let enumerated_ids = row_ids.enumerate(); - match (&self.block_list, &self.allow_list) { - (Some(block_list), Some(allow_list)) => { - // Only take rows that are both in the allow list and not in the block list - enumerated_ids - .filter(|(_, row_id)| { - !block_list.contains(**row_id) && allow_list.contains(**row_id) - }) - .map(|(idx, _)| idx as u64) - .collect() - } - (Some(block_list), None) => { - // Take rows that are not in the block list - enumerated_ids - .filter(|(_, row_id)| !block_list.contains(**row_id)) - .map(|(idx, _)| idx as u64) - .collect() - } - (None, Some(allow_list)) => { - // Take rows that are in the allow list - enumerated_ids - .filter(|(_, row_id)| allow_list.contains(**row_id)) - .map(|(idx, _)| idx as u64) - .collect() - } - (None, None) => { - // We should not encounter this case because callers should - // check is_empty first. 
- panic!("selected_indices called but prefilter has nothing to filter with") - } - } + row_ids + .enumerate() + .filter_map(|(idx, row_id)| { + if self.selected(*row_id) { + Some(idx as u64) + } else { + None + } + }) + .collect() } /// Also block the given ids pub fn also_block(self, block_list: RowAddrTreeMap) -> Self { - if block_list.is_empty() { - return self; - } - if let Some(existing) = self.block_list { - Self { - block_list: Some(existing | block_list), - allow_list: self.allow_list, - } - } else { - Self { - block_list: Some(block_list), - allow_list: self.allow_list, - } + match self { + Self::AllowList(allow_list) => Self::AllowList(allow_list - block_list), + Self::BlockList(existing) => Self::BlockList(existing | block_list), } } /// Also allow the given ids pub fn also_allow(self, allow_list: RowAddrTreeMap) -> Self { - if let Some(existing) = self.allow_list { - Self { - block_list: self.block_list, - allow_list: Some(existing | allow_list), - } - } else { - Self { - block_list: self.block_list, - // allow_list = None means "all rows allowed" and so allowing - // more rows is meaningless - allow_list: None, - } + match self { + Self::AllowList(existing) => Self::AllowList(existing | allow_list), + Self::BlockList(block_list) => Self::BlockList(block_list - allow_list), } } @@ -175,13 +120,17 @@ impl RowIdMask { /// We serialize this as a variable length binary array with two items. The first item /// is the block list and the second item is the allow list. pub fn into_arrow(&self) -> Result { - let block_list_length = self - .block_list + // NOTE: This serialization format must be stable as it is used in IPC. 
+ let (block_list, allow_list) = match self { + Self::AllowList(allow_list) => (None, Some(allow_list)), + Self::BlockList(block_list) => (Some(block_list), None), + }; + + let block_list_length = block_list .as_ref() .map(|bl| bl.serialized_size()) .unwrap_or(0); - let allow_list_length = self - .allow_list + let allow_list_length = allow_list .as_ref() .map(|al| al.serialized_size()) .unwrap_or(0); @@ -189,11 +138,11 @@ impl RowIdMask { let offsets = OffsetBuffer::from_lengths(lengths); let mut value_bytes = vec![0; block_list_length + allow_list_length]; let mut validity = vec![false, false]; - if let Some(block_list) = &self.block_list { + if let Some(block_list) = &block_list { validity[0] = true; block_list.serialize_into(&mut value_bytes[0..])?; } - if let Some(allow_list) = &self.allow_list { + if let Some(allow_list) = &allow_list { validity[1] = true; allow_list.serialize_into(&mut value_bytes[block_list_length..])?; } @@ -202,7 +151,7 @@ impl RowIdMask { Ok(BinaryArray::try_new(offsets, values, Some(nulls))?) 
} - /// Deserialize a row id mask from Arrow + /// Deserialize a row address mask from Arrow pub fn from_arrow(array: &GenericBinaryArray) -> Result { let block_list = if array.is_null(0) { None @@ -217,65 +166,40 @@ impl RowIdMask { Some(RowAddrTreeMap::deserialize_from(array.value(1))) } .transpose()?; - Ok(Self { - block_list, - allow_list, - }) + + let res = match (block_list, allow_list) { + (Some(bl), None) => Self::BlockList(bl), + (None, Some(al)) => Self::AllowList(al), + (Some(block), Some(allow)) => Self::AllowList(allow).also_block(block), + (None, None) => Self::all_rows(), + }; + Ok(res) } - /// Return the maximum number of row ids that could be selected by this mask + /// Return the maximum number of row addresses that could be selected by this mask /// - /// Will be None if there is no allow list + /// Will be None if this is a BlockList (unbounded) pub fn max_len(&self) -> Option { - if let Some(allow_list) = &self.allow_list { - // If there is a block list we could theoretically intersect the two - // but it's not clear if that is worth the effort. Feel free to add later. - allow_list.len() - } else { - None + match self { + Self::AllowList(selection) => selection.len(), + Self::BlockList(_) => None, } } - /// Iterate over the row ids that are selected by the mask + /// Iterate over the row addresses that are selected by the mask /// - /// This is only possible if there is an allow list and neither the - /// allow list nor the block list contain any "full fragment" blocks. - /// - /// TODO: We could probably still iterate efficiently even if the block - /// list contains "full fragment" blocks but that would require some - /// extra logic. + /// This is only possible if this is an AllowList and the maps don't contain + /// any "full fragment" blocks. 
pub fn iter_ids(&self) -> Option + '_>> { - if let Some(mut allow_iter) = self.allow_list.as_ref().and_then(|list| list.row_addrs()) { - if let Some(block_list) = &self.block_list { - if let Some(block_iter) = block_list.row_addrs() { - let mut block_iter = block_iter.peekable(); - Some(Box::new(iter::from_fn(move || { - for allow_id in allow_iter.by_ref() { - while let Some(block_id) = block_iter.peek() { - if *block_id >= allow_id { - break; - } - block_iter.next(); - } - if let Some(block_id) = block_iter.peek() { - if *block_id == allow_id { - continue; - } - } - return Some(allow_id); - } - None - }))) + match self { + Self::AllowList(allow_list) => { + if let Some(allow_iter) = allow_list.row_addrs() { + Some(Box::new(allow_iter)) } else { - // There is a block list but we can't iterate over it, give up None } - } else { - // There is no block list, use the allow list - Some(Box::new(allow_iter)) } - } else { - None + Self::BlockList(_) => None, // Can't iterate over block list } } } @@ -284,9 +208,9 @@ impl std::ops::Not for RowIdMask { type Output = Self; fn not(self) -> Self::Output { - Self { - block_list: self.allow_list, - allow_list: self.block_list, + match self { + Self::AllowList(allow_list) => Self::BlockList(allow_list), + Self::BlockList(block_list) => Self::AllowList(block_list), } } } @@ -295,21 +219,11 @@ impl std::ops::BitAnd for RowIdMask { type Output = Self; fn bitand(self, rhs: Self) -> Self::Output { - let block_list = match (self.block_list, rhs.block_list) { - (None, None) => None, - (Some(lhs), None) => Some(lhs), - (None, Some(rhs)) => Some(rhs), - (Some(lhs), Some(rhs)) => Some(lhs | rhs), - }; - let allow_list = match (self.allow_list, rhs.allow_list) { - (None, None) => None, - (Some(lhs), None) => Some(lhs), - (None, Some(rhs)) => Some(rhs), - (Some(lhs), Some(rhs)) => Some(lhs & rhs), - }; - Self { - block_list, - allow_list, + match (self, rhs) { + (Self::AllowList(a), Self::AllowList(b)) => Self::AllowList(a & b), + 
(Self::AllowList(allow), Self::BlockList(block)) + | (Self::BlockList(block), Self::AllowList(allow)) => Self::AllowList(allow - block), + (Self::BlockList(a), Self::BlockList(b)) => Self::BlockList(a | b), } } } @@ -318,44 +232,11 @@ impl std::ops::BitOr for RowIdMask { type Output = Self; fn bitor(self, rhs: Self) -> Self::Output { - let this = self.normalize(); - let rhs = rhs.normalize(); - let block_list = if let Some(mut self_block_list) = this.block_list { - match (&rhs.allow_list, rhs.block_list) { - // If RHS is allow all, then our block list disappears - (None, None) => None, - // If RHS is allow list, remove allowed from our block list - (Some(allow_list), None) => { - self_block_list -= allow_list; - Some(self_block_list) - } - // If RHS is block list, intersect - (None, Some(block_list)) => Some(self_block_list & block_list), - // We normalized to avoid this path - (Some(_), Some(_)) => unreachable!(), - } - } else if let Some(mut rhs_block_list) = rhs.block_list { - if let Some(allow_list) = &this.allow_list { - rhs_block_list -= allow_list; - Some(rhs_block_list) - } else { - Some(rhs_block_list) - } - } else { - None - }; - - let allow_list = match (this.allow_list, rhs.allow_list) { - (None, None) => None, - // Remember that an allow list of None means "all rows" and - // so "all rows" | "some rows" is always "all rows" - (Some(_), None) => None, - (None, Some(_)) => None, - (Some(lhs), Some(rhs)) => Some(lhs | rhs), - }; - Self { - block_list, - allow_list, + match (self, rhs) { + (Self::AllowList(a), Self::AllowList(b)) => Self::AllowList(a | b), + (Self::AllowList(allow), Self::BlockList(block)) + | (Self::BlockList(block), Self::AllowList(allow)) => Self::BlockList(block - allow), + (Self::BlockList(a), Self::BlockList(b)) => Self::BlockList(a & b), } } } @@ -682,14 +563,16 @@ impl RowAddrTreeMap { /// Apply a mask to the row ids /// - /// If there is an allow list then this will intersect the set with the allow list - /// If there is a block 
list then this will subtract the block list from the set + /// For AllowList: only keep rows that are in the selection and not null + /// For BlockList: remove rows that are blocked (not null) and remove nulls pub fn mask(&mut self, mask: &RowIdMask) { - if let Some(allow_list) = &mask.allow_list { - *self &= allow_list; - } - if let Some(block_list) = &mask.block_list { - *self -= block_list; + match mask { + RowIdMask::AllowList(allow_list) => { + *self &= allow_list; + } + RowIdMask::BlockList(block_list) => { + *self -= block_list; + } } } @@ -723,8 +606,23 @@ impl std::ops::BitOr for RowAddrTreeMap { } } +impl std::ops::BitOr<&Self> for RowAddrTreeMap { + type Output = Self; + + fn bitor(mut self, rhs: &Self) -> Self::Output { + self |= rhs; + self + } +} + impl std::ops::BitOrAssign for RowAddrTreeMap { fn bitor_assign(&mut self, rhs: Self) { + *self |= &rhs; + } +} + +impl std::ops::BitOrAssign<&Self> for RowAddrTreeMap { + fn bitor_assign(&mut self, rhs: &Self) { for (fragment, rhs_set) in &rhs.inner { let lhs_set = self.inner.get_mut(fragment); if let Some(lhs_set) = lhs_set { @@ -757,6 +655,21 @@ impl std::ops::BitAnd for RowAddrTreeMap { } } +impl std::ops::BitAnd<&Self> for RowAddrTreeMap { + type Output = Self; + + fn bitand(mut self, rhs: &Self) -> Self::Output { + self &= rhs; + self + } +} + +impl std::ops::BitAndAssign for RowAddrTreeMap { + fn bitand_assign(&mut self, rhs: Self) { + *self &= &rhs; + } +} + impl std::ops::BitAndAssign<&Self> for RowAddrTreeMap { fn bitand_assign(&mut self, rhs: &Self) { // Remove fragment that aren't on the RHS @@ -795,6 +708,15 @@ impl std::ops::Sub for RowAddrTreeMap { } } +impl std::ops::Sub<&Self> for RowAddrTreeMap { + type Output = Self; + + fn sub(mut self, rhs: &Self) -> Self { + self -= rhs; + self + } +} + impl std::ops::SubAssign<&Self> for RowAddrTreeMap { fn sub_assign(&mut self, rhs: &Self) { for (fragment, rhs_set) in &rhs.inner { @@ -868,6 +790,14 @@ impl From> for RowAddrTreeMap { } } +impl From> 
for RowAddrTreeMap { + fn from(range: RangeInclusive) -> Self { + let mut map = Self::default(); + map.insert_range(range); + map + } +} + impl From for RowAddrTreeMap { fn from(roaring: RoaringTreemap) -> Self { let mut inner = BTreeMap::new(); @@ -937,53 +867,189 @@ mod tests { use super::*; use proptest::prop_assert_eq; + fn rows(ids: &[u64]) -> RowAddrTreeMap { + RowAddrTreeMap::from_iter(ids) + } + + fn assert_mask_selects(mask: &RowIdMask, selected: &[u64], not_selected: &[u64]) { + for &id in selected { + assert!(mask.selected(id), "Expected row {} to be selected", id); + } + for &id in not_selected { + assert!(!mask.selected(id), "Expected row {} to NOT be selected", id); + } + } + + fn selected_in_range(mask: &RowIdMask, range: std::ops::Range) -> Vec { + range.filter(|val| mask.selected(*val)).collect() + } + + #[test] + fn test_row_id_mask_construction() { + let full_mask = RowIdMask::all_rows(); + assert_eq!(full_mask.max_len(), None); + assert_mask_selects(&full_mask, &[0, 1, 4 << 32 | 3], &[]); + assert_eq!(full_mask.allow_list(), None); + assert_eq!(full_mask.block_list(), Some(&RowAddrTreeMap::default())); + assert!(full_mask.iter_ids().is_none()); + + let empty_mask = RowIdMask::allow_nothing(); + assert_eq!(empty_mask.max_len(), Some(0)); + assert_mask_selects(&empty_mask, &[], &[0, 1, 4 << 32 | 3]); + assert_eq!(empty_mask.allow_list(), Some(&RowAddrTreeMap::default())); + assert_eq!(empty_mask.block_list(), None); + let iter = empty_mask.iter_ids(); + assert!(iter.is_some()); + assert_eq!(iter.unwrap().count(), 0); + + let allow_list = RowIdMask::from_allowed(rows(&[10, 20, 30])); + assert_eq!(allow_list.max_len(), Some(3)); + assert_mask_selects(&allow_list, &[10, 20, 30], &[0, 15, 25, 40]); + assert_eq!(allow_list.allow_list(), Some(&rows(&[10, 20, 30]))); + assert_eq!(allow_list.block_list(), None); + let iter = allow_list.iter_ids(); + assert!(iter.is_some()); + let ids: Vec = iter.unwrap().map(|addr| addr.into()).collect(); + 
assert_eq!(ids, vec![10, 20, 30]); + + let mut full_frag = RowAddrTreeMap::default(); + full_frag.insert_fragment(2); + let allow_list = RowIdMask::from_allowed(full_frag); + assert_eq!(allow_list.max_len(), None); + assert_mask_selects(&allow_list, &[(2 << 32) + 5], &[(3 << 32) + 5]); + assert!(allow_list.iter_ids().is_none()); + } + + #[test] + fn test_selected_indices() { + // Allow list + let mask = RowIdMask::from_allowed(rows(&[10, 20, 40])); + assert!(mask.selected_indices(std::iter::empty()).is_empty()); + assert_eq!(mask.selected_indices([25, 20, 14, 10].iter()), &[1, 3]); + + // Block list + let mask = RowIdMask::from_block(rows(&[10, 20, 40])); + assert!(mask.selected_indices(std::iter::empty()).is_empty()); + assert_eq!(mask.selected_indices([25, 20, 14, 10].iter()), &[0, 2]); + } + + #[test] + fn test_also_allow() { + // Allow list + let mask = RowIdMask::from_allowed(rows(&[10, 20])); + let new_mask = mask.also_allow(rows(&[20, 30, 40])); + assert_eq!(new_mask, RowIdMask::from_allowed(rows(&[10, 20, 30, 40]))); + + // Block list + let mask = RowIdMask::from_block(rows(&[10, 20, 30])); + let new_mask = mask.also_allow(rows(&[20, 40])); + assert_eq!(new_mask, RowIdMask::from_block(rows(&[10, 30]))); + } + + #[test] + fn test_also_block() { + // Allow list + let mask = RowIdMask::from_allowed(rows(&[10, 20, 30])); + let new_mask = mask.also_block(rows(&[20, 40])); + assert_eq!(new_mask, RowIdMask::from_allowed(rows(&[10, 30]))); + + // Block list + let mask = RowIdMask::from_block(rows(&[10, 20])); + let new_mask = mask.also_block(rows(&[20, 30, 40])); + assert_eq!(new_mask, RowIdMask::from_block(rows(&[10, 20, 30, 40]))); + } + + #[test] + fn test_iter_ids() { + // Allow list + let mask = RowIdMask::from_allowed(rows(&[10, 20, 30])); + let expected: Vec<_> = [10, 20, 30].into_iter().map(RowAddress::from).collect(); + assert_eq!(mask.iter_ids().unwrap().collect::>(), expected); + + // Allow list with full fragment + let mut inner = 
RowAddrTreeMap::default(); + inner.insert_fragment(10); + let mask = RowIdMask::from_allowed(inner); + assert!(mask.iter_ids().is_none()); + + // Block list + let mask = RowIdMask::from_block(rows(&[10, 20, 30])); + assert!(mask.iter_ids().is_none()); + } + + #[test] + fn test_row_id_mask_not() { + let allow_list = RowIdMask::from_allowed(rows(&[1, 2, 3])); + let block_list = !allow_list.clone(); + assert_eq!(block_list, RowIdMask::from_block(rows(&[1, 2, 3]))); + // Can roundtrip by negating again + assert_eq!(!block_list, allow_list); + } + #[test] fn test_ops() { let mask = RowIdMask::default(); - assert!(mask.selected(1)); - assert!(mask.selected(5)); - let block_list = mask.also_block(RowAddrTreeMap::from_iter(&[0, 5, 15])); - assert!(block_list.selected(1)); - assert!(!block_list.selected(5)); - let allow_list = RowIdMask::from_allowed(RowAddrTreeMap::from_iter(&[0, 2, 5])); - assert!(!allow_list.selected(1)); - assert!(allow_list.selected(5)); + assert_mask_selects(&mask, &[1, 5], &[]); + + let block_list = mask.also_block(rows(&[0, 5, 15])); + assert_mask_selects(&block_list, &[1], &[5]); + + let allow_list = RowIdMask::from_allowed(rows(&[0, 2, 5])); + assert_mask_selects(&allow_list, &[5], &[1]); + let combined = block_list & allow_list; - assert!(combined.selected(2)); - assert!(!combined.selected(0)); - assert!(!combined.selected(5)); - let other = RowIdMask::from_allowed(RowAddrTreeMap::from_iter(&[3])); + assert_mask_selects(&combined, &[2], &[0, 5]); + + let other = RowIdMask::from_allowed(rows(&[3])); let combined = combined | other; - assert!(combined.selected(2)); - assert!(combined.selected(3)); - assert!(!combined.selected(0)); - assert!(!combined.selected(5)); + assert_mask_selects(&combined, &[2, 3], &[0, 5]); - let block_list = RowIdMask::from_block(RowAddrTreeMap::from_iter(&[0])); - let allow_list = RowIdMask::from_allowed(RowAddrTreeMap::from_iter(&[3])); + let block_list = RowIdMask::from_block(rows(&[0])); + let allow_list = 
RowIdMask::from_allowed(rows(&[3])); let combined = block_list | allow_list; - assert!(combined.selected(1)); + assert_mask_selects(&combined, &[1], &[]); + } + + #[test] + fn test_logical_and() { + let allow1 = RowIdMask::from_allowed(rows(&[0, 1])); + let block1 = RowIdMask::from_block(rows(&[1, 2])); + let allow2 = RowIdMask::from_allowed(rows(&[1, 2, 3, 4])); + let block2 = RowIdMask::from_block(rows(&[3, 4])); + + fn check(lhs: &RowIdMask, rhs: &RowIdMask, expected: &[u64]) { + for mask in [lhs.clone() & rhs.clone(), rhs.clone() & lhs.clone()] { + assert_eq!(selected_in_range(&mask, 0..10), expected); + } + } + + // Allow & Allow + check(&allow1, &allow1, &[0, 1]); + check(&allow1, &allow2, &[1]); + + // Block & Block + check(&block1, &block1, &[0, 3, 4, 5, 6, 7, 8, 9]); + check(&block1, &block2, &[0, 5, 6, 7, 8, 9]); + + // Allow & Block + check(&allow1, &block1, &[0]); + check(&allow1, &block2, &[0, 1]); + check(&allow2, &block1, &[3, 4]); + check(&allow2, &block2, &[1, 2]); } #[test] fn test_logical_or() { - let allow1 = RowIdMask::from_allowed(RowAddrTreeMap::from_iter(&[5, 6, 7, 8, 9])); - let block1 = RowIdMask::from_block(RowAddrTreeMap::from_iter(&[5, 6])); - let mixed1 = allow1 - .clone() - .also_block(block1.block_list.as_ref().unwrap().clone()); - let allow2 = RowIdMask::from_allowed(RowAddrTreeMap::from_iter(&[2, 3, 4, 5, 6, 7, 8])); - let block2 = RowIdMask::from_block(RowAddrTreeMap::from_iter(&[4, 5])); - let mixed2 = allow2 - .clone() - .also_block(block2.block_list.as_ref().unwrap().clone()); + let allow1 = RowIdMask::from_allowed(rows(&[5, 6, 7, 8, 9])); + let block1 = RowIdMask::from_block(rows(&[5, 6])); + let mixed1 = allow1.clone().also_block(rows(&[5, 6])); + let allow2 = RowIdMask::from_allowed(rows(&[2, 3, 4, 5, 6, 7, 8])); + let block2 = RowIdMask::from_block(rows(&[4, 5])); + let mixed2 = allow2.clone().also_block(rows(&[4, 5])); fn check(lhs: &RowIdMask, rhs: &RowIdMask, expected: &[u64]) { for mask in [lhs.clone() | rhs.clone(), 
rhs.clone() | lhs.clone()] { - let values = (0..10) - .filter(|val| mask.selected(*val)) - .collect::>(); - assert_eq!(&values, expected); + assert_eq!(selected_in_range(&mask, 0..10), expected); } } @@ -1011,6 +1077,113 @@ mod tests { check(&block2, &mixed2, &[0, 1, 2, 3, 6, 7, 8, 9]); } + #[test] + fn test_deserialize_legacy_format() { + // Test that we can deserialize the old format where both allow_list + // and block_list could be present in the serialized form. + // + // The old format (before this PR) used a struct with both allow_list and block_list + // fields. The new format uses an enum. The deserialization code should handle + // the case where both lists are present by converting to AllowList(allow - block). + + // Create the RowIdTreeMaps and serialize them directly + let allow = rows(&[1, 2, 3, 4, 5, 10, 15]); + let block = rows(&[2, 4, 15]); + + // Serialize using the stable RowIdTreeMap serialization format + let block_bytes = { + let mut buf = Vec::with_capacity(block.serialized_size()); + block.serialize_into(&mut buf).unwrap(); + buf + }; + let allow_bytes = { + let mut buf = Vec::with_capacity(allow.serialized_size()); + allow.serialize_into(&mut buf).unwrap(); + buf + }; + + // Construct a binary array with both values present (simulating old format) + let old_format_array = + BinaryArray::from_opt_vec(vec![Some(&block_bytes), Some(&allow_bytes)]); + + // Deserialize - should handle this by creating AllowList(allow - block) + let deserialized = RowIdMask::from_arrow(&old_format_array).unwrap(); + + // The expected result: AllowList([1, 2, 3, 4, 5, 10, 15] - [2, 4, 15]) = [1, 3, 5, 10] + assert_mask_selects(&deserialized, &[1, 3, 5, 10], &[2, 4, 15]); + assert!( + deserialized.allow_list().is_some(), + "Should deserialize to AllowList variant" + ); + } + + #[test] + fn test_roundtrip_arrow() { + let row_addrs = rows(&[1, 2, 3, 100, 2000]); + + // Allow list + let original = RowIdMask::from_allowed(row_addrs.clone()); + let array = 
original.into_arrow().unwrap(); + assert_eq!(RowIdMask::from_arrow(&array).unwrap(), original); + + // Block list + let original = RowIdMask::from_block(row_addrs); + let array = original.into_arrow().unwrap(); + assert_eq!(RowIdMask::from_arrow(&array).unwrap(), original); + } + + #[test] + fn test_deserialize_legacy_empty_lists() { + // Case 1: Both None (should become all_rows) + let array = BinaryArray::from_opt_vec(vec![None, None]); + let mask = RowIdMask::from_arrow(&array).unwrap(); + assert_mask_selects(&mask, &[0, 100, u64::MAX], &[]); + + // Case 2: Only block list (no allow list) + let block = rows(&[5, 10]); + let block_bytes = { + let mut buf = Vec::with_capacity(block.serialized_size()); + block.serialize_into(&mut buf).unwrap(); + buf + }; + let array = BinaryArray::from_opt_vec(vec![Some(&block_bytes[..]), None]); + let mask = RowIdMask::from_arrow(&array).unwrap(); + assert_mask_selects(&mask, &[0, 15], &[5, 10]); + + // Case 3: Only allow list (no block list) + let allow = rows(&[5, 10]); + let allow_bytes = { + let mut buf = Vec::with_capacity(allow.serialized_size()); + allow.serialize_into(&mut buf).unwrap(); + buf + }; + let array = BinaryArray::from_opt_vec(vec![None, Some(&allow_bytes[..])]); + let mask = RowIdMask::from_arrow(&array).unwrap(); + assert_mask_selects(&mask, &[5, 10], &[0, 15]); + } + + #[test] + fn test_map_insert() { + let mut map = RowAddrTreeMap::default(); + + assert!(!map.contains(20)); + assert!(map.insert(20)); + assert!(map.contains(20)); + assert!(!map.insert(20)); // Inserting again should be no-op + + let bitmap = map.get_fragment_bitmap(0); + assert!(bitmap.is_some()); + let bitmap = bitmap.unwrap(); + assert_eq!(bitmap.len(), 1); + + assert!(map.get_fragment_bitmap(1).is_none()); + + map.insert_fragment(0); + assert!(map.contains(0)); + assert!(!map.insert(0)); // Inserting into full fragment should be no-op + assert!(map.get_fragment_bitmap(0).is_none()); + } + #[test] fn test_map_insert_range() { let ranges = 
&[ @@ -1067,6 +1240,111 @@ mod tests { // a lot of memory. } + #[test] + fn test_map_mask() { + let mask = rows(&[0, 1, 2]); + let mask2 = rows(&[0, 2, 3]); + + let allow_list = RowIdMask::AllowList(mask2.clone()); + let mut actual = mask.clone(); + actual.mask(&allow_list); + assert_eq!(actual, rows(&[0, 2])); + + let block_list = RowIdMask::BlockList(mask2); + let mut actual = mask; + actual.mask(&block_list); + assert_eq!(actual, rows(&[1])); + } + + #[test] + #[should_panic(expected = "Size of full fragment is unknown")] + fn test_map_insert_full_fragment_row() { + let mut mask = RowAddrTreeMap::default(); + mask.insert_fragment(0); + + unsafe { + let _ = mask.into_addr_iter().collect::>(); + } + } + + #[test] + fn test_map_into_addr_iter() { + let mut mask = RowAddrTreeMap::default(); + mask.insert(0); + mask.insert(1); + mask.insert(1 << 32 | 5); + mask.insert(2 << 32 | 10); + + let expected = vec![0u64, 1, 1 << 32 | 5, 2 << 32 | 10]; + let actual: Vec = unsafe { mask.into_addr_iter().collect() }; + assert_eq!(actual, expected); + } + + #[test] + fn test_map_from() { + let map = RowAddrTreeMap::from(10..12); + assert!(map.contains(10)); + assert!(map.contains(11)); + assert!(!map.contains(12)); + assert!(!map.contains(3)); + + let map = RowAddrTreeMap::from(10..=12); + assert!(map.contains(10)); + assert!(map.contains(11)); + assert!(map.contains(12)); + assert!(!map.contains(3)); + } + + #[test] + fn test_map_from_roaring() { + let bitmap = RoaringTreemap::from_iter(&[0, 1, 1 << 32]); + let map = RowAddrTreeMap::from(bitmap); + assert!(map.contains(0) && map.contains(1) && map.contains(1 << 32)); + assert!(!map.contains(2)); + } + + #[test] + fn test_map_extend() { + let mut map = RowAddrTreeMap::default(); + map.insert(0); + map.insert_fragment(1); + + let other_rows = [0, 2, 1 << 32 | 10, 3 << 32 | 5]; + map.extend(other_rows.iter().copied()); + + assert!(map.contains(0)); + assert!(map.contains(2)); + assert!(map.contains(1 << 32 | 5)); + 
assert!(map.contains(1 << 32 | 10)); + assert!(map.contains(3 << 32 | 5)); + assert!(!map.contains(3)); + } + + #[test] + fn test_map_extend_other_maps() { + let mut map = RowAddrTreeMap::default(); + map.insert(0); + map.insert_fragment(1); + map.insert(4 << 32); + + let mut other_map = rows(&[0, 2, 1 << 32 | 10, 3 << 32 | 5]); + other_map.insert_fragment(4); + map.extend(std::iter::once(other_map)); + + for id in [ + 0, + 2, + 1 << 32 | 5, + 1 << 32 | 10, + 3 << 32 | 5, + 4 << 32, + 4 << 32 | 7, + ] { + assert!(map.contains(id), "Expected {} to be contained", id); + } + assert!(!map.contains(3)); + } + proptest::proptest! { #[test] fn test_map_serialization_roundtrip( @@ -1229,50 +1507,245 @@ mod tests { } #[test] - fn test_iter_ids() { - let mut mask = RowIdMask::default(); - assert!(mask.iter_ids().is_none()); + fn test_row_addr_selection_deep_size_of() { + use deepsize::DeepSizeOf; + + // Test Full variant - should have minimal size (just the enum discriminant) + let full = RowAddrSelection::Full; + let full_size = full.deep_size_of(); + // Full variant has no heap allocations beyond the enum itself + assert!(full_size < 100); // Small sanity check + + // Test Partial variant - should include bitmap size + let mut bitmap = RoaringBitmap::new(); + bitmap.insert_range(0..100); + let partial = RowAddrSelection::Partial(bitmap.clone()); + let partial_size = partial.deep_size_of(); + // Partial variant should be larger due to bitmap + assert!(partial_size >= bitmap.serialized_size()); + } - // Test with just an allow list - let mut allow_list = RowAddrTreeMap::default(); - allow_list.extend([1, 5, 10].iter().copied()); - mask.allow_list = Some(allow_list); - - let ids: Vec<_> = mask.iter_ids().unwrap().collect(); - assert_eq!( - ids, - vec![ - RowAddress::new_from_parts(0, 1), - RowAddress::new_from_parts(0, 5), - RowAddress::new_from_parts(0, 10) - ] - ); + #[test] + fn test_row_addr_selection_union_all_with_full() { + let full = RowAddrSelection::Full; + let 
partial = RowAddrSelection::Partial(RoaringBitmap::from_iter(&[1, 2, 3])); + + assert!(matches!( + RowAddrSelection::union_all(&[&full, &partial]), + RowAddrSelection::Full + )); + + let partial2 = RowAddrSelection::Partial(RoaringBitmap::from_iter(&[4, 5, 6])); + let RowAddrSelection::Partial(bitmap) = RowAddrSelection::union_all(&[&partial, &partial2]) + else { + panic!("Expected Partial"); + }; + assert!(bitmap.contains(1) && bitmap.contains(4)); + } - // Test with both allow list and block list - let mut block_list = RowAddrTreeMap::default(); - block_list.extend([5].iter().copied()); - mask.block_list = Some(block_list); - - let ids: Vec<_> = mask.iter_ids().unwrap().collect(); - assert_eq!( - ids, - vec![ - RowAddress::new_from_parts(0, 1), - RowAddress::new_from_parts(0, 10) - ] - ); + #[test] + fn test_insert_range_unbounded_start() { + let mut map = RowAddrTreeMap::default(); + + // Test exclusive start bound + let count = map.insert_range((std::ops::Bound::Excluded(5), std::ops::Bound::Included(10))); + assert_eq!(count, 5); // 6, 7, 8, 9, 10 + assert!(!map.contains(5)); + assert!(map.contains(6)); + assert!(map.contains(10)); + + // Test unbounded end + let mut map2 = RowAddrTreeMap::default(); + let count = map2.insert_range(0..5); + assert_eq!(count, 5); + assert!(map2.contains(0)); + assert!(map2.contains(4)); + assert!(!map2.contains(5)); + } - // Test with full fragment in block list - let mut block_list = RowAddrTreeMap::default(); - block_list.insert_fragment(0); - mask.block_list = Some(block_list); - assert!(mask.iter_ids().is_none()); + #[test] + fn test_remove_from_full_fragment() { + let mut map = RowAddrTreeMap::default(); + map.insert_fragment(0); - // Test with full fragment in allow list - mask.block_list = None; - let mut allow_list = RowAddrTreeMap::default(); - allow_list.insert_fragment(0); - mask.allow_list = Some(allow_list); - assert!(mask.iter_ids().is_none()); + // Verify it's a full fragment - get_fragment_bitmap returns None 
for Full + for id in [0, 100, u32::MAX as u64] { + assert!(map.contains(id)); + } + assert!(map.get_fragment_bitmap(0).is_none()); + + // Remove a value from the full fragment + assert!(map.remove(50)); + + // Now it should be partial (a full RoaringBitmap minus one value) + assert!(map.contains(0) && !map.contains(50) && map.contains(100)); + assert!(map.get_fragment_bitmap(0).is_some()); + } + + #[test] + fn test_retain_fragments() { + let mut map = RowAddrTreeMap::default(); + map.insert(0); // fragment 0 + map.insert(1 << 32 | 5); // fragment 1 + map.insert(2 << 32 | 10); // fragment 2 + map.insert_fragment(3); // fragment 3 + + map.retain_fragments([0, 2]); + + assert!(map.contains(0) && map.contains(2 << 32 | 10)); + assert!(!map.contains(1 << 32 | 5) && !map.contains(3 << 32)); + } + + #[test] + fn test_bitor_assign_full_fragment() { + // Test BitOrAssign when LHS has Full and RHS has Partial + let mut map1 = RowAddrTreeMap::default(); + map1.insert_fragment(0); + let mut map2 = RowAddrTreeMap::default(); + map2.insert(5); + + map1 |= &map2; + // Full | Partial = Full + assert!(map1.contains(0) && map1.contains(5) && map1.contains(100)); + + // Test BitOrAssign when LHS has Partial and RHS has Full + let mut map3 = RowAddrTreeMap::default(); + map3.insert(5); + let mut map4 = RowAddrTreeMap::default(); + map4.insert_fragment(0); + + map3 |= &map4; + // Partial | Full = Full + assert!(map3.contains(0) && map3.contains(5) && map3.contains(100)); + } + + #[test] + fn test_bitand_assign_full_fragments() { + // Test BitAndAssign when both have Full for same fragment + let mut map1 = RowAddrTreeMap::default(); + map1.insert_fragment(0); + let mut map2 = RowAddrTreeMap::default(); + map2.insert_fragment(0); + + map1 &= &map2; + // Full & Full = Full + assert!(map1.contains(0) && map1.contains(100)); + + // Test BitAndAssign when LHS Full, RHS Partial + let mut map3 = RowAddrTreeMap::default(); + map3.insert_fragment(0); + let mut map4 = RowAddrTreeMap::default(); + 
map4.insert(5); + map4.insert(10); + + map3 &= &map4; + // Full & Partial([5,10]) = Partial([5,10]) + assert!(map3.contains(5) && map3.contains(10)); + assert!(!map3.contains(0) && !map3.contains(100)); + + // Test that empty intersection results in removal + let mut map5 = RowAddrTreeMap::default(); + map5.insert(5); + let mut map6 = RowAddrTreeMap::default(); + map6.insert(10); + + map5 &= &map6; + assert!(map5.is_empty()); + } + + #[test] + fn test_sub_assign_with_full_fragments() { + // Test SubAssign when LHS is Full and RHS is Partial + let mut map1 = RowAddrTreeMap::default(); + map1.insert_fragment(0); + let mut map2 = RowAddrTreeMap::default(); + map2.insert(5); + map2.insert(10); + + map1 -= &map2; + // Full - Partial([5,10]) = Full minus those values + assert!(map1.contains(0) && map1.contains(100)); + assert!(!map1.contains(5) && !map1.contains(10)); + + // Test SubAssign when both are Full for same fragment + let mut map3 = RowAddrTreeMap::default(); + map3.insert_fragment(0); + let mut map4 = RowAddrTreeMap::default(); + map4.insert_fragment(0); + + map3 -= &map4; + // Full - Full = empty + assert!(map3.is_empty()); + + // Test SubAssign when LHS is Partial and RHS is Full + let mut map5 = RowAddrTreeMap::default(); + map5.insert(5); + map5.insert(10); + let mut map6 = RowAddrTreeMap::default(); + map6.insert_fragment(0); + + map5 -= &map6; + // Partial - Full = empty + assert!(map5.is_empty()); + } + + #[test] + fn test_from_iterator_with_full_fragment() { + // Test that inserting into a full fragment is a no-op + let mut map = RowAddrTreeMap::default(); + map.insert_fragment(0); + + // Extend with values that would go into fragment 0 + map.extend([5u64, 10, 100].iter()); + + // Should still be full fragment + for id in [0, 5, 10, 100, u32::MAX as u64] { + assert!(map.contains(id)); + } + } + + #[test] + fn test_insert_range_excluded_end() { + // Test excluded end bound (line 391-393) + let mut map = RowAddrTreeMap::default(); + // Using RangeFrom 
with small range won't hit the unbounded case + // Instead test Bound::Excluded for end + let count = map.insert_range((std::ops::Bound::Included(5), std::ops::Bound::Excluded(10))); + assert_eq!(count, 5); // 5, 6, 7, 8, 9 + assert!(map.contains(5)); + assert!(map.contains(9)); + assert!(!map.contains(10)); + } + + #[test] + fn test_bitand_assign_owned() { + // Test BitAndAssign (owned, not reference) + let mut map1 = RowAddrTreeMap::default(); + map1.insert(5); + map1.insert(10); + + // Using owned rhs (not reference) + map1 &= rows(&[5, 15]); + + assert!(map1.contains(5)); + assert!(!map1.contains(10) && !map1.contains(15)); + } + + #[test] + fn test_from_iter_with_full_fragment() { + // When we collect into RowAddrTreeMap, it should handle duplicates + let map: RowAddrTreeMap = vec![5u64, 10, 100].into_iter().collect(); + assert!(map.contains(5) && map.contains(10)); + + // Test that extending a map with full fragment ignores new values + let mut map = RowAddrTreeMap::default(); + map.insert_fragment(0); + for val in [5, 10, 100] { + map.insert(val); // This should be no-op since fragment is full + } + // Still full fragment + for id in [0, 5, u32::MAX as u64] { + assert!(map.contains(id)); + } } } diff --git a/rust/lance-core/src/utils/mask/nullable.rs b/rust/lance-core/src/utils/mask/nullable.rs new file mode 100644 index 00000000000..5e5657d8d0d --- /dev/null +++ b/rust/lance-core/src/utils/mask/nullable.rs @@ -0,0 +1,621 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use deepsize::DeepSizeOf; + +use super::{RowAddrTreeMap, RowIdMask}; + +/// A set of row ids, with optional set of nulls. +/// +/// This is often a result of a filter, where `selected` represents the rows that +/// passed the filter, and `nulls` represents the rows where the filter evaluated +/// to null. For example, in SQL `NULL > 5` evaluates to null. 
This is distinct +/// from being deselected to support proper three-valued logic for NOT. +/// (`NOT FALSE` is TRUE, `NOT TRUE` is FALSE, but `NOT NULL` is NULL. +/// `NULL | TRUE = TRUE`, `NULL & FALSE = FALSE`, but `NULL | FALSE = NULL` +/// and `NULL & TRUE = NULL`). +#[derive(Clone, Debug, Default, DeepSizeOf)] +pub struct NullableRowAddrSet { + selected: RowAddrTreeMap, + // Rows that are NULL. These rows are considered NULL even if they are also in `selected`. + nulls: RowAddrTreeMap, +} + +impl NullableRowAddrSet { + /// Create a new RowSelection from selected rows and null rows. + /// + /// `nulls` may have overlap with `selected`. Rows in `nulls` are considered NULL, + /// even if they are also in `selected`. + pub fn new(selected: RowAddrTreeMap, nulls: RowAddrTreeMap) -> Self { + Self { selected, nulls } + } + + pub fn with_nulls(mut self, nulls: RowAddrTreeMap) -> Self { + self.nulls = nulls; + self + } + + /// Create an empty selection. Alias for [Default::default] + pub fn empty() -> Self { + Default::default() + } + + /// Get the number of TRUE rows (selected but not null). + /// + /// Returns None if the number of TRUE rows cannot be determined. This happens + /// if the underlying RowAddrTreeMap has full fragments selected. 
+    pub fn len(&self) -> Option<u64> {
+        self.true_rows().len()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.selected.is_empty()
+    }
+
+    /// Check if a row_id is selected (TRUE)
+    pub fn selected(&self, row_id: u64) -> bool {
+        self.selected.contains(row_id) && !self.nulls.contains(row_id)
+    }
+
+    /// Get the null rows
+    pub fn null_rows(&self) -> &RowAddrTreeMap {
+        &self.nulls
+    }
+
+    /// Get the TRUE rows (selected but not null)
+    pub fn true_rows(&self) -> RowAddrTreeMap {
+        self.selected.clone() - self.nulls.clone()
+    }
+
+    pub fn union_all(selections: &[Self]) -> Self {
+        let true_rows = selections
+            .iter()
+            .map(|s| s.true_rows())
+            .collect::<Vec<_>>();
+        let true_rows_refs = true_rows.iter().collect::<Vec<_>>();
+        let selected = RowAddrTreeMap::union_all(&true_rows_refs);
+        let nulls = RowAddrTreeMap::union_all(
+            &selections
+                .iter()
+                .map(|s| &s.nulls)
+                .collect::<Vec<_>>(),
+        );
+        // TRUE | NULL = TRUE, so remove any TRUE rows from nulls
+        let nulls = nulls - &selected;
+        Self { selected, nulls }
+    }
+}
+
+impl PartialEq for NullableRowAddrSet {
+    fn eq(&self, other: &Self) -> bool {
+        self.true_rows() == other.true_rows() && self.nulls == other.nulls
+    }
+}
+
+impl std::ops::BitAndAssign<&Self> for NullableRowAddrSet {
+    fn bitand_assign(&mut self, rhs: &Self) {
+        self.nulls = if self.nulls.is_empty() && rhs.nulls.is_empty() {
+            RowAddrTreeMap::new() // Fast path
+        } else {
+            (self.nulls.clone() & &rhs.nulls) // null and null -> null
+                | (self.nulls.clone() & &rhs.selected) // null and true -> null
+                | (rhs.nulls.clone() & &self.selected) // true and null -> null
+        };
+
+        self.selected &= &rhs.selected;
+    }
+}
+
+impl std::ops::BitOrAssign<&Self> for NullableRowAddrSet {
+    fn bitor_assign(&mut self, rhs: &Self) {
+        self.nulls = if self.nulls.is_empty() && rhs.nulls.is_empty() {
+            RowAddrTreeMap::new() // Fast path
+        } else {
+            // null or null -> null (excluding rows that are true in either)
+            let true_rows =
+                (self.selected.clone() - &self.nulls) | (rhs.selected.clone() 
- &rhs.nulls); + (self.nulls.clone() | &rhs.nulls) - true_rows + }; + + self.selected |= &rhs.selected; + } +} + +/// A version of [`RowIdMask`] that supports nulls. +/// +/// This mask handles three-valued logic for SQL expressions, where a filter can +/// evaluate to TRUE, FALSE, or NULL. The `selected` set includes rows that are +/// TRUE or NULL. The `nulls` set includes rows that are NULL. +#[derive(Clone, Debug)] +pub enum NullableRowIdMask { + AllowList(NullableRowAddrSet), + BlockList(NullableRowAddrSet), +} + +impl NullableRowIdMask { + pub fn selected(&self, row_id: u64) -> bool { + match self { + Self::AllowList(NullableRowAddrSet { selected, nulls }) => { + selected.contains(row_id) && !nulls.contains(row_id) + } + Self::BlockList(NullableRowAddrSet { selected, nulls }) => { + !selected.contains(row_id) && !nulls.contains(row_id) + } + } + } + + pub fn drop_nulls(self) -> RowIdMask { + match self { + Self::AllowList(NullableRowAddrSet { selected, nulls }) => { + RowIdMask::AllowList(selected - nulls) + } + Self::BlockList(NullableRowAddrSet { selected, nulls }) => { + RowIdMask::BlockList(selected | nulls) + } + } + } +} + +impl std::ops::Not for NullableRowIdMask { + type Output = Self; + + fn not(self) -> Self::Output { + match self { + Self::AllowList(set) => Self::BlockList(set), + Self::BlockList(set) => Self::AllowList(set), + } + } +} + +impl std::ops::BitAnd for NullableRowIdMask { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + // Null handling: + // * null and true -> null + // * null and null -> null + // * null and false -> false + match (self, rhs) { + (Self::AllowList(a), Self::AllowList(b)) => { + let nulls = if a.nulls.is_empty() && b.nulls.is_empty() { + RowAddrTreeMap::new() // Fast path + } else { + (a.nulls.clone() & &b.nulls) // null and null -> null + | (a.nulls & &b.selected) // null and true -> null + | (b.nulls & &a.selected) // true and null -> null + }; + let selected = a.selected & b.selected; + 
Self::AllowList(NullableRowAddrSet { selected, nulls }) + } + (Self::AllowList(allow), Self::BlockList(block)) + | (Self::BlockList(block), Self::AllowList(allow)) => { + let nulls = if allow.nulls.is_empty() && block.nulls.is_empty() { + RowAddrTreeMap::new() // Fast path + } else { + (allow.nulls.clone() & &block.nulls) // null and null -> null + | (allow.nulls - &block.selected) // null and true -> null + | (block.nulls & &allow.selected) // true and null -> null + }; + let selected = allow.selected - block.selected; + Self::AllowList(NullableRowAddrSet { selected, nulls }) + } + (Self::BlockList(a), Self::BlockList(b)) => { + let nulls = if a.nulls.is_empty() && b.nulls.is_empty() { + RowAddrTreeMap::new() // Fast path + } else { + (a.nulls.clone() & &b.nulls) // null and null -> null + | (a.nulls - &b.selected) // null and true -> null + | (b.nulls - &a.selected) // true and null -> null + }; + let selected = a.selected | b.selected; + Self::BlockList(NullableRowAddrSet { selected, nulls }) + } + } + } +} + +impl std::ops::BitOr for NullableRowIdMask { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + // Null handling: + // * null or true -> true + // * null or null -> null + // * null or false -> null + match (self, rhs) { + (Self::AllowList(a), Self::AllowList(b)) => { + let nulls = if a.nulls.is_empty() && b.nulls.is_empty() { + RowAddrTreeMap::new() // Fast path + } else { + // null or null -> null (excluding rows that are true in either) + let true_rows = + (a.selected.clone() - &a.nulls) | (b.selected.clone() - &b.nulls); + (a.nulls | b.nulls) - true_rows + }; + let selected = (a.selected | b.selected) | &nulls; + Self::AllowList(NullableRowAddrSet { selected, nulls }) + } + (Self::AllowList(allow), Self::BlockList(block)) + | (Self::BlockList(block), Self::AllowList(allow)) => { + let nulls = if allow.nulls.is_empty() && block.nulls.is_empty() { + RowAddrTreeMap::new() // Fast path + } else { + // null or null -> null (excluding 
rows that are true in either) + let allow_true = allow.selected.clone() - &allow.nulls; + ((allow.nulls | block.nulls) & block.selected.clone()) - allow_true + }; + let selected = (block.selected - allow.selected) | &nulls; + Self::BlockList(NullableRowAddrSet { selected, nulls }) + } + (Self::BlockList(a), Self::BlockList(b)) => { + let nulls = if a.nulls.is_empty() && b.nulls.is_empty() { + RowAddrTreeMap::new() // Fast path + } else { + // null or null -> null (excluding rows that are true in either) + let false_rows = + (a.selected.clone() - &a.nulls) & (b.selected.clone() - &b.nulls); + (a.nulls | &b.nulls) - false_rows + }; + let selected = (a.selected & b.selected) | &nulls; + Self::BlockList(NullableRowAddrSet { selected, nulls }) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn rows(ids: &[u64]) -> RowAddrTreeMap { + RowAddrTreeMap::from_iter(ids) + } + + fn nullable_set(selected: &[u64], nulls: &[u64]) -> NullableRowAddrSet { + NullableRowAddrSet::new(rows(selected), rows(nulls)) + } + + fn allow(selected: &[u64], nulls: &[u64]) -> NullableRowIdMask { + NullableRowIdMask::AllowList(nullable_set(selected, nulls)) + } + + fn block(selected: &[u64], nulls: &[u64]) -> NullableRowIdMask { + NullableRowIdMask::BlockList(nullable_set(selected, nulls)) + } + + fn assert_mask_selects(mask: &NullableRowIdMask, selected: &[u64], not_selected: &[u64]) { + for &id in selected { + assert!(mask.selected(id), "Expected row {} to be selected", id); + } + for &id in not_selected { + assert!(!mask.selected(id), "Expected row {} to NOT be selected", id); + } + } + + #[test] + fn test_not_with_nulls() { + // Test case from issue #4756: x != 5 on data [0, 5, null] + // x = 5 should return: AllowList with selected=[1,2], nulls=[2] + // NOT(x = 5) should return: BlockList with selected=[1,2], nulls=[2] + // selected() should return TRUE for row 0, FALSE for rows 1 and 2 + let mask = allow(&[1, 2], &[2]); + let not_mask = !mask; + + // Row 0: selected (x=0, 
which is != 5) + // Row 1: NOT selected (x=5, which is == 5) + // Row 2: NOT selected (x=null, comparison result is null) + assert_mask_selects(¬_mask, &[0], &[1, 2]); + } + + #[test] + fn test_and_with_nulls() { + // Test Kleene AND logic: true AND null = null, false AND null = false + + // Case 1: TRUE mask AND mask with nulls + let true_mask = allow(&[0, 1, 2, 3, 4], &[]); + let null_mask = allow(&[0, 1, 2, 3, 4], &[1, 3]); + let result = true_mask & null_mask.clone(); + + // TRUE AND TRUE = TRUE; TRUE AND NULL = NULL (filtered out) + assert_mask_selects(&result, &[0, 2, 4], &[1, 3]); + + // Case 2: FALSE mask AND mask with nulls + let false_mask = block(&[0, 1, 2, 3, 4], &[]); + let result = false_mask & null_mask; + + // FALSE AND anything = FALSE + assert_mask_selects(&result, &[], &[0, 1, 2, 3, 4]); + + // Case 3: Both masks have nulls - union of null sets + let mask1 = allow(&[0, 1, 2], &[1]); + let mask2 = allow(&[0, 2, 3], &[2]); + let result = mask1 & mask2; + + // Only row 0 is TRUE in both; rows 1,2 are null in at least one; row 3 not in first + assert_mask_selects(&result, &[0], &[1, 2, 3]); + } + + #[test] + fn test_or_with_nulls() { + // Test Kleene OR logic: true OR null = true, false OR null = null + + // Case 1: FALSE mask OR mask with nulls + let false_mask = block(&[0, 1, 2], &[]); + let null_mask = allow(&[0, 1, 2], &[1, 2]); + let result = false_mask | null_mask.clone(); + + // FALSE OR TRUE = TRUE; FALSE OR NULL = NULL (filtered out) + assert_mask_selects(&result, &[0], &[1, 2]); + + // Case 2: TRUE mask OR mask with nulls + let true_mask = allow(&[0, 1, 2], &[]); + let result = true_mask | null_mask; + + // TRUE OR anything = TRUE + assert_mask_selects(&result, &[0, 1, 2], &[]); + + // Case 3: Both have nulls + let mask1 = block(&[0, 1, 2, 3], &[1, 2]); + let mask2 = block(&[0, 1, 2, 3], &[2, 3]); + let result = mask1 | mask2; + + // Row 0: FALSE in both; Rows 1,2,3: NULL in at least one + assert_mask_selects(&result, &[], &[0, 1, 2, 3]); + 
} + + #[test] + fn test_row_selection_bit_or() { + // [T, N, T, N, F, F, F] + let left = nullable_set(&[1, 2, 3, 4], &[2, 4]); + // [F, F, T, N, T, N, N] + let right = nullable_set(&[3, 4, 5, 6], &[4, 6, 7]); + // [T, N, T, N, T, N, N] + let expected_true = rows(&[1, 3, 5]); + let expected_nulls = rows(&[2, 4, 6, 7]); + + let mut result = left.clone(); + result |= &right; + assert_eq!(&result.true_rows(), &expected_true); + assert_eq!(result.null_rows(), &expected_nulls); + + // Commutative property holds + let mut result = right.clone(); + result |= &left; + assert_eq!(&result.true_rows(), &expected_true); + assert_eq!(result.null_rows(), &expected_nulls); + } + + #[test] + fn test_row_selection_bit_and() { + // [T, N, T, N, F, F, F] + let left = nullable_set(&[1, 2, 3, 4], &[2, 4]); + // [F, F, T, N, T, N, N] + let right = nullable_set(&[3, 4, 5, 6], &[4, 6, 7]); + // [F, F, T, N, F, F, F] + let expected_true = rows(&[3]); + let expected_nulls = rows(&[4]); + + let mut result = left.clone(); + result &= &right; + assert_eq!(&result.true_rows(), &expected_true); + assert_eq!(result.null_rows(), &expected_nulls); + + // Commutative property holds + let mut result = right.clone(); + result &= &left; + assert_eq!(&result.true_rows(), &expected_true); + assert_eq!(result.null_rows(), &expected_nulls); + } + + #[test] + fn test_union_all() { + // Union all is basically a series of ORs. 
+ // [T, T, T, N, N, N, F, F, F] + let set1 = nullable_set(&[1, 2, 3, 4], &[4, 5, 6]); + // [T, N, F, T, N, F, T, N, F] + let set2 = nullable_set(&[1, 4, 7, 8], &[2, 5, 8]); + let set3 = NullableRowAddrSet::empty(); + + let result = NullableRowAddrSet::union_all(&[set1, set2, set3]); + + // [T, T, T, T, N, N, T, N, F] + assert_eq!(&result.true_rows(), &rows(&[1, 2, 3, 4, 7])); + assert_eq!(result.null_rows(), &rows(&[5, 6, 8])); + } + + #[test] + fn test_nullable_row_addr_set_with_nulls() { + let set = NullableRowAddrSet::new(rows(&[1, 2, 3]), RowAddrTreeMap::new()); + let set_with_nulls = set.with_nulls(rows(&[2])); + + assert!(set_with_nulls.selected(1) && set_with_nulls.selected(3)); + assert!(!set_with_nulls.selected(2)); // null + } + + #[test] + fn test_nullable_row_addr_set_len_and_is_empty() { + let set = nullable_set(&[1, 2, 3, 4, 5], &[2, 4]); + + // len() returns count of TRUE rows (selected - nulls) + assert_eq!(set.len(), Some(3)); // 1, 3, 5 + assert!(!set.is_empty()); + + let empty_set = NullableRowAddrSet::empty(); + assert!(empty_set.is_empty()); + assert_eq!(empty_set.len(), Some(0)); + } + + #[test] + fn test_nullable_row_addr_set_selected() { + let set = nullable_set(&[1, 2, 3], &[2]); + + // selected() returns true only for TRUE rows (in selected and not in nulls) + assert!(set.selected(1) && set.selected(3)); + assert!(!set.selected(2)); // null + assert!(!set.selected(4)); // not in selected + } + + #[test] + fn test_nullable_row_addr_set_partial_eq() { + let set1 = nullable_set(&[1, 2, 3], &[2]); + let set2 = nullable_set(&[1, 2, 3], &[2]); + // set3 has same true_rows but different nulls + let set3 = nullable_set(&[1, 3], &[3]); + + assert_eq!(set1, set2); + assert_ne!(set1, set3); // different nulls + } + + #[test] + fn test_nullable_row_addr_set_bitand_fast_path() { + // Test fast path when both have no nulls + let set1 = nullable_set(&[1, 2, 3], &[]); + let set2 = nullable_set(&[2, 3, 4], &[]); + + let mut result = set1; + result &= 
&set2; + + // Intersection: [2, 3] + assert!(result.selected(2) && result.selected(3)); + assert!(!result.selected(1) && !result.selected(4)); + assert!(result.null_rows().is_empty()); + } + + #[test] + fn test_nullable_row_addr_set_bitor_fast_path() { + // Test fast path when both have no nulls + let set1 = nullable_set(&[1, 2], &[]); + let set2 = nullable_set(&[3, 4], &[]); + + let mut result = set1; + result |= &set2; + + // Union: [1, 2, 3, 4] + for id in [1, 2, 3, 4] { + assert!(result.selected(id)); + } + assert!(result.null_rows().is_empty()); + } + + #[test] + fn test_nullable_row_id_mask_drop_nulls() { + // Test drop_nulls for AllowList + let allow_mask = allow(&[1, 2, 3, 4], &[2, 4]); + let dropped = allow_mask.drop_nulls(); + // Should be AllowList([1, 3]) after removing nulls + assert!(dropped.selected(1) && dropped.selected(3)); + assert!(!dropped.selected(2) && !dropped.selected(4)); + + // Test drop_nulls for BlockList + let block_mask = block(&[1, 2], &[3]); + let dropped = block_mask.drop_nulls(); + // BlockList: blocked = [1, 2] | [3] = [1, 2, 3] + assert!(!dropped.selected(1) && !dropped.selected(2) && !dropped.selected(3)); + assert!(dropped.selected(4) && dropped.selected(5)); + } + + #[test] + fn test_nullable_row_id_mask_not_blocklist() { + let block_mask = block(&[1, 2], &[2]); + let not_mask = !block_mask; + + // NOT(BlockList) = AllowList + assert!(matches!(not_mask, NullableRowIdMask::AllowList(_))); + } + + #[test] + fn test_nullable_row_id_mask_bitand_allow_allow_fast_path() { + // Test AllowList & AllowList with no nulls (fast path) + let mask1 = allow(&[1, 2, 3], &[]); + let mask2 = allow(&[2, 3, 4], &[]); + + let result = mask1 & mask2; + assert_mask_selects(&result, &[2, 3], &[1, 4]); + } + + #[test] + fn test_nullable_row_id_mask_bitand_allow_block() { + let allow_mask = allow(&[1, 2, 3, 4, 5], &[2]); + let block_mask = block(&[3, 4], &[4]); + + let result = allow_mask & block_mask; + // allow: T=[1,3,4,5], N=[2] + // block: 
F=[3,4], N=[4] + // T & T = T; N & T = N (filtered); T & F = F; T & N = N (filtered) + assert_mask_selects(&result, &[1, 5], &[2, 3, 4]); + } + + #[test] + fn test_nullable_row_id_mask_bitand_allow_block_fast_path() { + // Test AllowList & BlockList fast path (no nulls) + let allow_mask = allow(&[1, 2, 3], &[]); + let block_mask = block(&[2], &[]); + + let result = allow_mask & block_mask; + assert_mask_selects(&result, &[1, 3], &[2]); + } + + #[test] + fn test_nullable_row_id_mask_bitand_block_block() { + let block1 = block(&[1, 2], &[2]); + let block2 = block(&[2, 3], &[3]); + + let result = block1 & block2; + // block1: F=[1], N=[2]; block2: F=[2], N=[3] + // F & T = F; N & F = F; T & N = N (filtered); T & T = T + assert_mask_selects(&result, &[4], &[1, 2, 3]); + } + + #[test] + fn test_nullable_row_id_mask_bitand_block_block_fast_path() { + // Test BlockList & BlockList fast path (no nulls) + let block1 = block(&[1], &[]); + let block2 = block(&[2], &[]); + + let result = block1 & block2; + assert_mask_selects(&result, &[3], &[1, 2]); + } + + #[test] + fn test_nullable_row_id_mask_bitor_allow_allow_fast_path() { + // Test AllowList | AllowList with no nulls (fast path) + let mask1 = allow(&[1, 2], &[]); + let mask2 = allow(&[3, 4], &[]); + + let result = mask1 | mask2; + assert_mask_selects(&result, &[1, 2, 3, 4], &[5]); + } + + #[test] + fn test_nullable_row_id_mask_bitor_allow_block() { + let allow_mask = allow(&[1, 2, 3], &[2]); + let block_mask = block(&[1, 4], &[4]); + + let result = allow_mask | block_mask; + // allow: T=[1,3], N=[2]; block: F=[1], N=[4], T=everything else + // T|F=T, T|T=T, N|T=T + assert_mask_selects(&result, &[1, 2, 3], &[]); + } + + #[test] + fn test_nullable_row_id_mask_bitor_allow_block_fast_path() { + // Test AllowList | BlockList fast path (no nulls) + let allow_mask = allow(&[1], &[]); + let block_mask = block(&[2], &[]); + + let result = allow_mask | block_mask; + // AllowList([1]) | BlockList([2]) = BlockList([2] - [1]) = 
BlockList([2]) + assert_mask_selects(&result, &[1, 3], &[2]); + } + + #[test] + fn test_nullable_row_id_mask_bitor_block_block_fast_path() { + // Test BlockList | BlockList with no nulls (fast path) + let block1 = block(&[1, 2], &[]); + let block2 = block(&[2, 3], &[]); + + let result = block1 | block2; + // OR of BlockLists: BlockList([1,2] & [2,3]) = BlockList([2]) + assert_mask_selects(&result, &[1, 3, 4], &[2]); + } +} diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 484cbb5cb2a..98ce994890c 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -19,7 +19,7 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; use deepsize::DeepSizeOf; use inverted::query::{fill_fts_query_column, FtsQuery, FtsQueryNode, FtsSearchParams, MatchQuery}; -use lance_core::utils::mask::RowAddrTreeMap; +use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap}; use lance_core::{Error, Result}; use serde::Serialize; use snafu::location; @@ -685,20 +685,40 @@ impl AnyQuery for TokenQuery { #[derive(Debug, PartialEq)] pub enum SearchResult { /// The exact row ids that satisfy the query - Exact(RowAddrTreeMap), + Exact(NullableRowAddrSet), /// Any row id satisfying the query will be in this set but not every /// row id in this set will satisfy the query, a further recheck step /// is needed - AtMost(RowAddrTreeMap), + AtMost(NullableRowAddrSet), /// All of the given row ids satisfy the query but there may be more /// /// No scalar index actually returns this today but it can arise from /// boolean operations (e.g. 
NOT(AtMost(x)) == AtLeast(NOT(x)))
-    AtLeast(RowAddrTreeMap),
+    AtLeast(NullableRowAddrSet),
 }
 
 impl SearchResult {
-    pub fn row_addrs(&self) -> &RowAddrTreeMap {
+    pub fn exact(row_ids: impl Into<RowAddrTreeMap>) -> Self {
+        Self::Exact(NullableRowAddrSet::new(row_ids.into(), Default::default()))
+    }
+
+    pub fn at_most(row_ids: impl Into<RowAddrTreeMap>) -> Self {
+        Self::AtMost(NullableRowAddrSet::new(row_ids.into(), Default::default()))
+    }
+
+    pub fn at_least(row_ids: impl Into<RowAddrTreeMap>) -> Self {
+        Self::AtLeast(NullableRowAddrSet::new(row_ids.into(), Default::default()))
+    }
+
+    pub fn with_nulls(self, nulls: impl Into<RowAddrTreeMap>) -> Self {
+        match self {
+            Self::Exact(row_ids) => Self::Exact(row_ids.with_nulls(nulls.into())),
+            Self::AtMost(row_ids) => Self::AtMost(row_ids.with_nulls(nulls.into())),
+            Self::AtLeast(row_ids) => Self::AtLeast(row_ids.with_nulls(nulls.into())),
+        }
+    }
+
+    pub fn row_addrs(&self) -> &NullableRowAddrSet {
         match self {
             Self::Exact(row_addrs) => row_addrs,
             Self::AtMost(row_addrs) => row_addrs,
diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs
index 8d9eb8fe60e..3766f3381d5 100644
--- a/rust/lance-index/src/scalar/bitmap.rs
+++ b/rust/lance-index/src/scalar/bitmap.rs
@@ -21,7 +21,10 @@ use futures::{stream, StreamExt, TryStreamExt};
 use lance_core::{
     cache::{CacheKey, LanceCache, WeakLanceCache},
     error::LanceOptionExt,
-    utils::{mask::RowAddrTreeMap, tokio::get_num_compute_intensive_cpus},
+    utils::{
+        mask::{NullableRowAddrSet, RowAddrTreeMap},
+        tokio::get_num_compute_intensive_cpus,
+    },
     Error, Result, ROW_ID,
 };
 use roaring::RoaringBitmap;
@@ -404,15 +407,21 @@ impl ScalarIndex for BitmapIndex {
     ) -> Result<SearchResult> {
         let query = query.as_any().downcast_ref::<SargableQuery>().unwrap();
 
-        let row_ids = match query {
+        let (row_ids, null_row_ids) = match query {
             SargableQuery::Equals(val) => {
                 metrics.record_comparisons(1);
                 if val.is_null() {
-                    (*self.null_map).clone()
+                    // Querying FOR nulls - they are the TRUE result, not NULL result
+                    ((*self.null_map).clone(), 
None) } else { let key = OrderableScalarValue(val.clone()); let bitmap = self.load_bitmap(&key, Some(metrics)).await?; - (*bitmap).clone() + let null_rows = if !self.null_map.is_empty() { + Some((*self.null_map).clone()) + } else { + None + }; + ((*bitmap).clone(), null_rows) } } SargableQuery::Range(start, end) => { @@ -436,7 +445,7 @@ impl ScalarIndex for BitmapIndex { metrics.record_comparisons(keys.len()); - if keys.is_empty() { + let result = if keys.is_empty() { RowAddrTreeMap::default() } else { let bitmaps: Vec<_> = stream::iter( @@ -449,7 +458,14 @@ impl ScalarIndex for BitmapIndex { let bitmap_refs: Vec<_> = bitmaps.iter().map(|b| b.as_ref()).collect(); RowAddrTreeMap::union_all(&bitmap_refs) - } + }; + + let null_rows = if !self.null_map.is_empty() { + Some((*self.null_map).clone()) + } else { + None + }; + (result, null_rows) } SargableQuery::IsIn(values) => { metrics.record_comparisons(values.len()); @@ -473,35 +489,41 @@ impl ScalarIndex for BitmapIndex { }) .collect(); - if keys.is_empty() && (!has_null || self.null_map.is_empty()) { + // Load bitmaps in parallel + let mut bitmaps: Vec<_> = stream::iter( + keys.into_iter() + .map(|key| async move { self.load_bitmap(&key, None).await }), + ) + .buffer_unordered(get_num_compute_intensive_cpus()) + .try_collect() + .await?; + + // Add null bitmap if needed + if has_null && !self.null_map.is_empty() { + bitmaps.push(self.null_map.clone()); + } + + let result = if bitmaps.is_empty() { RowAddrTreeMap::default() } else { - // Load bitmaps in parallel - let mut bitmaps: Vec<_> = stream::iter( - keys.into_iter() - .map(|key| async move { self.load_bitmap(&key, None).await }), - ) - .buffer_unordered(get_num_compute_intensive_cpus()) - .try_collect() - .await?; - - // Add null bitmap if needed - if has_null && !self.null_map.is_empty() { - bitmaps.push(self.null_map.clone()); - } + // Convert Arc to &RowAddrTreeMap for union_all + let bitmap_refs: Vec<_> = bitmaps.iter().map(|b| b.as_ref()).collect(); + 
RowAddrTreeMap::union_all(&bitmap_refs) + }; - if bitmaps.is_empty() { - RowAddrTreeMap::default() - } else { - // Convert Arc to &RowAddrTreeMap for union_all - let bitmap_refs: Vec<_> = bitmaps.iter().map(|b| b.as_ref()).collect(); - RowAddrTreeMap::union_all(&bitmap_refs) - } - } + // If the query explicitly includes null, then nulls are TRUE (not NULL) + // Otherwise, nulls remain NULL (unknown) + let null_rows = if !has_null && !self.null_map.is_empty() { + Some((*self.null_map).clone()) + } else { + None + }; + (result, null_rows) } SargableQuery::IsNull() => { metrics.record_comparisons(1); - (*self.null_map).clone() + // Querying FOR nulls - they are the TRUE result, not NULL result + ((*self.null_map).clone(), None) } SargableQuery::FullTextSearch(_) => { return Err(Error::NotSupported { @@ -511,7 +533,8 @@ impl ScalarIndex for BitmapIndex { } }; - Ok(SearchResult::Exact(row_ids)) + let selection = NullableRowAddrSet::new(row_ids, null_row_ids.unwrap_or_default()); + Ok(SearchResult::Exact(selection)) } fn can_remap(&self) -> bool { @@ -817,7 +840,7 @@ pub mod tests { use super::*; use crate::metrics::NoOpMetricsCollector; use crate::scalar::lance_format::LanceIndexStore; - use arrow_array::{RecordBatch, StringArray, UInt64Array}; + use arrow_array::{record_batch, RecordBatch, StringArray, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use futures::stream; @@ -883,7 +906,12 @@ pub mod tests { // Verify results let expected_red_rows = vec![0u64, 3, 6, 10, 11]; if let SearchResult::Exact(row_ids) = result { - let mut actual: Vec = row_ids.row_addrs().unwrap().map(|id| id.into()).collect(); + let mut actual: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(|id| id.into()) + .collect(); actual.sort(); assert_eq!(actual, expected_red_rows); } else { @@ -893,7 +921,12 @@ pub mod tests { // Test 2: Search for "red" again - should hit cache let result = index.search(&query, 
&NoOpMetricsCollector).await.unwrap(); if let SearchResult::Exact(row_ids) = result { - let mut actual: Vec = row_ids.row_addrs().unwrap().map(|id| id.into()).collect(); + let mut actual: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(|id| id.into()) + .collect(); actual.sort(); assert_eq!(actual, expected_red_rows); } @@ -907,7 +940,12 @@ pub mod tests { let expected_range_rows = vec![1u64, 2, 5, 7, 8, 12, 13]; if let SearchResult::Exact(row_ids) = result { - let mut actual: Vec = row_ids.row_addrs().unwrap().map(|id| id.into()).collect(); + let mut actual: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(|id| id.into()) + .collect(); actual.sort(); assert_eq!(actual, expected_range_rows); } @@ -921,7 +959,12 @@ pub mod tests { let expected_in_rows = vec![0u64, 3, 4, 6, 9, 10, 11, 14]; if let SearchResult::Exact(row_ids) = result { - let mut actual: Vec = row_ids.row_addrs().unwrap().map(|id| id.into()).collect(); + let mut actual: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(|id| id.into()) + .collect(); actual.sort(); assert_eq!(actual, expected_in_rows); } @@ -1292,7 +1335,12 @@ pub mod tests { .await .unwrap(); if let crate::scalar::SearchResult::Exact(row_ids) = result { - let mut actual: Vec = row_ids.row_addrs().unwrap().map(u64::from).collect(); + let mut actual: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); actual.sort(); let expected: Vec = vec![ RowAddress::new_from_parts(3, 2).into(), @@ -1308,7 +1356,12 @@ pub mod tests { .await .unwrap(); if let crate::scalar::SearchResult::Exact(row_ids) = result { - let mut actual: Vec = row_ids.row_addrs().unwrap().map(u64::from).collect(); + let mut actual: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); actual.sort(); let expected: Vec = vec![ RowAddress::new_from_parts(3, 4).into(), @@ -1324,7 +1377,12 @@ pub mod tests { .await .unwrap(); if let 
crate::scalar::SearchResult::Exact(row_ids) = result { - let mut actual: Vec = row_ids.row_addrs().unwrap().map(u64::from).collect(); + let mut actual: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); actual.sort(); assert_eq!( actual, expected_null_addrs, @@ -1332,4 +1390,114 @@ pub mod tests { ); } } + + #[tokio::test] + async fn test_bitmap_null_handling_in_queries() { + // Test that bitmap index correctly returns null_list for queries + let tmpdir = TempObjDir::default(); + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + // Create test data: [0, 5, null] + let batch = record_batch!( + ("value", Int64, [Some(0), Some(5), None]), + ("_rowid", UInt64, [0, 1, 2]) + ) + .unwrap(); + let schema = batch.schema(); + let stream = stream::once(async move { Ok(batch) }); + let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream)); + + // Train and write the bitmap index + BitmapIndexPlugin::train_bitmap_index(stream, store.as_ref()) + .await + .unwrap(); + + let cache = LanceCache::with_capacity(1024 * 1024); + let index = BitmapIndex::load(store.clone(), None, &cache) + .await + .unwrap(); + + // Test 1: Search for value 5 - should return allow=[1], null=[2] + let query = SargableQuery::Equals(ScalarValue::Int64(Some(5))); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::Exact(row_ids) => { + let actual_rows: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!(actual_rows, vec![1], "Should find row 1 where value == 5"); + + let null_row_ids = row_ids.null_rows(); + // Check that null_row_ids contains row 2 + assert!(!null_row_ids.is_empty(), "null_row_ids should be Some"); + let null_rows: Vec = + null_row_ids.row_addrs().unwrap().map(u64::from).collect(); + assert_eq!(null_rows, vec![2], "Should report row 2 as null"); + 
} + _ => panic!("Expected Exact search result"), + } + + // Test 2: Search for null values - should return allow=[2], null=None + let query = SargableQuery::IsNull(); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::Exact(row_addrs) => { + let actual_rows: Vec = row_addrs + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!( + actual_rows, + vec![2], + "IsNull should find row 2 where value is null" + ); + + let null_row_ids = row_addrs.null_rows(); + // When querying FOR nulls, null_row_ids should be None (nulls are the TRUE result) + assert!( + null_row_ids.is_empty(), + "null_row_ids should be None for IsNull query" + ); + } + _ => panic!("Expected Exact search result"), + } + + // Test 3: Range query - should return matching rows and null_list + let query = SargableQuery::Range( + std::ops::Bound::Included(ScalarValue::Int64(Some(0))), + std::ops::Bound::Included(ScalarValue::Int64(Some(3))), + ); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::Exact(row_addrs) => { + let actual_rows: Vec = row_addrs + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!(actual_rows, vec![0], "Should find row 0 where value == 0"); + + // Should report row 2 as null + let null_row_ids = row_addrs.null_rows(); + assert!(!null_row_ids.is_empty(), "null_row_ids should be Some"); + let null_rows: Vec = + null_row_ids.row_addrs().unwrap().map(u64::from).collect(); + assert_eq!(null_rows, vec![2], "Should report row 2 as null"); + } + _ => panic!("Expected Exact search result"), + } + } } diff --git a/rust/lance-index/src/scalar/bloomfilter.rs b/rust/lance-index/src/scalar/bloomfilter.rs index e1ca463143e..73851ca7aeb 100644 --- a/rust/lance-index/src/scalar/bloomfilter.rs +++ b/rust/lance-index/src/scalar/bloomfilter.rs @@ -1201,7 +1201,7 @@ mod tests { use std::sync::Arc; use 
crate::scalar::bloomfilter::BloomFilterIndexPlugin; - use arrow_array::{RecordBatch, UInt64Array}; + use arrow_array::{record_batch, RecordBatch, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; use datafusion::execution::SendableRecordBatchStream; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; @@ -1285,7 +1285,7 @@ mod tests { // Equals query: null (should match nothing, as there are no nulls in empty index) let query = BloomFilterQuery::Equals(ScalarValue::Int32(None)); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); } #[tokio::test] @@ -1342,14 +1342,14 @@ mod tests { // Should match the block since value 50 is in the range [0, 100) let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..100); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a value that shouldn't exist let query = BloomFilterQuery::Equals(ScalarValue::Int32(Some(500))); // Value not in [0, 100) let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); // Should return empty result since bloom filter correctly filters out this value - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // Test calculate_included_frags assert_eq!( @@ -1436,7 +1436,7 @@ mod tests { // Value 150 is only in fragment 1 (values 100-199), not in fragment 0 (values 0-99) let mut expected = RowAddrTreeMap::new(); expected.insert_range((1u64 << 32) + 50..((1u64 << 32) + 100)); // Only the block containing 150 - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test calculate_included_frags assert_eq!( @@ -1502,7 +1502,7 @@ mod tests { // Should match all blocks since they all contain NaN values 
let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..500); // All rows since NaN is in every block - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a specific finite value that exists in the data let query = BloomFilterQuery::Equals(ScalarValue::Float32(Some(5.0))); @@ -1511,7 +1511,7 @@ mod tests { // Should match only the first block since 5.0 only exists in rows 0-99 let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..100); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a value that doesn't exist but is within expected range let query = BloomFilterQuery::Equals(ScalarValue::Float32(Some(250.0))); @@ -1520,14 +1520,14 @@ mod tests { // Should match the third block since 250.0 would be in that range if it existed let mut expected = RowAddrTreeMap::new(); expected.insert_range(200..300); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a value way outside the range let query = BloomFilterQuery::Equals(ScalarValue::Float32(Some(10000.0))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); // Should return empty since bloom filter correctly filters out this value - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // Test IsIn query with NaN and finite values let query = BloomFilterQuery::IsIn(vec![ @@ -1540,7 +1540,7 @@ mod tests { // Should match all blocks since they all contain NaN values let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); } #[tokio::test] @@ -1601,14 +1601,14 @@ mod tests { // Should match zone 2 let mut expected = RowAddrTreeMap::new(); 
expected.insert_range(2000..3000); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a value way outside the range let query = BloomFilterQuery::Equals(ScalarValue::Int64(Some(50000))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); // Should return empty since bloom filter correctly filters out this value - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // Test IsIn query with values from different zones let query = BloomFilterQuery::IsIn(vec![ @@ -1624,7 +1624,7 @@ mod tests { expected.insert_range(0..1000); // Zone 0 expected.insert_range(2000..3000); // Zone 2 expected.insert_range(7000..8000); // Zone 7 - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test calculate_included_frags assert_eq!( @@ -1680,7 +1680,7 @@ mod tests { // Should match the first zone let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..100); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a value in the second zone let query = BloomFilterQuery::Equals(ScalarValue::Utf8(Some("value_150".to_string()))); @@ -1689,7 +1689,7 @@ mod tests { // Should match the second zone let mut expected = RowAddrTreeMap::new(); expected.insert_range(100..200); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a value that doesn't exist let query = @@ -1697,7 +1697,7 @@ mod tests { let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); // Should return empty since bloom filter correctly filters out this value - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // Test IsIn query with string 
values let query = BloomFilterQuery::IsIn(vec![ @@ -1710,7 +1710,7 @@ mod tests { // Should match both zones let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..200); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); } #[tokio::test] @@ -1762,7 +1762,7 @@ mod tests { // Should match the first zone let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..50); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a value in the second zone let query = BloomFilterQuery::Equals(ScalarValue::Binary(Some(vec![75, 76, 77]))); @@ -1771,14 +1771,14 @@ mod tests { // Should match the second zone let mut expected = RowAddrTreeMap::new(); expected.insert_range(50..100); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a value that doesn't exist let query = BloomFilterQuery::Equals(ScalarValue::Binary(Some(vec![255, 254, 253]))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); // Should return empty since bloom filter correctly filters out this value - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); } #[tokio::test] @@ -1831,7 +1831,7 @@ mod tests { // Should match the first zone let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..50); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a value that doesn't exist let query = BloomFilterQuery::Equals(ScalarValue::LargeUtf8(Some( @@ -1840,7 +1840,7 @@ mod tests { let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); // Should return empty since bloom filter correctly filters out this value - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, 
SearchResult::at_most(RowAddrTreeMap::new())); } #[tokio::test] @@ -1887,19 +1887,19 @@ mod tests { let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..50); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for Date32 value in second zone let query = BloomFilterQuery::Equals(ScalarValue::Date32(Some(75))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); let mut expected = RowAddrTreeMap::new(); expected.insert_range(50..100); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for Date32 value that doesn't exist let query = BloomFilterQuery::Equals(ScalarValue::Date32(Some(500))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); } #[tokio::test] @@ -1951,7 +1951,7 @@ mod tests { let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..50); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for Timestamp value in second zone let second_timestamp = timestamp_values[75]; @@ -1962,13 +1962,13 @@ mod tests { let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); let mut expected = RowAddrTreeMap::new(); expected.insert_range(50..100); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for Timestamp value that doesn't exist let query = BloomFilterQuery::Equals(ScalarValue::TimestampNanosecond(Some(999_999_999i64), None)); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, 
SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // Test IsIn query with multiple timestamp values let query = BloomFilterQuery::IsIn(vec![ @@ -1979,7 +1979,7 @@ mod tests { let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..100); // Should match both zones - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); } #[tokio::test] @@ -2030,12 +2030,12 @@ mod tests { let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..25); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for Time64 value that doesn't exist let query = BloomFilterQuery::Equals(ScalarValue::Time64Microsecond(Some(999_999_999i64))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); } #[tokio::test] @@ -2081,12 +2081,12 @@ mod tests { let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); let mut expected = RowAddrTreeMap::new(); expected.insert_range(500..750); // Should match the zone containing 500 - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test IsNull query let query = BloomFilterQuery::IsNull(); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); // No nulls in the data + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // No nulls in the data // Test IsIn query let query = BloomFilterQuery::IsIn(vec![ @@ -2097,6 +2097,86 @@ mod tests { let mut expected = RowAddrTreeMap::new(); 
expected.insert_range(0..250); // Zone containing 100 expected.insert_range(500..750); // Zone containing 600 - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); + } + + #[tokio::test] + async fn test_bloomfilter_null_handling_in_queries() { + // Test that bloomfilter index correctly returns null_list for queries + let tmpdir = TempObjDir::default(); + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + // Create test data: [0, 5, null] + let batch = record_batch!( + (VALUE_COLUMN_NAME, Int64, [Some(0), Some(5), None]), + (ROW_ADDR, UInt64, [0, 1, 2]) + ) + .unwrap(); + let schema = batch.schema(); + let stream = stream::once(async move { Ok(batch) }); + let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream)); + + // Train and write the bloomfilter index + BloomFilterIndexPlugin::train_bloomfilter_index(stream, store.as_ref(), None) + .await + .unwrap(); + + let cache = LanceCache::with_capacity(1024 * 1024); + let index = BloomFilterIndex::load(store.clone(), None, &cache) + .await + .unwrap(); + + // Test 1: Search for value 5 - bloomfilter should return at_most with all rows + // Like ZoneMap, BloomFilter returns AtMost (superset) and includes nulls + let query = BloomFilterQuery::Equals(ScalarValue::Int64(Some(5))); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::AtMost(row_addrs) => { + // Bloomfilter returns all rows in the zone including nulls + let all_rows: Vec = row_addrs + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!( + all_rows, + vec![0, 1, 2], + "Should return all rows (including nulls) since BloomFilter is inexact" + ); + + // For AtMost results, nulls are included in the superset + } + _ => panic!("Expected AtMost search result from bloomfilter"), + } + + // Test 2: IsIn query - should also 
return all rows + let query = BloomFilterQuery::IsIn(vec![ + ScalarValue::Int64(Some(0)), + ScalarValue::Int64(Some(10)), + ]); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::AtMost(row_addrs) => { + let all_rows: Vec = row_addrs + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!( + all_rows, + vec![0, 1, 2], + "Should return all rows in zone as possible matches" + ); + } + _ => panic!("Expected AtMost search result from bloomfilter"), + } } } diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 291e32169ad..6a04dc88d40 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -44,7 +44,7 @@ use lance_core::{ cache::{CacheKey, LanceCache, WeakLanceCache}, error::LanceOptionExt, utils::{ - mask::RowAddrTreeMap, + mask::NullableRowAddrSet, tokio::get_num_compute_intensive_cpus, tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS}, }, @@ -832,7 +832,7 @@ impl BTreeIndex { page_number: u32, index_reader: LazyIndexReader, metrics: &dyn MetricsCollector, - ) -> Result { + ) -> Result { let subindex = self.lookup_page(page_number, index_reader, metrics).await?; // TODO: If this is an IN query we can perhaps simplify the subindex query by restricting it to the // values that might be in the page. E.g. 
if we are searching for X IN [5, 3, 7] and five is in pages @@ -1172,13 +1172,19 @@ impl ScalarIndex for BTreeIndex { }) .collect::>(); debug!("Searching {} btree pages", page_tasks.len()); - let row_ids = stream::iter(page_tasks) + + // Collect both matching row IDs and null row IDs from all pages + let results: Vec = stream::iter(page_tasks) // I/O and compute mixed here but important case is index in cache so // use compute intensive thread count .buffered(get_num_compute_intensive_cpus()) - .try_collect::() + .try_collect() .await?; - Ok(SearchResult::Exact(row_ids)) + + // Merge matching row IDs + let selection = NullableRowAddrSet::union_all(&results); + + Ok(SearchResult::Exact(selection)) } fn can_remap(&self) -> bool { @@ -2003,7 +2009,7 @@ mod tests { use std::{collections::HashMap, sync::Arc}; use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type}; - use arrow_array::FixedSizeListArray; + use arrow_array::{record_batch, FixedSizeListArray}; use arrow_schema::DataType; use datafusion::{ execution::{SendableRecordBatchStream, TaskContext}, @@ -2012,12 +2018,14 @@ mod tests { use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; use deepsize::DeepSizeOf; + use futures::stream; use futures::TryStreamExt; use lance_core::utils::tempfile::TempObjDir; use lance_core::{cache::LanceCache, utils::mask::RowAddrTreeMap}; use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt}; use lance_datagen::{array, gen_batch, ArrayGeneratorExt, BatchCount, RowCount}; use lance_io::object_store::ObjectStore; + use object_store::path::Path; use crate::metrics::LocalMetricsCollector; use crate::{ @@ -2165,7 +2173,7 @@ mod tests { let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); assert_eq!( result, - SearchResult::Exact(RowAddrTreeMap::from_iter(((idx as u64)..1000).step_by(7))) + SearchResult::exact(RowAddrTreeMap::from_iter(((idx as 
u64)..1000).step_by(7))) ); } } @@ -2872,4 +2880,117 @@ mod tests { // This test mainly verifies that the function doesn't panic and handles edge cases super::cleanup_partition_files(&test_store, &lookup_files, &page_files).await; } + + #[tokio::test] + async fn test_btree_null_handling_in_queries() { + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::memory()), + Path::default(), + Arc::new(LanceCache::no_cache()), + )); + + // Create test data: [null, 0, 5] at row IDs [0, 1, 2] + // BTree expects sorted data with nulls first (or filtered out) + let batch = record_batch!( + ("value", Int32, [None, Some(0), Some(5)]), + ("_rowid", UInt64, [0, 1, 2]) + ) + .unwrap(); + let stream = stream::once(futures::future::ok(batch.clone())); + let stream = Box::pin(RecordBatchStreamAdapter::new(batch.schema(), stream)); + + // Train the btree index with FlatIndexMetadata as sub-index + let sub_index_trainer = super::FlatIndexMetadata::new(DataType::Int32); + super::train_btree_index(stream, &sub_index_trainer, store.as_ref(), 256, None) + .await + .unwrap(); + + let cache = LanceCache::with_capacity(1024 * 1024); + let index = super::BTreeIndex::load(store.clone(), None, &cache) + .await + .unwrap(); + + // Test 1: Search for value 5 - should return allow=[2], null=[0] + let query = SargableQuery::Equals(ScalarValue::Int32(Some(5))); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::Exact(row_ids) => { + let actual_rows: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!(actual_rows, vec![2], "Should find row 2 where value == 5"); + + // Check that null_row_ids contains row 0 + let null_row_ids = row_ids.null_rows(); + assert!(!null_row_ids.is_empty(), "null_row_ids should be non-empty"); + let null_rows: Vec = + null_row_ids.row_addrs().unwrap().map(u64::from).collect(); + assert_eq!(null_rows, vec![0], "Should report row 0 as null"); + } + _ => 
panic!("Expected Exact search result"), + } + + // Test 2: Range query [0, 3] - should return allow=[1], null=[0] + let query = SargableQuery::Range( + std::ops::Bound::Included(ScalarValue::Int32(Some(0))), + std::ops::Bound::Included(ScalarValue::Int32(Some(3))), + ); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::Exact(row_ids) => { + let actual_rows: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!(actual_rows, vec![1], "Should find row 1 where value == 0"); + + // Should report row 0 as null + let null_row_ids = row_ids.null_rows(); + assert!(!null_row_ids.is_empty(), "null_row_ids should be non-empty"); + let null_rows: Vec = + null_row_ids.row_addrs().unwrap().map(u64::from).collect(); + assert_eq!(null_rows, vec![0], "Should report row 0 as null"); + } + _ => panic!("Expected Exact search result"), + } + + // Test 3: IsIn query [0, 5] - should return allow=[1, 2], null=[0] + let query = SargableQuery::IsIn(vec![ + ScalarValue::Int32(Some(0)), + ScalarValue::Int32(Some(5)), + ]); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::Exact(row_ids) => { + let mut actual_rows: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); + actual_rows.sort(); + assert_eq!( + actual_rows, + vec![1, 2], + "Should find rows 1 and 2 where value in [0, 5]" + ); + + // Should report row 0 as null + let null_row_ids = row_ids.null_rows(); + assert!(!null_row_ids.is_empty(), "null_row_ids should be non-empty"); + let null_rows: Vec = + null_row_ids.row_addrs().unwrap().map(u64::from).collect(); + assert_eq!(null_rows, vec![0], "Should report row 0 as null"); + } + _ => panic!("Expected Exact search result"), + } + } } diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index 2e867bc9de8..0162163acb9 100644 --- 
a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -16,13 +16,16 @@ use datafusion_expr::{ expr::{InList, ScalarFunction}, Between, BinaryExpr, Expr, Operator, ReturnFieldArgs, ScalarUDF, }; +use tokio::try_join; use super::{ AnyQuery, BloomFilterQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, TextQuery, TokenQuery, }; -use futures::join; -use lance_core::{utils::mask::RowIdMask, Error, Result}; +use lance_core::{ + utils::mask::{NullableRowIdMask, RowIdMask}, + Error, Result, +}; use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner}; use roaring::RoaringBitmap; use snafu::location; @@ -902,6 +905,81 @@ pub static INDEX_EXPR_RESULT_SCHEMA: LazyLock = LazyLock::new(|| { ])) }); +#[derive(Debug)] +enum NullableIndexExprResult { + Exact(NullableRowIdMask), + AtMost(NullableRowIdMask), + AtLeast(NullableRowIdMask), +} + +impl From for NullableIndexExprResult { + fn from(result: SearchResult) -> Self { + match result { + SearchResult::Exact(mask) => Self::Exact(NullableRowIdMask::AllowList(mask)), + SearchResult::AtMost(mask) => Self::AtMost(NullableRowIdMask::AllowList(mask)), + SearchResult::AtLeast(mask) => Self::AtLeast(NullableRowIdMask::AllowList(mask)), + } + } +} + +impl std::ops::BitAnd for NullableIndexExprResult { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self { + match (self, rhs) { + (Self::Exact(lhs), Self::Exact(rhs)) => Self::Exact(lhs & rhs), + (Self::Exact(lhs), Self::AtMost(rhs)) | (Self::AtMost(lhs), Self::Exact(rhs)) => { + Self::AtMost(lhs & rhs) + } + (Self::Exact(exact), Self::AtLeast(_)) | (Self::AtLeast(_), Self::Exact(exact)) => { + // We could do better here, elements in both lhs and rhs are known + // to be true and don't require a recheck. 
We only need to recheck + // elements in lhs that are not in rhs + Self::AtMost(exact) + } + (Self::AtMost(lhs), Self::AtMost(rhs)) => Self::AtMost(lhs & rhs), + (Self::AtLeast(lhs), Self::AtLeast(rhs)) => Self::AtLeast(lhs & rhs), + (Self::AtMost(most), Self::AtLeast(_)) | (Self::AtLeast(_), Self::AtMost(most)) => { + Self::AtMost(most) + } + } + } +} + +impl std::ops::BitOr for NullableIndexExprResult { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self { + match (self, rhs) { + (Self::Exact(lhs), Self::Exact(rhs)) => Self::Exact(lhs | rhs), + (Self::Exact(lhs), Self::AtMost(rhs)) | (Self::AtMost(rhs), Self::Exact(lhs)) => { + // We could do better here, elements in lhs are known to be true + // and don't require a recheck. We only need to recheck elements + // in rhs that are not in lhs + Self::AtMost(lhs | rhs) + } + (Self::Exact(lhs), Self::AtLeast(rhs)) | (Self::AtLeast(rhs), Self::Exact(lhs)) => { + Self::AtLeast(lhs | rhs) + } + (Self::AtMost(lhs), Self::AtMost(rhs)) => Self::AtMost(lhs | rhs), + (Self::AtLeast(lhs), Self::AtLeast(rhs)) => Self::AtLeast(lhs | rhs), + (Self::AtMost(_), Self::AtLeast(least)) | (Self::AtLeast(least), Self::AtMost(_)) => { + Self::AtLeast(least) + } + } + } +} + +impl NullableIndexExprResult { + pub fn drop_nulls(self) -> IndexExprResult { + match self { + Self::Exact(mask) => IndexExprResult::Exact(mask.drop_nulls()), + Self::AtMost(mask) => IndexExprResult::AtMost(mask.drop_nulls()), + Self::AtLeast(mask) => IndexExprResult::AtLeast(mask.drop_nulls()), + } + } +} + #[derive(Debug)] pub enum IndexExprResult { // The answer is exactly the rows in the allow list minus the rows in the block list @@ -981,117 +1059,59 @@ impl ScalarIndexExpr { /// TODO: We could potentially try and be smarter about reusing loaded indices for /// any situations where the session cache has been disabled. 
#[async_recursion] - #[instrument(level = "debug", skip_all)] - pub async fn evaluate( + async fn evaluate_impl( &self, index_loader: &dyn ScalarIndexLoader, metrics: &dyn MetricsCollector, - ) -> Result { + ) -> Result { match self { Self::Not(inner) => { - let result = inner.evaluate(index_loader, metrics).await?; - match result { - IndexExprResult::Exact(mask) => Ok(IndexExprResult::Exact(!mask)), - IndexExprResult::AtMost(mask) => Ok(IndexExprResult::AtLeast(!mask)), - IndexExprResult::AtLeast(mask) => Ok(IndexExprResult::AtMost(!mask)), - } - } - Self::And(lhs, rhs) => { - let lhs_result = lhs.evaluate(index_loader, metrics); - let rhs_result = rhs.evaluate(index_loader, metrics); - let (lhs_result, rhs_result) = join!(lhs_result, rhs_result); - match (lhs_result?, rhs_result?) { - (IndexExprResult::Exact(lhs), IndexExprResult::Exact(rhs)) => { - Ok(IndexExprResult::Exact(lhs & rhs)) - } - (IndexExprResult::Exact(lhs), IndexExprResult::AtMost(rhs)) - | (IndexExprResult::AtMost(lhs), IndexExprResult::Exact(rhs)) => { - Ok(IndexExprResult::AtMost(lhs & rhs)) - } - (IndexExprResult::Exact(lhs), IndexExprResult::AtLeast(_)) => { - // We could do better here, elements in both lhs and rhs are known - // to be true and don't require a recheck. 
We only need to recheck - // elements in lhs that are not in rhs - Ok(IndexExprResult::AtMost(lhs)) - } - (IndexExprResult::AtLeast(_), IndexExprResult::Exact(rhs)) => { - // We could do better here (see above) - Ok(IndexExprResult::AtMost(rhs)) - } - (IndexExprResult::AtMost(lhs), IndexExprResult::AtMost(rhs)) => { - Ok(IndexExprResult::AtMost(lhs & rhs)) - } - (IndexExprResult::AtLeast(lhs), IndexExprResult::AtLeast(rhs)) => { - Ok(IndexExprResult::AtLeast(lhs & rhs)) + let result = inner.evaluate_impl(index_loader, metrics).await?; + // Flip certainty: NOT(AtMost) → AtLeast, NOT(AtLeast) → AtMost + Ok(match result { + NullableIndexExprResult::Exact(mask) => NullableIndexExprResult::Exact(!mask), + NullableIndexExprResult::AtMost(mask) => { + NullableIndexExprResult::AtLeast(!mask) } - (IndexExprResult::AtLeast(_), IndexExprResult::AtMost(rhs)) => { - Ok(IndexExprResult::AtMost(rhs)) + NullableIndexExprResult::AtLeast(mask) => { + NullableIndexExprResult::AtMost(!mask) } - (IndexExprResult::AtMost(lhs), IndexExprResult::AtLeast(_)) => { - Ok(IndexExprResult::AtMost(lhs)) - } - } + }) + } + Self::And(lhs, rhs) => { + let lhs_result = lhs.evaluate_impl(index_loader, metrics); + let rhs_result = rhs.evaluate_impl(index_loader, metrics); + let (lhs_result, rhs_result) = try_join!(lhs_result, rhs_result)?; + Ok(lhs_result & rhs_result) } Self::Or(lhs, rhs) => { - let lhs_result = lhs.evaluate(index_loader, metrics); - let rhs_result = rhs.evaluate(index_loader, metrics); - let (lhs_result, rhs_result) = join!(lhs_result, rhs_result); - match (lhs_result?, rhs_result?) { - (IndexExprResult::Exact(lhs), IndexExprResult::Exact(rhs)) => { - Ok(IndexExprResult::Exact(lhs | rhs)) - } - (IndexExprResult::Exact(lhs), IndexExprResult::AtMost(rhs)) - | (IndexExprResult::AtMost(lhs), IndexExprResult::Exact(rhs)) => { - // We could do better here. Elements in the exact side don't need - // re-check. 
We only need to recheck elements exclusively in the - // at-most side - Ok(IndexExprResult::AtMost(lhs | rhs)) - } - (IndexExprResult::Exact(lhs), IndexExprResult::AtLeast(rhs)) => { - Ok(IndexExprResult::AtLeast(lhs | rhs)) - } - (IndexExprResult::AtLeast(lhs), IndexExprResult::Exact(rhs)) => { - Ok(IndexExprResult::AtLeast(lhs | rhs)) - } - (IndexExprResult::AtMost(lhs), IndexExprResult::AtMost(rhs)) => { - Ok(IndexExprResult::AtMost(lhs | rhs)) - } - (IndexExprResult::AtLeast(lhs), IndexExprResult::AtLeast(rhs)) => { - Ok(IndexExprResult::AtLeast(lhs | rhs)) - } - (IndexExprResult::AtLeast(lhs), IndexExprResult::AtMost(_)) => { - Ok(IndexExprResult::AtLeast(lhs)) - } - (IndexExprResult::AtMost(_), IndexExprResult::AtLeast(rhs)) => { - Ok(IndexExprResult::AtLeast(rhs)) - } - } + let lhs_result = lhs.evaluate_impl(index_loader, metrics); + let rhs_result = rhs.evaluate_impl(index_loader, metrics); + let (lhs_result, rhs_result) = try_join!(lhs_result, rhs_result)?; + Ok(lhs_result | rhs_result) } Self::Query(search) => { let index = index_loader .load_index(&search.column, &search.index_name, metrics) .await?; let search_result = index.search(search.query.as_ref(), metrics).await?; - match search_result { - SearchResult::Exact(matching_row_ids) => { - Ok(IndexExprResult::Exact(RowIdMask { - block_list: None, - allow_list: Some(matching_row_ids), - })) - } - SearchResult::AtMost(row_ids) => Ok(IndexExprResult::AtMost(RowIdMask { - block_list: None, - allow_list: Some(row_ids), - })), - SearchResult::AtLeast(row_ids) => Ok(IndexExprResult::AtLeast(RowIdMask { - block_list: None, - allow_list: Some(row_ids), - })), - } + Ok(search_result.into()) } } } + #[instrument(level = "debug", skip_all)] + pub async fn evaluate( + &self, + index_loader: &dyn ScalarIndexLoader, + metrics: &dyn MetricsCollector, + ) -> Result { + Ok(self + .evaluate_impl(index_loader, metrics) + .await? 
+ .drop_nulls()) + } + pub fn to_expr(&self) -> Expr { match self { Self::Not(inner) => Expr::Not(inner.to_expr().into()), @@ -2175,4 +2195,125 @@ mod tests { check_no_index(&index_info, "aisle BETWEEN 5 AND NULL"); check_no_index(&index_info, "aisle BETWEEN NULL AND 10"); } + + #[tokio::test] + async fn test_not_flips_certainty() { + use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap}; + + // Test that NOT flips certainty for inexact index results + // This tests the implementation in evaluate_impl for Self::Not + + // Helper function that mimics the NOT logic we just fixed + fn apply_not(result: NullableIndexExprResult) -> NullableIndexExprResult { + match result { + NullableIndexExprResult::Exact(mask) => NullableIndexExprResult::Exact(!mask), + NullableIndexExprResult::AtMost(mask) => NullableIndexExprResult::AtLeast(!mask), + NullableIndexExprResult::AtLeast(mask) => NullableIndexExprResult::AtMost(!mask), + } + } + + // AtMost: superset of matches (e.g., bloom filter says "might be in [1,2]") + let at_most = NullableIndexExprResult::AtMost(NullableRowIdMask::AllowList( + NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()), + )); + // NOT(AtMost) should be AtLeast (definitely NOT in [1,2], might be elsewhere) + assert!(matches!( + apply_not(at_most), + NullableIndexExprResult::AtLeast(_) + )); + + // AtLeast: subset of matches (e.g., definitely in [1,2], might be more) + let at_least = NullableIndexExprResult::AtLeast(NullableRowIdMask::AllowList( + NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()), + )); + // NOT(AtLeast) should be AtMost (might NOT be in [1,2], definitely elsewhere) + assert!(matches!( + apply_not(at_least), + NullableIndexExprResult::AtMost(_) + )); + + // Exact should stay Exact + let exact = NullableIndexExprResult::Exact(NullableRowIdMask::AllowList( + NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()), + )); + assert!(matches!( 
+ apply_not(exact), + NullableIndexExprResult::Exact(_) + )); + } + + #[tokio::test] + async fn test_and_or_preserve_certainty() { + use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap}; + + // Test that AND/OR correctly propagate certainty + let make_at_most = || { + NullableIndexExprResult::AtMost(NullableRowIdMask::AllowList(NullableRowAddrSet::new( + RowAddrTreeMap::from_iter(&[1, 2, 3]), + RowAddrTreeMap::new(), + ))) + }; + + let make_at_least = || { + NullableIndexExprResult::AtLeast(NullableRowIdMask::AllowList(NullableRowAddrSet::new( + RowAddrTreeMap::from_iter(&[2, 3, 4]), + RowAddrTreeMap::new(), + ))) + }; + + let make_exact = || { + NullableIndexExprResult::Exact(NullableRowIdMask::AllowList(NullableRowAddrSet::new( + RowAddrTreeMap::from_iter(&[1, 2]), + RowAddrTreeMap::new(), + ))) + }; + + // AtMost & AtMost → AtMost + assert!(matches!( + make_at_most() & make_at_most(), + NullableIndexExprResult::AtMost(_) + )); + + // AtLeast & AtLeast → AtLeast + assert!(matches!( + make_at_least() & make_at_least(), + NullableIndexExprResult::AtLeast(_) + )); + + // AtMost & AtLeast → AtMost (superset remains superset) + assert!(matches!( + make_at_most() & make_at_least(), + NullableIndexExprResult::AtMost(_) + )); + + // AtMost | AtMost → AtMost + assert!(matches!( + make_at_most() | make_at_most(), + NullableIndexExprResult::AtMost(_) + )); + + // AtLeast | AtLeast → AtLeast + assert!(matches!( + make_at_least() | make_at_least(), + NullableIndexExprResult::AtLeast(_) + )); + + // AtMost | AtLeast → AtLeast (subset coverage guaranteed) + assert!(matches!( + make_at_most() | make_at_least(), + NullableIndexExprResult::AtLeast(_) + )); + + // Exact & AtMost → AtMost + assert!(matches!( + make_exact() & make_at_most(), + NullableIndexExprResult::AtMost(_) + )); + + // Exact | AtLeast → AtLeast + assert!(matches!( + make_exact() | make_at_least(), + NullableIndexExprResult::AtLeast(_) + )); + } } diff --git a/rust/lance-index/src/scalar/flat.rs 
b/rust/lance-index/src/scalar/flat.rs index a350dcb470c..ff6c0bcd11c 100644 --- a/rust/lance-index/src/scalar/flat.rs +++ b/rust/lance-index/src/scalar/flat.rs @@ -15,7 +15,7 @@ use datafusion_physical_expr::expressions::{in_list, lit, Column}; use deepsize::DeepSizeOf; use lance_core::error::LanceOptionExt; use lance_core::utils::address::RowAddress; -use lance_core::utils::mask::RowAddrTreeMap; +use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap}; use lance_core::{Error, Result, ROW_ID}; use roaring::RoaringBitmap; use snafu::location; @@ -299,14 +299,37 @@ impl ScalarIndex for FlatIndex { let valid_values = arrow::compute::is_not_null(self.values())?; predicate = arrow::compute::and(&valid_values, &predicate)?; } + + // Track null row IDs for Kleene logic + // When querying FOR nulls (IS NULL or Equals(null)), don't track them as "null results" + // because they are the TRUE result of the query + let null_row_ids = if self.has_nulls + && !matches!(query, SargableQuery::IsNull()) + && !matches!(query, SargableQuery::Equals(val) if val.is_null()) + { + let null_mask = arrow::compute::is_null(self.values())?; + let null_ids = arrow_select::filter::filter(self.ids(), &null_mask)?; + let null_ids = null_ids + .as_any() + .downcast_ref::() + .expect("Result of arrow_select::filter::filter did not match input type"); + if null_ids.is_empty() { + None + } else { + Some(RowAddrTreeMap::from_iter(null_ids.values())) + } + } else { + None + }; + let matching_ids = arrow_select::filter::filter(self.ids(), &predicate)?; let matching_ids = matching_ids .as_any() .downcast_ref::() .expect("Result of arrow_select::filter::filter did not match input type"); - Ok(SearchResult::Exact(RowAddrTreeMap::from_iter( - matching_ids.values(), - ))) + let selected = RowAddrTreeMap::from_iter(matching_ids.values()); + let selection = NullableRowAddrSet::new(selected, null_row_ids.unwrap_or_default()); + Ok(SearchResult::Exact(selection)) } fn can_remap(&self) -> bool { @@ 
-372,7 +395,8 @@ mod tests { let SearchResult::Exact(actual_row_ids) = actual else { panic! {"Expected exact search result"} }; - let expected = RowAddrTreeMap::from_iter(expected); + let expected = + NullableRowAddrSet::new(RowAddrTreeMap::from_iter(expected), Default::default()); assert_eq!(actual_row_ids, expected); } diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 4f30040fc31..08b9192a88e 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -547,7 +547,7 @@ impl ScalarIndex for InvertedIndex { .downcast_ref::() .unwrap(); let row_ids = row_ids.iter().flatten().collect_vec(); - Ok(SearchResult::AtMost(RowAddrTreeMap::from_iter(row_ids))) + Ok(SearchResult::at_most(RowAddrTreeMap::from_iter(row_ids))) } } } diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs index 91d3a9063fe..b9850b3c01c 100644 --- a/rust/lance-index/src/scalar/label_list.rs +++ b/rust/lance-index/src/scalar/label_list.rs @@ -13,7 +13,8 @@ use datafusion_common::ScalarValue; use deepsize::DeepSizeOf; use futures::{stream::BoxStream, StreamExt, TryStream, TryStreamExt}; use lance_core::cache::LanceCache; -use lance_core::{utils::mask::RowAddrTreeMap, Error, Result}; +use lance_core::utils::mask::NullableRowAddrSet; +use lance_core::{Error, Result}; use roaring::RoaringBitmap; use snafu::location; use tracing::instrument; @@ -41,7 +42,7 @@ trait LabelListSubIndex: ScalarIndex + DeepSizeOf { &self, query: &dyn AnyQuery, metrics: &dyn MetricsCollector, - ) -> Result { + ) -> Result { let result = self.search(query, metrics).await?; match result { SearchResult::Exact(row_ids) => Ok(row_ids), @@ -118,7 +119,7 @@ impl LabelListIndex { &'a self, values: &'a Vec, metrics: &'a dyn MetricsCollector, - ) -> BoxStream<'a, Result> { + ) -> BoxStream<'a, Result> { futures::stream::iter(values) .then(move |value| { let value_query = 
SargableQuery::Equals(value.clone()); @@ -129,24 +130,24 @@ impl LabelListIndex { async fn set_union<'a>( &'a self, - mut sets: impl TryStream + 'a + Unpin, + mut sets: impl TryStream + 'a + Unpin, single_set: bool, - ) -> Result { + ) -> Result { let mut union_bitmap = sets.try_next().await?.unwrap(); if single_set { return Ok(union_bitmap); } while let Some(next) = sets.try_next().await? { - union_bitmap |= next; + union_bitmap |= &next; } Ok(union_bitmap) } async fn set_intersection<'a>( &'a self, - mut sets: impl TryStream + 'a + Unpin, + mut sets: impl TryStream + 'a + Unpin, single_set: bool, - ) -> Result { + ) -> Result { let mut intersect_bitmap = sets.try_next().await?.unwrap(); if single_set { return Ok(intersect_bitmap); diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs index ac63b89c7e6..463953ee801 100644 --- a/rust/lance-index/src/scalar/lance_format.rs +++ b/rust/lance-index/src/scalar/lance_format.rs @@ -313,7 +313,7 @@ pub mod tests { bitmap::BitmapIndex, btree::{train_btree_index, DEFAULT_BTREE_BATCH_SIZE}, flat::FlatIndexMetadata, - LabelListQuery, SargableQuery, ScalarIndex, + LabelListQuery, SargableQuery, ScalarIndex, SearchResult, }; use super::*; @@ -321,7 +321,7 @@ pub mod tests { use arrow_array::{ cast::AsArray, types::{Int32Type, UInt64Type}, - RecordBatchIterator, RecordBatchReader, StringArray, UInt64Array, + ListArray, RecordBatchIterator, RecordBatchReader, StringArray, UInt64Array, }; use arrow_schema::Schema as ArrowSchema; use arrow_schema::{DataType, Field, TimeUnit}; @@ -402,9 +402,9 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); - assert_eq!(Some(1), row_addrs.len()); - assert!(row_addrs.contains(10000)); + let row_ids = result.row_addrs().true_rows(); + assert_eq!(Some(1), row_ids.len()); + assert!(row_ids.contains(10000)); let result = index .search( @@ -418,7 +418,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); 
- let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); assert_eq!(Some(0), row_addrs.len()); @@ -434,7 +434,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); assert_eq!(Some(100), row_addrs.len()); } @@ -494,7 +494,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); assert_eq!(Some(1), row_addrs.len()); assert!(row_addrs.contains(10000)); @@ -508,7 +508,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); assert_eq!(Some(1), row_addrs.len()); assert!(row_addrs.contains(500_000)); @@ -518,7 +518,7 @@ pub mod tests { let results = index.search(&query, &NoOpMetricsCollector).await.unwrap(); assert!(results.is_exact()); let expected_arr = RowAddrTreeMap::from_iter(expected); - assert_eq!(results.row_addrs(), &expected_arr); + assert_eq!(&results.row_addrs().true_rows(), &expected_arr); } #[tokio::test] @@ -823,7 +823,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); // The random data may have had duplicates so there might be more than 1 result // but even for boolean we shouldn't match the entire thing @@ -886,7 +886,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); assert!(row_addrs.is_empty()); @@ -895,7 +895,7 @@ pub mod tests { .await .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); assert_eq!(row_addrs.len(), Some(4096)); } @@ -962,7 +962,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = 
result.row_addrs().true_rows(); assert_eq!(Some(1), row_addrs.len()); assert!(row_addrs.contains(2)); @@ -975,7 +975,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); assert_eq!(Some(3), row_addrs.len()); assert!(row_addrs.contains(1)); assert!(row_addrs.contains(3)); @@ -1004,7 +1004,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); assert_eq!(Some(1), row_addrs.len()); assert!(row_addrs.contains(10000)); @@ -1020,7 +1020,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); assert!(row_addrs.is_empty()); let result = index @@ -1035,7 +1035,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); assert_eq!(Some(100), row_addrs.len()); } @@ -1043,7 +1043,7 @@ pub mod tests { let results = index.search(&query, &NoOpMetricsCollector).await.unwrap(); assert!(results.is_exact()); let expected_arr = RowAddrTreeMap::from_iter(expected); - assert_eq!(results.row_addrs(), &expected_arr); + assert_eq!(&results.row_addrs().true_rows(), &expected_arr); } #[tokio::test] @@ -1307,7 +1307,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); assert_eq!(Some(1), row_addrs.len()); assert!(row_addrs.contains(5000)); } @@ -1357,7 +1357,7 @@ pub mod tests { .await .unwrap() .row_addrs() - .contains(65)); + .selected(65)); // Deleted assert!(remapped_index .search( @@ -1377,7 +1377,7 @@ pub mod tests { .await .unwrap() .row_addrs() - .contains(3)); + .selected(3)); } async fn train_tag( @@ -1442,7 +1442,7 @@ pub mod tests { .unwrap(); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); 
assert!(result.is_exact()); - let row_addrs = result.row_addrs(); + let row_addrs = result.row_addrs().true_rows(); let row_addrs_set = row_addrs .row_addrs() @@ -1506,4 +1506,85 @@ pub mod tests { ) .await; } + + #[tokio::test] + async fn test_label_list_null_handling() { + let tempdir = TempDir::default(); + let index_store = test_store(&tempdir); + + // Create test data with null items within lists: + // Row 0: [1, 2] - no nulls + // Row 1: [3, null] - has a null item + // Row 2: [4] - no nulls + let list_array = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), None]), + Some(vec![Some(4)]), + ]); + let row_ids = UInt64Array::from_iter_values(0..3); + // Create schema with nullable list items to match the ListArray + let schema = Arc::new(Schema::new(vec![ + Field::new( + VALUE_COLUMN_NAME, + DataType::List(Arc::new(Field::new("item", DataType::UInt8, true))), + true, + ), + Field::new(ROW_ID, DataType::UInt64, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(list_array), Arc::new(row_ids)], + ) + .unwrap(); + + let batch_reader = RecordBatchIterator::new(vec![Ok(batch)], schema); + train_tag(&index_store, batch_reader).await; + + let index = LabelListIndexPlugin + .load_index( + index_store, + &default_details::(), + None, + &LanceCache::no_cache(), + ) + .await + .unwrap(); + + // Test: Search for lists containing value 1 + // Row 0: [1, 2] - contains 1 → TRUE + // Row 1: [3, null] - has null item, unknown if it matches → NULL + // Row 2: [4] - doesn't contain 1 → FALSE + let query = LabelListQuery::HasAnyLabel(vec![ScalarValue::UInt8(Some(1))]); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::Exact(row_ids) => { + let actual_rows: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!( + actual_rows, + vec![0], + "Should find row 0 where list contains 1" + ); + + let 
null_row_ids = row_ids.null_rows(); + assert!( + !null_row_ids.is_empty(), + "null_row_ids should not be empty - row 1 has null item" + ); + let null_rows: Vec = + null_row_ids.row_addrs().unwrap().map(u64::from).collect(); + assert_eq!( + null_rows, + vec![1], + "Should report row 1 as null because it contains a null item" + ); + } + _ => panic!("Expected Exact search result"), + } + } } diff --git a/rust/lance-index/src/scalar/ngram.rs b/rust/lance-index/src/scalar/ngram.rs index aec9cc29dcc..052d712b619 100644 --- a/rust/lance-index/src/scalar/ngram.rs +++ b/rust/lance-index/src/scalar/ngram.rs @@ -451,7 +451,7 @@ impl ScalarIndex for NGramIndex { TextQuery::StringContains(substr) => { if substr.len() < NGRAM_N { // We know nothing on short searches, need to recheck all - return Ok(SearchResult::AtLeast(RowAddrTreeMap::new())); + return Ok(SearchResult::at_least(RowAddrTreeMap::new())); } let mut row_offsets = Vec::with_capacity(substr.len() * 3); @@ -466,7 +466,7 @@ impl ScalarIndex for NGramIndex { }); // At least one token was missing, so we know there are zero results if missing { - return Ok(SearchResult::Exact(RowAddrTreeMap::new())); + return Ok(SearchResult::exact(RowAddrTreeMap::new())); } let posting_lists = futures::stream::iter( row_offsets @@ -479,7 +479,7 @@ impl ScalarIndex for NGramIndex { metrics.record_comparisons(posting_lists.len()); let list_refs = posting_lists.iter().map(|list| list.as_ref()); let row_ids = NGramPostingList::intersect(list_refs); - Ok(SearchResult::AtMost(RowAddrTreeMap::from(row_ids))) + Ok(SearchResult::at_most(RowAddrTreeMap::from(row_ids))) } } } @@ -1487,7 +1487,7 @@ mod tests { .await .unwrap(); - let expected = SearchResult::AtMost(RowAddrTreeMap::from_iter([0, 2, 3])); + let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([0, 2, 3])); assert_eq!(expected, res); @@ -1499,7 +1499,7 @@ mod tests { ) .await .unwrap(); - let expected = SearchResult::AtMost(RowAddrTreeMap::from_iter([8])); + let expected = 
SearchResult::at_most(RowAddrTreeMap::from_iter([8])); assert_eq!(expected, res); // No matches @@ -1510,7 +1510,7 @@ mod tests { ) .await .unwrap(); - let expected = SearchResult::Exact(RowAddrTreeMap::new()); + let expected = SearchResult::exact(RowAddrTreeMap::new()); assert_eq!(expected, res); // False positive @@ -1521,7 +1521,7 @@ mod tests { ) .await .unwrap(); - let expected = SearchResult::AtMost(RowAddrTreeMap::from_iter([8])); + let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([8])); assert_eq!(expected, res); // Too short, don't know anything @@ -1532,7 +1532,7 @@ mod tests { ) .await .unwrap(); - let expected = SearchResult::AtLeast(RowAddrTreeMap::new()); + let expected = SearchResult::at_least(RowAddrTreeMap::new()); assert_eq!(expected, res); // One short string but we still get at least one trigram, this is ok @@ -1543,7 +1543,7 @@ mod tests { ) .await .unwrap(); - let expected = SearchResult::AtMost(RowAddrTreeMap::from_iter([8])); + let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([8])); assert_eq!(expected, res); } @@ -1582,7 +1582,7 @@ mod tests { ) .await .unwrap(); - let expected = SearchResult::AtMost(RowAddrTreeMap::from_iter([0, 4])); + let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([0, 4])); assert_eq!(expected, res); let null_posting_list = get_null_posting_list(&index).await; diff --git a/rust/lance-index/src/scalar/zoned.rs b/rust/lance-index/src/scalar/zoned.rs index 02ef1098ee0..bb2be962d16 100644 --- a/rust/lance-index/src/scalar/zoned.rs +++ b/rust/lance-index/src/scalar/zoned.rs @@ -270,7 +270,7 @@ where } } - Ok(crate::scalar::SearchResult::AtMost(row_addr_tree_map)) + Ok(crate::scalar::SearchResult::at_most(row_addr_tree_map)) } /// Helper that retrains zones from `stream` and appends them to the existing @@ -745,14 +745,14 @@ mod tests { }; // Fragment 0, offsets 0 and 1 - assert!(map.contains(0)); - assert!(map.contains(1)); + assert!(map.selected(0)); + assert!(map.selected(1)); // 
Fragment 1 should be skipped entirely - assert!(!map.contains((1_u64 << 32) + 5)); - assert!(!map.contains((1_u64 << 32) + 7)); + assert!(!map.selected((1_u64 << 32) + 5)); + assert!(!map.selected((1_u64 << 32) + 7)); // Fragment 2 includes only the single offset 10 - assert!(map.contains((2_u64 << 32) + 10)); - assert!(!map.contains((2_u64 << 32) + 11)); + assert!(map.selected((2_u64 << 32) + 10)); + assert!(!map.selected((2_u64 << 32) + 11)); } #[test] diff --git a/rust/lance-index/src/scalar/zonemap.rs b/rust/lance-index/src/scalar/zonemap.rs index f41d97ee57d..b631ba89d48 100644 --- a/rust/lance-index/src/scalar/zonemap.rs +++ b/rust/lance-index/src/scalar/zonemap.rs @@ -941,12 +941,13 @@ mod tests { use crate::scalar::zoned::ZoneBound; use crate::scalar::zonemap::{ZoneMapIndexPlugin, ZoneMapStatistics}; use arrow::datatypes::Float32Type; - use arrow_array::{Array, RecordBatch, UInt64Array}; + use arrow_array::{record_batch, Array, RecordBatch, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; use datafusion::execution::SendableRecordBatchStream; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion_common::ScalarValue; use futures::{stream, StreamExt, TryStreamExt}; + use lance_core::utils::mask::NullableRowAddrSet; use lance_core::utils::tempfile::TempObjDir; use lance_core::{cache::LanceCache, utils::mask::RowAddrTreeMap, ROW_ADDR}; use lance_datafusion::datagen::DatafusionDatagenExt; @@ -1029,7 +1030,7 @@ mod tests { // Equals query: null (should match nothing, as there are no nulls) let query = SargableQuery::Equals(ScalarValue::Int32(None)); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); } #[tokio::test] @@ -1085,7 +1086,7 @@ mod tests { let end = start + 5000; expected.insert_range(start..end); } - assert_eq!(result, SearchResult::AtMost(expected)); + 
assert_eq!(result, SearchResult::at_most(expected)); // Test update - add new data with Float32 values (matching the original data type) let new_data = @@ -1143,7 +1144,7 @@ mod tests { let end = start + 5000; expected.insert_range(start..end); } - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a value that should be in the new zone let query = SargableQuery::Equals(ScalarValue::Float32(Some(2.5))); // Value 2500/1000 = 2.5 @@ -1157,7 +1158,90 @@ mod tests { let start = 10u64 << 32; let end = start + 5000; expected.insert_range(start..end); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); + } + + #[tokio::test] + async fn test_zonemap_null_handling_in_queries() { + // Test that zonemap index correctly returns null_list for queries + let tmpdir = TempObjDir::default(); + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + // Create test data: [0, 5, null] + let batch = record_batch!( + (VALUE_COLUMN_NAME, Int64, [Some(0), Some(5), None]), + (ROW_ADDR, UInt64, [0, 1, 2]) + ) + .unwrap(); + let schema = batch.schema(); + let stream = stream::once(async move { Ok(batch) }); + let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream)); + + // Train and write the zonemap index + ZoneMapIndexPlugin::train_zonemap_index(stream, store.as_ref(), None) + .await + .unwrap(); + + let cache = LanceCache::with_capacity(1024 * 1024); + let index = ZoneMapIndex::load(store.clone(), None, &cache) + .await + .unwrap(); + + // Test 1: Search for value 5 - zonemap should return at_most with all rows + // Since ZoneMap returns AtMost (superset), it's correct to include nulls in the result + let query = SargableQuery::Equals(ScalarValue::Int64(Some(5))); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + 
SearchResult::AtMost(row_ids) => { + // Zonemap can't determine exact matches, so it returns all rows in the zone + // This includes nulls because ZoneMap can't prove they don't match + let all_rows: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!( + all_rows, + vec![0, 1, 2], + "Should return all rows (including nulls) since ZoneMap is inexact" + ); + + // For AtMost results, nulls are included in the superset + // Downstream processing will handle null filtering + } + _ => panic!("Expected AtMost search result from zonemap"), + } + + // Test 2: Range query - should also return all rows as AtMost + let query = SargableQuery::Range( + std::ops::Bound::Included(ScalarValue::Int64(Some(0))), + std::ops::Bound::Included(ScalarValue::Int64(Some(3))), + ); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::AtMost(row_ids) => { + // Again, ZoneMap returns superset including nulls + let all_rows: Vec = row_ids + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!( + all_rows, + vec![0, 1, 2], + "Should return all rows in zone as possible matches" + ); + } + _ => panic!("Expected AtMost search result from zonemap"), + } } #[tokio::test] @@ -1237,7 +1321,7 @@ mod tests { // Should match all zones since they all contain NaN values let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..500); // All rows since NaN is in every zone - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a specific finite value that exists in the data let query = SargableQuery::Equals(ScalarValue::Float32(Some(5.0))); @@ -1246,7 +1330,7 @@ mod tests { // Should match only the first zone since 5.0 only exists in rows 0-99 let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..100); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, 
SearchResult::at_most(expected)); // Test search for a value that doesn't exist let query = SargableQuery::Equals(ScalarValue::Float32(Some(1000.0))); @@ -1256,7 +1340,7 @@ mod tests { // as potential matches for any finite target (false positive, but acceptable for zone maps) let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test range query that should include finite values let query = SargableQuery::Range( @@ -1268,7 +1352,7 @@ mod tests { // Should match the first three zones since they contain values in the range [0, 250] let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..300); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test IsIn query with NaN and finite values let query = SargableQuery::IsIn(vec![ @@ -1281,7 +1365,7 @@ mod tests { // Should match all zones since they all contain NaN values let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test range query that excludes all values let query = SargableQuery::Range( @@ -1294,12 +1378,12 @@ mod tests { // as potential matches for any range query (false positive, but acceptable for zone maps) let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test IsNull query (should match nothing since there are no null values) let query = SargableQuery::IsNull(); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::AtMost(NullableRowAddrSet::empty())); // Test range queries with NaN bounds // Range with NaN as start bound 
(included) @@ -1311,7 +1395,7 @@ mod tests { // Should match all zones since they all contain NaN values let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Range with NaN as end bound (included) let query = SargableQuery::Range( @@ -1322,7 +1406,7 @@ mod tests { // Should match all zones since they all contain NaN values let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Range with NaN as end bound (excluded) let query = SargableQuery::Range( @@ -1333,7 +1417,7 @@ mod tests { // Should match all zones since everything is less than NaN let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Range with NaN as start bound (excluded) let query = SargableQuery::Range( @@ -1342,7 +1426,7 @@ mod tests { ); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); // Should match nothing since nothing is greater than NaN - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::AtMost(NullableRowAddrSet::empty())); // Test IsIn query with mixed float types (Float16, Float32, Float64) let query = SargableQuery::IsIn(vec![ @@ -1355,7 +1439,7 @@ mod tests { // Should match all zones since they all contain NaN values let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); } #[tokio::test] @@ -1481,10 +1565,7 @@ mod tests { Bound::Unbounded, ); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::AtMost(RowAddrTreeMap::from_iter(0..=100)) - ); + 
assert_eq!(result, SearchResult::at_most(0..=100)); // 2. Range query: [0, 50] let query = SargableQuery::Range( @@ -1492,10 +1573,7 @@ mod tests { Bound::Included(ScalarValue::Int32(Some(50))), ); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::AtMost(RowAddrTreeMap::from_iter(0..=99)) - ); + assert_eq!(result, SearchResult::at_most(0..=99)); // 3. Range query: [101, 200] (should only match the second zone, which is row 100) let query = SargableQuery::Range( @@ -1504,7 +1582,7 @@ mod tests { ); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); // Only row 100 is in the second zone, but its value is 100, so this should be empty - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // 4. Range query: [100, 100] (should match only the last row) let query = SargableQuery::Range( @@ -1512,37 +1590,27 @@ mod tests { Bound::Included(ScalarValue::Int32(Some(100))), ); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::AtMost(RowAddrTreeMap::from_iter(100..=100)) - ); + assert_eq!(result, SearchResult::at_most(100..=100)); // 5. Equals query: 0 (should match first row) let query = SargableQuery::Equals(ScalarValue::Int32(Some(0))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::AtMost(RowAddrTreeMap::from_iter(0..100)) - ); + assert_eq!(result, SearchResult::at_most(0..=99)); // 6. Equals query: 100 (should match only last row) let query = SargableQuery::Equals(ScalarValue::Int32(Some(100))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::AtMost(RowAddrTreeMap::from_iter(100..=100)) - ); + assert_eq!(result, SearchResult::at_most(100..=100)); // 7. 
Equals query: 101 (should match nothing) let query = SargableQuery::Equals(ScalarValue::Int32(Some(101))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // 8. IsNull query (no nulls in data, should match nothing) let query = SargableQuery::IsNull(); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); - + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // 9. IsIn query: [0, 100, 101, 50] let query = SargableQuery::IsIn(vec![ ScalarValue::Int32(Some(0)), @@ -1552,10 +1620,7 @@ mod tests { ]); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); // 0 and 50 are in the first zone, 100 in the second, 101 is not present - assert_eq!( - result, - SearchResult::AtMost(RowAddrTreeMap::from_iter(0..=100)) - ); + assert_eq!(result, SearchResult::at_most(0..=100)); // 10. IsIn query: [101, 102] (should match nothing) let query = SargableQuery::IsIn(vec![ @@ -1563,17 +1628,17 @@ mod tests { ScalarValue::Int32(Some(102)), ]); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // 11. IsIn query: [null] (should match nothing, as there are no nulls) let query = SargableQuery::IsIn(vec![ScalarValue::Int32(None)]); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // 12. 
Equals query: null (should match nothing, as there are no nulls) let query = SargableQuery::Equals(ScalarValue::Int32(None)); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); } #[tokio::test] @@ -1675,7 +1740,7 @@ mod tests { // Should match row 1000 in fragment 0: row address = (0 << 32) + 1000 = 1000 let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..=8191); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Search for a value in the second zone let query = SargableQuery::Equals(ScalarValue::Int64(Some(9000))); @@ -1683,12 +1748,12 @@ mod tests { // Should match row 9000 in fragment 0: row address = (0 << 32) + 9000 = 9000 let mut expected = RowAddrTreeMap::new(); expected.insert_range(8192..=16383); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Search for a value not present in any zone let query = SargableQuery::Equals(ScalarValue::Int64(Some(20000))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // Search for a range that spans multiple zones let query = SargableQuery::Range( @@ -1699,7 +1764,7 @@ mod tests { // Should match all rows from 8000 to 16400 (inclusive) let mut expected = RowAddrTreeMap::new(); expected.insert_range(8192..=16425); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); } #[tokio::test] @@ -1921,7 +1986,7 @@ mod tests { expected.insert_range(5000..8192); // zone 2 expected.insert_range((1u64 << 32)..((1u64 << 32) + 5000)); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, 
SearchResult::at_most(expected)); // Test exact match query from zone 2 let query = SargableQuery::Equals(ScalarValue::Int64(Some(8192))); @@ -1929,7 +1994,7 @@ mod tests { // Should include zone 2 since it contains value 8192 let mut expected = RowAddrTreeMap::new(); expected.insert_range((1u64 << 32)..((1u64 << 32) + 5000)); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test exact match query from zone 4 let query = SargableQuery::Equals(ScalarValue::Int64(Some(16385))); @@ -1937,19 +2002,19 @@ mod tests { // Should include zone 4 since it contains value 16385 let mut expected = RowAddrTreeMap::new(); expected.insert_range(2u64 << 32..((2u64 << 32) + 42)); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test query that matches nothing let query = SargableQuery::Equals(ScalarValue::Int64(Some(99999))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); // Test is_in query let query = SargableQuery::IsIn(vec![ScalarValue::Int64(Some(16385))]); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); let mut expected = RowAddrTreeMap::new(); expected.insert_range(2u64 << 32..((2u64 << 32) + 42)); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test equals query with null let query = SargableQuery::Equals(ScalarValue::Int64(None)); @@ -1957,7 +2022,7 @@ mod tests { let mut expected = RowAddrTreeMap::new(); expected.insert_range(0..=16425); // expected = {:?}", expected - assert_eq!(result, SearchResult::AtMost(RowAddrTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowAddrTreeMap::new())); } // Each fragment is its own batch diff --git a/rust/lance/benches/scalar_index.rs 
b/rust/lance/benches/scalar_index.rs index 16787aa8776..f24816710bd 100644 --- a/rust/lance/benches/scalar_index.rs +++ b/rust/lance/benches/scalar_index.rs @@ -118,7 +118,7 @@ async fn warm_indexed_equality_search(index: &dyn ScalarIndex) { let SearchResult::Exact(row_ids) = result else { panic!("Expected exact results") }; - assert_eq!(row_ids.len(), Some(1)); + assert_eq!(row_ids.true_rows().len(), Some(1)); } async fn baseline_inequality_search(fixture: &BenchmarkFixture) { @@ -155,7 +155,7 @@ async fn warm_indexed_inequality_search(index: &dyn ScalarIndex) { }; // 100Mi - 50M = 54,857,600 - assert_eq!(row_ids.len(), Some(54857600)); + assert_eq!(row_ids.true_rows().len(), Some(54857600)); } async fn warm_indexed_isin_search(index: &dyn ScalarIndex) { @@ -176,7 +176,7 @@ async fn warm_indexed_isin_search(index: &dyn ScalarIndex) { }; // Only 3 because 150M is not in dataset - assert_eq!(row_ids.len(), Some(3)); + assert_eq!(row_ids.true_rows().len(), Some(3)); } fn bench_baseline(c: &mut Criterion) { diff --git a/rust/lance/src/index/prefilter.rs b/rust/lance/src/index/prefilter.rs index 9ccf9c485c2..cb202b9c8cb 100644 --- a/rust/lance/src/index/prefilter.rs +++ b/rust/lance/src/index/prefilter.rs @@ -350,7 +350,7 @@ mod test { ); assert!(mask.is_some()); let mask = mask.unwrap().await.unwrap(); - assert_eq!(mask.block_list.as_ref().and_then(|x| x.len()), Some(1)); // There was just one row deleted. + assert_eq!(mask.block_list().and_then(|x| x.len()), Some(1)); // There was just one row deleted. // If there are deletions and missing fragments, we should get a mask let mask = DatasetPreFilter::create_deletion_mask( @@ -361,7 +361,7 @@ mod test { let mask = mask.unwrap().await.unwrap(); let mut expected = RowAddrTreeMap::from_iter(vec![(2 << 32) + 2]); expected.insert_fragment(1); - assert_eq!(&mask.block_list, &Some(expected)); + assert_eq!(mask.block_list(), Some(&expected)); // If we don't pass the missing fragment id, we should get a smaller mask. 
let mask = DatasetPreFilter::create_deletion_mask( @@ -370,7 +370,7 @@ mod test { ); assert!(mask.is_some()); let mask = mask.unwrap().await.unwrap(); - assert_eq!(mask.block_list.as_ref().and_then(|x| x.len()), Some(1)); + assert_eq!(mask.block_list().and_then(|x| x.len()), Some(1)); // If there are only missing fragments, we should still get a mask let mask = DatasetPreFilter::create_deletion_mask( @@ -382,7 +382,7 @@ mod test { let mut expected = RowAddrTreeMap::new(); expected.insert_fragment(1); expected.insert_fragment(2); - assert_eq!(&mask.block_list, &Some(expected)); + assert_eq!(mask.block_list(), Some(&expected)); } #[tokio::test] @@ -405,7 +405,7 @@ mod test { assert!(mask.is_some()); let mask = mask.unwrap().await.unwrap(); let expected = RowAddrTreeMap::from_iter(0..8); - assert_eq!(mask.allow_list, Some(expected)); // There was just one row deleted. + assert_eq!(mask.allow_list(), Some(&expected)); // There was just one row deleted. // If there are deletions and missing fragments, we should get an allow list let mask = DatasetPreFilter::create_deletion_mask( @@ -414,7 +414,7 @@ mod test { ); assert!(mask.is_some()); let mask = mask.unwrap().await.unwrap(); - assert_eq!(mask.allow_list.as_ref().and_then(|x| x.len()), Some(5)); // There were five rows left over; + assert_eq!(mask.allow_list().and_then(|x| x.len()), Some(5)); // There were five rows left over; // If there are only missing fragments, we should get an allow list let mask = DatasetPreFilter::create_deletion_mask( @@ -423,6 +423,6 @@ mod test { ); assert!(mask.is_some()); let mask = mask.unwrap().await.unwrap(); - assert_eq!(mask.allow_list.as_ref().and_then(|x| x.len()), Some(3)); // There were three rows left over; + assert_eq!(mask.allow_list().and_then(|x| x.len()), Some(3)); // There were three rows left over; } } diff --git a/rust/lance/src/io/exec/filtered_read.rs b/rust/lance/src/io/exec/filtered_read.rs index f97ebcdadbf..4db065ec997 100644 --- 
a/rust/lance/src/io/exec/filtered_read.rs +++ b/rust/lance/src/io/exec/filtered_read.rs @@ -81,7 +81,7 @@ impl EvaluatedIndex { if batch.num_rows() != 2 { return Err(Error::InvalidInput { source: format!( - "Expected a batch with exactly one row but there are {} rows", + "Expected a batch with exactly 2 rows but there are {} rows", batch.num_rows() ) .into(), diff --git a/rust/lance/src/io/exec/scalar_index.rs b/rust/lance/src/io/exec/scalar_index.rs index 6e1718bdc9f..5657625ca77 100644 --- a/rust/lance/src/io/exec/scalar_index.rs +++ b/rust/lance/src/io/exec/scalar_index.rs @@ -318,29 +318,16 @@ impl MapIndexExec { row_id_mask = row_id_mask & deletion_mask.as_ref().clone(); } - if let Some(mut allow_list) = row_id_mask.allow_list { - // Flatten the allow list - if let Some(block_list) = row_id_mask.block_list { - allow_list -= &block_list; - } - - let allow_list = - allow_list - .row_addrs() - .ok_or(datafusion::error::DataFusionError::External( - "IndexedLookupExec: row addresses didn't have an iterable allow list" - .into(), - ))?; - let allow_list: UInt64Array = allow_list.map(u64::from).collect(); - Ok(RecordBatch::try_new( - INDEX_LOOKUP_SCHEMA.clone(), - vec![Arc::new(allow_list)], - )?) - } else { - Err(datafusion::error::DataFusionError::Internal( - "IndexedLookupExec: row addresses didn't have an allow list".to_string(), - )) - } + let row_id_iter = row_id_mask + .iter_ids() + .ok_or(datafusion::error::DataFusionError::Internal( + "IndexedLookupExec: Cannot iterate over row addresses (BlockList or contains full fragments)".to_string(), + ))?; + let allow_list: UInt64Array = row_id_iter.map(u64::from).collect(); + Ok(RecordBatch::try_new( + INDEX_LOOKUP_SCHEMA.clone(), + vec![Arc::new(allow_list)], + )?) 
} async fn do_execute @@ -589,8 +576,8 @@ async fn row_ids_for_mask( dataset: &Dataset, fragments: &[Fragment], ) -> Result<Vec<u64>> { - match (mask.allow_list, mask.block_list) { - (None, None) => { + match mask { + RowIdMask::BlockList(block_list) if block_list.is_empty() => { // Matches all row ids in the given fragments. if dataset.manifest.uses_stable_row_ids() { let sequences = load_row_id_sequences(dataset, fragments) .map_ok(|(_frag_id, sequence)| sequence) @@ -608,7 +595,7 @@ async fn row_ids_for_mask( Ok(FragIdIter::new(fragments).collect::<Vec<_>>()) } } - (Some(mut allow_list), None) => { + RowIdMask::AllowList(mut allow_list) => { retain_fragments(&mut allow_list, fragments, dataset).await?; if let Some(allow_list_iter) = allow_list.row_addrs() { @@ -621,7 +608,7 @@ async fn row_ids_for_mask( .collect()) } } - (None, Some(block_list)) => { + RowIdMask::BlockList(block_list) => { if dataset.manifest.uses_stable_row_ids() { let sequences = load_row_id_sequences(dataset, fragments) .map_ok(|(_frag_id, sequence)| sequence) @@ -645,29 +632,6 @@ async fn row_ids_for_mask( .collect()) } } - (Some(mut allow_list), Some(block_list)) => { - // We need to filter out irrelevant fragments as well. - retain_fragments(&mut allow_list, fragments, dataset).await?; - - if let Some(allow_list_iter) = allow_list.row_addrs() { - Ok(allow_list_iter - .filter_map(|addr| { - let row_id = u64::from(addr); - if !block_list.contains(row_id) { - Some(row_id) - } else { - None - } - }) - .collect::<Vec<_>>()) - } else { - // We shouldn't hit this branch if the row ids are stable. - debug_assert!(!dataset.manifest.uses_stable_row_ids()); - Ok(FragIdIter::new(fragments) - .filter(|row_id| !block_list.contains(*row_id) && allow_list.contains(*row_id)) - .collect()) - } - } } }