diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_base_client_async.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_base_client_async.py index 2d97f49c3ee6..753802047992 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_base_client_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_base_client_async.py @@ -6,7 +6,7 @@ from typing import Any from azure.core.credentials import AzureKeyCredential from azure.core.pipeline.policies import AzureKeyCredentialPolicy -from ._policies_async import AsyncTextAnalyticsResponseHookPolicy +from .._policies import TextAnalyticsResponseHookPolicy from .._user_agent import USER_AGENT from .._multiapi import load_generated_api @@ -34,11 +34,10 @@ def __init__(self, endpoint, credential, **kwargs): credential=credential, sdk_moniker=USER_AGENT, authentication_policy=_authentication_policy(credential), - custom_hook_policy=AsyncTextAnalyticsResponseHookPolicy(**kwargs), + custom_hook_policy=TextAnalyticsResponseHookPolicy(**kwargs), **kwargs ) - async def __aenter__(self) -> "AsyncTextAnalyticsClientBase": await self._client.__aenter__() return self diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_policies_async.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_policies_async.py deleted file mode 100644 index f035a52520f8..000000000000 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_policies_async.py +++ /dev/null @@ -1,36 +0,0 @@ -# coding=utf-8 -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ - -import asyncio -from azure.core.pipeline.policies import ContentDecodePolicy -from azure.core.pipeline.policies import SansIOHTTPPolicy -from .._models import TextDocumentBatchStatistics - - -class AsyncTextAnalyticsResponseHookPolicy(SansIOHTTPPolicy): - - def __init__(self, **kwargs): - self._response_callback = kwargs.get('raw_response_hook') - super(AsyncTextAnalyticsResponseHookPolicy, self).__init__() - - async def on_request(self, request): - self._response_callback = request.context.options.pop("raw_response_hook", self._response_callback) - - async def on_response(self, request, response): - if self._response_callback: - data = ContentDecodePolicy.deserialize_from_http_generics(response.http_response) - statistics = data.get("statistics", None) - model_version = data.get("modelVersion", None) - - if statistics or model_version: - batch_statistics = TextDocumentBatchStatistics._from_generated(statistics) # pylint: disable=protected-access - response.statistics = batch_statistics - response.model_version = model_version - response.raw_response = data - if asyncio.iscoroutine(self._response_callback): - await self._response_callback(response) - else: - self._response_callback(response)