diff --git a/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/_client.py b/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/_client.py index 21f8467589d8..54179455add4 100644 --- a/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/_client.py +++ b/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/_client.py @@ -6,16 +6,25 @@ from typing import Union, Any, TYPE_CHECKING, List from azure.core.tracing.decorator import distributed_trace +from azure.core.polling import LROPoller +from azure.core.polling.base_polling import LROBasePolling from ._generated import BatchDocumentTranslationClient as _BatchDocumentTranslationClient +from ._generated.models import BatchStatusDetail as _BatchStatusDetail +from ._models import ( + JobStatusDetail, + DocumentStatusDetail, + BatchDocumentInput, + FileFormat +) from ._helpers import get_authentication_policy from ._user_agent import USER_AGENT +from ._polling import TranslationPolling if TYPE_CHECKING: from azure.core.paging import ItemPaged from azure.core.credentials import AzureKeyCredential, TokenCredential - from ._models import JobStatusDetail, DocumentStatusDetail, BatchDocumentInput, FileFormat -class DocumentTranslationClient(object): +class DocumentTranslationClient(object): # pylint: disable=r0205 """DocumentTranslationClient """ @@ -55,12 +64,24 @@ def create_translation_job(self, batch, **kwargs): :rtype: JobStatusDetail """ - return self._client.document_translation.begin_submit_batch_request( - inputs=batch, - polling=True, + # submit translation job + response_headers = self._client.document_translation._submit_batch_request_initial( # pylint: disable=protected-access + inputs=BatchDocumentInput._to_generated_list(batch), # pylint: disable=protected-access + cls=lambda pipeline_response, _, response_headers: response_headers, **kwargs ) + def get_job_id(response_headers): + operation_loc_header = response_headers['Operation-Location'] + return operation_loc_header.split('/')[-1] + + # get job id from response header + job_id = get_job_id(response_headers) + + # get job status + return self.get_job_status(job_id) + + @distributed_trace def get_job_status(self, job_id, **kwargs): # type: (str, **Any) -> JobStatusDetail @@ -71,7 +92,8 @@ def get_job_status(self, job_id, **kwargs): :rtype: ~azure.ai.documenttranslation.JobStatusDetail """ - return self._client.document_translation.get_operation_status(job_id, **kwargs) + job_status = self._client.document_translation.get_operation_status(job_id, **kwargs) + return JobStatusDetail._from_generated(job_status) # pylint: disable=protected-access @distributed_trace def cancel_job(self, job_id, **kwargs): @@ -95,7 +117,27 @@ def wait_until_done(self, job_id, **kwargs): :return: JobStatusDetail :rtype: JobStatusDetail """ - pass # pylint: disable=unnecessary-pass + + pipeline_response = self._client.document_translation.get_operation_status( + job_id, + cls=lambda pipeline_response, _, response_headers: pipeline_response + ) + + def callback(raw_response): + detail = self._client._deserialize(_BatchStatusDetail, raw_response) # pylint: disable=protected-access + return JobStatusDetail._from_generated(detail) # pylint: disable=protected-access + + poller = LROPoller( + client=self._client._client, # pylint: disable=protected-access + initial_response=pipeline_response, + deserialization_callback=callback, + polling_method=LROBasePolling( + timeout=30, + lro_algorithms=[TranslationPolling()], + **kwargs + ), + ) + return poller.result() @distributed_trace def list_submitted_jobs(self, **kwargs): @@ -106,7 +148,25 @@ def list_submitted_jobs(self, **kwargs): :keyword int skip: :rtype: ~azure.core.polling.ItemPaged[JobStatusDetail] """ - return self._client.document_translation.get_operations(**kwargs) + + skip = kwargs.pop('skip', None) + results_per_page = kwargs.pop('results_per_page', None) + + def _convert_from_generated_model(generated_model): # pylint: disable=protected-access + return JobStatusDetail._from_generated(generated_model) # pylint: disable=protected-access + + model_conversion_function = kwargs.pop( + "cls", + lambda job_statuses: [ + _convert_from_generated_model(job_status) for job_status in job_statuses + ]) + + return self._client.document_translation.get_operations( + top=results_per_page, + skip=skip, + cls=model_conversion_function, + **kwargs + ) @distributed_trace def list_documents_statuses(self, job_id, **kwargs): @@ -120,7 +180,26 @@ def list_documents_statuses(self, job_id, **kwargs): :rtype: ~azure.core.paging.ItemPaged[DocumentStatusDetail] """ - return self._client.document_translation.get_operation_documents_status(job_id, **kwargs) + skip = kwargs.pop('skip', None) + results_per_page = kwargs.pop('results_per_page', None) + + def _convert_from_generated_model(generated_model): + return DocumentStatusDetail._from_generated(generated_model) # pylint: disable=protected-access + + model_conversion_function = kwargs.pop( + "cls", + lambda doc_statuses: [ + _convert_from_generated_model(doc_status) for doc_status in doc_statuses + ]) + + return self._client.document_translation.get_operation_documents_status( + id=job_id, + top=results_per_page, + skip=skip, + cls=model_conversion_function, + **kwargs + ) + @distributed_trace def get_document_status(self, job_id, document_id, **kwargs): @@ -133,16 +212,12 @@ def get_document_status(self, job_id, document_id, **kwargs): :type document_id: str :rtype: ~azure.ai.documenttranslation.DocumentStatusDetail """ - return self._client.document_translation.get_document_status(job_id, document_id, **kwargs) - - @distributed_trace - def get_supported_storage_sources(self, **kwargs): - # type: (**Any) -> List[str] - """ - :rtype: List[str] - """ - return self._client.document_translation.get_document_storage_source(**kwargs) + document_status = self._client.document_translation.get_document_status( + job_id, + document_id, + **kwargs) + return DocumentStatusDetail._from_generated(document_status) # pylint: disable=protected-access @distributed_trace def get_supported_glossary_formats(self, **kwargs): @@ -152,7 +227,8 @@ def get_supported_glossary_formats(self, **kwargs): :rtype: List[FileFormat] """ - return self._client.document_translation.get_glossary_formats(**kwargs) + glossary_formats = self._client.document_translation.get_glossary_formats(**kwargs) + return FileFormat._from_generated_list(glossary_formats.value) # pylint: disable=protected-access @distributed_trace def get_supported_document_formats(self, **kwargs): @@ -162,4 +238,5 @@ def get_supported_document_formats(self, **kwargs): :rtype: List[FileFormat] """ - return self._client.document_translation.get_document_formats(**kwargs) + document_formats = self._client.document_translation.get_document_formats(**kwargs) + return FileFormat._from_generated_list(document_formats.value) # pylint: disable=protected-access diff --git a/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/_models.py b/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/_models.py index b8e69056fc82..50f6fcaefc93 100644 --- a/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/_models.py +++ b/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/_models.py @@ -4,10 +4,18 @@ # Licensed under the MIT License. # ------------------------------------ +# pylint: disable=unused-import from typing import Any, List +import six +from ._generated.models import ( + BatchRequest as _BatchRequest, + SourceInput as _SourceInput, + DocumentFilter as _DocumentFilter, + TargetInput as _TargetInput, + Glossary as _Glossary +) - -class TranslationGlossary(object): +class TranslationGlossary(object): # pylint: disable=useless-object-inheritance """Glossary / translation memory for the request. :param glossary_url: Required. Location of the glossary. @@ -32,8 +40,30 @@ def __init__( self.format_version = kwargs.get("format_version", None) self.storage_source = kwargs.get("storage_source", None) + def _to_generated(self): + return _Glossary( + glossary_url=self.glossary_url, + format=self.format, + version=self.format_version, + storage_source=self.storage_source + ) + + @staticmethod + def _to_generated_unknown_type(glossary): + if isinstance(glossary, TranslationGlossary): + return glossary._to_generated() # pylint: disable=protected-access + if isinstance(glossary, six.string_types): + return _Glossary( + glossary_url=glossary, + ) + return None + + @staticmethod + def _to_generated_list(glossaries): + return [TranslationGlossary._to_generated_unknown_type(glossary) for glossary in glossaries] + -class StorageTarget(object): +class StorageTarget(object): # pylint: disable=useless-object-inheritance """Destination for the finished translated documents. :param target_url: Required. Location of the folder / container with your documents. @@ -60,8 +90,23 @@ def __init__( self.glossaries = kwargs.get("glossaries", None) self.storage_source = kwargs.get("storage_source", None) + def _to_generated(self): + return _TargetInput( + target_url=self.target_url, + category=self.category_id, + language=self.language, + storage_source=self.storage_source, + glossaries=TranslationGlossary._to_generated_list(self.glossaries) # pylint: disable=protected-access + if self.glossaries else None + ) -class BatchDocumentInput(object): + @staticmethod + def _to_generated_list(targets): + return [target._to_generated() for target in targets] # pylint: disable=protected-access + + +class BatchDocumentInput(object): # pylint: disable=useless-object-inheritance + # pylint: disable=C0301 """Definition for the input batch translation request. :param source_url: Required. Location of the folder / container or single file with your @@ -97,8 +142,30 @@ def __init__( self.prefix = kwargs.get("prefix", None) self.suffix = kwargs.get("suffix", None) + def _to_generated(self): + return _BatchRequest( + source=_SourceInput( + source_url=self.source_url, + filter=_DocumentFilter( + prefix=self.prefix, + suffix=self.suffix + ), + language=self.source_language, + storage_source=self.storage_source + ), + targets=StorageTarget._to_generated_list(self.targets), # pylint: disable=protected-access + storage_type=self.storage_type + ) + + @staticmethod + def _to_generated_list(batch_document_inputs): + return [ + batch_document_input._to_generated() # pylint: disable=protected-access + for batch_document_input in batch_document_inputs + ] -class JobStatusDetail(object): # pylint: disable=too-many-instance-attributes + +class JobStatusDetail(object): # pylint: disable=useless-object-inheritance, too-many-instance-attributes """Job status response. :ivar id: Required. Id of the job. @@ -143,8 +210,26 @@ def __init__( self.documents_cancelled_count = kwargs.get('documents_cancelled_count', None) self.total_characters_charged = kwargs.get('total_characters_charged', None) + @classmethod + def _from_generated(cls, batch_status_details): + return cls( + id=batch_status_details.id, + created_on=batch_status_details.created_date_time_utc, + last_updated_on=batch_status_details.last_action_date_time_utc, + status=batch_status_details.status, + error=DocumentTranslationError._from_generated(batch_status_details.error) # pylint: disable=protected-access + if batch_status_details.error else None, + documents_total_count=batch_status_details.summary.total, + documents_failed_count=batch_status_details.summary.failed, + documents_succeeded_count=batch_status_details.summary.success, + documents_in_progress_count=batch_status_details.summary.in_progress, + documents_not_yet_started_count=batch_status_details.summary.not_yet_started, + documents_cancelled_count=batch_status_details.summary.cancelled, + total_characters_charged=batch_status_details.summary.total_character_charged + ) + -class DocumentStatusDetail(object): +class DocumentStatusDetail(object): # pylint: disable=useless-object-inheritance, R0903 """DocumentStatusDetail. :ivar url: Required. Location of the document or folder. @@ -186,7 +271,22 @@ def __init__( self.characters_charged = kwargs.get('characters_charged', None) -class DocumentTranslationError(object): + @classmethod + def _from_generated(cls, doc_status): + return cls( + url=doc_status.path, + created_on=doc_status.created_date_time_utc, + last_updated_on=doc_status.last_action_date_time_utc, + status=doc_status.status, + translate_to=doc_status.to, + error=DocumentTranslationError._from_generated(doc_status.error) if doc_status.error else None, # pylint: disable=protected-access + translation_progress=doc_status.progress, + id=doc_status.id, + characters_charged=doc_status.character_charged + ) + + +class DocumentTranslationError(object): # pylint: disable=useless-object-inheritance, R0903 """This contains an outer error with error code, message, details, target and an inner error with more descriptive details. @@ -210,8 +310,16 @@ def __init__( self.message = None self.target = None + @classmethod + def _from_generated(cls, error): + return cls( + code=error.code, + message=error.message, + target=error.target + ) -class FileFormat(object): + +class FileFormat(object): # pylint: disable=useless-object-inheritance, R0903 """FileFormat. :ivar format: Name of the format. @@ -233,3 +341,16 @@ def __init__( self.file_extensions = kwargs.get('file_extensions', None) self.content_types = kwargs.get('content_types', None) self.versions = kwargs.get('versions', None) + + @classmethod + def _from_generated(cls, file_format): + return cls( + format=file_format.format, + file_extensions=file_format.file_extensions, + content_types=file_format.content_types, + versions=file_format.versions + ) + + @staticmethod + def _from_generated_list(file_formats): + return [FileFormat._from_generated(file_formats) for file_formats in file_formats] diff --git a/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/_polling.py b/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/_polling.py new file mode 100644 index 000000000000..c07d92cfc192 --- /dev/null +++ b/sdk/documenttranslation/azure-ai-documenttranslation/azure/ai/documenttranslation/_polling.py @@ -0,0 +1,89 @@ +# coding=utf-8 +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ + +from azure.core.polling.base_polling import ( + LongRunningOperation, + _is_empty, + _as_json, + BadResponse, + OperationFailed +) + + +class TranslationPolling(LongRunningOperation): + """Implements a Location polling. + """ + + def __init__(self): + self._async_url = None + + def can_poll(self, pipeline_response): + # type: (PipelineResponseType) -> bool + """Answer if this polling method could be used. + """ + response = pipeline_response.http_response + if not _is_empty(response): + body = _as_json(response) + status = body.get("status") + if status: + return True + return False + + def get_polling_url(self): + # type: () -> str + """Return the polling URL. + """ + return self._async_url + + def set_initial_status(self, pipeline_response): + # type: (PipelineResponseType) -> str + """Process first response after initiating long running operation. + + :param azure.core.pipeline.PipelineResponse response: initial REST call response. + """ + self._async_url = pipeline_response.http_response.request.url + + response = pipeline_response.http_response + if response.status_code in {200, 201, 202, 204} and self._async_url: + return "InProgress" + raise OperationFailed("Operation failed or canceled") + + def get_status(self, pipeline_response): + # type: (PipelineResponseType) -> str + """Process the latest status update retrieved from a 'location' header. + + :param azure.core.pipeline.PipelineResponse response: latest REST call response. + :raises: BadResponse if response has no body and not status 202. + """ + response = pipeline_response.http_response + if not _is_empty(response): + body = _as_json(response) + status = body.get("status") + if status: + return self._map_nonstandard_statuses(status) + raise BadResponse("No status found in body") + raise BadResponse("The response from long running operation does not contain a body.") + + def get_final_get_url(self, pipeline_response): + # type: (PipelineResponseType) -> Optional[str] + """If a final GET is needed, returns the URL. + + :rtype: str + """ + return None + + # pylint: disable=R0201 + def _map_nonstandard_statuses(self, status): + # type: (str) -> str + """Map non-standard statuses. + + :param str status: lro process status. + """ + if status in ["ValidationFailed"]: + return "Failed" + if status in ["Cancelled", "Cancelling"]: + return "Canceled" + return status