diff --git a/rust/lance-index/src/scalar/inverted/wand.rs b/rust/lance-index/src/scalar/inverted/wand.rs index 61786bdb23a..b9bbfad62bd 100644 --- a/rust/lance-index/src/scalar/inverted/wand.rs +++ b/rust/lance-index/src/scalar/inverted/wand.rs @@ -148,7 +148,7 @@ impl PostingIterator { num_doc: usize, ) -> Self { let approximate_upper_bound = match list.max_score() { - Some(max_score) => max_score, // the index doesn't include the full BM25 upper bound at indexing time, so we need to multiply it here + Some(max_score) => max_score, None => idf(list.len(), num_doc) * (K1 + 1.0), }; @@ -275,7 +275,7 @@ impl PostingIterator { #[inline] fn block_max_score(&self) -> f32 { match self.list { - PostingList::Compressed(ref list) => list.block_max_score(self.block_idx) * (K1 + 1.0), + PostingList::Compressed(ref list) => list.block_max_score(self.block_idx), PostingList::Plain(_) => self.approximate_upper_bound, } } @@ -993,4 +993,23 @@ mod tests { assert!(result.is_ok()); } + + #[test] + fn test_block_max_score_matches_stored_value() { + let doc_ids = vec![0_u32]; + let block_max_scores = vec![0.7_f32]; + let posting_list = generate_posting_list(doc_ids, 0.7, Some(block_max_scores), true); + let expected = match &posting_list { + PostingList::Compressed(list) => list.block_max_score(0), + PostingList::Plain(_) => unreachable!("expected compressed posting list"), + }; + + let posting = PostingIterator::new(String::from("test"), 0, 0, posting_list, 1); + + let actual = posting.block_max_score(); + assert!( + (actual - expected).abs() < 1e-6, + "block max score should match stored value" + ); + } }