diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 783a0772b97..a119933fdd8 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2369,6 +2369,7 @@ def create_scalar_index( field = lance_field.to_arrow() field_type = field.type + field_meta = field.metadata if hasattr(field_type, "storage_type"): field_type = field_type.storage_type @@ -2397,12 +2398,17 @@ def create_scalar_index( value_type = field_type if pa.types.is_list(field_type) or pa.types.is_large_list(field_type): value_type = field_type.value_type - if not pa.types.is_string(value_type) and not pa.types.is_large_string( - value_type + if ( + not pa.types.is_string(value_type) + and not pa.types.is_large_string(value_type) + and not ( + pa.types.is_large_binary(value_type) + and (field_meta or {}).get(b"ARROW:extension:name") == b"lance.json" + ) ): raise TypeError( f"INVERTED index column {column} must be string, large string" - " or list of strings, but got {value_type}" + f" or list of strings, or json, but got {value_type}" ) if pa.types.is_duration(field_type): diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index d33ca4a9b6d..c2010dc5670 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The Lance Authors +import json import os import random import re @@ -3945,3 +3946,70 @@ def test_nested_field_bitmap_index(tmp_path): # Verify query still works after optimization results = ds.to_table(filter="attributes.color = 'red'", prefilter=True) assert results.num_rows == 51 # 34 + 17 from new data + + +def test_json_inverted_match_query(tmp_path): + # Prepare dataset with JSON data + json_data = [ + { + "Title": "HarryPotter Chapter One", + "Content": "Once upon a time, there was a boy named Harry.", + "Author": "J.K. 
Rowling", + "Price": 99, + "Language": ["english", "french"], + }, + { + "Title": "HarryPotter Chapter Two", + "Content": "Nearly ten years had passed since the Dursleys had woken up...", + "Author": "J.K. Rowling", + "Price": 128, + "Language": ["english", "chinese"], + }, + { + "Title": "The Hobbit", + "Content": "In a hole in the ground there lived a hobbit.", + "Author": "J.R.R. Tolkien", + "Price": 89, + "Language": ["english"], + }, + ] + + # Convert to JSON strings + json_strings = pa.array([json.dumps(doc) for doc in json_data], type=pa.json_()) + table = pa.table({"json_col": json_strings, "id": range(len(json_data))}) + dataset = lance.write_dataset(table, tmp_path) + + # Create inverted index with JSON tokenizer + dataset.create_scalar_index( + "json_col", + index_type="INVERTED", + base_tokenizer="simple", + max_token_length=10, + stem=True, + lower_case=True, + remove_stop_words=True, + ) + + # Test match query with token exceeding max_token_length + results = dataset.to_table( + full_text_query=MatchQuery("Title,str,harrypotter", "json_col") + ) + assert results.num_rows == 0 + + # Test stemming + results = dataset.to_table( + full_text_query=MatchQuery("Content,str,onc", "json_col") + ) + assert results.num_rows == 1 + + # Test language match + results = dataset.to_table( + full_text_query=MatchQuery("Language,str,english", "json_col") + ) + assert results.num_rows == 3 + + # Test author match + results = dataset.to_table( + full_text_query=MatchQuery("Author,str,tolkien", "json_col") + ) + assert results.num_rows == 1 diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index 8da433647d8..4d2e44b450a 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -7,6 +7,7 @@ use super::{ InvertedIndexParams, }; use crate::scalar::inverted::json::JsonTextStream; +use crate::scalar::inverted::lance_tokenizer::DocType; use 
crate::scalar::inverted::tokenizer::lance_tokenizer::LanceTokenizer; use crate::scalar::lance_format::LanceIndexStore; use crate::scalar::IndexStore; @@ -144,6 +145,15 @@ impl InvertedIndexBuilder { ) -> Result<()> { let schema = new_data.schema(); let doc_col = schema.field(0).name(); + + // infer lance_tokenizer based on document type + if self.params.lance_tokenizer.is_none() { + let schema = new_data.schema(); + let field = schema.column_with_name(doc_col).expect_ok()?.1; + let doc_type = DocType::try_from(field)?; + self.params.lance_tokenizer = Some(doc_type.as_ref().to_string()); + } + let new_data = document_input(new_data, doc_col)?; self.update_index(new_data).await?; diff --git a/rust/lance-index/src/scalar/inverted/tokenizer.rs b/rust/lance-index/src/scalar/inverted/tokenizer.rs index 698283b7fea..324ae9d7b6a 100644 --- a/rust/lance-index/src/scalar/inverted/tokenizer.rs +++ b/rust/lance-index/src/scalar/inverted/tokenizer.rs @@ -30,7 +30,7 @@ pub struct InvertedIndexParams { /// lance tokenizer takes care of different data types, such as text, json, etc. /// - 'text': parsing input documents into tokens /// - 'json': parsing input json string into tokens - /// - default: text + /// - none: auto type inference pub(crate) lance_tokenizer: Option<String>, /// base tokenizer: /// - `simple`: splits tokens on whitespace and punctuation @@ -146,7 +146,7 @@ impl InvertedIndexParams { /// Default to `English`. 
pub fn new(base_tokenizer: String, language: tantivy::tokenizer::Language) -> Self { Self { - lance_tokenizer: Some("text".to_owned()), + lance_tokenizer: None, base_tokenizer, language, with_position: false, diff --git a/rust/lance-index/src/scalar/inverted/tokenizer/lance_tokenizer.rs b/rust/lance-index/src/scalar/inverted/tokenizer/lance_tokenizer.rs index 50171999a33..30107bb2546 100644 --- a/rust/lance-index/src/scalar/inverted/tokenizer/lance_tokenizer.rs +++ b/rust/lance-index/src/scalar/inverted/tokenizer/lance_tokenizer.rs @@ -1,10 +1,54 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use arrow_schema::{DataType, Field}; +use lance_arrow::json::JSON_EXT_NAME; +use lance_arrow::ARROW_EXT_NAME_KEY; use serde_json::Value; use snafu::location; use tantivy::tokenizer::{BoxTokenStream, Token, TokenStream}; +/// Document type for full text search. +pub enum DocType { + Text, + Json, +} + +impl AsRef<str> for DocType { + fn as_ref(&self) -> &str { + match self { + Self::Text => "text", + Self::Json => "json", + } + } +} + +impl TryFrom<&Field> for DocType { + type Error = lance_core::Error; + + fn try_from(field: &Field) -> Result<Self, Self::Error> { + match field.data_type() { + DataType::Utf8 | DataType::LargeUtf8 => Ok(Self::Text), + DataType::List(field) | DataType::LargeList(field) + if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) => + { + Ok(Self::Text) + } + DataType::LargeBinary => match field.metadata().get(ARROW_EXT_NAME_KEY) { + Some(name) if name.as_str() == JSON_EXT_NAME => Ok(Self::Json), + _ => Err(lance_core::Error::InvalidInput { + source: format!("field {} is not json", field.name()).into(), + location: location!(), + }), + }, + _ => Err(lance_core::Error::InvalidInput { + source: format!("field {} is not json", field.name()).into(), + location: location!(), + }), + } + } +} + /// Lance full text search tokenizer. 
/// /// `LanceTokenizer` defines 2 methods for tokenization, normally they are the same, but sometimes diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 6cc56932eb6..facbb283756 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -9497,4 +9497,38 @@ mod tests { // Should have both data writes (10 rows total) assert_eq!(final_dataset.count_rows(None).await.unwrap(), 10); } + + #[tokio::test] + async fn test_auto_infer_lance_tokenizer() { + let (mut dataset, json_col) = prepare_json_dataset().await; + + // Create inverted index for json col. Expect auto-infer 'json' for lance tokenizer. + dataset + .create_index( + &[&json_col], + IndexType::Inverted, + None, + &InvertedIndexParams::default(), + true, + ) + .await + .unwrap(); + + // Match query succeed only when lance tokenizer is 'json' + let query = FullTextSearchQuery { + query: FtsQuery::Match( + MatchQuery::new("Content,str,once".to_string()).with_column(Some(json_col.clone())), + ), + limit: None, + wand_factor: None, + }; + let batch = dataset + .scan() + .full_text_search(query) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(1, batch.num_rows()); + } }