Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions rust/lance-index/benches/inverted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use lance_core::cache::LanceCache;
use lance_core::ROW_ID;
use lance_datagen::{array, RowCount};
use lance_index::prefilter::NoFilter;
use lance_index::scalar::inverted::query::{FtsSearchParams, Operator};
use lance_index::scalar::inverted::lance_tokenizer::DocType;
use lance_index::scalar::inverted::query::{FtsSearchParams, Operator, Tokens};
use lance_index::scalar::inverted::{InvertedIndex, InvertedIndexBuilder};
use lance_index::scalar::lance_format::LanceIndexStore;
use lance_index::{
Expand Down Expand Up @@ -99,7 +100,10 @@ fn bench_inverted(c: &mut Criterion) {
black_box(
invert_index
.bm25_search(
vec![sample_words[word_idx].clone()].into(),
Arc::new(Tokens::new(
vec![sample_words[word_idx].clone()],
DocType::Text,
)),
params.clone().into(),
Operator::Or,
no_filter.clone(),
Expand Down
36 changes: 15 additions & 21 deletions rust/lance-index/src/scalar/inverted/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ use std::{
cmp::{min, Reverse},
collections::BinaryHeap,
};
use std::{
collections::{HashMap, HashSet},
ops::Range,
};
use std::{collections::HashMap, ops::Range};

use crate::metrics::NoOpMetricsCollector;
use crate::prefilter::NoFilter;
Expand Down Expand Up @@ -232,7 +229,7 @@ impl InvertedIndex {
#[instrument(level = "debug", skip_all)]
pub async fn bm25_search(
&self,
tokens: Arc<Vec<String>>,
tokens: Arc<Tokens>,
params: Arc<FtsSearchParams>,
operator: Operator,
prefilter: Arc<dyn PreFilter>,
Expand Down Expand Up @@ -500,7 +497,7 @@ impl InvertedIndex {

let (doc_ids, _) = self
.bm25_search(
tokens.into(),
Arc::new(tokens),
params.into(),
Operator::And,
Arc::new(NoFilter),
Expand Down Expand Up @@ -679,7 +676,7 @@ impl InvertedPartition {
self.tokens.get(token)
}

pub fn expand_fuzzy(&self, tokens: &[String], params: &FtsSearchParams) -> Result<Vec<String>> {
pub fn expand_fuzzy(&self, tokens: &Tokens, params: &FtsSearchParams) -> Result<Tokens> {
let mut new_tokens = Vec::with_capacity(min(tokens.len(), params.max_expansions));
for token in tokens {
let fuzziness = match params.fuzziness {
Expand All @@ -692,8 +689,9 @@ impl InvertedPartition {
location: location!(),
})?;

let base_len = tokens.token_type().prefix_len(token) as u32;
if let TokenMap::Fst(ref map) = self.tokens.tokens {
match params.prefix_length {
match base_len + params.prefix_length {
0 => take_fst_keys(map.search(lev), &mut new_tokens, params.max_expansions),
prefix_length => {
let prefix = &token[..min(prefix_length as usize, token.len())];
Expand All @@ -712,7 +710,7 @@ impl InvertedPartition {
});
}
}
Ok(new_tokens)
Ok(Tokens::new(new_tokens, tokens.token_type().clone()))
}

// search the documents that contain the query
Expand All @@ -721,7 +719,7 @@ impl InvertedPartition {
#[instrument(level = "debug", skip_all)]
pub async fn bm25_search(
&self,
tokens: &[String],
tokens: &Tokens,
params: &FtsSearchParams,
operator: Operator,
mask: Arc<RowIdMask>,
Expand All @@ -731,7 +729,7 @@ impl InvertedPartition {
let is_phrase_query = params.phrase_slop.is_some();
let tokens = match is_fuzzy {
true => self.expand_fuzzy(tokens, params)?,
false => tokens.to_vec(),
false => tokens.clone(),
};
let mut token_ids = Vec::with_capacity(tokens.len());
for token in tokens {
Expand Down Expand Up @@ -2337,9 +2335,7 @@ fn do_flat_full_text_search<Offset: OffsetSizeTrait>(
let mut results = Vec::new();
let mut tokenizer =
tokenizer.unwrap_or_else(|| InvertedIndexParams::default().build().unwrap());
let query_tokens = collect_query_tokens(query, &mut tokenizer, None)
.into_iter()
.collect::<HashSet<_>>();
let query_tokens = collect_query_tokens(query, &mut tokenizer, None);

for batch in batches {
let row_id_array = batch[ROW_ID].as_primitive::<UInt64Type>();
Expand All @@ -2361,7 +2357,7 @@ fn do_flat_full_text_search<Offset: OffsetSizeTrait>(
pub fn flat_bm25_search(
batch: RecordBatch,
doc_col: &str,
query_tokens: &HashSet<String>,
query_tokens: &Tokens,
tokenizer: &mut Box<dyn LanceTokenizer>,
scorer: &mut MemBM25Scorer,
) -> std::result::Result<RecordBatch, DataFusionError> {
Expand Down Expand Up @@ -2389,7 +2385,7 @@ pub fn flat_bm25_search(
.or_insert(1);
}
let mut score = 0.0;
for token in query_tokens.iter() {
for token in query_tokens {
let freq = doc_token_count.get(token).copied().unwrap_or_default() as f32;

let idf = idf(scorer.num_docs_containing_token(token), scorer.num_docs());
Expand Down Expand Up @@ -2420,10 +2416,7 @@ pub fn flat_bm25_search_stream(
.build(),
)),
};
let tokens = collect_query_tokens(&query, &mut tokenizer, None)
.into_iter()
.sorted_unstable()
.collect::<HashSet<_>>();
let tokens = collect_query_tokens(&query, &mut tokenizer, None);

let mut bm25_scorer = match index {
Some(index) => {
Expand Down Expand Up @@ -2473,6 +2466,7 @@ pub fn is_phrase_query(query: &str) -> bool {

#[cfg(test)]
mod tests {
use crate::scalar::inverted::lance_tokenizer::DocType;
use lance_core::cache::LanceCache;
use lance_core::utils::tempfile::TempObjDir;
use lance_io::object_store::ObjectStore;
Expand Down Expand Up @@ -2671,7 +2665,7 @@ mod tests {
// Prewarm the inverted index (this loads posting lists into cache)
index.prewarm().await.unwrap();

let tokens = Arc::new(vec!["test".to_string()]);
let tokens = Arc::new(Tokens::new(vec!["test".to_string()], DocType::Text));
let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
let prefilter = Arc::new(NoFilter);
let metrics = Arc::new(NoOpMetricsCollector);
Expand Down
71 changes: 66 additions & 5 deletions rust/lance-index/src/scalar/inverted/query.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use crate::scalar::inverted::lance_tokenizer::DocType;
use crate::scalar::inverted::tokenizer::lance_tokenizer::LanceTokenizer;
use lance_core::{Error, Result};
use serde::ser::SerializeMap;
Expand Down Expand Up @@ -650,11 +651,70 @@ impl FtsQueryNode for BooleanQuery {
}
}

/// A sequence of tokens paired with a set view of the same tokens and the
/// document type they belong to.
///
/// Iteration yields the tokens in the order they were supplied, duplicates
/// included; [`Tokens::contains`] answers membership from the set.
#[derive(Clone)]
pub struct Tokens {
    /// Tokens in input order, duplicates preserved.
    tokens: Vec<String>,
    /// Deduplicated copy of `tokens`, used for membership lookups.
    tokens_set: HashSet<String>,
    /// Document type handed to [`Tokens::new`].
    token_type: DocType,
}

impl Tokens {
    /// Creates a `Tokens` from `tokens` and the `token_type` they belong to.
    ///
    /// The vector is stored unchanged; every token is additionally cloned into
    /// the lookup set.
    pub fn new(tokens: Vec<String>, token_type: DocType) -> Self {
        let tokens_set = tokens.iter().cloned().collect::<HashSet<_>>();
        Self {
            tokens,
            tokens_set,
            token_type,
        }
    }

    /// Number of tokens in the sequence (duplicates are counted).
    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    /// Returns `true` when the sequence holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Document type associated with these tokens.
    pub fn token_type(&self) -> &DocType {
        &self.token_type
    }

    /// Returns `true` when `token` is one of the stored tokens.
    pub fn contains(&self, token: &str) -> bool {
        self.tokens_set.get(token).is_some()
    }
}

/// Consumes the `Tokens`, yielding owned tokens in input order.
impl IntoIterator for Tokens {
    type Item = String;
    type IntoIter = std::vec::IntoIter<String>;

    fn into_iter(self) -> Self::IntoIter {
        // Only the ordered vector is iterated; the set and type are dropped.
        let Self { tokens, .. } = self;
        tokens.into_iter()
    }
}

/// Borrows the `Tokens`, yielding token references in input order.
impl<'a> IntoIterator for &'a Tokens {
    type Item = &'a String;
    type IntoIter = std::slice::Iter<'a, String>;

    fn into_iter(self) -> Self::IntoIter {
        self.tokens.as_slice().iter()
    }
}

pub fn collect_query_tokens(
text: &str,
tokenizer: &mut Box<dyn LanceTokenizer>,
inclusive: Option<&HashSet<String>>,
) -> Vec<String> {
) -> Tokens {
let token_type = tokenizer.doc_type();
let mut stream = tokenizer.token_stream_for_search(text);
let mut tokens = Vec::new();
while let Some(token) = stream.next() {
Expand All @@ -665,14 +725,15 @@ pub fn collect_query_tokens(
}
tokens.push(token.text.to_owned());
}
tokens
Tokens::new(tokens, token_type)
}

pub fn collect_doc_tokens(
text: &str,
tokenizer: &mut Box<dyn LanceTokenizer>,
inclusive: Option<&HashSet<String>>,
) -> Vec<String> {
inclusive: Option<&Tokens>,
) -> Tokens {
let token_type = tokenizer.doc_type();
let mut stream = tokenizer.token_stream_for_doc(text);
let mut tokens = Vec::new();
while let Some(token) = stream.next() {
Expand All @@ -683,7 +744,7 @@ pub fn collect_doc_tokens(
}
tokens.push(token.text.to_owned());
}
tokens
Tokens::new(tokens, token_type)
}

pub fn fill_fts_query_column(
Expand Down
3 changes: 2 additions & 1 deletion rust/lance-index/src/scalar/inverted/scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-FileCopyrightText: Copyright The Lance Authors

use super::InvertedPartition;
use crate::scalar::inverted::query::Tokens;
use std::collections::HashMap;

// the Scorer trait is used to calculate the score of a token in a document
Expand Down Expand Up @@ -43,7 +44,7 @@ impl MemBM25Scorer {
///
/// # Arguments
/// * `tokens` - The tokens of the new document.
pub fn update(&mut self, tokens: &Vec<String>) {
pub fn update(&mut self, tokens: &Tokens) {
self.total_tokens += tokens.len() as u64;
self.num_docs += 1;
for token in tokens {
Expand Down
30 changes: 30 additions & 0 deletions rust/lance-index/src/scalar/inverted/tokenizer/lance_tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use snafu::location;
use tantivy::tokenizer::{BoxTokenStream, Token, TokenStream};

/// Document type for full text search.
#[derive(Debug, Clone)]
pub enum DocType {
Text,
Json,
Expand Down Expand Up @@ -49,6 +50,25 @@ impl TryFrom<&Field> for DocType {
}
}

impl DocType {
    /// Returns the byte length of the prefix that precedes the value in `token`.
    ///
    /// - JSON token: `path,type,value` — the prefix ends just after the second comma.
    /// - Text token: `value` — there is no prefix, so the result is `0`.
    ///
    /// # Panics
    ///
    /// Panics for [`DocType::Json`] when `token` contains fewer than two commas.
    pub fn prefix_len(&self, token: &str) -> usize {
        match self {
            Self::Text => 0,
            // The second comma's byte offset, plus one for the comma itself.
            Self::Json => match token.match_indices(',').nth(1) {
                Some((second_comma, _)) => second_comma + 1,
                None => panic!("json token must be in format of <path>,<type>,<value>"),
            },
        }
    }
}

/// Lance full text search tokenizer.
///
/// `LanceTokenizer` defines 2 methods for tokenization, normally they are the same, but sometimes
Expand All @@ -63,6 +83,8 @@ pub trait LanceTokenizer: Send + Sync {
fn token_stream_for_doc<'a>(&'a mut self, text: &'a str) -> BoxTokenStream<'a>;
/// Clone the tokenizer.
fn box_clone(&self) -> Box<dyn LanceTokenizer>;
/// Get document type.
fn doc_type(&self) -> DocType;
}

impl Clone for Box<dyn LanceTokenizer> {
Expand Down Expand Up @@ -94,6 +116,10 @@ impl LanceTokenizer for TextTokenizer {
/// Clones this tokenizer and returns the copy boxed as a `dyn LanceTokenizer`.
fn box_clone(&self) -> Box<dyn LanceTokenizer> {
Box::new(self.clone())
}

/// Returns `DocType::Text` as the document type of this tokenizer.
fn doc_type(&self) -> DocType {
DocType::Text
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -129,6 +155,10 @@ impl LanceTokenizer for JsonTokenizer {
/// Clones this tokenizer and returns the copy boxed as a `dyn LanceTokenizer`.
fn box_clone(&self) -> Box<dyn LanceTokenizer> {
Box::new(self.clone())
}

/// Returns `DocType::Json` as the document type of this tokenizer.
fn doc_type(&self) -> DocType {
DocType::Json
}
}

fn flatten_triplet(
Expand Down
Loading
Loading