Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9500853
wip: adding Text and Hybrid queries
justin-cechmanek Mar 24, 2025
3fa93ff
tokenizer helper function
rbs333 Mar 24, 2025
3b8e2b6
adds TextQuery class
justin-cechmanek Mar 26, 2025
7e0f24d
adds nltk requirement
justin-cechmanek Mar 27, 2025
49b2aba
makes stopwords user defined in TextQuery
justin-cechmanek Mar 27, 2025
f7a4b9e
adds hybrid aggregation query and tests. modifies search index to acc…
justin-cechmanek Apr 1, 2025
6e007f7
Validate passed-in Redis clients (#296)
abrookins Mar 21, 2025
298d055
Add batch_search to sync Index (#305)
abrookins Mar 29, 2025
94eea52
Support client-side schema validation using Pydantic (#304)
tylerhutcherson Mar 31, 2025
123ee22
Run API tests once (#306)
abrookins Mar 31, 2025
9025bfe
Add option to normalize vector distances on query (#298)
rbs333 Mar 31, 2025
ae69ae9
adds TextQuery class
justin-cechmanek Mar 26, 2025
e403934
makes stopwords user defined in TextQuery
justin-cechmanek Mar 27, 2025
9348583
adds hybrid aggregation query and tests. modifies search index to acc…
justin-cechmanek Apr 1, 2025
10f4474
cleans text and hybrid tests
justin-cechmanek Apr 2, 2025
018fe9f
merge conflicts
justin-cechmanek Apr 2, 2025
3518121
updates lock file
justin-cechmanek Apr 2, 2025
091148c
mypy cannot find defined methods
justin-cechmanek Apr 2, 2025
9069dd5
updates nltk requirement
justin-cechmanek Apr 2, 2025
c5ad696
I swear I have changed this 4 times now
justin-cechmanek Apr 2, 2025
ea5d087
wip: debugging aggregations and filters
justin-cechmanek Apr 2, 2025
1672ea3
fixes query string parsing. adds more tests
justin-cechmanek Apr 3, 2025
f32067a
test now checks default dialect is 2
justin-cechmanek Apr 3, 2025
9b1dc18
makes methods private
justin-cechmanek Apr 3, 2025
ff44041
abstracts AggregationQuery to follow BaseQuery calls in search index
justin-cechmanek Apr 3, 2025
c0be24f
updates docstrings
justin-cechmanek Apr 3, 2025
aae3949
fixes docstring
justin-cechmanek Apr 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
555 changes: 242 additions & 313 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ tenacity = ">=8.2.2"
tabulate = "^0.9.0"
ml-dtypes = "^0.4.0"
python-ulid = "^3.0.0"
nltk = { version = "^3.8.1", optional = true }
jsonpath-ng = "^1.5.0"

openai = { version = "^1.13.0", optional = true }
sentence-transformers = { version = "^3.4.0", optional = true }
scipy = [
Expand All @@ -58,6 +58,7 @@ mistralai = ["mistralai"]
voyageai = ["voyageai"]
ranx = ["ranx"]
bedrock = ["boto3"]
nltk = ["nltk"]

[tool.poetry.group.dev.dependencies]
black = "^25.1.0"
Expand Down
76 changes: 75 additions & 1 deletion redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Union,
)

from redisvl.redis.utils import convert_bytes, make_dict
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper

if TYPE_CHECKING:
Expand All @@ -39,7 +40,14 @@
SchemaValidationError,
)
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
from redisvl.query import BaseQuery, BaseVectorQuery, CountQuery, FilterQuery
from redisvl.query import (
AggregationQuery,
BaseQuery,
BaseVectorQuery,
CountQuery,
FilterQuery,
HybridAggregationQuery,
)
from redisvl.query.filter import FilterExpression
from redisvl.redis.connection import (
RedisConnectionFactory,
Expand Down Expand Up @@ -138,6 +146,34 @@ def _process(doc: "Document") -> Dict[str, Any]:
return [_process(doc) for doc in results.docs]


def process_aggregate_results(
results: "AggregateResult", query: AggregationQuery, storage_type: StorageType
) -> List[Dict[str, Any]]:
"""Convert an aggregate reslt object into a list of document dictionaries.

This function processes results from Redis, handling different storage
types and query types. For JSON storage with empty return fields, it
unpacks the JSON object while retaining the document ID. The 'payload'
field is also removed from all resulting documents for consistency.

Args:
results (AggregarteResult): The aggregart results from Redis.
query (AggregationQuery): The aggregation query object used for the aggregation.
storage_type (StorageType): The storage type of the search
index (json or hash).

Returns:
List[Dict[str, Any]]: A list of processed document dictionaries.
"""

def _process(row):
result = make_dict(convert_bytes(row))
result.pop("__score", None)
return result

return [_process(r) for r in results.rows]


class BaseSearchIndex:
"""Base search engine class"""

Expand Down Expand Up @@ -650,6 +686,44 @@ def fetch(self, id: str) -> Optional[Dict[str, Any]]:
return convert_bytes(obj[0])
return None

def aggregate_query(
self, aggregation_query: AggregationQuery
) -> List[Dict[str, Any]]:
"""Execute an aggretation query and processes the results.

This method takes an AggregationHyridQuery object directly, runs the search, and
handles post-processing of the search.

Args:
aggregation_query (AggregationQuery): The aggregation query to run.

Returns:
List[Result]: A list of search results.

.. code-block:: python

from redisvl.query import HybridAggregationQuery

aggregation = HybridAggregationQuery(
text="the text to search for",
text_field="description",
vector=[0.16, -0.34, 0.98, 0.23],
vector_field="embedding",
num_results=3
)

results = index.aggregate_query(aggregation_query)

"""
results = self.aggregate(
aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined]
)
return process_aggregate_results(
results,
query=aggregation_query,
storage_type=self.schema.index.storage_type,
)

def aggregate(self, *args, **kwargs) -> "AggregateResult":
"""Perform an aggregation operation against the index.

Expand Down
5 changes: 5 additions & 0 deletions redisvl/query/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from redisvl.query.aggregate import AggregationQuery, HybridAggregationQuery
from redisvl.query.query import (
BaseQuery,
BaseVectorQuery,
CountQuery,
FilterQuery,
RangeQuery,
TextQuery,
VectorQuery,
VectorRangeQuery,
)
Expand All @@ -16,4 +18,7 @@
"RangeQuery",
"VectorRangeQuery",
"CountQuery",
"TextQuery",
"AggregationQuery",
"HybridAggregationQuery",
]
221 changes: 221 additions & 0 deletions redisvl/query/aggregate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import nltk
from nltk.corpus import stopwords as nltk_stopwords
from redis.commands.search.aggregation import AggregateRequest, Desc

from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer
from redisvl.utils.token_escaper import TokenEscaper


# base class
class AggregationQuery(AggregateRequest):
"""
Base class for aggregation queries used to create aggregation queries for Redis.
"""

def __init__(self, query_string):
super().__init__(query_string)


class HybridAggregationQuery(AggregationQuery):
"""
HybridAggregationQuery combines text and vector search in Redis.
It allows you to perform a hybrid search using both text and vector similarity.
It scores documents based on a weighted combination of text and vector similarity.
"""

DISTANCE_ID: str = "vector_distance"
VECTOR_PARAM: str = "vector"

def __init__(
self,
text: str,
text_field: str,
vector: Union[bytes, List[float]],
vector_field: str,
text_scorer: str = "BM25STD",
filter_expression: Optional[Union[str, FilterExpression]] = None,
alpha: float = 0.7,
dtype: str = "float32",
num_results: int = 10,
return_fields: Optional[List[str]] = None,
stopwords: Optional[Union[str, Set[str]]] = "english",
dialect: int = 4,
):
"""
Instantiages a HybridAggregationQuery object.
Args:
text (str): The text to search for.
text_field (str): The text field name to search in.
vector (Union[bytes, List[float]]): The vector to perform vector similarity search.
vector_field (str): The vector field name to search in.
text_scorer (str, optional): The text scorer to use. Options are {TFIDF, TFIDF.DOCNORM,
BM25, DISMAX, DOCSCORE, BM25STD}. Defaults to "BM25STD".
filter_expression (Optional[FilterExpression], optional): The filter expression to use.
Defaults to None.
alpha (float, optional): The weight of the vector similarity. Documents will be scored
as: hybrid_score = (alpha) * vector_score + (1-alpha) * text_score.
Defaults to 0.7.
dtype (str, optional): The data type of the vector. Defaults to "float32".
num_results (int, optional): The number of results to return. Defaults to 10.
return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
stopwords (Optional[Union[str, Set[str]]], optional): The stopwords to remove from the
provided text prior to searchuse. If a string such as "english" "german" is
provided then a default set of stopwords for that language will be used. if a list,
set, or tuple of strings is provided then those will be used as stopwords.
Defaults to "english". if set to "None" then no stopwords will be removed.
dialect (int, optional): The Redis dialect version. Defaults to 4.
Raises:
ValueError: If the text string is empty, or if the text string becomes empty after
stopwords are removed.
TypeError: If the stopwords are not a set, list, or tuple of strings.
.. code-block:: python
from redisvl.query.aggregate import HybridAggregationQuery
from redisvl.index import SearchIndex
index = SearchIndex("my_index")
query = HybridAggregationQuery(
text="example text",
text_field="text_field",
vector=[0.1, 0.2, 0.3],
vector_field="vector_field",
text_scorer="BM25STD",
filter_expression=None,
alpha=0.7,
dtype="float32",
num_results=10,
return_fields=["field1", "field2"],
stopwords="english",
dialect=4,
)
results = index.aggregate_query(query)
"""

if not text.strip():
raise ValueError("text string cannot be empty")

self._text = text
self._text_field = text_field
self._vector = vector
self._vector_field = vector_field
self._filter_expression = filter_expression
self._alpha = alpha
self._dtype = dtype
self._num_results = num_results
self.set_stopwords(stopwords)

query_string = self._build_query_string()
super().__init__(query_string)

self.scorer(text_scorer) # type: ignore[attr-defined]
self.add_scores() # type: ignore[attr-defined]
self.apply(
vector_similarity=f"(2 - @{self.DISTANCE_ID})/2", text_score="@__score"
)
self.apply(hybrid_score=f"{1-alpha}*@text_score + {alpha}*@vector_similarity")
self.sort_by(Desc("@hybrid_score"), max=num_results)
self.dialect(dialect) # type: ignore[attr-defined]

if return_fields:
self.load(*return_fields)

@property
def params(self) -> Dict[str, Any]:
"""Return the parameters for the aggregation.
Returns:
Dict[str, Any]: The parameters for the aggregation.
"""
if isinstance(self._vector, bytes):
vector = self._vector
else:
vector = array_to_buffer(self._vector, dtype=self._dtype)

params = {self.VECTOR_PARAM: vector}

return params

@property
def stopwords(self) -> Set[str]:
"""Return the stopwords used in the query.
Returns:
Set[str]: The stopwords used in the query.
"""
return self._stopwords.copy() if self._stopwords else set()

def set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
"""Set the stopwords to use in the query.
Args:
stopwords (Optional[Union[str, Set[str]]]): The stopwords to use. If a string
such as "english" "german" is provided then a default set of stopwords for that
language will be used. if a list, set, or tuple of strings is provided then those
will be used as stopwords. Defaults to "english". if set to "None" then no stopwords
will be removed.
Raises:
TypeError: If the stopwords are not a set, list, or tuple of strings.
"""
if not stopwords:
self._stopwords = set()
elif isinstance(stopwords, str):
try:
nltk.download("stopwords")
self._stopwords = set(nltk_stopwords.words(stopwords))
except Exception as e:
raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
elif isinstance(stopwords, (Set, List, Tuple)) and all( # type: ignore
isinstance(word, str) for word in stopwords
):
self._stopwords = set(stopwords)
else:
raise TypeError("stopwords must be a set, list, or tuple of strings")

def tokenize_and_escape_query(self, user_query: str) -> str:
"""Convert a raw user query to a redis full text query joined by ORs
Args:
user_query (str): The user query to tokenize and escape.
Returns:
str: The tokenized and escaped query string.
Raises:
ValueError: If the text string becomes empty after stopwords are removed.
"""

escaper = TokenEscaper()

tokens = [
escaper.escape(
token.strip().strip(",").replace("“", "").replace("”", "").lower()
)
for token in user_query.split()
]
tokenized = " | ".join(
[token for token in tokens if token and token not in self._stopwords]
)

if not tokenized:
raise ValueError("text string cannot be empty after removing stopwords")
return tokenized

def _build_query_string(self) -> str:
"""Build the full query string for text search with optional filtering."""
if isinstance(self._filter_expression, FilterExpression):
filter_expression = str(self._filter_expression)
else:
filter_expression = ""

# base KNN query
knn_query = f"KNN {self._num_results} @{self._vector_field} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}"

text = f"(~@{self._text_field}:({self.tokenize_and_escape_query(self._text)}))"

if filter_expression and filter_expression != "*":
text += f"({filter_expression})"

return f"{text}=>[{knn_query}]"
Loading
Loading