Skip to content

Commit

Permalink
Better support bytes, IPs, and JSON (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
GodTamIt authored Nov 19, 2023
1 parent 4ac17da commit eeaad34
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 21 deletions.
59 changes: 55 additions & 4 deletions src/document.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ use pyo3::{
basic::CompareOp,
prelude::*,
types::{
PyAny, PyBool, PyDateAccess, PyDateTime, PyDict, PyList, PyTimeAccess,
PyTuple,
PyAny, PyBool, PyDateAccess, PyDateTime, PyDict, PyInt, PyList,
PyTimeAccess, PyTuple,
},
Python,
};

use chrono::{offset::TimeZone, NaiveDateTime, Utc};
Expand All @@ -23,7 +24,8 @@ use serde_json::Value as JsonValue;
use std::{
collections::{BTreeMap, HashMap},
fmt,
net::Ipv6Addr,
net::{IpAddr, Ipv6Addr},
str::FromStr,
};

pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
Expand All @@ -50,6 +52,11 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
if let Ok(b) = any.extract::<Vec<u8>>() {
return Ok(Value::Bytes(b));
}
if let Ok(dict) = any.downcast::<PyDict>() {
if let Ok(json) = pythonize::depythonize(dict) {
return Ok(Value::JsonObject(json));
}
}
Err(to_pyerr(format!("Value unsupported {any:?}")))
}

Expand Down Expand Up @@ -105,7 +112,37 @@ pub(crate) fn extract_value_for_type(
.map_err(to_pyerr_for_type("Facet", field_name, any))?
.inner,
),
_ => return Err(to_pyerr(format!("Value unsupported {:?}", any))),
tv::schema::Type::Bytes => Value::Bytes(
any.extract::<Vec<u8>>()
.map_err(to_pyerr_for_type("Bytes", field_name, any))?,
),
tv::schema::Type::Json => {
if let Ok(json_str) = any.extract::<&str>() {
return serde_json::from_str(json_str)
.map(Value::JsonObject)
.map_err(to_pyerr_for_type("Json", field_name, any));
}

Value::JsonObject(
any.downcast::<PyDict>()
.map(|dict| pythonize::depythonize(&dict))
.map_err(to_pyerr_for_type("Json", field_name, any))?
.map_err(to_pyerr_for_type("Json", field_name, any))?,
)
}
tv::schema::Type::IpAddr => {
let val = any
.extract::<&str>()
.map_err(to_pyerr_for_type("IpAddr", field_name, any))?;

IpAddr::from_str(val)
.map(|addr| match addr {
IpAddr::V4(addr) => addr.to_ipv6_mapped(),
IpAddr::V6(addr) => addr,
})
.map(Value::IpAddr)
.map_err(to_pyerr_for_type("IpAddr", field_name, any))?
}
};

Ok(value)
Expand All @@ -126,6 +163,20 @@ fn extract_value_single_or_list_for_type(
) -> PyResult<Vec<Value>> {
// Check if a numeric fast field supports multivalues.
if let Ok(values) = any.downcast::<PyList>() {
// Process an array of integers as a single entry if it is a bytes field.
if field_type.value_type() == tv::schema::Type::Bytes
&& values
.get_item(0)
.map(|v| v.is_instance_of::<PyInt>())
.unwrap_or(false)
{
return Ok(vec![extract_value_for_type(
values,
field_type.value_type(),
field_name,
)?]);
}

values
.iter()
.map(|any| {
Expand Down
90 changes: 82 additions & 8 deletions src/schemabuilder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

use pyo3::{exceptions, prelude::*};

use tantivy::schema;

use crate::schema::Schema;
use std::sync::{Arc, RwLock};
use tantivy::schema::{DateOptions, INDEXED};
use tantivy::schema::{
self, BytesOptions, DateOptions, IpAddrOptions, INDEXED,
};

/// Tantivy has a very strict schema.
/// You need to specify in advance whether a field is indexed or not,
Expand Down Expand Up @@ -357,22 +357,96 @@ impl SchemaBuilder {

/// Add a fast bytes field to the schema.
///
/// Bytes field are not searchable and are only used
/// as fast field, to associate any kind of payload
/// to a document.
/// Args:
/// name (str): The name of the field.
/// stored (bool, optional): If true sets the field as stored, the
/// content of the field can be later restored from a Searcher.
/// Defaults to False.
/// indexed (bool, optional): If true sets the field to be indexed.
/// fast (str, optional): Set the bytes options as a fast field. A fast
/// field is a column-oriented fashion storage for tantivy. It is
/// designed for the fast random access of some document fields
/// given a document id.
#[pyo3(signature = (
name,
stored = false,
indexed = false,
fast = false
))]
fn add_bytes_field(
&mut self,
name: &str,
stored: bool,
indexed: bool,
fast: bool,
) -> PyResult<Self> {
let builder = &mut self.builder;
let mut opts = BytesOptions::default();
if stored {
opts = opts.set_stored();
}
if indexed {
opts = opts.set_indexed();
}
if fast {
opts = opts.set_fast();
}

if let Some(builder) = builder.write().unwrap().as_mut() {
builder.add_bytes_field(name, opts);
} else {
return Err(exceptions::PyValueError::new_err(
"Schema builder object isn't valid anymore.",
));
}
Ok(self.clone())
}

/// Add an IP address field to the schema.
///
/// Args:
/// name (str): The name of the field.
fn add_bytes_field(&mut self, name: &str) -> PyResult<Self> {
/// stored (bool, optional): If true sets the field as stored, the
/// content of the field can be later restored from a Searcher.
/// Defaults to False.
/// indexed (bool, optional): If true sets the field to be indexed.
/// fast (str, optional): Set the IP address options as a fast field. A
/// fast field is a column-oriented fashion storage for tantivy. It
/// is designed for the fast random access of some document fields
/// given a document id.
#[pyo3(signature = (
name,
stored = false,
indexed = false,
fast = false
))]
fn add_ip_addr_field(
&mut self,
name: &str,
stored: bool,
indexed: bool,
fast: bool,
) -> PyResult<Self> {
let builder = &mut self.builder;
let mut opts = IpAddrOptions::default();
if stored {
opts = opts.set_stored();
}
if indexed {
opts = opts.set_indexed();
}
if fast {
opts = opts.set_fast();
}

if let Some(builder) = builder.write().unwrap().as_mut() {
builder.add_bytes_field(name, INDEXED);
builder.add_ip_addr_field(name, opts);
} else {
return Err(exceptions::PyValueError::new_err(
"Schema builder object isn't valid anymore.",
));
}

Ok(self.clone())
}

