diff --git a/rust/lance-index/src/scalar/inverted/wand.rs b/rust/lance-index/src/scalar/inverted/wand.rs index b9bbfad62bd..0d3e57fb743 100644 --- a/rust/lance-index/src/scalar/inverted/wand.rs +++ b/rust/lance-index/src/scalar/inverted/wand.rs @@ -22,7 +22,7 @@ use super::{ encoding::{decompress_positions, decompress_posting_block, decompress_posting_remainder}, query::FtsSearchParams, scorer::Scorer, - DocSet, PostingList, RawDocInfo, + CompressedPostingList, DocSet, PostingList, RawDocInfo, }; use super::{builder::BLOCK_SIZE, DocInfo}; use super::{ @@ -140,6 +140,28 @@ impl Ord for PostingIterator { } impl PostingIterator { + #[inline] + fn compressed_state_ptr(&self) -> *mut CompressedState { + debug_assert!(self.compressed.is_some()); + // this method is called very frequently, so we prefer to use `UnsafeCell` instead of + // `RefCell` to avoid the overhead of runtime borrow checking + self.compressed.as_ref().unwrap().get() + } + + #[inline] + fn ensure_compressed_block_ptr( + &self, + list: &CompressedPostingList, + block_idx: usize, + ) -> *mut CompressedState { + let compressed = unsafe { &mut *self.compressed_state_ptr() }; + if compressed.block_idx != block_idx || compressed.doc_ids.is_empty() { + let block = list.blocks.value(block_idx); + compressed.decompress(block, block_idx, list.blocks.len(), list.length); + } + compressed as *mut CompressedState + } + pub(crate) fn new( token: String, token_id: u32, @@ -194,19 +216,9 @@ impl PostingIterator { match self.list { PostingList::Compressed(ref list) => { - debug_assert!(self.compressed.is_some()); - // this method is called very frequently, so we prefer to use `UnsafeCell` instead of `RefCell` - // to avoid the overhead of runtime borrow checking - let compressed = unsafe { - let compressed = self.compressed.as_ref().unwrap(); - &mut *compressed.get() - }; let block_idx = self.index / BLOCK_SIZE; let block_offset = self.index % BLOCK_SIZE; - if compressed.block_idx != block_idx || 
compressed.doc_ids.is_empty() { - let block = list.blocks.value(block_idx); - compressed.decompress(block, block_idx, list.blocks.len(), list.length); - } + let compressed = unsafe { &mut *self.ensure_compressed_block_ptr(list, block_idx) }; // Read from the decompressed block let doc_id = compressed.doc_ids[block_offset]; @@ -232,7 +244,7 @@ impl PostingIterator { // move to the next doc id that is greater than or equal to least_id fn next(&mut self, least_id: u64) { match self.list { - PostingList::Compressed(ref mut list) => { + PostingList::Compressed(ref list) => { debug_assert!(least_id <= u32::MAX as u64); let least_id = least_id as u32; let mut block_idx = self.index / BLOCK_SIZE; @@ -242,9 +254,24 @@ impl PostingIterator { block_idx += 1; } self.index = self.index.max(block_idx * BLOCK_SIZE); - let length = self.list.len(); - while self.index < length && (self.doc().unwrap().doc_id() as u32) < least_id { - self.index += 1; + let length = list.length as usize; + while self.index < length { + let block_idx = self.index / BLOCK_SIZE; + let block_offset = self.index % BLOCK_SIZE; + let compressed = + unsafe { &mut *self.ensure_compressed_block_ptr(list, block_idx) }; + let in_block = &compressed.doc_ids[block_offset..]; + let offset_in_block = in_block.partition_point(|&doc_id| doc_id < least_id); + let new_offset = block_offset + offset_in_block; + if new_offset < compressed.doc_ids.len() { + self.index = block_idx * BLOCK_SIZE + new_offset; + break; + } + if block_idx + 1 >= list.blocks.len() { + self.index = length; + break; + } + self.index = (block_idx + 1) * BLOCK_SIZE; } self.block_idx = self.index / BLOCK_SIZE; } @@ -256,7 +283,7 @@ impl PostingIterator { fn shallow_next(&mut self, least_id: u64) { match self.list { - PostingList::Compressed(ref mut list) => { + PostingList::Compressed(ref list) => { debug_assert!(least_id <= u32::MAX as u64); let least_id = least_id as u32; while self.block_idx + 1 < list.blocks.len() @@ -952,6 +979,29 @@ mod tests { 
assert_eq!(result.len(), 0); // Should not panic } + #[test] + fn test_posting_iterator_next_compressed_partition_point() { + let mut docs = DocSet::default(); + let num_docs = (BLOCK_SIZE * 2 + 5) as u32; + for i in 0..num_docs { + docs.append(i as u64, 1); + } + + let doc_ids = (0..num_docs).collect::<Vec<_>>(); + let posting = generate_posting_list(doc_ids, 1.0, None, true); + let mut iter = PostingIterator::new(String::from("term"), 0, 0, posting, docs.len()); + + iter.next(10); + assert_eq!(iter.doc().unwrap().doc_id(), 10); + + let target = BLOCK_SIZE as u64 + 3; + iter.next(target); + assert_eq!(iter.doc().unwrap().doc_id(), target); + + iter.next(num_docs as u64 + 10); + assert!(iter.doc().is_none()); + } + #[test] fn test_wand_skip_to_next_block() { let mut docs = DocSet::default();