diff --git a/sdk/search/azure-search-documents/azure/search/documents/_paging.py b/sdk/search/azure-search-documents/azure/search/documents/_paging.py index 42c40910a10d..afdd9457a02b 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/_paging.py +++ b/sdk/search/azure-search-documents/azure/search/documents/_paging.py @@ -134,6 +134,7 @@ def _extract_data_cb(self, response): # pylint:disable=no-self-use @_ensure_response def get_facets(self): + self.continuation_token = None facets = self._response.facets if facets is not None and self._facets is None: self._facets = {k: [x.as_dict() for x in v] for k, v in facets.items()} @@ -141,12 +142,15 @@ def get_facets(self): @_ensure_response def get_coverage(self): + self.continuation_token = None return self._response.coverage @_ensure_response def get_count(self): + self.continuation_token = None return self._response.count @_ensure_response def get_answers(self): + self.continuation_token = None return self._response.answers diff --git a/sdk/search/azure-search-documents/azure/search/documents/aio/_paging.py b/sdk/search/azure-search-documents/azure/search/documents/aio/_paging.py index 1bd544a957a7..d2498f955d9f 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/aio/_paging.py +++ b/sdk/search/azure-search-documents/azure/search/documents/aio/_paging.py @@ -118,6 +118,7 @@ async def _extract_data_cb(self, response): # pylint:disable=no-self-use @_ensure_response async def get_facets(self): + self.continuation_token = None facets = self._response.facets if facets is not None and self._facets is None: self._facets = {k: [x.as_dict() for x in v] for k, v in facets.items()} @@ -125,12 +126,15 @@ async def get_facets(self): @_ensure_response async def get_coverage(self): + self.continuation_token = None return self._response.coverage @_ensure_response async def get_count(self): + self.continuation_token = None return self._response.count @_ensure_response async def get_answers(self): + self.continuation_token = None return self._response.answers diff --git a/sdk/search/azure-search-documents/tests/async_tests/test_search_client_async.py b/sdk/search/azure-search-documents/tests/async_tests/test_search_client_async.py new file mode 100644 index 000000000000..9ba5807d6548 --- /dev/null +++ b/sdk/search/azure-search-documents/tests/async_tests/test_search_client_async.py @@ -0,0 +1,30 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +try: + from unittest import mock +except ImportError: + import mock +from azure.core.credentials import AzureKeyCredential +from azure.search.documents._generated.models import SearchDocumentsResult, SearchResult +from azure.search.documents.aio import SearchClient +from azure.search.documents.aio._search_client_async import AsyncSearchPageIterator + +CREDENTIAL = AzureKeyCredential(key="test_api_key") + +class TestSearchClientAsync(object): + @mock.patch( + "azure.search.documents._generated.aio.operations._documents_operations.DocumentsOperations.search_post" + ) + async def test_get_count_reset_continuation_token(self, mock_search_post): + client = SearchClient("endpoint", "index name", CREDENTIAL) + result = await client.search(search_text="search text") + assert result._page_iterator_class is AsyncSearchPageIterator + search_result = SearchDocumentsResult() + search_result.results = [SearchResult(additional_properties={"key": "val"})] + mock_search_post.return_value = search_result + await result.__anext__() + result._first_page_iterator_instance.continuation_token = "fake token" + await result.get_count() + assert not result._first_page_iterator_instance.continuation_token \ No newline at end of file diff --git a/sdk/search/azure-search-documents/tests/test_search_client.py b/sdk/search/azure-search-documents/tests/test_search_client.py index d1cb060e6891..1d450b5e6910 100644 --- a/sdk/search/azure-search-documents/tests/test_search_client.py +++ b/sdk/search/azure-search-documents/tests/test_search_client.py @@ -181,6 +181,22 @@ def test_suggest_bad_argument(self): repr("bad_query") ) + @mock.patch( + "azure.search.documents._generated.operations._documents_operations.DocumentsOperations.search_post" + ) + def test_get_count_reset_continuation_token(self, mock_search_post): + client = SearchClient("endpoint", "index name", CREDENTIAL) + result = client.search(search_text="search text") + assert isinstance(result, ItemPaged) + assert result._page_iterator_class is SearchPageIterator + search_result = SearchDocumentsResult() + search_result.results = [SearchResult(additional_properties={"key": "val"})] + mock_search_post.return_value = search_result + result.__next__() + result._first_page_iterator_instance.continuation_token = "fake token" + result.get_count() + assert not result._first_page_iterator_instance.continuation_token + @mock.patch( "azure.search.documents._generated.operations._documents_operations.DocumentsOperations.autocomplete_post" )