Expand Down
83 changes: 74 additions & 9 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
import datetime
import json
import tantivy
import pickle
import pytest
Expand Down Expand Up @@ -365,7 +366,9 @@ def test_order_by_search(self):
searched_doc = index.searcher().doc(doc_address)
assert searched_doc["title"] == ["Test title"]

result = searcher.search(query, 10, order_by_field="order", order=tantivy.Order.Asc)
result = searcher.search(
query, 10, order_by_field="order", order=tantivy.Order.Asc
)

assert len(result.hits) == 3

Expand Down Expand Up @@ -443,7 +446,7 @@ def test_with_merges(self):

assert searcher.num_segments < 8

def test_doc_from_dict_schema_validation(self):
def test_doc_from_dict_numeric_validation(self):
schema = (
SchemaBuilder()
.add_unsigned_field("unsigned")
Expand Down Expand Up @@ -504,6 +507,70 @@ def test_doc_from_dict_schema_validation(self):
schema,
)

def test_doc_from_dict_bytes_validation(self):
schema = SchemaBuilder().add_bytes_field("bytes").build()

good = Document.from_dict({"bytes": b"hello"}, schema)
good = Document.from_dict({"bytes": [[1, 2, 3], [4, 5, 6]]}, schema)
good = Document.from_dict({"bytes": [1, 2, 3]}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"bytes": [1, 2, 256]}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"bytes": "hello"}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"bytes": [1024, "there"]}, schema)

def test_doc_from_dict_ip_addr_validation(self):
schema = SchemaBuilder().add_ip_addr_field("ip").build()

good = Document.from_dict({"ip": "127.0.0.1"}, schema)
good = Document.from_dict({"ip": "::1"}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"ip": 12309812348}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"ip": "256.100.0.1"}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict(
{"ip": "1234:5678:9ABC:DEF0:1234:5678:9ABC:DEF0:1234"}, schema
)

with pytest.raises(ValueError):
bad = Document.from_dict(
{"ip": "1234:5678:9ABC:DEF0:1234:5678:9ABC:GHIJ"}, schema
)

def test_doc_from_dict_json_validation(self):
# Test implicit JSON
good = Document.from_dict({"dict": {"hello": "world"}})

schema = SchemaBuilder().add_json_field("json").build()

good = Document.from_dict({"json": {}}, schema)
good = Document.from_dict({"json": {"hello": "world"}}, schema)
good = Document.from_dict(
{"nested": {"hello": ["world", "!"]}, "numbers": [1, 2, 3]}, schema
)

list_of_jsons = [
{"hello": "world"},
{"nested": {"hello": ["world", "!"]}, "numbers": [1, 2, 3]},
]
good = Document.from_dict({"json": list_of_jsons}, schema)

good = Document.from_dict({"json": json.dumps(list_of_jsons[1])}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"json": 123}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"json": "hello"}, schema)

def test_search_result_eq(self, ram_index, spanish_index):
eng_index = ram_index
eng_query = eng_index.parse_query("sea whale", ["title", "body"])
Expand Down Expand Up @@ -650,10 +717,6 @@ def test_document_with_facet(self):
doc = tantivy.Document(facet=facet)
assert doc["facet"][0].to_path() == ["asia/oceania", "fiji"]

def test_document_error(self):
with pytest.raises(ValueError):
tantivy.Document(name={})

def test_document_eq(self):
doc1 = tantivy.Document(name="Bill", reference=[1, 2])
doc2 = tantivy.Document.from_dict({"name": "Bill", "reference": [1, 2]})
Expand Down Expand Up @@ -848,9 +911,11 @@ def test_document_snippet(self, dir_index):
result = searcher.search(query)
assert len(result.hits) == 1

snippet_generator = SnippetGenerator.create(searcher, query, doc_schema, "title")
snippet_generator = SnippetGenerator.create(
searcher, query, doc_schema, "title"
)

for (score, doc_address) in result.hits:
for score, doc_address in result.hits:
doc = searcher.doc(doc_address)
snippet = snippet_generator.snippet_from_doc(doc)
highlights = snippet.highlighted()
Expand All @@ -859,4 +924,4 @@ def test_document_snippet(self, dir_index):
assert first.start == 20
assert first.end == 23
html_snippet = snippet.to_html()
assert html_snippet == 'The Old Man and the <b>Sea</b>'
assert html_snippet == "The Old Man and the <b>Sea</b>"

0 comments on commit eeaad34

Please sign in to comment.