Review note (preamble text; ignored by `git apply` / `patch`):
  * Line structure and generic parameters were restored from a whitespace-collapsed copy of this
    patch. `LazyLock<SchemaRef>` and `Result<Option<(usize, DocInfo)>>` on context lines, as well
    as indentation and blank-line placement, are inferred -- TODO confirm against the tree.
  * `tokens_by_position` is now sized by the largest term index instead of `postings.len()`.
  * The `index` blob hashes below are those of the original patch, not of this revision.

diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs
index 1c7dac32397..f9794079811 100644
--- a/rust/lance-index/src/scalar/inverted/index.rs
+++ b/rust/lance-index/src/scalar/inverted/index.rs
@@ -111,6 +111,26 @@ pub static FTS_SCHEMA: LazyLock<SchemaRef> =
 static ROW_ID_SCHEMA: LazyLock<SchemaRef> =
     LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()])));
 
+/// Result of searching one partition: the matching documents together with
+/// the tokens needed to resolve the term indices stored in their `freqs`.
+#[derive(Debug)]
+struct PartitionCandidates {
+    /// Token text for each term index; slots without a posting stay empty.
+    tokens_by_position: Vec<String>,
+    /// Documents returned by the partition's `bm25_search`.
+    candidates: Vec<DocCandidate>,
+}
+
+impl PartitionCandidates {
+    /// A result with no tokens and no candidates.
+    fn empty() -> Self {
+        Self {
+            tokens_by_position: Vec::new(),
+            candidates: Vec::new(),
+        }
+    }
+}
+
 #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Default)]
 pub enum TokenSetFormat {
     Arrow,
@@ -256,19 +276,36 @@ impl InvertedIndex {
                         .load_posting_lists(tokens.as_ref(), params.as_ref(), metrics.as_ref())
                         .await?;
                     if postings.is_empty() {
-                        return Ok(Vec::new());
+                        return Ok(PartitionCandidates::empty());
+                    }
+                    // Size the table by the largest term index rather than by
+                    // `postings.len()`, so the writes below stay in bounds even if
+                    // the term indices are not a dense `0..postings.len()` range.
+                    let table_len = postings
+                        .iter()
+                        .map(|posting| posting.term_index() as usize + 1)
+                        .max()
+                        .unwrap_or(0);
+                    let mut tokens_by_position = vec![String::new(); table_len];
+                    for posting in &postings {
+                        let idx = posting.term_index() as usize;
+                        tokens_by_position[idx] = posting.token().to_owned();
                     }
                     let params = params.clone();
                     let mask = mask.clone();
                     let metrics = metrics.clone();
                     spawn_cpu(move || {
-                        part.bm25_search(
+                        let candidates = part.bm25_search(
                             params.as_ref(),
                             operator,
                             mask,
                             postings,
                             metrics.as_ref(),
-                        )
+                        )?;
+                        Ok(PartitionCandidates {
+                            tokens_by_position,
+                            candidates,
+                        })
                     })
                     .await
                 }
@@ -277,14 +314,20 @@ impl InvertedIndex {
         let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus());
         let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
         while let Some(res) = parts.try_next().await? {
+            if res.candidates.is_empty() {
+                continue;
+            }
+            let tokens_by_position = &res.tokens_by_position;
             for DocCandidate {
                 row_id,
                 freqs,
                 doc_length,
-            } in res
+            } in res.candidates
             {
                 let mut score = 0.0;
-                for (token, freq) in freqs.into_iter() {
+                for (term_index, freq) in freqs.into_iter() {
+                    debug_assert!((term_index as usize) < tokens_by_position.len());
+                    let token = &tokens_by_position[term_index as usize];
                     score += scorer.score(token.as_str(), freq, doc_length);
                 }
                 if candidates.len() < limit {
diff --git a/rust/lance-index/src/scalar/inverted/wand.rs b/rust/lance-index/src/scalar/inverted/wand.rs
index 786e01337d7..61786bdb23a 100644
--- a/rust/lance-index/src/scalar/inverted/wand.rs
+++ b/rust/lance-index/src/scalar/inverted/wand.rs
@@ -166,6 +166,18 @@ impl PostingIterator {
         }
     }
 
+    /// Returns this iterator's `position`, which callers use as the term index.
+    #[inline]
+    pub(crate) fn term_index(&self) -> u32 {
+        self.position
+    }
+
+    /// Returns the token string held by this iterator.
+    #[inline]
+    pub(crate) fn token(&self) -> &str {
+        &self.token
+    }
+
     #[inline]
     fn approximate_upper_bound(&self) -> f32 {
         self.approximate_upper_bound
@@ -293,9 +305,11 @@ impl PostingIterator {
     }
 }
 
+#[derive(Debug)]
 pub struct DocCandidate {
     pub row_id: u64,
-    pub freqs: Vec<(String, u32)>,
+    /// `(term_index, freq)` pairs; `term_index` comes from `PostingIterator::term_index`.
+    pub freqs: Vec<(u32, u32)>,
     pub doc_length: u32,
 }
 
@@ -359,7 +373,7 @@ impl<'a, S: Scorer> Wand<'a, S> {
             _ => {}
         }
 
-        let mut candidates = BinaryHeap::new();
+        let mut candidates = BinaryHeap::with_capacity(std::cmp::min(limit, BLOCK_SIZE * 10));
         let mut num_comparisons = 0;
         while let Some((pivot, doc)) = self.next()? {
             if let Some(cur_doc) = self.cur_doc {
@@ -394,10 +408,7 @@ impl<'a, S: Scorer> Wand<'a, S> {
                 DocInfo::Located(doc) => self.docs.num_tokens_by_row_id(doc.row_id),
             };
             let score = self.score(pivot, doc_length);
-            let freqs = self
-                .iter_token_freqs(pivot)
-                .map(|(token, freq)| (token.to_owned(), freq))
-                .collect();
+            let freqs = self.iter_term_freqs(pivot).collect();
             if candidates.len() < limit {
                 candidates.push(Reverse((ScoredDoc::new(row_id, score), freqs, doc_length)));
                 if candidates.len() == limit {
@@ -522,10 +533,7 @@ impl<'a, S: Scorer> Wand<'a, S> {
             };
 
             let score = self.score(max_pivot, doc_length);
-            let freqs = self
-                .iter_token_freqs(max_pivot)
-                .map(|(token, freq)| (token.to_owned(), freq))
-                .collect();
+            let freqs = self.iter_term_freqs(max_pivot).collect();
 
             if candidates.len() < limit {
                 candidates.push(Reverse((ScoredDoc::new(row_id, score), freqs, doc_length)));
@@ -568,6 +576,15 @@ impl<'a, S: Scorer> Wand<'a, S> {
         })
     }
 
+    // yield (term index, frequency) for every posting in `postings[..=pivot]` that has a doc
+    fn iter_term_freqs(&self, pivot: usize) -> impl Iterator<Item = (u32, u32)> + '_ {
+        self.postings[..=pivot].iter().filter_map(|posting| {
+            posting
+                .doc()
+                .map(|doc| (posting.term_index(), doc.frequency()))
+        })
+    }
+
     // find the next doc candidate
     fn next(&mut self) -> Result<Option<(usize, DocInfo)>> {
         while let Some((pivot, max_pivot)) = self.find_pivot_term() {