Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions rust/lance-index/src/scalar/inverted/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,21 @@ pub static FTS_SCHEMA: LazyLock<SchemaRef> =
static ROW_ID_SCHEMA: LazyLock<SchemaRef> =
LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()])));

/// Result of a BM25 search over a single index partition: the candidate
/// documents plus the mapping from term index back to the original token
/// string, so the caller can re-score candidates against global statistics.
#[derive(Debug, Default)]
struct PartitionCandidates {
    // Token text indexed by term index: `tokens_by_position[term_index]`
    // recovers the token a `DocCandidate` frequency entry refers to.
    tokens_by_position: Vec<String>,
    // Matching documents found in this partition.
    candidates: Vec<DocCandidate>,
}

impl PartitionCandidates {
    /// An empty result (no tokens, no candidates), e.g. when the partition
    /// has no postings for any query term.
    // Kept as a named constructor for readability at call sites; the
    // derived `Default` supplies the actual empty vectors.
    fn empty() -> Self {
        Self::default()
    }
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Default)]
pub enum TokenSetFormat {
Arrow,
Expand Down Expand Up @@ -256,19 +271,28 @@ impl InvertedIndex {
.load_posting_lists(tokens.as_ref(), params.as_ref(), metrics.as_ref())
.await?;
if postings.is_empty() {
return Ok(Vec::new());
return Ok(PartitionCandidates::empty());
}
let mut tokens_by_position = vec![String::new(); postings.len()];
for posting in &postings {
let idx = posting.term_index() as usize;
tokens_by_position[idx] = posting.token().to_owned();
}
let params = params.clone();
let mask = mask.clone();
let metrics = metrics.clone();
spawn_cpu(move || {
part.bm25_search(
let candidates = part.bm25_search(
params.as_ref(),
operator,
mask,
postings,
metrics.as_ref(),
)
)?;
Ok(PartitionCandidates {
tokens_by_position,
candidates,
})
})
.await
}
Expand All @@ -277,14 +301,20 @@ impl InvertedIndex {
let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus());
let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
while let Some(res) = parts.try_next().await? {
if res.candidates.is_empty() {
continue;
}
let tokens_by_position = &res.tokens_by_position;
for DocCandidate {
row_id,
freqs,
doc_length,
} in res
} in res.candidates
{
let mut score = 0.0;
for (token, freq) in freqs.into_iter() {
for (term_index, freq) in freqs.into_iter() {
debug_assert!((term_index as usize) < tokens_by_position.len());
let token = &tokens_by_position[term_index as usize];
score += scorer.score(token.as_str(), freq, doc_length);
}
if candidates.len() < limit {
Expand Down
35 changes: 25 additions & 10 deletions rust/lance-index/src/scalar/inverted/wand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,16 @@ impl PostingIterator {
}
}

/// Index of the term this posting list belongs to (mirrors `self.position`).
/// NOTE(review): callers appear to use this to index a per-query
/// `tokens_by_position` table — confirm `position` is the query-term index.
#[inline]
pub(crate) fn term_index(&self) -> u32 {
    self.position
}

/// The token string this posting list was built for.
#[inline]
pub(crate) fn token(&self) -> &str {
    &self.token
}

#[inline]
fn approximate_upper_bound(&self) -> f32 {
self.approximate_upper_bound
Expand Down Expand Up @@ -293,9 +303,11 @@ impl PostingIterator {
}
}

/// A document matched during WAND evaluation, carrying the per-term
/// frequencies needed to (re)compute its BM25 score.
#[derive(Debug)]
pub struct DocCandidate {
    /// Row id of the matched document.
    pub row_id: u64,
    /// (term_index, freq)
    pub freqs: Vec<(u32, u32)>,
    /// Token count of the document; passed to the scorer as the doc length.
    pub doc_length: u32,
}

Expand Down Expand Up @@ -359,7 +371,7 @@ impl<'a, S: Scorer> Wand<'a, S> {
_ => {}
}

let mut candidates = BinaryHeap::new();
let mut candidates = BinaryHeap::with_capacity(std::cmp::min(limit, BLOCK_SIZE * 10));
let mut num_comparisons = 0;
while let Some((pivot, doc)) = self.next()? {
if let Some(cur_doc) = self.cur_doc {
Expand Down Expand Up @@ -394,10 +406,7 @@ impl<'a, S: Scorer> Wand<'a, S> {
DocInfo::Located(doc) => self.docs.num_tokens_by_row_id(doc.row_id),
};
let score = self.score(pivot, doc_length);
let freqs = self
.iter_token_freqs(pivot)
.map(|(token, freq)| (token.to_owned(), freq))
.collect();
let freqs = self.iter_term_freqs(pivot).collect();
if candidates.len() < limit {
candidates.push(Reverse((ScoredDoc::new(row_id, score), freqs, doc_length)));
if candidates.len() == limit {
Expand Down Expand Up @@ -522,10 +531,7 @@ impl<'a, S: Scorer> Wand<'a, S> {
};

let score = self.score(max_pivot, doc_length);
let freqs = self
.iter_token_freqs(max_pivot)
.map(|(token, freq)| (token.to_owned(), freq))
.collect();
let freqs = self.iter_term_freqs(max_pivot).collect();

if candidates.len() < limit {
candidates.push(Reverse((ScoredDoc::new(row_id, score), freqs, doc_length)));
Expand Down Expand Up @@ -568,6 +574,15 @@ impl<'a, S: Scorer> Wand<'a, S> {
})
}

/// Walk the posting lists up to and including `pivot` and yield, for each
/// one currently positioned on a document, its `(term_index, frequency)`
/// pair. Posting lists exhausted (no current doc) are skipped.
fn iter_term_freqs(&self, pivot: usize) -> impl Iterator<Item = (u32, u32)> + '_ {
    let up_to_pivot = &self.postings[..=pivot];
    up_to_pivot
        .iter()
        .filter_map(|posting| Some((posting.term_index(), posting.doc()?.frequency())))
}

// find the next doc candidate
fn next(&mut self) -> Result<Option<(usize, DocInfo)>> {
while let Some((pivot, max_pivot)) = self.find_pivot_term() {
Expand Down
Loading