diff --git a/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/aio/_client_async.py b/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/aio/_client_async.py index 6d011d8e957a..843c3de579ae 100644 --- a/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/aio/_client_async.py +++ b/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/aio/_client_async.py @@ -7,11 +7,22 @@ from typing import Union, Any, List, TYPE_CHECKING from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.tracing.decorator import distributed_trace +from azure.core.polling import AsyncLROPoller +from azure.core.polling.async_base_polling import AsyncLROBasePolling from azure.core.async_paging import AsyncItemPaged from .._generated.aio import BatchDocumentTranslationClient as _BatchDocumentTranslationClient from .._user_agent import USER_AGENT -from .._models import JobStatusDetail, DocumentStatusDetail, BatchDocumentInput, FileFormat +from .._generated.models import ( + BatchStatusDetail as _BatchStatusDetail, +) +from .._models import ( + JobStatusDetail, + BatchDocumentInput, + FileFormat, + DocumentStatusDetail +) from .._helpers import get_authentication_policy +from .._polling import TranslationPolling if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential from azure.core.credentials import AzureKeyCredential @@ -59,12 +70,27 @@ async def create_translation_job(self, batch, **kwargs): :rtype: JobStatusDetail """ - return await self._client.document_translation.begin_submit_batch_request( - inputs=batch, + # submit translation job + response_headers = await self._client.document_translation._submit_batch_request_initial( # pylint: disable=protected-access + # pylint: disable=protected-access + inputs=BatchDocumentInput._to_generated_list(batch), + cls=lambda pipeline_response, _, response_headers: response_headers, polling=True, **kwargs ) + def get_job_id(response_headers): + # extract job id. + operation_location_header = response_headers['Operation-Location'] + return operation_location_header.split('/')[-1] + + # get job id from response header + job_id = get_job_id(response_headers) + + # get job status + return await self.get_job_status(job_id) + + @distributed_trace_async async def get_job_status(self, job_id, **kwargs): # type: (str, **Any) -> JobStatusDetail @@ -75,7 +101,9 @@ async def get_job_status(self, job_id, **kwargs): :rtype: ~azure.ai.documenttranslation.JobStatusDetail """ - return await self._client.document_translation.get_operation_status(job_id, **kwargs) + job_status = await self._client.document_translation.get_operation_status(job_id, **kwargs) + # pylint: disable=protected-access + return JobStatusDetail._from_generated(job_status) @distributed_trace_async async def cancel_job(self, job_id, **kwargs): @@ -99,7 +127,26 @@ async def wait_until_done(self, job_id, **kwargs): :return: JobStatusDetail :rtype: JobStatusDetail """ - pass # pylint: disable=unnecessary-pass + pipeline_response = await self._client.document_translation.get_operation_status( + job_id, + cls=lambda pipeline_response, _, response_headers: pipeline_response + ) + + def callback(raw_response): + detail = self._client._deserialize(_BatchStatusDetail, raw_response) # pylint: disable=protected-access + return JobStatusDetail._from_generated(detail) # pylint: disable=protected-access + + poller = AsyncLROPoller( + client=self._client._client, # pylint: disable=protected-access + initial_response=pipeline_response, + deserialization_callback=callback, + polling_method=AsyncLROBasePolling( + timeout=30, + lro_algorithms=[TranslationPolling()], + **kwargs + ), + ) + return poller.result() @distributed_trace def list_submitted_jobs(self, **kwargs): @@ -110,7 +157,24 @@ def list_submitted_jobs(self, **kwargs): :keyword int skip: :rtype: ~azure.core.polling.AsyncItemPaged[JobStatusDetail] """ - return self._client.document_translation.get_operations(**kwargs) + skip = kwargs.pop('skip', None) + results_per_page = kwargs.pop('results_per_page', None) + + def _convert_from_generated_model(generated_model): + # pylint: disable=protected-access + return JobStatusDetail._from_generated(generated_model) + + model_conversion_function = kwargs.pop( + "cls", + lambda job_statuses: [_convert_from_generated_model(job_status) for job_status in job_statuses] + ) + + return self._client.document_translation.get_operations( + top=results_per_page, + skip=skip, + cls=model_conversion_function, + **kwargs + ) @distributed_trace def list_documents_statuses(self, job_id, **kwargs): @@ -123,8 +187,26 @@ def list_documents_statuses(self, job_id, **kwargs): :keyword int skip: :rtype: ~azure.core.paging.AsyncItemPaged[DocumentStatusDetail] """ + skip = kwargs.pop('skip', None) + results_per_page = kwargs.pop('results_per_page', None) + + def _convert_from_generated_model(generated_model): + # pylint: disable=protected-access + return DocumentStatusDetail._from_generated(generated_model) + + model_conversion_function = kwargs.pop( + "cls", + lambda doc_statuses: [_convert_from_generated_model(doc_status) for doc_status in doc_statuses] + ) + + return self._client.document_translation.get_operation_documents_status( + id=job_id, + top=results_per_page, + skip=skip, + cls=model_conversion_function, + **kwargs + ) - return self._client.document_translation.get_operation_documents_status(job_id, **kwargs) @distributed_trace_async async def get_document_status(self, job_id, document_id, **kwargs): @@ -137,16 +219,10 @@ async def get_document_status(self, job_id, document_id, **kwargs): :type document_id: str :rtype: ~azure.ai.documenttranslation.DocumentStatusDetail """ - return await self._client.document_translation.get_document_status(job_id, document_id, **kwargs) + document_status = await self._client.document_translation.get_document_status(job_id, document_id, **kwargs) + # pylint: disable=protected-access + return DocumentStatusDetail._from_generated(document_status) - @distributed_trace_async - async def get_supported_storage_sources(self, **kwargs): - # type: (**Any) -> List[str] - """ - - :rtype: list[str] - """ - return await self._client.document_translation.get_document_storage_source(**kwargs) @distributed_trace_async async def get_supported_glossary_formats(self, **kwargs): @@ -155,8 +231,9 @@ async def get_supported_glossary_formats(self, **kwargs): :rtype: list[FileFormat] """ - - return await self._client.document_translation.get_glossary_formats(**kwargs) + glossary_formats = await self._client.document_translation.get_glossary_formats(**kwargs) + # pylint: disable=protected-access + return FileFormat._from_generated_list(glossary_formats.value) @distributed_trace_async async def get_supported_document_formats(self, **kwargs): @@ -165,5 +242,6 @@ async def get_supported_document_formats(self, **kwargs): :rtype: list[FileFormat] """ - - return await self._client.document_translation.get_document_formats(**kwargs) + document_formats = await self._client.document_translation.get_document_formats(**kwargs) + # pylint: disable=protected-access + return FileFormat._from_generated_list(document_formats.value)