Skip to content

Commit

Permalink
Collapse support added to AsyncSearch helper (#769)
Browse files Browse the repository at this point in the history
  • Loading branch information
Radoslaw Kuczynski committed Sep 26, 2024
1 parent 45d8172 commit 26ad804
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

## [Unreleased]
### Added
- Added `AsyncSearch#collapse` ([827](https://github.com/opensearch-project/opensearch-py/pull/827))
### Changed
### Deprecated
### Removed
Expand Down
39 changes: 36 additions & 3 deletions opensearchpy/_async/helpers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# GitHub history for details.

import copy
from typing import Any, Sequence
from typing import Any, Dict, Sequence, cast

from opensearchpy._async.helpers.actions import aiter, async_scan
from opensearchpy.connection.async_connections import get_connection
Expand Down Expand Up @@ -39,6 +39,7 @@ def __init__(self, **kwargs: Any) -> None:

self.aggs = AggsProxy(self)
self._sort: Sequence[Any] = []
self._collapse: Dict[str, Any] = {}
self._source: Any = None
self._highlight: Any = {}
self._highlight_opts: Any = {}
Expand Down Expand Up @@ -111,13 +112,13 @@ def from_dict(cls, d: Any) -> Any:
s.update_from_dict(d)
return s

def _clone(self) -> Any:
def _clone(self) -> "AsyncSearch":
"""
Return a clone of the current search request. Performs a shallow copy
of all the underlying objects. Used internally by most state modifying
APIs.
"""
s = super()._clone()
s = cast(AsyncSearch, super()._clone())

s._response_class = self._response_class
s._sort = self._sort[:]
Expand All @@ -126,6 +127,7 @@ def _clone(self) -> Any:
s._highlight_opts = self._highlight_opts.copy()
s._suggest = self._suggest.copy()
s._script_fields = self._script_fields.copy()
s._collapse = self._collapse.copy()
for x in ("query", "post_filter"):
getattr(s, x)._proxied = getattr(self, x)._proxied

Expand Down Expand Up @@ -281,6 +283,34 @@ def sort(self, *keys: Any) -> Any:
s._sort.append(k)
return s

def collapse(
self,
field: Any = None,
inner_hits: Any = None,
max_concurrent_group_searches: Any = None,
) -> "AsyncSearch":
"""
Add collapsing information to the search request.
If called without providing ``field``, it will remove all collapse
requirements, otherwise it will replace them with the provided
arguments.
The API returns a copy of the AsyncSearch object and can thus be chained.
"""
s = self._clone()
s._collapse = {}

if field is None:
return s

s._collapse["field"] = field
if inner_hits:
s._collapse["inner_hits"] = inner_hits
if max_concurrent_group_searches:
s._collapse["max_concurrent_group_searches"] = max_concurrent_group_searches
return s

def highlight_options(self, **kwargs: Any) -> Any:
"""
Update the global highlighting options used for this request. For
Expand Down Expand Up @@ -376,6 +406,9 @@ def to_dict(self, count: bool = False, **kwargs: Any) -> Any:
if self._sort:
d["sort"] = self._sort

if self._collapse:
d["collapse"] = self._collapse

d.update(recursive_to_dict(self._extra))

if self._source not in (None, {}):
Expand Down
50 changes: 50 additions & 0 deletions test_opensearchpy/test_async/test_helpers/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,40 @@ async def test_sort_by_score() -> None:
s.sort("-_score")


def test_collapse() -> None:
s = search.AsyncSearch()

inner_hits = {"name": "most_recent", "size": 5, "sort": [{"@timestamp": "desc"}]}
s = s.collapse(
field="user.id", inner_hits=inner_hits, max_concurrent_group_searches=4
)

assert {
"field": "user.id",
"inner_hits": {
"name": "most_recent",
"size": 5,
"sort": [{"@timestamp": "desc"}],
},
"max_concurrent_group_searches": 4,
} == s._collapse
assert {
"collapse": {
"field": "user.id",
"inner_hits": {
"name": "most_recent",
"size": 5,
"sort": [{"@timestamp": "desc"}],
},
"max_concurrent_group_searches": 4,
}
} == s.to_dict()

s = s.collapse()
assert {} == s._collapse
assert search.AsyncSearch().to_dict() == s.to_dict()


async def test_slice() -> None:
s = search.AsyncSearch()
assert {"from": 3, "size": 7} == s[3:10].to_dict()
Expand Down Expand Up @@ -546,3 +580,19 @@ async def test_rescore_query_to_dict() -> None:
},
},
}


def test_collapse_chaining() -> None:
s = search.AsyncSearch(index="index_name")
s = s.filter("term", color="red")
s = s.collapse(field="category")
s = s.filter("term", brand="something")

assert {
"query": {
"bool": {
"filter": [{"term": {"color": "red"}}, {"term": {"brand": "something"}}]
}
},
"collapse": {"field": "category"},
} == s.to_dict()

0 comments on commit 26ad804

Please sign in to comment.