Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9500853
wip: adding Text and Hybrid queries
justin-cechmanek Mar 24, 2025
3fa93ff
tokenizer helper function
rbs333 Mar 24, 2025
3b8e2b6
adds TextQuery class
justin-cechmanek Mar 26, 2025
7e0f24d
adds nltk requirement
justin-cechmanek Mar 27, 2025
49b2aba
makes stopwords user defined in TextQuery
justin-cechmanek Mar 27, 2025
f7a4b9e
adds hybrid aggregation query and tests. modifies search index to acc…
justin-cechmanek Apr 1, 2025
6e007f7
Validate passed-in Redis clients (#296)
abrookins Mar 21, 2025
298d055
Add batch_search to sync Index (#305)
abrookins Mar 29, 2025
94eea52
Support client-side schema validation using Pydantic (#304)
tylerhutcherson Mar 31, 2025
123ee22
Run API tests once (#306)
abrookins Mar 31, 2025
9025bfe
Add option to normalize vector distances on query (#298)
rbs333 Mar 31, 2025
ae69ae9
adds TextQuery class
justin-cechmanek Mar 26, 2025
e403934
makes stopwords user defined in TextQuery
justin-cechmanek Mar 27, 2025
9348583
adds hybrid aggregation query and tests. modifies search index to acc…
justin-cechmanek Apr 1, 2025
10f4474
cleans text and hybrid tests
justin-cechmanek Apr 2, 2025
018fe9f
merge conflicts
justin-cechmanek Apr 2, 2025
3518121
updates lock file
justin-cechmanek Apr 2, 2025
091148c
mypy cannot find defined methods
justin-cechmanek Apr 2, 2025
9069dd5
updates nltk requirement
justin-cechmanek Apr 2, 2025
c5ad696
I swear I have changed this 4 times now
justin-cechmanek Apr 2, 2025
ea5d087
wip: debugging aggregations and filters
justin-cechmanek Apr 2, 2025
1672ea3
fixes query string parsing. adds more tests
justin-cechmanek Apr 3, 2025
f32067a
test now checks default dialect is 2
justin-cechmanek Apr 3, 2025
9b1dc18
makes methods private
justin-cechmanek Apr 3, 2025
ff44041
abstracts AggregationQuery to follow BaseQuery calls in search index
justin-cechmanek Apr 3, 2025
c0be24f
updates docstrings
justin-cechmanek Apr 3, 2025
aae3949
fixes docstring
justin-cechmanek Apr 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ tenacity = ">=8.2.2"
tabulate = "^0.9.0"
ml-dtypes = "^0.4.0"
python-ulid = "^3.0.0"
nltk = { version = "^3.8.1", optional = true }
openai = { version = "^1.13.0", optional = true }
sentence-transformers = { version = "^3.4.0", optional = true }
scipy = [
Expand Down
4 changes: 4 additions & 0 deletions redisvl/query/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
BaseQuery,
CountQuery,
FilterQuery,
HybridQuery,
RangeQuery,
TextQuery,
VectorQuery,
VectorRangeQuery,
)
Expand All @@ -14,4 +16,6 @@
"RangeQuery",
"VectorRangeQuery",
"CountQuery",
"TextQuery",
"HybridQuery",
]
128 changes: 126 additions & 2 deletions redisvl/query/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from redis.commands.search.aggregation import AggregateRequest, Desc
from redis.commands.search.query import Query as RedisQuery

from redisvl.query.filter import FilterExpression
Expand Down Expand Up @@ -91,7 +92,8 @@ def __init__(
num_results (Optional[int], optional): The number of results to return. Defaults to 10.
dialect (int, optional): The query dialect. Defaults to 2.
sort_by (Optional[str], optional): The field to order the results by. Defaults to None.
in_order (bool, optional): Requires the terms in the field to have the same order as the terms in the query filter. Defaults to False.
in_order (bool, optional): Requires the terms in the field to have the same order as the
terms in the query filter. Defaults to False.
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.

Raises:
Expand Down Expand Up @@ -136,7 +138,8 @@ def __init__(
"""A query for a simple count operation provided some filter expression.

Args:
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to query with. Defaults to None.
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
query with. Defaults to None.
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.

Raises:
Expand Down Expand Up @@ -214,6 +217,7 @@ def __init__(
"float32".
num_results (int, optional): The top k results to return from the
vector search. Defaults to 10.

return_score (bool, optional): Whether to return the vector
distance. Defaults to True.
dialect (int, optional): The RediSearch query dialect.
Expand Down Expand Up @@ -647,3 +651,123 @@ def params(self) -> Dict[str, Any]:
class RangeQuery(VectorRangeQuery):
# keep for backwards compatibility
pass


class TextQuery(FilterQuery):
def __init__(
self,
text: str,
text_field: str,
text_scorer: str = "BM25",
filter_expression: Optional[Union[str, FilterExpression]] = None,
return_fields: Optional[List[str]] = None,
num_results: int = 10,
return_score: bool = True,
dialect: int = 2,
sort_by: Optional[str] = None,
in_order: bool = False,
params: Optional[Dict[str, Any]] = None,
):
"""A query for running a full text and vector search, along with an optional
filter expression.

Args:
text (str): The text string to perform the text search with.
text_field (str): The name of the document field to perform text search on.
text_scorer (str, optional): The text scoring algorithm to use.
Defaults to BM25. Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
filter_expression (Union[str, FilterExpression], optional): A filter to apply
along with the text search. Defaults to None.
return_fields (List[str]): The declared fields to return with search
results.
num_results (int, optional): The top k results to return from the
search. Defaults to 10.
return_score (bool, optional): Whether to return the text score.
Defaults to True.
dialect (int, optional): The RediSearch query dialect.
Defaults to 2.
sort_by (Optional[str]): The field to order the results by. Defaults
to None. Results will be ordered by text score.
in_order (bool): Requires the terms in the field to have
the same order as the terms in the query filter, regardless of
the offsets between them. Defaults to False.
params (Optional[Dict[str, Any]], optional): The parameters for the query.
Defaults to None.
"""
import nltk
from nltk.corpus import stopwords

nltk.download("stopwords")
self._stopwords = set(stopwords.words("english"))

self._text = text
self._text_field = text_field
self._text_scorer = text_scorer

self.set_filter(filter_expression)
self._num_results = num_results

query_string = self._build_query_string()

super().__init__(
query_string,
return_fields=return_fields,
num_results=num_results,
dialect=dialect,
sort_by=sort_by,
in_order=in_order,
params=params,
)

# Handle query modifiers
self.scorer(self._text_scorer)
self.paging(0, self._num_results).dialect(dialect)

if return_score:
self.with_scores()

def tokenize_and_escape_query(self, user_query: str) -> str:
"""Convert a raw user query to a redis full text query joined by ORs"""
from redisvl.utils.token_escaper import TokenEscaper

escaper = TokenEscaper()

tokens = [
escaper.escape(
token.strip().strip(",").replace("“", "").replace("”", "").lower()
)
for token in user_query.split()
]
return " | ".join(
[token for token in tokens if token and token not in self._stopwords]
)

def _build_query_string(self) -> str:
"""Build the full query string for text search with optional filtering."""
filter_expression = self._filter_expression
if isinstance(filter_expression, FilterExpression):
filter_expression = str(filter_expression)
else:
filter_expression = ""

text = f"(~@{self._text_field}:({self.tokenize_and_escape_query(self._text)}))"
if filter_expression and filter_expression != "*":
text += f"({filter_expression})"
return text


class HybridQuery(AggregateRequest):
def __init__(
self, text_query: TextQuery, vector_query: VectorQuery, alpha: float = 0.7
):
"""An aggregate query for running a hybrid full text and vector search.

Args:
text_query (TextQuery): The text query to run text search with.
vector_query (VectorQuery): The vector query to run vector search with.
alpha (float, optional): The amount to weight the vector similarity
score relative to the text similarity score. Defaults to 0.7

"""
pass
72 changes: 71 additions & 1 deletion tests/unit/test_query_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from redis.commands.search.result import Result

from redisvl.index.index import process_results
from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery
from redisvl.query import (
CountQuery,
FilterQuery,
HybridQuery,
RangeQuery,
TextQuery,
VectorQuery,
)
from redisvl.query.filter import Tag
from redisvl.query.query import VectorRangeQuery

Expand Down Expand Up @@ -188,6 +195,69 @@ def test_range_query():
assert range_query._in_order


def test_text_query():
text_string = "the toon squad play basketball against a gang of aliens"
text_field_name = "description"
return_fields = ["title", "genre", "rating"]
text_query = TextQuery(
text=text_string,
text_field=text_field_name,
return_fields=return_fields,
return_score=False,
)

# Check properties
assert text_query._return_fields == return_fields
assert text_query._num_results == 10
assert (
text_query.filter
== f"(~@{text_field_name}:({text_query.tokenize_and_escape_query(text_string)}))"
)
assert isinstance(text_query, Query)
assert isinstance(text_query.query, Query)
assert isinstance(text_query.params, dict)
assert text_query._text_scorer == "BM25"
assert text_query.params == {}
assert text_query._dialect == 2
assert text_query._in_order == False

# Test paging functionality
text_query.paging(5, 7)
assert text_query._offset == 5
assert text_query._num == 7
assert text_query._num_results == 10

# Test sort_by functionality
filter_expression = Tag("genre") == "comedy"
scorer = "TFIDF"
text_query = TextQuery(
text_string,
text_field_name,
scorer,
filter_expression,
return_fields,
num_results=10,
sort_by="rating",
)
assert text_query._sortby is not None

# Test in_order functionality
text_query = TextQuery(
text_string,
text_field_name,
scorer,
filter_expression,
return_fields,
num_results=10,
in_order=True,
)
assert text_query._in_order


def test_hybrid_query():
pass


@pytest.mark.parametrize(
"query",
[
Expand Down
Loading