Add support for full text queries and hybrid search queries #303
Merged
Commits (27):
9500853 wip: adding Text and Hybrid queries (justin-cechmanek)
3fa93ff tokenizer helper function (rbs333)
3b8e2b6 adds TextQuery class (justin-cechmanek)
7e0f24d adds nltk requirement (justin-cechmanek)
49b2aba makes stopwords user defined in TextQuery (justin-cechmanek)
f7a4b9e adds hybrid aggregation query and tests. modifies search index to acc… (justin-cechmanek)
6e007f7 Validate passed-in Redis clients (#296) (abrookins)
298d055 Add batch_search to sync Index (#305) (abrookins)
94eea52 Support client-side schema validation using Pydantic (#304) (tylerhutcherson)
123ee22 Run API tests once (#306) (abrookins)
9025bfe Add option to normalize vector distances on query (#298) (rbs333)
ae69ae9 adds TextQuery class (justin-cechmanek)
e403934 makes stopwords user defined in TextQuery (justin-cechmanek)
9348583 adds hybrid aggregation query and tests. modifies search index to acc… (justin-cechmanek)
10f4474 cleans text and hybrid tests (justin-cechmanek)
018fe9f merge conflicts (justin-cechmanek)
3518121 updates lock file (justin-cechmanek)
091148c mypy cannot find defined methods (justin-cechmanek)
9069dd5 updates nltk requirement (justin-cechmanek)
c5ad696 I swear I have changed this 4 times now (justin-cechmanek)
ea5d087 wip: debugging aggregations and filters (justin-cechmanek)
1672ea3 fixes query string parsing. adds more tests (justin-cechmanek)
f32067a test now checks default dialect is 2 (justin-cechmanek)
9b1dc18 makes methods private (justin-cechmanek)
ff44041 abstracts AggregationQuery to follow BaseQuery calls in search index (justin-cechmanek)
c0be24f updates docstrings (justin-cechmanek)
aae3949 fixes docstring (justin-cechmanek)
`redisvl/query/aggregate.py` (new file, +221 lines):

```python
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import nltk
from nltk.corpus import stopwords as nltk_stopwords
from redis.commands.search.aggregation import AggregateRequest, Desc

from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer
from redisvl.utils.token_escaper import TokenEscaper
```
```python
# base class
class AggregationQuery(AggregateRequest):
    """
    Base class for building Redis aggregation queries.
    """

    def __init__(self, query_string):
        super().__init__(query_string)
```
```python
class HybridAggregationQuery(AggregationQuery):
    """
    HybridAggregationQuery combines text and vector search in Redis.
    It allows you to perform a hybrid search using both text and vector
    similarity, scoring documents with a weighted combination of the two.
    """

    DISTANCE_ID: str = "vector_distance"
    VECTOR_PARAM: str = "vector"

    def __init__(
        self,
        text: str,
        text_field: str,
        vector: Union[bytes, List[float]],
        vector_field: str,
        text_scorer: str = "BM25STD",
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        alpha: float = 0.7,
        dtype: str = "float32",
        num_results: int = 10,
        return_fields: Optional[List[str]] = None,
        stopwords: Optional[Union[str, Set[str]]] = "english",
        dialect: int = 4,
    ):
        """
        Instantiates a HybridAggregationQuery object.

        Args:
            text (str): The text to search for.
            text_field (str): The text field name to search in.
            vector (Union[bytes, List[float]]): The vector to perform vector similarity search.
            vector_field (str): The vector field name to search in.
            text_scorer (str, optional): The text scorer to use. Options are {TFIDF,
                TFIDF.DOCNORM, BM25, DISMAX, DOCSCORE, BM25STD}. Defaults to "BM25STD".
            filter_expression (Optional[FilterExpression], optional): The filter expression
                to use. Defaults to None.
            alpha (float, optional): The weight of the vector similarity. Documents will be
                scored as: hybrid_score = (alpha) * vector_score + (1-alpha) * text_score.
                Defaults to 0.7.
            dtype (str, optional): The data type of the vector. Defaults to "float32".
            num_results (int, optional): The number of results to return. Defaults to 10.
            return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
            stopwords (Optional[Union[str, Set[str]]], optional): The stopwords to remove from
                the provided text prior to search. If a string such as "english" or "german"
                is provided, a default set of stopwords for that language is used. If a list,
                set, or tuple of strings is provided, those are used as stopwords. If set to
                None, no stopwords are removed. Defaults to "english".
            dialect (int, optional): The Redis dialect version. Defaults to 4.

        Raises:
            ValueError: If the text string is empty, or if the text string becomes empty after
                stopwords are removed.
            TypeError: If the stopwords are not a set, list, or tuple of strings.

        .. code-block:: python

            from redisvl.query.aggregate import HybridAggregationQuery
            from redisvl.index import SearchIndex

            index = SearchIndex("my_index")

            query = HybridAggregationQuery(
                text="example text",
                text_field="text_field",
                vector=[0.1, 0.2, 0.3],
                vector_field="vector_field",
                text_scorer="BM25STD",
                filter_expression=None,
                alpha=0.7,
                dtype="float32",
                num_results=10,
                return_fields=["field1", "field2"],
                stopwords="english",
                dialect=4,
            )

            results = index.aggregate_query(query)
        """
        if not text.strip():
            raise ValueError("text string cannot be empty")

        self._text = text
        self._text_field = text_field
        self._vector = vector
        self._vector_field = vector_field
        self._filter_expression = filter_expression
        self._alpha = alpha
        self._dtype = dtype
        self._num_results = num_results
        self.set_stopwords(stopwords)

        query_string = self._build_query_string()
        super().__init__(query_string)

        self.scorer(text_scorer)  # type: ignore[attr-defined]
        self.add_scores()  # type: ignore[attr-defined]
        self.apply(
            vector_similarity=f"(2 - @{self.DISTANCE_ID})/2", text_score="@__score"
        )
        self.apply(hybrid_score=f"{1-alpha}*@text_score + {alpha}*@vector_similarity")
        self.sort_by(Desc("@hybrid_score"), max=num_results)
        self.dialect(dialect)  # type: ignore[attr-defined]

        if return_fields:
            self.load(*return_fields)
```
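The two `apply` calls above blend the scores server-side: cosine distance in [0, 2] is rescaled to a similarity in [0, 1], then mixed with the full-text score using `alpha`. A standalone sketch in plain Python, mirroring those APPLY expressions rather than calling the library:

```python
# Mirror of the server-side expressions:
#   vector_similarity = (2 - @vector_distance) / 2
#   hybrid_score = (1 - alpha) * @text_score + alpha * @vector_similarity
def hybrid_score(vector_distance: float, text_score: float, alpha: float = 0.7) -> float:
    vector_similarity = (2 - vector_distance) / 2  # distance 0 -> 1.0, distance 2 -> 0.0
    return (1 - alpha) * text_score + alpha * vector_similarity

# An exact vector match (distance 0) with text score 1.0 yields 1.0;
# the worst vector match (distance 2) with the same text score yields ~0.3.
print(hybrid_score(0.0, 1.0))
print(hybrid_score(2.0, 1.0))
```

Raising `alpha` toward 1.0 favors vector similarity; lowering it favors the BM25-style text score.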
```python
    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the aggregation.

        Returns:
            Dict[str, Any]: The parameters for the aggregation.
        """
        if isinstance(self._vector, bytes):
            vector = self._vector
        else:
            vector = array_to_buffer(self._vector, dtype=self._dtype)

        params = {self.VECTOR_PARAM: vector}

        return params
```
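`array_to_buffer` turns the float list into the raw byte string Redis expects for the `$vector` parameter. A minimal stdlib-only sketch of the same idea (the real helper lives in `redisvl.redis.utils` and supports multiple dtypes; the function name here is illustrative):

```python
import struct

def to_float32_buffer(values: list) -> bytes:
    # Pack floats as little-endian float32, 4 bytes each, matching dtype="float32".
    return struct.pack(f"<{len(values)}f", *values)

buf = to_float32_buffer([0.1, 0.2, 0.3])
print(len(buf))  # 12 (3 floats * 4 bytes each)
```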
```python
    @property
    def stopwords(self) -> Set[str]:
        """Return the stopwords used in the query.

        Returns:
            Set[str]: The stopwords used in the query.
        """
        return self._stopwords.copy() if self._stopwords else set()

    def set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
        """Set the stopwords to use in the query.

        Args:
            stopwords (Optional[Union[str, Set[str]]]): The stopwords to use. If a string
                such as "english" or "german" is provided, a default set of stopwords for
                that language is used. If a list, set, or tuple of strings is provided,
                those are used as stopwords. If set to None, no stopwords are removed.
                Defaults to "english".

        Raises:
            TypeError: If the stopwords are not a set, list, or tuple of strings.
        """
        if not stopwords:
            self._stopwords = set()
        elif isinstance(stopwords, str):
            try:
                nltk.download("stopwords")
                self._stopwords = set(nltk_stopwords.words(stopwords))
            except Exception as e:
                raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
        elif isinstance(stopwords, (Set, List, Tuple)) and all(  # type: ignore
            isinstance(word, str) for word in stopwords
        ):
            self._stopwords = set(stopwords)
        else:
            raise TypeError("stopwords must be a set, list, or tuple of strings")
```
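A standalone sketch of the normalization rules `set_stopwords` applies, with the string-language branch (which downloads word lists via nltk) omitted: falsy input becomes an empty set, an iterable of strings becomes a set, and anything else raises `TypeError`. The function name is hypothetical.

```python
def normalize_stopwords(stopwords):
    # Falsy (None, empty) -> no stopword filtering.
    if not stopwords:
        return set()
    # Iterable of strings -> deduplicate into a set.
    if isinstance(stopwords, (set, list, tuple)) and all(
        isinstance(word, str) for word in stopwords
    ):
        return set(stopwords)
    # Everything else (including mixed-type iterables) is rejected.
    raise TypeError("stopwords must be a set, list, or tuple of strings")

print(normalize_stopwords(("the", "a", "the")))  # duplicates collapse into a set
```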
```python
    def tokenize_and_escape_query(self, user_query: str) -> str:
        """Convert a raw user query to a Redis full-text query joined by ORs.

        Args:
            user_query (str): The user query to tokenize and escape.

        Returns:
            str: The tokenized and escaped query string.

        Raises:
            ValueError: If the text string becomes empty after stopwords are removed.
        """
        escaper = TokenEscaper()

        tokens = [
            escaper.escape(
                token.strip().strip(",").replace("“", "").replace("”", "").lower()
            )
            for token in user_query.split()
        ]
        tokenized = " | ".join(
            [token for token in tokens if token and token not in self._stopwords]
        )

        if not tokenized:
            raise ValueError("text string cannot be empty after removing stopwords")
        return tokenized
```
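To see the shape of this output without a running index, here is a simplified standalone re-implementation: a hypothetical regex escaper stands in for redisvl's `TokenEscaper` (which backslash-escapes RediSearch-reserved punctuation), and stopwords are passed explicitly.

```python
import re

# Hypothetical stand-in for TokenEscaper: backslash-escape common
# RediSearch-reserved punctuation so tokens are matched literally.
RESERVED = re.compile(r"[,.<>{}\[\]\"':;!@#$%^&*()\-+=~/ ]")

def tokenize_and_escape(user_query: str, stopwords: set) -> str:
    tokens = [
        RESERVED.sub(lambda m: "\\" + m.group(0),
                     tok.strip().strip(",").replace("“", "").replace("”", "").lower())
        for tok in user_query.split()
    ]
    tokenized = " | ".join(t for t in tokens if t and t not in stopwords)
    if not tokenized:
        raise ValueError("text string cannot be empty after removing stopwords")
    return tokenized

print(tokenize_and_escape("The quick, brown fox", {"the", "a", "an"}))
# quick | brown | fox
```

Note that stopwords are filtered after lowercasing, so "The" and "the" are treated the same.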
```python
    def _build_query_string(self) -> str:
        """Build the full query string for text search with optional filtering."""
        if isinstance(self._filter_expression, FilterExpression):
            filter_expression = str(self._filter_expression)
        else:
            filter_expression = ""

        # base KNN query
        knn_query = f"KNN {self._num_results} @{self._vector_field} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}"

        text = f"(~@{self._text_field}:({self.tokenize_and_escape_query(self._text)}))"

        if filter_expression and filter_expression != "*":
            text += f"({filter_expression})"

        return f"{text}=>[{knn_query}]"
```
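Putting the pieces together, the result is the standard RediSearch hybrid-query shape, `(<text clause>)(<filter>)=>[KNN ...]`. A sketch with hypothetical field names showing the string this method returns:

```python
# Hypothetical field names, illustrating the final query string.
text_field, vector_field = "description", "embedding"
tokenized = "quick | brown | fox"  # output of tokenize_and_escape_query
num_results = 10

knn_query = f"KNN {num_results} @{vector_field} $vector AS vector_distance"
query = f"(~@{text_field}:({tokenized}))=>[{knn_query}]"
print(query)
# (~@description:(quick | brown | fox))=>[KNN 10 @embedding $vector AS vector_distance]
```

The leading `~` makes the text clause optional, so documents that match only the KNN part are still returned (with a text score of 0), which is what lets the hybrid blend degrade gracefully to pure vector search.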