From eeaad34a984e04605db3dbadfe0d29cc088ff13a Mon Sep 17 00:00:00 2001 From: Chris Tam Date: Sun, 19 Nov 2023 06:08:33 -0500 Subject: [PATCH] Better support bytes, IPs, and JSON (#152) --- src/document.rs | 59 ++++++++++++++++++++++++++-- src/schemabuilder.rs | 90 +++++++++++++++++++++++++++++++++++++++---- tests/tantivy_test.py | 83 ++++++++++++++++++++++++++++++++++----- 3 files changed, 211 insertions(+), 21 deletions(-) diff --git a/src/document.rs b/src/document.rs index 3adb13d1..f4e2cb7d 100644 --- a/src/document.rs +++ b/src/document.rs @@ -6,9 +6,10 @@ use pyo3::{ basic::CompareOp, prelude::*, types::{ - PyAny, PyBool, PyDateAccess, PyDateTime, PyDict, PyList, PyTimeAccess, - PyTuple, + PyAny, PyBool, PyDateAccess, PyDateTime, PyDict, PyInt, PyList, + PyTimeAccess, PyTuple, }, + Python, }; use chrono::{offset::TimeZone, NaiveDateTime, Utc}; @@ -23,7 +24,8 @@ use serde_json::Value as JsonValue; use std::{ collections::{BTreeMap, HashMap}, fmt, - net::Ipv6Addr, + net::{IpAddr, Ipv6Addr}, + str::FromStr, }; pub(crate) fn extract_value(any: &PyAny) -> PyResult { @@ -50,6 +52,11 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult { if let Ok(b) = any.extract::>() { return Ok(Value::Bytes(b)); } + if let Ok(dict) = any.downcast::() { + if let Ok(json) = pythonize::depythonize(dict) { + return Ok(Value::JsonObject(json)); + } + } Err(to_pyerr(format!("Value unsupported {any:?}"))) } @@ -105,7 +112,37 @@ pub(crate) fn extract_value_for_type( .map_err(to_pyerr_for_type("Facet", field_name, any))? 
.inner, ), - _ => return Err(to_pyerr(format!("Value unsupported {:?}", any))), + tv::schema::Type::Bytes => Value::Bytes( + any.extract::>() + .map_err(to_pyerr_for_type("Bytes", field_name, any))?, + ), + tv::schema::Type::Json => { + if let Ok(json_str) = any.extract::<&str>() { + return serde_json::from_str(json_str) + .map(Value::JsonObject) + .map_err(to_pyerr_for_type("Json", field_name, any)); + } + + Value::JsonObject( + any.downcast::() + .map(|dict| pythonize::depythonize(&dict)) + .map_err(to_pyerr_for_type("Json", field_name, any))? + .map_err(to_pyerr_for_type("Json", field_name, any))?, + ) + } + tv::schema::Type::IpAddr => { + let val = any + .extract::<&str>() + .map_err(to_pyerr_for_type("IpAddr", field_name, any))?; + + IpAddr::from_str(val) + .map(|addr| match addr { + IpAddr::V4(addr) => addr.to_ipv6_mapped(), + IpAddr::V6(addr) => addr, + }) + .map(Value::IpAddr) + .map_err(to_pyerr_for_type("IpAddr", field_name, any))? + } }; Ok(value) @@ -126,6 +163,20 @@ fn extract_value_single_or_list_for_type( ) -> PyResult> { // Check if a numeric fast field supports multivalues. if let Ok(values) = any.downcast::() { + // Process an array of integers as a single entry if it is a bytes field. + if field_type.value_type() == tv::schema::Type::Bytes + && values + .get_item(0) + .map(|v| v.is_instance_of::()) + .unwrap_or(false) + { + return Ok(vec![extract_value_for_type( + values, + field_type.value_type(), + field_name, + )?]); + } + values .iter() .map(|any| { diff --git a/src/schemabuilder.rs b/src/schemabuilder.rs index 4984b6df..abe89518 100644 --- a/src/schemabuilder.rs +++ b/src/schemabuilder.rs @@ -2,11 +2,11 @@ use pyo3::{exceptions, prelude::*}; -use tantivy::schema; - use crate::schema::Schema; use std::sync::{Arc, RwLock}; -use tantivy::schema::{DateOptions, INDEXED}; +use tantivy::schema::{ + self, BytesOptions, DateOptions, IpAddrOptions, INDEXED, +}; /// Tantivy has a very strict schema. 
/// You need to specify in advance whether a field is indexed or not, @@ -357,22 +357,96 @@ impl SchemaBuilder { /// Add a fast bytes field to the schema. /// - /// Bytes field are not searchable and are only used - /// as fast field, to associate any kind of payload - /// to a document. + /// Args: + /// name (str): The name of the field. + /// stored (bool, optional): If true sets the field as stored, the + /// content of the field can be later restored from a Searcher. + /// Defaults to False. + /// indexed (bool, optional): If true sets the field to be indexed. + /// fast (bool, optional): Set the bytes options as a fast field. A fast + /// field is a column-oriented fashion storage for tantivy. It is + /// designed for the fast random access of some document fields + /// given a document id. + #[pyo3(signature = ( + name, + stored = false, + indexed = false, + fast = false + ))] + fn add_bytes_field( + &mut self, + name: &str, + stored: bool, + indexed: bool, + fast: bool, + ) -> PyResult { + let builder = &mut self.builder; + let mut opts = BytesOptions::default(); + if stored { + opts = opts.set_stored(); + } + if indexed { + opts = opts.set_indexed(); + } + if fast { + opts = opts.set_fast(); + } + + if let Some(builder) = builder.write().unwrap().as_mut() { + builder.add_bytes_field(name, opts); + } else { + return Err(exceptions::PyValueError::new_err( + "Schema builder object isn't valid anymore.", + )); + } + Ok(self.clone()) + } + + /// Add an IP address field to the schema. /// /// Args: /// name (str): The name of the field. - fn add_bytes_field(&mut self, name: &str) -> PyResult { + /// stored (bool, optional): If true sets the field as stored, the + /// content of the field can be later restored from a Searcher. + /// Defaults to False. + /// indexed (bool, optional): If true sets the field to be indexed. + /// fast (bool, optional): Set the IP address options as a fast field. A + /// fast field is a column-oriented fashion storage for tantivy. 
It + /// is designed for the fast random access of some document fields + /// given a document id. + #[pyo3(signature = ( + name, + stored = false, + indexed = false, + fast = false + ))] + fn add_ip_addr_field( + &mut self, + name: &str, + stored: bool, + indexed: bool, + fast: bool, + ) -> PyResult { let builder = &mut self.builder; + let mut opts = IpAddrOptions::default(); + if stored { + opts = opts.set_stored(); + } + if indexed { + opts = opts.set_indexed(); + } + if fast { + opts = opts.set_fast(); + } if let Some(builder) = builder.write().unwrap().as_mut() { - builder.add_bytes_field(name, INDEXED); + builder.add_ip_addr_field(name, opts); } else { return Err(exceptions::PyValueError::new_err( "Schema builder object isn't valid anymore.", )); } + Ok(self.clone()) } diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 5c8e1b56..5e409207 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -2,6 +2,7 @@ import copy import datetime +import json import tantivy import pickle import pytest @@ -365,7 +366,9 @@ def test_order_by_search(self): searched_doc = index.searcher().doc(doc_address) assert searched_doc["title"] == ["Test title"] - result = searcher.search(query, 10, order_by_field="order", order=tantivy.Order.Asc) + result = searcher.search( + query, 10, order_by_field="order", order=tantivy.Order.Asc + ) assert len(result.hits) == 3 @@ -443,7 +446,7 @@ def test_with_merges(self): assert searcher.num_segments < 8 - def test_doc_from_dict_schema_validation(self): + def test_doc_from_dict_numeric_validation(self): schema = ( SchemaBuilder() .add_unsigned_field("unsigned") @@ -504,6 +507,70 @@ def test_doc_from_dict_schema_validation(self): schema, ) + def test_doc_from_dict_bytes_validation(self): + schema = SchemaBuilder().add_bytes_field("bytes").build() + + good = Document.from_dict({"bytes": b"hello"}, schema) + good = Document.from_dict({"bytes": [[1, 2, 3], [4, 5, 6]]}, schema) + good = Document.from_dict({"bytes": [1, 2, 3]}, 
schema) + + with pytest.raises(ValueError): + bad = Document.from_dict({"bytes": [1, 2, 256]}, schema) + + with pytest.raises(ValueError): + bad = Document.from_dict({"bytes": "hello"}, schema) + + with pytest.raises(ValueError): + bad = Document.from_dict({"bytes": [1024, "there"]}, schema) + + def test_doc_from_dict_ip_addr_validation(self): + schema = SchemaBuilder().add_ip_addr_field("ip").build() + + good = Document.from_dict({"ip": "127.0.0.1"}, schema) + good = Document.from_dict({"ip": "::1"}, schema) + + with pytest.raises(ValueError): + bad = Document.from_dict({"ip": 12309812348}, schema) + + with pytest.raises(ValueError): + bad = Document.from_dict({"ip": "256.100.0.1"}, schema) + + with pytest.raises(ValueError): + bad = Document.from_dict( + {"ip": "1234:5678:9ABC:DEF0:1234:5678:9ABC:DEF0:1234"}, schema + ) + + with pytest.raises(ValueError): + bad = Document.from_dict( + {"ip": "1234:5678:9ABC:DEF0:1234:5678:9ABC:GHIJ"}, schema + ) + + def test_doc_from_dict_json_validation(self): + # Test implicit JSON + good = Document.from_dict({"dict": {"hello": "world"}}) + + schema = SchemaBuilder().add_json_field("json").build() + + good = Document.from_dict({"json": {}}, schema) + good = Document.from_dict({"json": {"hello": "world"}}, schema) + good = Document.from_dict( + {"nested": {"hello": ["world", "!"]}, "numbers": [1, 2, 3]}, schema + ) + + list_of_jsons = [ + {"hello": "world"}, + {"nested": {"hello": ["world", "!"]}, "numbers": [1, 2, 3]}, + ] + good = Document.from_dict({"json": list_of_jsons}, schema) + + good = Document.from_dict({"json": json.dumps(list_of_jsons[1])}, schema) + + with pytest.raises(ValueError): + bad = Document.from_dict({"json": 123}, schema) + + with pytest.raises(ValueError): + bad = Document.from_dict({"json": "hello"}, schema) + def test_search_result_eq(self, ram_index, spanish_index): eng_index = ram_index eng_query = eng_index.parse_query("sea whale", ["title", "body"]) @@ -650,10 +717,6 @@ def 
test_document_with_facet(self): doc = tantivy.Document(facet=facet) assert doc["facet"][0].to_path() == ["asia/oceania", "fiji"] - def test_document_error(self): - with pytest.raises(ValueError): - tantivy.Document(name={}) - def test_document_eq(self): doc1 = tantivy.Document(name="Bill", reference=[1, 2]) doc2 = tantivy.Document.from_dict({"name": "Bill", "reference": [1, 2]}) @@ -848,9 +911,11 @@ def test_document_snippet(self, dir_index): result = searcher.search(query) assert len(result.hits) == 1 - snippet_generator = SnippetGenerator.create(searcher, query, doc_schema, "title") + snippet_generator = SnippetGenerator.create( + searcher, query, doc_schema, "title" + ) - for (score, doc_address) in result.hits: + for score, doc_address in result.hits: doc = searcher.doc(doc_address) snippet = snippet_generator.snippet_from_doc(doc) highlights = snippet.highlighted() @@ -859,4 +924,4 @@ def test_document_snippet(self, dir_index): assert first.start == 20 assert first.end == 23 html_snippet = snippet.to_html() - assert html_snippet == 'The Old Man and the Sea' + assert html_snippet == "The Old Man and the Sea"