diff --git a/sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md b/sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md
index 953e4956eeaa..fa4e485f5fce 100644
--- a/sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md
+++ b/sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md
@@ -2,6 +2,12 @@

 ## 5.1.0b8 (Unreleased)

+**Breaking Changes**
+
+- Changed the response structure of `begin_analyze_actions`. It now returns one list of action results per input document, ordered first by the input documents and then by the input actions.
+- Removed `AnalyzeActionsType`
+- Removed `AnalyzeActionsResult`
+- Removed `AnalyzeActionsError`

 ## 5.1.0b7 (2021-05-18)

diff --git a/sdk/textanalytics/azure-ai-textanalytics/README.md b/sdk/textanalytics/azure-ai-textanalytics/README.md
index b1e69be921de..61596afe22dd 100644
--- a/sdk/textanalytics/azure-ai-textanalytics/README.md
+++ b/sdk/textanalytics/azure-ai-textanalytics/README.md
@@ -506,10 +506,7 @@ from azure.core.credentials import AzureKeyCredential
 from azure.ai.textanalytics import (
     TextAnalyticsClient,
     RecognizeEntitiesAction,
-    RecognizePiiEntitiesAction,
-    ExtractKeyPhrasesAction,
-    RecognizeLinkedEntitiesAction,
-    AnalyzeSentimentAction
+    AnalyzeSentimentAction,
 )

 credential = AzureKeyCredential("")
@@ -524,81 +521,39 @@ poller = text_analytics_client.begin_analyze_actions(
     display_name="Sample Text Analysis",
     actions=[
         RecognizeEntitiesAction(),
-        RecognizePiiEntitiesAction(),
-        ExtractKeyPhrasesAction(),
-        RecognizeLinkedEntitiesAction(),
         AnalyzeSentimentAction()
     ]
 )

 # results are returned per document, ordered like the input documents;
 # each document's action results follow the order the actions were passed in
-result = poller.result()
-
-first_action_result = next(result)
-print("Results of Entities Recognition action:")
-docs = [doc for doc in first_action_result.document_results if not doc.is_error]
-
-for idx, doc in enumerate(docs):
-    print("\nDocument text: {}".format(documents[idx]))
-    for entity in doc.entities:
-        print("Entity: {}".format(entity.text))
-        print("...Category: {}".format(entity.category))
-        print("...Confidence Score: {}".format(entity.confidence_score))
-        print("...Offset: {}".format(entity.offset))
-    print("------------------------------------------")
-
-second_action_result = next(result)
-print("Results of PII Entities Recognition action:")
-docs = [doc for doc in second_action_result.document_results if not doc.is_error]
-
-for idx, doc in enumerate(docs):
-    print("Document text: {}".format(documents[idx]))
-    for entity in doc.entities:
-        print("Entity: {}".format(entity.text))
-        print("Category: {}".format(entity.category))
-        print("Confidence Score: {}\n".format(entity.confidence_score))
-    print("------------------------------------------")
-
-third_action_result = next(result)
-print("Results of Key Phrase Extraction action:")
-docs = [doc for doc in third_action_result.document_results if not doc.is_error]
-
-for idx, doc in enumerate(docs):
-    print("Document text: {}\n".format(documents[idx]))
-    print("Key Phrases: {}\n".format(doc.key_phrases))
-    print("------------------------------------------")
-
-fourth_action_result = next(result)
-print("Results of Linked Entities Recognition action:")
-docs = [doc for doc in fourth_action_result.document_results if not doc.is_error]
-
-for idx, doc in enumerate(docs):
-    print("Document text: {}\n".format(documents[idx]))
-    for linked_entity in doc.entities:
-        print("Entity name: {}".format(linked_entity.name))
-        print("...Data source: {}".format(linked_entity.data_source))
-        print("...Data source language: 
{}".format(linked_entity.language)) - print("...Data source entity ID: {}".format(linked_entity.data_source_entity_id)) - print("...Data source URL: {}".format(linked_entity.url)) - print("...Document matches:") - for match in linked_entity.matches: - print("......Match text: {}".format(match.text)) - print(".........Confidence Score: {}".format(match.confidence_score)) - print(".........Offset: {}".format(match.offset)) - print(".........Length: {}".format(match.length)) - print("------------------------------------------") - -fifth_action_result = next(result) -print("Results of Sentiment Analysis action:") -docs = [doc for doc in fifth_action_result.document_results if not doc.is_error] - -for doc in docs: - print("Overall sentiment: {}".format(doc.sentiment)) - print("Scores: positive={}; neutral={}; negative={} \n".format( - doc.confidence_scores.positive, - doc.confidence_scores.neutral, - doc.confidence_scores.negative, - )) +document_results = poller.result() +for doc, action_results in zip(documents, document_results): + recognize_entities_result, analyze_sentiment_result = action_results + print("\nDocument text: {}".format(doc)) + print("...Results of Recognize Entities Action:") + if recognize_entities_result.is_error: + print("......Is an error with code '{}' and message '{}'".format( + recognize_entities_result.code, recognize_entities_result.message + )) + else: + for entity in recognize_entities_result.entities: + print("......Entity: {}".format(entity.text)) + print(".........Category: {}".format(entity.category)) + print(".........Confidence Score: {}".format(entity.confidence_score)) + print(".........Offset: {}".format(entity.offset)) + + print("...Results of Analyze Sentiment action:") + if analyze_sentiment_result.is_error: + print("......Is an error with code '{}' and message '{}'".format( + analyze_sentiment_result.code, analyze_sentiment_result.message + )) + else: + print("......Overall sentiment: {}".format(analyze_sentiment_result.sentiment)) + print("......Scores: positive={}; neutral={}; negative={} \n".format( + analyze_sentiment_result.confidence_scores.positive, + analyze_sentiment_result.confidence_scores.neutral, + analyze_sentiment_result.confidence_scores.negative, + )) print("------------------------------------------") ``` diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/__init__.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/__init__.py index 366fae23c386..45cdd5d2d375 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/__init__.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/__init__.py @@ -39,9 +39,7 @@ RecognizeLinkedEntitiesAction, RecognizePiiEntitiesAction, ExtractKeyPhrasesAction, - AnalyzeActionsResult, - AnalyzeActionsType, - AnalyzeActionsError, + _AnalyzeActionsType, HealthcareEntityRelationRoleType, HealthcareRelation, HealthcareRelationRole, @@ -91,9 +89,7 @@ 'RecognizeLinkedEntitiesAction', 'RecognizePiiEntitiesAction', 'ExtractKeyPhrasesAction', - 'AnalyzeActionsResult', - 'AnalyzeActionsType', - "AnalyzeActionsError", + '_AnalyzeActionsType', "PiiEntityCategoryType", "HealthcareEntityRelationType", "HealthcareEntityRelationRoleType", diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py index e72b41de3c7c..8566ea326c1e 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py +++ 
b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py @@ -1356,7 +1356,7 @@ def __repr__(self): .format(self.positive, self.neutral, self.negative)[:1024] -class AnalyzeActionsType(str, Enum): +class _AnalyzeActionsType(str, Enum): """The type of action that was applied to the documents """ RECOGNIZE_ENTITIES = "recognize_entities" #: Entities Recognition action. @@ -1365,62 +1365,6 @@ class AnalyzeActionsType(str, Enum): RECOGNIZE_LINKED_ENTITIES = "recognize_linked_entities" #: Linked Entities Recognition action. ANALYZE_SENTIMENT = "analyze_sentiment" #: Sentiment Analysis action. - -class AnalyzeActionsResult(DictMixin): - """AnalyzeActionsResult contains the results of a recognize entities action - on a list of documents. Returned by `begin_analyze_actions` - - :ivar document_results: A list of objects containing results for all Entity Recognition actions - included in the analysis. - :vartype document_results: list[~azure.ai.textanalytics.RecognizeEntitiesResult] - :ivar bool is_error: Boolean check for error item when iterating over list of - actions. Always False for an instance of a AnalyzeActionsResult. - :ivar action_type: The type of action this class is a result of. - :vartype action_type: str or ~azure.ai.textanalytics.AnalyzeActionsType - :ivar ~datetime.datetime completed_on: Date and time (UTC) when the result completed - on the service. - """ - def __init__(self, **kwargs): - self.document_results = kwargs.get("document_results") - self.is_error = False - self.action_type = kwargs.get("action_type") - self.completed_on = kwargs.get("completed_on") - - def __repr__(self): - return "AnalyzeActionsResult(document_results={}, is_error={}, action_type={}, completed_on={})".format( - repr(self.document_results), - self.is_error, - self.action_type, - self.completed_on, - )[:1024] - - -class AnalyzeActionsError(DictMixin): - """AnalyzeActionsError is an error object which represents an an - error response for an action. - - :ivar error: The action result error. - :vartype error: ~azure.ai.textanalytics.TextAnalyticsError - :ivar bool is_error: Boolean check for error item when iterating over list of - results. Always True for an instance of a DocumentError. - """ - - def __init__(self, **kwargs): - self.error = kwargs.get("error") - self.is_error = True - - def __repr__(self): - return "AnalyzeActionsError(error={}, is_error={}".format( - repr(self.error), self.is_error - ) - - @classmethod - def _from_generated(cls, error): - return cls( - error=TextAnalyticsError(code=error.code, message=error.message, target=error.target) - ) - - class RecognizeEntitiesAction(DictMixin): """RecognizeEntitiesAction encapsulates the parameters for starting a long-running Entities Recognition operation. 
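With `AnalyzeActionsResult` and `AnalyzeActionsError` removed above, callers consume `begin_analyze_actions` results as plain per-document lists. The following migration sketch is not part of this diff; it mirrors the README sample in this PR, and the endpoint, key, and input documents are placeholders you would substitute with your own:

```python
from azure.core.credentials import AzureKeyCredential
from azure.ai.textanalytics import (
    TextAnalyticsClient,
    RecognizeEntitiesAction,
    AnalyzeSentimentAction,
)

# placeholder endpoint/key -- substitute real values
client = TextAnalyticsClient("<endpoint>", AzureKeyCredential("<api key>"))
documents = ["We went to Contoso Steakhouse last week.", "The food didn't come fast enough."]

poller = client.begin_analyze_actions(
    documents,
    actions=[RecognizeEntitiesAction(), AnalyzeSentimentAction()],
)

# new shape: one inner list per input document, holding one result per action,
# both in input order; a failed action surfaces as a result whose is_error is True
for doc, action_results in zip(documents, poller.result()):
    entities_result, sentiment_result = action_results
    if not entities_result.is_error:
        print(doc, "->", [entity.text for entity in entities_result.entities])
    if not sentiment_result.is_error:
        print(doc, "->", sentiment_result.sentiment)
```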
diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py index baf47085ba0b..b4b4d4466e6e 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py @@ -14,7 +14,7 @@ RecognizePiiEntitiesAction, RecognizeLinkedEntitiesAction, AnalyzeSentimentAction, - AnalyzeActionsType, + _AnalyzeActionsType, ) def _validate_input(documents, hint, whole_input_hint): @@ -71,14 +71,14 @@ def _validate_input(documents, hint, whole_input_hint): def _determine_action_type(action): if isinstance(action, RecognizeEntitiesAction): - return AnalyzeActionsType.RECOGNIZE_ENTITIES + return _AnalyzeActionsType.RECOGNIZE_ENTITIES if isinstance(action, RecognizePiiEntitiesAction): - return AnalyzeActionsType.RECOGNIZE_PII_ENTITIES + return _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES if isinstance(action, RecognizeLinkedEntitiesAction): - return AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES + return _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES if isinstance(action, AnalyzeSentimentAction): - return AnalyzeActionsType.ANALYZE_SENTIMENT - return AnalyzeActionsType.EXTRACT_KEY_PHRASES + return _AnalyzeActionsType.ANALYZE_SENTIMENT + return _AnalyzeActionsType.EXTRACT_KEY_PHRASES def _check_string_index_type_arg(string_index_type_arg, api_version, string_index_type_default="UnicodeCodePoint"): string_index_type = None diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers.py index 9af6d2a069ca..0e4e3dfe1f17 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers.py @@ -32,10 +32,7 @@ RecognizePiiEntitiesResult, PiiEntity, AnalyzeHealthcareEntitiesResult, - AnalyzeActionsResult, - AnalyzeActionsType, - AnalyzeActionsError, - _get_indices, + _AnalyzeActionsType, ) class CSODataV4Format(ODataV4Format): @@ -86,10 +83,10 @@ def order_lro_results(doc_id_order, combined): def prepare_result(func): def choose_wrapper(*args, **kwargs): - def wrapper(response, obj, response_headers): # pylint: disable=unused-argument + def wrapper(response, obj, response_headers, ordering_function): # pylint: disable=unused-argument if obj.errors: combined = obj.documents + obj.errors - results = order_results(response, combined) + results = ordering_function(response, combined) else: results = obj.documents @@ -101,27 +98,11 @@ def wrapper(response, obj, response_headers): # pylint: disable=unused-argument results[idx] = func(item, results) return results - def lro_wrapper(doc_id_order, obj, response_headers): # pylint: disable=unused-argument - if obj.errors: - combined = obj.documents + obj.errors - - results = order_lro_results(doc_id_order, combined) - else: - results = obj.documents - - for idx, item in enumerate(results): - if hasattr(item, "error"): - results[idx] = DocumentError(id=item.id, error=TextAnalyticsError._from_generated(item.error)) # pylint: disable=protected-access - else: - results[idx] = func(item, results) - return results - lro = kwargs.get("lro", False) if lro: - return lro_wrapper(*args) - - return wrapper(*args) + return wrapper(*args, ordering_function=order_lro_results) + return wrapper(*args, 
ordering_function=order_results) return choose_wrapper @@ -198,96 +179,57 @@ def healthcare_extract_page_data(doc_id_order, obj, response_headers, health_job healthcare_result(doc_id_order, health_job_state.results, response_headers, lro=True)) def _get_deserialization_callback_from_task_type(task_type): - if task_type == AnalyzeActionsType.RECOGNIZE_ENTITIES: + if task_type == _AnalyzeActionsType.RECOGNIZE_ENTITIES: return entities_result - if task_type == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES: + if task_type == _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES: return pii_entities_result - if task_type == AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES: + if task_type == _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES: return linked_entities_result - if task_type == AnalyzeActionsType.ANALYZE_SENTIMENT: + if task_type == _AnalyzeActionsType.ANALYZE_SENTIMENT: return sentiment_result return key_phrases_result def _get_property_name_from_task_type(task_type): - if task_type == AnalyzeActionsType.RECOGNIZE_ENTITIES: + if task_type == _AnalyzeActionsType.RECOGNIZE_ENTITIES: return "entity_recognition_tasks" - if task_type == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES: + if task_type == _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES: return "entity_recognition_pii_tasks" - if task_type == AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES: + if task_type == _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES: return "entity_linking_tasks" - if task_type == AnalyzeActionsType.ANALYZE_SENTIMENT: + if task_type == _AnalyzeActionsType.ANALYZE_SENTIMENT: return "sentiment_analysis_tasks" return "key_phrase_extraction_tasks" -def _num_tasks_in_current_page(returned_tasks_object): - return ( - len(returned_tasks_object.entity_recognition_tasks or []) + - len(returned_tasks_object.entity_recognition_pii_tasks or []) + - len(returned_tasks_object.key_phrase_extraction_tasks or []) + - len(returned_tasks_object.entity_linking_tasks or []) + - len(returned_tasks_object.sentiment_analysis_tasks or []) - ) - -def _get_task_type_from_error(error): - if "pii" in error.target.lower(): - return AnalyzeActionsType.RECOGNIZE_PII_ENTITIES - if "entityrecognition" in error.target.lower(): - return AnalyzeActionsType.RECOGNIZE_ENTITIES - if "entitylinking" in error.target.lower(): - return AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES - if "sentiment" in error.target.lower(): - return AnalyzeActionsType.ANALYZE_SENTIMENT - return AnalyzeActionsType.EXTRACT_KEY_PHRASES - -def _get_mapped_errors(analyze_job_state): - """ - """ - mapped_errors = defaultdict(list) - if not analyze_job_state.errors: - return mapped_errors - for error in analyze_job_state.errors: - mapped_errors[_get_task_type_from_error(error)].append((_get_error_index(error), error)) - return mapped_errors - -def _get_error_index(error): - return _get_indices(error.target)[-1] - def _get_good_result(current_task_type, index_of_task_result, doc_id_order, response_headers, returned_tasks_object): deserialization_callback = _get_deserialization_callback_from_task_type(current_task_type) property_name = _get_property_name_from_task_type(current_task_type) response_task_to_deserialize = getattr(returned_tasks_object, property_name)[index_of_task_result] - document_results = deserialization_callback( + return deserialization_callback( doc_id_order, response_task_to_deserialize.results, response_headers, lro=True ) - return AnalyzeActionsResult( - document_results=document_results, - action_type=current_task_type, - completed_on=response_task_to_deserialize.last_update_date_time, - ) 
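The rewritten `get_iter_items` just below regroups deserialized task results by document id and then restores input-document order. A standalone sketch of that regrouping, using hypothetical `SimpleNamespace` stand-ins for the deserialized result objects:

```python
from collections import defaultdict
from types import SimpleNamespace

# hypothetical stand-ins for deserialized action results, tagged with document ids;
# results arrive action-by-action (entities first, then sentiment), not document-by-document
entities_results = [SimpleNamespace(id="1", kind="entities"), SimpleNamespace(id="0", kind="entities")]
sentiment_results = [SimpleNamespace(id="1", kind="sentiment"), SimpleNamespace(id="0", kind="sentiment")]

doc_id_order = ["0", "1"]  # the order the documents were submitted in
iter_items = defaultdict(list)  # map doc id to that document's action results
for results in (entities_results, sentiment_results):
    for result in results:
        iter_items[result.id].append(result)

# one inner list per document, restored to input-document order
grouped = [iter_items[doc_id] for doc_id in doc_id_order if doc_id in iter_items]
assert [[r.kind for r in doc_results] for doc_results in grouped] == [
    ["entities", "sentiment"],
    ["entities", "sentiment"],
]
```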
def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state):
-    iter_items = []
+    iter_items = defaultdict(list)  # map doc id to action results
     task_type_to_index = defaultdict(int)  # need to keep track of how many of each type of tasks we've seen
     returned_tasks_object = analyze_job_state.tasks
-    mapped_errors = _get_mapped_errors(analyze_job_state)
     for current_task_type in task_order:
         index_of_task_result = task_type_to_index[current_task_type]
+        results = _get_good_result(
+            current_task_type,
+            index_of_task_result,
+            doc_id_order,
+            response_headers,
+            returned_tasks_object,
+        )
+        for result in results:
+            iter_items[result.id].append(result)
-
-        try:
-            # try to deserailize as error. If fails, we know it's good
-            # kind of a weird way to order things, but we can fail when deserializing
-            # the curr response as an error, not when deserializing as a good response.
-
-            current_task_type_errors = mapped_errors[current_task_type]
-            error = next(err for err in current_task_type_errors if err[0] == index_of_task_result)
-            result = AnalyzeActionsError._from_generated(error[1])  # pylint: disable=protected-access
-        except StopIteration:
-            result = _get_good_result(
-                current_task_type, index_of_task_result, doc_id_order, response_headers, returned_tasks_object
-            )
-        iter_items.append(result)
         task_type_to_index[current_task_type] += 1
-    return iter_items
+    return [
+        iter_items[doc_id]
+        for doc_id in doc_id_order
+        if doc_id in iter_items
+    ]

 def analyze_extract_page_data(doc_id_order, task_order, response_headers, analyze_job_state):
     # return next link, list of
diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py
index e41ddda9b05a..12656a8002e6 100644
--- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py
+++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py
@@ -33,7 +33,7 @@
     analyze_paged_result,
 )

-from ._models import AnalyzeActionsType
+from ._models import _AnalyzeActionsType

 from ._lro import (
     TextAnalyticsOperationResourcePolling,
@@ -59,7 +59,6 @@
     ExtractKeyPhrasesAction,
     AnalyzeSentimentAction,
     AnalyzeHealthcareEntitiesResult,
-    AnalyzeActionsResult,
 )

 from ._lro import AnalyzeHealthcareEntitiesLROPoller, AnalyzeActionsLROPoller
@@ -825,9 +824,13 @@ def begin_analyze_actions(  # type: ignore
         documents,  # type: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]]
         actions,  # type: List[Union[RecognizeEntitiesAction, RecognizeLinkedEntitiesAction, RecognizePiiEntitiesAction, ExtractKeyPhrasesAction, AnalyzeSentimentAction]]  # pylint: disable=line-too-long
         **kwargs  # type: Any
-    ):  # type: (...) -> AnalyzeActionsLROPoller[ItemPaged[AnalyzeActionsResult]]
+    ):  # type: (...) -> AnalyzeActionsLROPoller[ItemPaged[List[Union[RecognizeEntitiesResult, RecognizeLinkedEntitiesResult, RecognizePiiEntitiesResult, ExtractKeyPhrasesResult, AnalyzeSentimentResult]]]]  # pylint: disable=line-too-long
         """Start a long-running operation to perform a variety of text analysis actions over a batch of documents.

+        We recommend you use this function if you're looking to analyze larger documents and/or
+        combine multiple Text Analytics actions into one call. Otherwise, we recommend you use
+        the action-specific endpoints, for example :func:`analyze_sentiment`.
+
         :param documents: The set of documents to process as part of this batch.
If you wish to specify the ID and language on a per-item basis you must use as input a list[:class:`~azure.ai.textanalytics.TextDocumentInput`] or a list of @@ -853,11 +856,22 @@ def begin_analyze_actions( # type: ignore :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 30 seconds. :return: An instance of an AnalyzeActionsLROPoller. Call `result()` on the poller - object to return a pageable heterogeneous list of the action results in the order - the actions were sent in this method. + object to return a pageable heterogeneous list of lists. This list of lists is first ordered + by the documents you input, then ordered by the actions you input. For example, + if you have documents input ["Hello", "world"], and actions + :class:`~azure.ai.textanalytics.RecognizeEntitiesAction` and + :class:`~azure.ai.textanalytics.AnalyzeSentimentAction`, when iterating over the list of lists, + you will first iterate over the action results for the "Hello" document, getting the + :class:`~azure.ai.textanalytics.RecognizeEntitiesResult` of "Hello", + then the :class:`~azure.ai.textanalytics.AnalyzeSentimentResult` of "Hello". + Then, you will get the :class:`~azure.ai.textanalytics.RecognizeEntitiesResult` and + :class:`~azure.ai.textanalytics.AnalyzeSentimentResult` of "world". :rtype: ~azure.ai.textanalytics.AnalyzeActionsLROPoller[~azure.core.paging.ItemPaged[ - ~azure.ai.textanalytics.AnalyzeActionsResult]] + list[ + RecognizeEntitiesResult or RecognizeLinkedEntitiesResult or RecognizePiiEntitiesResult or + ExtractKeyPhrasesResult or AnalyzeSentimentResult + ]]] :raises ~azure.core.exceptions.HttpResponseError or TypeError or ValueError or NotImplementedError: .. admonition:: Example: @@ -888,26 +902,26 @@ def begin_analyze_actions( # type: ignore analyze_tasks = self._client.models(api_version='v3.1-preview.5').JobManifestTasks( entity_recognition_tasks=[ t.to_generated() for t in - [a for a in actions if _determine_action_type(a) == AnalyzeActionsType.RECOGNIZE_ENTITIES] + [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_ENTITIES] ], entity_recognition_pii_tasks=[ t.to_generated() for t in - [a for a in actions if _determine_action_type(a) == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES] + [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES] ], key_phrase_extraction_tasks=[ t.to_generated() for t in - [a for a in actions if _determine_action_type(a) == AnalyzeActionsType.EXTRACT_KEY_PHRASES] + [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.EXTRACT_KEY_PHRASES] ], entity_linking_tasks=[ t.to_generated() for t in [ a for a in actions - if _determine_action_type(a) == AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES + if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES ] ], sentiment_analysis_tasks=[ t.to_generated() for t in - [a for a in actions if _determine_action_type(a) == AnalyzeActionsType.ANALYZE_SENTIMENT] + [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.ANALYZE_SENTIMENT] ] ) analyze_body = self._client.models(api_version='v3.1-preview.5').AnalyzeBatchInput( diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_text_analytics_client_async.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_text_analytics_client_async.py index d212d26a33b7..8f57a97a2de5 100644 --- 
a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_text_analytics_client_async.py
+++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_text_analytics_client_async.py
@@ -41,8 +41,7 @@
     RecognizeEntitiesAction,
     RecognizePiiEntitiesAction,
     ExtractKeyPhrasesAction,
-    AnalyzeActionsResult,
-    AnalyzeActionsType,
+    _AnalyzeActionsType,
     RecognizeLinkedEntitiesAction,
     AnalyzeSentimentAction,
     AnalyzeHealthcareEntitiesResult,
@@ -808,9 +807,13 @@ async def begin_analyze_actions(  # type: ignore
         documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
         actions: List[Union[RecognizeEntitiesAction, RecognizeLinkedEntitiesAction, RecognizePiiEntitiesAction, ExtractKeyPhrasesAction, AnalyzeSentimentAction]],  # pylint: disable=line-too-long
         **kwargs: Any
-    ) -> AsyncAnalyzeActionsLROPoller[AsyncItemPaged[AnalyzeActionsResult]]:
+    ) -> AsyncAnalyzeActionsLROPoller[AsyncItemPaged[List[Union[RecognizeEntitiesResult, RecognizeLinkedEntitiesResult, RecognizePiiEntitiesResult, ExtractKeyPhrasesResult, AnalyzeSentimentResult]]]]:  # pylint: disable=line-too-long
         """Start a long-running operation to perform a variety of text analysis actions over a batch of documents.

+        We recommend you use this function if you're looking to analyze larger documents and/or
+        combine multiple Text Analytics actions into one call. Otherwise, we recommend you use
+        the action-specific endpoints, for example :func:`analyze_sentiment`.
+
         :param documents: The set of documents to process as part of this batch.
             If you wish to specify the ID and language on a per-item basis you must
             use as input a list[:class:`~azure.ai.textanalytics.TextDocumentInput`] or a list of
@@ -835,12 +838,23 @@ async def begin_analyze_actions(  # type: ignore
         :keyword bool show_stats: If set to true, response will contain document level statistics.
         :keyword int polling_interval: Waiting time between two polls for LRO operations
             if no Retry-After header is present. Defaults to 30 seconds.
-        :return: An instance of an AsyncAnalyzeActionsLROPoller. Call `result()` on the poller
-            object to return a pageable heterogeneous list of the action results in the order
-            the actions were sent in this method.
+        :return: An instance of an AsyncAnalyzeActionsLROPoller. Call `result()` on the poller
+            object to return a pageable heterogeneous list of lists. This list of lists is first ordered
+            by the documents you input, then ordered by the actions you input. For example,
+            if you have documents input ["Hello", "world"], and actions
+            :class:`~azure.ai.textanalytics.RecognizeEntitiesAction` and
+            :class:`~azure.ai.textanalytics.AnalyzeSentimentAction`, when iterating over the list of lists,
+            you will first iterate over the action results for the "Hello" document, getting the
+            :class:`~azure.ai.textanalytics.RecognizeEntitiesResult` of "Hello",
+            then the :class:`~azure.ai.textanalytics.AnalyzeSentimentResult` of "Hello".
+            Then, you will get the :class:`~azure.ai.textanalytics.RecognizeEntitiesResult` and
+            :class:`~azure.ai.textanalytics.AnalyzeSentimentResult` of "world".
        :rtype:
-            ~azure.ai.textanalytics.aio.AsyncAnalyzeActionsLROPoller[~azure.core.async_paging.AsyncItemPaged[
-            ~azure.ai.textanalytics.AnalyzeActionsResult]]
+            ~azure.ai.textanalytics.aio.AsyncAnalyzeActionsLROPoller[~azure.core.async_paging.AsyncItemPaged[
+            list[
+                RecognizeEntitiesResult or RecognizeLinkedEntitiesResult or RecognizePiiEntitiesResult or
+                ExtractKeyPhrasesResult or AnalyzeSentimentResult
+            ]]]
        :raises ~azure.core.exceptions.HttpResponseError or TypeError or ValueError or NotImplementedError:

        .. admonition:: Example:
@@ -871,26 +885,26 @@
        analyze_tasks = self._client.models(api_version='v3.1-preview.5').JobManifestTasks(
            entity_recognition_tasks=[
                t.to_generated() for t in
-                [a for a in actions if _determine_action_type(a) == AnalyzeActionsType.RECOGNIZE_ENTITIES]
+                [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_ENTITIES]
            ],
            entity_recognition_pii_tasks=[
                t.to_generated() for t in
-                [a for a in actions if _determine_action_type(a) == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES]
+                [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES]
            ],
            key_phrase_extraction_tasks=[
                t.to_generated() for t in
-                [a for a in actions if _determine_action_type(a) == AnalyzeActionsType.EXTRACT_KEY_PHRASES]
+                [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.EXTRACT_KEY_PHRASES]
            ],
            entity_linking_tasks=[
                t.to_generated() for t in
                [
                    a for a in actions if \
-                    _determine_action_type(a) == AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
+                    _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
                ]
            ],
            sentiment_analysis_tasks=[
                t.to_generated() for t in
-                [a for a in actions if _determine_action_type(a) == AnalyzeActionsType.ANALYZE_SENTIMENT]
+                [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.ANALYZE_SENTIMENT]
            ]
        )
        analyze_body = self._client.models(api_version='v3.1-preview.5').AnalyzeBatchInput(
diff --git a/sdk/textanalytics/azure-ai-textanalytics/samples/async_samples/sample_analyze_actions_async.py b/sdk/textanalytics/azure-ai-textanalytics/samples/async_samples/sample_analyze_actions_async.py
index d8894f025c51..9f46de7ec3a9 100644
--- a/sdk/textanalytics/azure-ai-textanalytics/samples/async_samples/sample_analyze_actions_async.py
+++ b/sdk/textanalytics/azure-ai-textanalytics/samples/async_samples/sample_analyze_actions_async.py
@@ -39,7 +39,6 @@ async def analyze_async(self):
        RecognizePiiEntitiesAction,
        ExtractKeyPhrasesAction,
        AnalyzeSentimentAction,
-        AnalyzeActionsType
    )

    endpoint = os.environ["AZURE_TEXT_ANALYTICS_ENDPOINT"]
@@ -51,12 +50,15 @@
    )

    documents = [
-        "We went to Contoso Steakhouse located at midtown NYC last week for a dinner party, and we adore the spot! \
-        They provide marvelous food and they have a great menu. The chief cook happens to be the owner (I think his name is John Doe) \
-        and he is super nice, coming out of the kitchen and greeted us all. We enjoyed very much dining in the place! \
-        The Sirloin steak I ordered was tender and juicy, and the place was impeccably clean. You can even pre-order from their \
-        online menu at www.contososteakhouse.com, call 312-555-0176 or send email to order@contososteakhouse.com! \
-        The only complaint I have is the food didn't come fast enough. Overall I highly recommend it!"
+        'We went to Contoso Steakhouse located at midtown NYC last week for a dinner party, and we adore the spot! '
+        'They provide marvelous food and they have a great menu. The chief cook happens to be the owner (I think his name is John Doe) '
+        'and he is super nice, coming out of the kitchen and greeted us all.',
+
+        'We enjoyed very much dining in the place! '
+        'The Sirloin steak I ordered was tender and juicy, and the place was impeccably clean. You can even pre-order from their '
+        'online menu at www.contososteakhouse.com, call 312-555-0176 or send email to order@contososteakhouse.com! '
+        'The only complaint I have is the food didn\'t come fast enough. Overall I highly recommend it!'
    ]

    async with text_analytics_client:
@@ -72,71 +74,86 @@
            ]
        )

-        result = await poller.result()
-
-        async for action_result in result:
-            if action_result.is_error:
-                raise ValueError(
-                    "Action has failed with message: {}".format(
-                        action_result.error.message
-                    )
-                )
-            if action_result.action_type == AnalyzeActionsType.RECOGNIZE_ENTITIES:
-                print("Results of Entities Recognition action:")
-                for idx, doc in enumerate(action_result.document_results):
-                    print("\nDocument text: {}".format(documents[idx]))
-                    for entity in doc.entities:
-                        print("Entity: {}".format(entity.text))
-                        print("...Category: {}".format(entity.category))
-                        print("...Confidence Score: {}".format(entity.confidence_score))
-                        print("...Offset: {}".format(entity.offset))
-                    print("------------------------------------------")
-
-            if action_result.action_type == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES:
-                print("Results of PII Entities Recognition action:")
-                for idx, doc in enumerate(action_result.document_results):
-                    print("Document text: {}".format(documents[idx]))
-                    for entity in doc.entities:
-                        print("Entity: {}".format(entity.text))
-                        print("Category: {}".format(entity.category))
-                        print("Confidence Score: {}\n".format(entity.confidence_score))
-                    print("------------------------------------------")
-
-            if action_result.action_type == AnalyzeActionsType.EXTRACT_KEY_PHRASES:
-                print("Results of Key Phrase Extraction action:")
-                for idx, doc in enumerate(action_result.document_results):
-                    print("Document text: {}\n".format(documents[idx]))
-                    print("Key Phrases: {}\n".format(doc.key_phrases))
-                    print("------------------------------------------")
-
-            if action_result.action_type == AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES:
-                print("Results of Linked Entities Recognition action:")
-                for idx, doc in enumerate(action_result.document_results):
-                    print("Document text: {}\n".format(documents[idx]))
-                    for linked_entity in doc.entities:
-                        print("Entity name: {}".format(linked_entity.name))
-                        print("...Data source: {}".format(linked_entity.data_source))
-                        print("...Data source language: {}".format(linked_entity.language))
-                        print("...Data source entity ID: {}".format(linked_entity.data_source_entity_id))
-                        print("...Data source URL: {}".format(linked_entity.url))
-                        print("...Document matches:")
-                        for match in linked_entity.matches:
-                            print("......Match text: {}".format(match.text))
-                            print(".........Confidence Score: {}".format(match.confidence_score))
-                            print(".........Offset: {}".format(match.offset))
-                            print(".........Length: {}".format(match.length))
-                    print("------------------------------------------")
-
-            if action_result.action_type == AnalyzeActionsType.ANALYZE_SENTIMENT:
-                print("Results of Sentiment Analysis action:")
-                for doc in action_result.document_results:
-                    print("Overall sentiment: {}".format(doc.sentiment))
-                    print("Scores: positive={}; neutral={}; negative={} \n".format(
-                        doc.confidence_scores.positive,
-                        doc.confidence_scores.neutral,
-                        doc.confidence_scores.negative,
-                    ))
-                    print("------------------------------------------")
+        pages = await poller.result()
+
+        # To enumerate or zip async results, you first have to read all of the elements
+        # into memory, unless you install a third-party library.
+        # If you don't need to enumerate or zip, we recommend looping over the results
+        # asynchronously as they arrive and skipping this buffering step.
+        document_results = []
+        async for page in pages:
+            document_results.append(page)
+
+        for doc, action_results in zip(documents, document_results):
+            print("\nDocument text: {}".format(doc))
+            recognize_entities_result = action_results[0]
+            print("...Results of Recognize Entities Action:")
+            if recognize_entities_result.is_error:
+                print("...Is an error with code '{}' and message '{}'".format(
+                    recognize_entities_result.code, recognize_entities_result.message
+                ))
+            else:
+                for entity in recognize_entities_result.entities:
+                    print("......Entity: {}".format(entity.text))
+                    print(".........Category: {}".format(entity.category))
+                    print(".........Confidence Score: {}".format(entity.confidence_score))
+                    print(".........Offset: {}".format(entity.offset))
+
+            recognize_pii_entities_result = action_results[1]
+            print("...Results of Recognize PII Entities action:")
+            if recognize_pii_entities_result.is_error:
+                print("...Is an error with code '{}' and message '{}'".format(
+                    recognize_pii_entities_result.code, recognize_pii_entities_result.message
+                ))
+            else:
+                for entity in recognize_pii_entities_result.entities:
+                    print("......Entity: {}".format(entity.text))
+                    print(".........Category: {}".format(entity.category))
+                    print(".........Confidence Score: {}".format(entity.confidence_score))
+
+            extract_key_phrases_result = action_results[2]
+            print("...Results of Extract Key Phrases action:")
+            if extract_key_phrases_result.is_error:
+                print("...Is an error with code '{}' and message '{}'".format(
+                    extract_key_phrases_result.code, extract_key_phrases_result.message
+                ))
+            else:
+                print("......Key Phrases: {}".format(extract_key_phrases_result.key_phrases))
+
+            recognize_linked_entities_result = action_results[3]
+            print("...Results of Recognize Linked Entities action:")
+            if recognize_linked_entities_result.is_error:
+                print("...Is an error with code '{}' and message '{}'".format(
+                    recognize_linked_entities_result.code, recognize_linked_entities_result.message
+                ))
+            else:
+                for linked_entity in recognize_linked_entities_result.entities:
+                    print("......Entity name: {}".format(linked_entity.name))
+                    print(".........Data source: {}".format(linked_entity.data_source))
+                    print(".........Data source language: {}".format(linked_entity.language))
+                    print(".........Data source entity ID: {}".format(linked_entity.data_source_entity_id))
+                    print(".........Data source URL: {}".format(linked_entity.url))
+                    print(".........Document matches:")
+                    for match in linked_entity.matches:
+                        print("............Match text: {}".format(match.text))
+                        print("............Confidence Score: {}".format(match.confidence_score))
+                        print("............Offset: {}".format(match.offset))
+                        print("............Length: {}".format(match.length))
+
+            analyze_sentiment_result = action_results[4]
+            print("...Results of Analyze Sentiment action:")
+            if analyze_sentiment_result.is_error:
+                print("...Is an error with code '{}' and message '{}'".format(
+                    analyze_sentiment_result.code, analyze_sentiment_result.message
+                ))
+            else:
+                print("......Overall sentiment: {}".format(analyze_sentiment_result.sentiment))
+                print("......Scores: positive={}; neutral={}; negative={} \n".format(
+                    analyze_sentiment_result.confidence_scores.positive,
+                    analyze_sentiment_result.confidence_scores.neutral,
+                    analyze_sentiment_result.confidence_scores.negative,
+                ))
+            print("------------------------------------------")
    # [END analyze_async]
diff --git a/sdk/textanalytics/azure-ai-textanalytics/samples/sample_analyze_actions.py b/sdk/textanalytics/azure-ai-textanalytics/samples/sample_analyze_actions.py
index c613fd6d45a5..16f431efb750 100644
--- a/sdk/textanalytics/azure-ai-textanalytics/samples/sample_analyze_actions.py
+++ b/sdk/textanalytics/azure-ai-textanalytics/samples/sample_analyze_actions.py
@@ -50,12 +50,15 @@ def analyze(self):
    )

    documents = [
-        "We went to Contoso Steakhouse located at midtown NYC last week for a dinner party, and we adore the spot! \
-        They provide marvelous food and they have a great menu. The chief cook happens to be the owner (I think his name is John Doe) \
-        and he is super nice, coming out of the kitchen and greeted us all. We enjoyed very much dining in the place! \
-        The Sirloin steak I ordered was tender and juicy, and the place was impeccably clean. You can even pre-order from their \
-        online menu at www.contososteakhouse.com, call 312-555-0176 or send email to order@contososteakhouse.com! \
-        The only complaint I have is the food didn't come fast enough. Overall I highly recommend it!"
+        'We went to Contoso Steakhouse located at midtown NYC last week for a dinner party, and we adore the spot! '
+        'They provide marvelous food and they have a great menu. The chief cook happens to be the owner (I think his name is John Doe) '
+        'and he is super nice, coming out of the kitchen and greeted us all.',
+
+        'We enjoyed very much dining in the place! '
+        'The Sirloin steak I ordered was tender and juicy, and the place was impeccably clean. You can even pre-order from their '
+        'online menu at www.contososteakhouse.com, call 312-555-0176 or send email to order@contososteakhouse.com! '
+        'The only complaint I have is the food didn\'t come fast enough. 
Overall I highly recommend it!'\ ] poller = text_analytics_client.begin_analyze_actions( @@ -70,78 +73,76 @@ def analyze(self): ], ) - result = poller.result() - action_results = [action_result for action_result in list(result) if not action_result.is_error] - - first_action_result = action_results[0] - print("Results of Entities Recognition action:") - docs = [doc for doc in first_action_result.document_results if not doc.is_error] - - for idx, doc in enumerate(docs): - print("\nDocument text: {}".format(documents[idx])) - for entity in doc.entities: - print("Entity: {}".format(entity.text)) - print("...Category: {}".format(entity.category)) - print("...Confidence Score: {}".format(entity.confidence_score)) - print("...Offset: {}".format(entity.offset)) - print("...Length: {}".format(entity.length)) - print("------------------------------------------") - - second_action_result = action_results[1] - print("Results of PII Entities Recognition action:") - docs = [doc for doc in second_action_result.document_results if not doc.is_error] - - for idx, doc in enumerate(docs): - print("Document text: {}".format(documents[idx])) - print("Document text with redactions: {}".format(doc.redacted_text)) - for entity in doc.entities: - print("Entity: {}".format(entity.text)) - print("...Category: {}".format(entity.category)) - print("...Confidence Score: {}\n".format(entity.confidence_score)) - print("...Offset: {}".format(entity.offset)) - print("...Length: {}".format(entity.length)) - print("------------------------------------------") - - third_action_result = action_results[2] - print("Results of Key Phrase Extraction action:") - docs = [doc for doc in third_action_result.document_results if not doc.is_error] - - for idx, doc in enumerate(docs): - print("Document text: {}\n".format(documents[idx])) - print("Key Phrases: {}\n".format(doc.key_phrases)) - print("------------------------------------------") - - fourth_action_result = action_results[3] - print("Results of Linked Entities Recognition action:") - docs = [doc for doc in fourth_action_result.document_results if not doc.is_error] - - for idx, doc in enumerate(docs): - print("Document text: {}\n".format(documents[idx])) - for linked_entity in doc.entities: - print("Entity name: {}".format(linked_entity.name)) - print("...Data source: {}".format(linked_entity.data_source)) - print("...Data source language: {}".format(linked_entity.language)) - print("...Data source entity ID: {}".format(linked_entity.data_source_entity_id)) - print("...Data source URL: {}".format(linked_entity.url)) - print("...Document matches:") - for match in linked_entity.matches: - print("......Match text: {}".format(match.text)) - print(".........Confidence Score: {}".format(match.confidence_score)) - print(".........Offset: {}".format(match.offset)) - print(".........Length: {}".format(match.length)) - print("------------------------------------------") - - fifth_action_result = action_results[4] - print("Results of Sentiment Analysis action:") - docs = [doc for doc in fifth_action_result.document_results if not doc.is_error] - - for doc in docs: - print("Overall sentiment: {}".format(doc.sentiment)) - print("Scores: positive={}; neutral={}; negative={} \n".format( - doc.confidence_scores.positive, - doc.confidence_scores.neutral, - doc.confidence_scores.negative, - )) + document_results = poller.result() + for doc, action_results in zip(documents, document_results): + print("\nDocument text: {}".format(doc)) + recognize_entities_result = action_results[0] + 
print("...Results of Recognize Entities Action:") + if recognize_entities_result.is_error: + print("...Is an error with code '{}' and message '{}'".format( + recognize_entities_result.code, recognize_entities_result.message + )) + else: + for entity in recognize_entities_result.entities: + print("......Entity: {}".format(entity.text)) + print(".........Category: {}".format(entity.category)) + print(".........Confidence Score: {}".format(entity.confidence_score)) + print(".........Offset: {}".format(entity.offset)) + + recognize_pii_entities_result = action_results[1] + print("...Results of Recognize PII Entities action:") + if recognize_pii_entities_result.is_error: + print("...Is an error with code '{}' and message '{}'".format( + recognize_pii_entities_result.code, recognize_pii_entities_result.message + )) + else: + for entity in recognize_pii_entities_result.entities: + print("......Entity: {}".format(entity.text)) + print(".........Category: {}".format(entity.category)) + print(".........Confidence Score: {}".format(entity.confidence_score)) + + extract_key_phrases_result = action_results[2] + print("...Results of Extract Key Phrases action:") + if extract_key_phrases_result.is_error: + print("...Is an error with code '{}' and message '{}'".format( + extract_key_phrases_result.code, extract_key_phrases_result.message + )) + else: + print("......Key Phrases: {}".format(extract_key_phrases_result.key_phrases)) + + recognize_linked_entities_result = action_results[3] + print("...Results of Recognize Linked Entities action:") + if recognize_linked_entities_result.is_error: + print("...Is an error with code '{}' and message '{}'".format( + recognize_linked_entities_result.code, recognize_linked_entities_result.message + )) + else: + for linked_entity in recognize_linked_entities_result.entities: + print("......Entity name: {}".format(linked_entity.name)) + print(".........Data source: {}".format(linked_entity.data_source)) + print(".........Data source language: {}".format(linked_entity.language)) + print(".........Data source entity ID: {}".format(linked_entity.data_source_entity_id)) + print(".........Data source URL: {}".format(linked_entity.url)) + print(".........Document matches:") + for match in linked_entity.matches: + print("............Match text: {}".format(match.text)) + print("............Confidence Score: {}".format(match.confidence_score)) + print("............Offset: {}".format(match.offset)) + print("............Length: {}".format(match.length)) + + analyze_sentiment_result = action_results[4] + print("...Results of Analyze Sentiment action:") + if analyze_sentiment_result.is_error: + print("...Is an error with code '{}' and message '{}'".format( + analyze_sentiment_result.code, analyze_sentiment_result.message + )) + else: + print("......Overall sentiment: {}".format(analyze_sentiment_result.sentiment)) + print("......Scores: positive={}; neutral={}; negative={} \n".format( + analyze_sentiment_result.confidence_scores.positive, + analyze_sentiment_result.confidence_scores.neutral, + analyze_sentiment_result.confidence_scores.negative, + )) print("------------------------------------------") # [END analyze] diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py index 96e7419257fc..b36d16dbf390 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. 
# ------------------------------------
+from collections import defaultdict
 import os
 import pytest
 import platform
@@ -26,7 +27,12 @@
     TextDocumentInput,
     VERSION,
     TextAnalyticsApiVersion,
-    AnalyzeActionsType,
+    _AnalyzeActionsType,
+    ExtractKeyPhrasesResult,
+    AnalyzeSentimentResult,
+    RecognizeLinkedEntitiesResult,
+    RecognizeEntitiesResult,
+    RecognizePiiEntitiesResult,
 )

 # pre-apply the client_cls positional argument so it needn't be explicitly passed below
@@ -57,19 +63,17 @@ def test_all_successful_passing_dict_key_phrase_task(self, client):
             polling_interval=self._interval(),
         ).result()

-        action_results = list(response)
+        document_results = list(response)

-        assert len(action_results) == 1
-        action_result = action_results[0]
-
-        assert action_result.action_type == AnalyzeActionsType.EXTRACT_KEY_PHRASES
-        assert len(action_result.document_results) == len(docs)
-
-        for doc in action_result.document_results:
-            self.assertIn("Paul Allen", doc.key_phrases)
-            self.assertIn("Bill Gates", doc.key_phrases)
-            self.assertIn("Microsoft", doc.key_phrases)
-            self.assertIsNotNone(doc.id)
+        assert len(document_results) == 2
+        for action_results in document_results:
+            assert len(action_results) == 1
+            for document_result in action_results:
+                assert isinstance(document_result, ExtractKeyPhrasesResult)
+                assert "Paul Allen" in document_result.key_phrases
+                assert "Bill Gates" in document_result.key_phrases
+                assert "Microsoft" in document_result.key_phrases
+                assert document_result.id is not None

 @GlobalTextAnalyticsAccountPreparer()
 @TextAnalyticsClientPreparer()
@@ -85,33 +89,31 @@ def test_all_successful_passing_dict_sentiment_task(self, client):
             polling_interval=self._interval(),
         ).result()

-        action_results = list(response)
-
-        assert len(action_results) == 1
-        action_result = action_results[0]
-
-        assert action_result.action_type == AnalyzeActionsType.ANALYZE_SENTIMENT
-        assert len(action_result.document_results) == len(docs)
-
-        self.assertEqual(action_result.document_results[0].sentiment, "neutral")
-        self.assertEqual(action_result.document_results[1].sentiment, "negative")
-        self.assertEqual(action_result.document_results[2].sentiment, "positive")
-
-        for doc in action_result.document_results:
-            self.assertIsNotNone(doc.id)
-            self.assertIsNotNone(doc.statistics)
-            self.validateConfidenceScores(doc.confidence_scores)
-            self.assertIsNotNone(doc.sentences)
-
-        self.assertEqual(len(action_result.document_results[0].sentences), 1)
-        self.assertEqual(action_result.document_results[0].sentences[0].text, "Microsoft was founded by Bill Gates and Paul Allen.")
-        self.assertEqual(len(action_result.document_results[1].sentences), 2)
-        self.assertEqual(action_result.document_results[1].sentences[0].text, "I did not like the hotel we stayed at.")
-        self.assertEqual(action_result.document_results[1].sentences[1].text, "It was too expensive.")
-        self.assertEqual(len(action_result.document_results[2].sentences), 2)
-        self.assertEqual(action_result.document_results[2].sentences[0].text, "The restaurant had really good food.")
-        self.assertEqual(action_result.document_results[2].sentences[1].text, "I recommend you try it.")
-
+        pages = list(response)
+
+        assert len(pages) == len(docs)
+        for idx, document_results in enumerate(pages):
+            assert len(document_results) == 1
+            document_result = document_results[0]
+            assert isinstance(document_result, AnalyzeSentimentResult)
+            assert document_result.id is not None
+            assert document_result.statistics is not None
+            
self.validateConfidenceScores(document_result.confidence_scores) + assert document_result.sentences is not None + if idx == 0: + assert document_result.sentiment == "neutral" + assert len(document_result.sentences) == 1 + assert document_result.sentences[0].text == "Microsoft was founded by Bill Gates and Paul Allen." + elif idx == 1: + assert document_result.sentiment == "negative" + assert len(document_result.sentences) == 2 + assert document_result.sentences[0].text == "I did not like the hotel we stayed at." + assert document_result.sentences[1].text == "It was too expensive." + else: + assert document_result.sentiment == "positive" + assert len(document_result.sentences) == 2 + assert document_result.sentences[0].text == "The restaurant had really good food." + assert document_result.sentences[1].text == "I recommend you try it." @GlobalTextAnalyticsAccountPreparer() @TextAnalyticsClientPreparer() @@ -128,16 +130,14 @@ def test_sentiment_analysis_task_with_opinion_mining(self, client): polling_interval=self._interval(), ).result() - action_results = list(response) - - assert len(action_results) == 1 - action_result = action_results[0] + pages = list(response) - assert action_result.action_type == AnalyzeActionsType.ANALYZE_SENTIMENT - assert len(action_result.document_results) == len(documents) - - for idx, doc in enumerate(action_result.document_results): - for sentence in doc.sentences: + assert len(pages) == len(documents) + for idx, document_results in enumerate(pages): + assert len(document_results) == 1 + document_result = document_results[0] + assert isinstance(document_result, AnalyzeSentimentResult) + for sentence in document_result.sentences: if idx == 0: for mined_opinion in sentence.mined_opinions: target = mined_opinion.target @@ -206,19 +206,20 @@ def test_all_successful_passing_text_document_input_entities_task(self, client): polling_interval=self._interval(), ).result() - action_results = list(response) - - assert len(action_results) == 1 - action_result = action_results[0] - - assert action_result.action_type == AnalyzeActionsType.RECOGNIZE_ENTITIES - assert len(action_result.document_results) == len(docs) - - for doc in action_result.document_results: - self.assertEqual(len(doc.entities), 4) - self.assertIsNotNone(doc.id) - for entity in doc.entities: - self.assertIsNotNone(entity.text) + pages = list(response) + assert len(pages) == len(docs) + + for document_results in pages: + assert len(document_results) == 1 + document_result = document_results[0] + assert isinstance(document_result, RecognizeEntitiesResult) + assert len(document_result.entities) == 4 + assert document_result.id is not None + for entity in document_result.entities: + assert entity.text is not None + assert entity.category is not None + assert entity.offset is not None + assert entity.confidence_score is not None self.assertIsNotNone(entity.category) self.assertIsNotNone(entity.offset) self.assertIsNotNone(entity.confidence_score) @@ -239,29 +240,23 @@ def test_all_successful_passing_string_pii_entities_task(self, client): polling_interval=self._interval(), ).result() - action_results = list(response) - - assert len(action_results) == 1 - action_result = action_results[0] - - assert action_result.action_type == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES - assert len(action_result.document_results) == len(docs) - - self.assertEqual(action_result.document_results[0].entities[0].text, "859-98-0987") - self.assertEqual(action_result.document_results[0].entities[0].category, "USSocialSecurityNumber") - 
self.assertEqual(action_result.document_results[1].entities[0].text, "111000025") - # self.assertEqual(results[1].entities[0].category, "ABA Routing Number") # Service is currently returning PhoneNumber here - - # commenting out brazil cpf, currently service is not returning it - # self.assertEqual(action_result.document_results[2].entities[0].text, "998.214.865-68") - # self.assertEqual(action_result.document_results[2].entities[0].category, "Brazil CPF Number") - for doc in action_result.document_results: - self.assertIsNotNone(doc.id) - for entity in doc.entities: - self.assertIsNotNone(entity.text) - self.assertIsNotNone(entity.category) - self.assertIsNotNone(entity.offset) - self.assertIsNotNone(entity.confidence_score) + pages = list(response) + assert len(pages) == len(docs) + + for idx, document_results in enumerate(pages): + assert len(document_results) == 1 + document_result = document_results[0] + assert isinstance(document_result, RecognizePiiEntitiesResult) + if idx == 0: + assert document_result.entities[0].text == "859-98-0987" + assert document_result.entities[0].category == "USSocialSecurityNumber" + elif idx == 1: + assert document_result.entities[0].text == "111000025" + for entity in document_result.entities: + assert entity.text is not None + assert entity.category is not None + assert entity.offset is not None + assert entity.confidence_score is not None @GlobalTextAnalyticsAccountPreparer() @TextAnalyticsClientPreparer() @@ -331,23 +326,22 @@ def test_out_of_order_ids_multiple_tasks(self, client): polling_interval=self._interval(), ).result() - action_results = list(response) - assert len(action_results) == 5 - - assert action_results[0].action_type == AnalyzeActionsType.RECOGNIZE_ENTITIES - assert action_results[1].action_type == AnalyzeActionsType.EXTRACT_KEY_PHRASES - assert action_results[2].action_type == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES - assert action_results[3].action_type == AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES - assert action_results[4].action_type == AnalyzeActionsType.ANALYZE_SENTIMENT - - action_results = [r for r in action_results if not r.is_error] - assert all([action_result for action_result in action_results if len(action_result.document_results) == len(docs)]) - - in_order = ["56", "0", "19", "1"] + results = list(response) + assert len(results) == len(docs) - for action_result in action_results: - for idx, resp in enumerate(action_result.document_results): - self.assertEqual(resp.id, in_order[idx]) + document_order = ["56", "0", "19", "1"] + action_order = [ + _AnalyzeActionsType.RECOGNIZE_ENTITIES, + _AnalyzeActionsType.EXTRACT_KEY_PHRASES, + _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES, + _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES, + _AnalyzeActionsType.ANALYZE_SENTIMENT, + ] + for doc_idx, document_results in enumerate(results): + assert len(document_results) == 5 + for action_idx, document_result in enumerate(document_results): + self.assertEqual(document_result.id, document_order[doc_idx]) + self.assertEqual(self.document_result_to_action_type(document_result), action_order[action_idx]) @GlobalTextAnalyticsAccountPreparer() @TextAnalyticsClientPreparer() @@ -393,20 +387,21 @@ def callback(resp): response = poller.result() - action_results = list(response) - assert len(action_results) == 5 - assert action_results[0].action_type == AnalyzeActionsType.RECOGNIZE_ENTITIES - assert action_results[1].action_type == AnalyzeActionsType.EXTRACT_KEY_PHRASES - assert action_results[2].action_type == 
+        document_order = ["56", "0", "19", "1"]
+        action_order = [
+            _AnalyzeActionsType.RECOGNIZE_ENTITIES,
+            _AnalyzeActionsType.EXTRACT_KEY_PHRASES,
+            _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES,
+            _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES,
+            _AnalyzeActionsType.ANALYZE_SENTIMENT,
+        ]
+        for doc_idx, document_results in enumerate(results):
+            assert len(document_results) == 5
+            for action_idx, document_result in enumerate(document_results):
+                self.assertEqual(document_result.id, document_order[doc_idx])
+                self.assertEqual(self.document_result_to_action_type(document_result), action_order[action_idx])

     @GlobalTextAnalyticsAccountPreparer()
     @TextAnalyticsClientPreparer()
@@ -393,20 +387,21 @@ def callback(resp):

         response = poller.result()

-        action_results = list(response)
-        assert len(action_results) == 5
-        assert action_results[0].action_type == AnalyzeActionsType.RECOGNIZE_ENTITIES
-        assert action_results[1].action_type == AnalyzeActionsType.EXTRACT_KEY_PHRASES
-        assert action_results[2].action_type == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES
-        assert action_results[3].action_type == AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
-        assert action_results[4].action_type == AnalyzeActionsType.ANALYZE_SENTIMENT
-
-        assert all([action_result for action_result in action_results if len(action_result.document_results) == len(docs)])
-
-        for action_result in action_results:
-            assert not hasattr(action_result, "statistics")
-            for doc in action_result.document_results:
-                assert doc.statistics
+        pages = list(response)
+        assert len(pages) == len(docs)
+        action_order = [
+            _AnalyzeActionsType.RECOGNIZE_ENTITIES,
+            _AnalyzeActionsType.EXTRACT_KEY_PHRASES,
+            _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES,
+            _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES,
+            _AnalyzeActionsType.ANALYZE_SENTIMENT,
+        ]
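+        # show_stats=True should populate statistics on every per-document action result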
+        for document_results in pages:
+            assert len(document_results) == len(action_order)
+            for document_result in document_results:
+                assert document_result.statistics
+                assert document_result.statistics.character_count
+                assert document_result.statistics.transaction_count

     @GlobalTextAnalyticsAccountPreparer()
     @TextAnalyticsClientPreparer()
@@ -422,7 +417,7 @@ def test_poller_metadata(self, client):
             polling_interval=self._interval(),
         )

-        response = poller.result()
+        poller.result()

         assert isinstance(poller.created_on, datetime.datetime)
         poller._polling_method.display_name
@@ -463,8 +458,8 @@ def test_poller_metadata(self, client):
     #         raw_response_hook=callback
     #     ).result())

-    #     for action_result in response:
-    #         for doc in action_result.document_results:
+    #     for document_results in response:
+    #         for doc in document_results:
     #             self.assertFalse(doc.is_error)

     # @GlobalTextAnalyticsAccountPreparer()
@@ -498,8 +493,8 @@ def test_poller_metadata(self, client):
     #         polling_interval=self._interval(),
     #     ).result())

-    #     for action_result in response:
-    #         for doc in action_result.document_results:
+    #     for document_results in response:
+    #         for doc in document_results:
     #             assert not doc.is_error

     @GlobalTextAnalyticsAccountPreparer()
@@ -518,8 +513,8 @@ def test_invalid_language_hint_method(self, client):
             polling_interval=self._interval(),
         ).result())

-        for action_result in response:
-            for doc in action_result.document_results:
+        for document_results in response:
+            for doc in document_results:
                 assert doc.is_error

     @GlobalTextAnalyticsAccountPreparer()
@@ -528,7 +523,7 @@ def test_bad_model_version_error_multiple_tasks(self, client):
         docs = [{"id": "1", "language": "english", "text": "I did not like the hotel we stayed at."}]

         with pytest.raises(HttpResponseError):
-            response = client.begin_analyze_actions(
+            client.begin_analyze_actions(
                 docs,
                 actions=[
                     RecognizeEntitiesAction(model_version="latest"),
@@ -546,7 +541,7 @@ def test_bad_model_version_error_all_tasks(self, client):  # TODO: verify behavi
         docs = [{"id": "1", "language": "english", "text": "I did not like the hotel we stayed at."}]

         with self.assertRaises(HttpResponseError):
-            response = client.begin_analyze_actions(
+            client.begin_analyze_actions(
                 docs,
                 actions=[
                     RecognizeEntitiesAction(model_version="bad"),
@@ -617,41 +612,27 @@ def test_multiple_pages_of_results_returned_successfully(self, client):
             polling_interval=self._interval(),
         ).result()

-        recognize_entities_results = []
-        extract_key_phrases_results = []
-        recognize_pii_entities_results = []
-        recognize_linked_entities_results = []
-        analyze_sentiment_results = []
-
-        action_results = list(result)
-
-        # do 2 pages of 5 task results
-        for idx, action_result in enumerate(action_results):
-            if idx % 5 == 0:
-                assert action_result.action_type == AnalyzeActionsType.RECOGNIZE_ENTITIES
-                recognize_entities_results.append(action_result)
-            elif idx % 5 == 1:
-                assert action_result.action_type == AnalyzeActionsType.EXTRACT_KEY_PHRASES
-                extract_key_phrases_results.append(action_result)
-            elif idx % 5 == 2:
-                assert action_result.action_type == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES
-                recognize_pii_entities_results.append(action_result)
-            elif idx % 5 == 3:
-                assert action_result.action_type == AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
-                recognize_linked_entities_results.append(action_result)
-            else:
-                assert action_result.action_type == AnalyzeActionsType.ANALYZE_SENTIMENT
-                analyze_sentiment_results.append(action_result)
-            if idx < 5:  # first page of task results
-                assert len(action_result.document_results) == 20
-            else:
-                assert len(action_result.document_results) == 5
-
-        assert all([action_result for action_result in recognize_entities_results if len(action_result.document_results) == len(docs)])
-        assert all([action_result for action_result in extract_key_phrases_results if len(action_result.document_results) == len(docs)])
-        assert all([action_result for action_result in recognize_pii_entities_results if len(action_result.document_results) == len(docs)])
-        assert all([action_result for action_result in recognize_linked_entities_results if len(action_result.document_results) == len(docs)])
-        assert all([action_result for action_result in analyze_sentiment_results if len(action_result.document_results) == len(docs)])
+        pages = list(result)
+        assert len(pages) == len(docs)
+        action_order = [
+            _AnalyzeActionsType.RECOGNIZE_ENTITIES,
+            _AnalyzeActionsType.EXTRACT_KEY_PHRASES,
+            _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES,
+            _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES,
+            _AnalyzeActionsType.ANALYZE_SENTIMENT,
+        ]
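+        # bucket the results by action type to verify each action produced one result per document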
+        action_type_to_document_results = defaultdict(list)
+
+        for doc_idx, page in enumerate(pages):
+            for action_idx, document_result in enumerate(page):
+                self.assertEqual(document_result.id, str(doc_idx))
+                action_type = self.document_result_to_action_type(document_result)
+                self.assertEqual(action_type, action_order[action_idx])
+                action_type_to_document_results[action_type].append(document_result)
+
+        assert len(action_type_to_document_results) == len(action_order)
+        for document_results in action_type_to_document_results.values():
+            assert len(document_results) == len(docs)

     @GlobalTextAnalyticsAccountPreparer()
     @TextAnalyticsClientPreparer()
diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_async.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_async.py
index 6230eb51c9ce..dea8e8e93ca2 100644
--- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_async.py
+++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_async.py
@@ -3,6 +3,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 # ------------------------------------
+from collections import defaultdict
 import datetime
 import os
 import pytest
@@ -22,14 +23,17 @@
 from azure.ai.textanalytics.aio import TextAnalyticsClient
 from azure.ai.textanalytics import (
     TextDocumentInput,
-    VERSION,
-    TextAnalyticsApiVersion,
     RecognizeEntitiesAction,
     RecognizeLinkedEntitiesAction,
     RecognizePiiEntitiesAction,
     ExtractKeyPhrasesAction,
     AnalyzeSentimentAction,
-    AnalyzeActionsType
+    _AnalyzeActionsType,
+    RecognizePiiEntitiesResult,
+    RecognizeEntitiesResult,
+    RecognizeLinkedEntitiesResult,
+    AnalyzeSentimentResult,
+    ExtractKeyPhrasesResult,
 )

 # pre-apply the client_cls positional argument so it needn't be explicitly passed below
@@ -73,20 +77,19 @@ async def test_all_successful_passing_dict_key_phrase_task(self, client):
             polling_interval=self._interval()
         )).result()

-        action_results = []
+        pages = []
         async for p in response:
-            action_results.append(p)
-        assert len(action_results) == 1
-        action_result = action_results[0]
-
-        assert action_result.action_type == AnalyzeActionsType.EXTRACT_KEY_PHRASES
-        assert len(action_result.document_results) == len(docs)
-
-        for doc in action_result.document_results:
-            self.assertIn("Paul Allen", doc.key_phrases)
-            self.assertIn("Bill Gates", doc.key_phrases)
-            self.assertIn("Microsoft", doc.key_phrases)
-            self.assertIsNotNone(doc.id)
+            pages.append(p)
+        assert len(pages) == 2
+
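+        # one page per input document, each holding only the key phrase action's result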
+        for document_results in pages:
+            assert len(document_results) == 1
+            for document_result in document_results:
+                assert isinstance(document_result, ExtractKeyPhrasesResult)
+                assert "Paul Allen" in document_result.key_phrases
+                assert "Bill Gates" in document_result.key_phrases
+                assert "Microsoft" in document_result.key_phrases
+                assert document_result.id is not None

     @GlobalTextAnalyticsAccountPreparer()
     @TextAnalyticsClientPreparer()
@@ -103,34 +106,34 @@ async def test_all_successful_passing_dict_sentiment_task(self, client):
             polling_interval=self._interval(),
         )).result()

-        action_results = []
+        pages = []
         async for p in response:
-            action_results.append(p)
-
-        assert len(action_results) == 1
-        action_result = action_results[0]
-
-        assert action_result.action_type == AnalyzeActionsType.ANALYZE_SENTIMENT
-        assert len(action_result.document_results) == len(docs)
-
-        self.assertEqual(action_result.document_results[0].sentiment, "neutral")
-        self.assertEqual(action_result.document_results[1].sentiment, "negative")
-        self.assertEqual(action_result.document_results[2].sentiment, "positive")
-
-        for doc in action_result.document_results:
-            self.assertIsNotNone(doc.id)
-            self.assertIsNotNone(doc.statistics)
-            self.validateConfidenceScores(doc.confidence_scores)
-            self.assertIsNotNone(doc.sentences)
-
-        self.assertEqual(len(action_result.document_results[0].sentences), 1)
-        self.assertEqual(action_result.document_results[0].sentences[0].text, "Microsoft was founded by Bill Gates and Paul Allen.")
-        self.assertEqual(len(action_result.document_results[1].sentences), 2)
-        self.assertEqual(action_result.document_results[1].sentences[0].text, "I did not like the hotel we stayed at.")
-        self.assertEqual(action_result.document_results[1].sentences[1].text, "It was too expensive.")
-        self.assertEqual(len(action_result.document_results[2].sentences), 2)
-        self.assertEqual(action_result.document_results[2].sentences[0].text, "The restaurant had really good food.")
-        self.assertEqual(action_result.document_results[2].sentences[1].text, "I recommend you try it.")
+            pages.append(p)
+
+        assert len(pages) == len(docs)
+
+        for idx, document_results in enumerate(pages):
+            assert len(document_results) == 1
+            document_result = document_results[0]
+            assert isinstance(document_result, AnalyzeSentimentResult)
+            assert document_result.id is not None
+            assert document_result.statistics is not None
+            self.validateConfidenceScores(document_result.confidence_scores)
+            assert document_result.sentences is not None
+            if idx == 0:
+                assert document_result.sentiment == "neutral"
+                assert len(document_result.sentences) == 1
+                assert document_result.sentences[0].text == "Microsoft was founded by Bill Gates and Paul Allen."
+            elif idx == 1:
+                assert document_result.sentiment == "negative"
+                assert len(document_result.sentences) == 2
+                assert document_result.sentences[0].text == "I did not like the hotel we stayed at."
+                assert document_result.sentences[1].text == "It was too expensive."
+            else:
+                assert document_result.sentiment == "positive"
+                assert len(document_result.sentences) == 2
+                assert document_result.sentences[0].text == "The restaurant had really good food."
+                assert document_result.sentences[1].text == "I recommend you try it."

     @GlobalTextAnalyticsAccountPreparer()
     @TextAnalyticsClientPreparer()
@@ -148,18 +151,17 @@ async def test_sentiment_analysis_task_with_opinion_mining(self, client):
             polling_interval=self._interval(),
         )).result()

-        action_results = []
+        pages = []
         async for p in response:
-            action_results.append(p)
-
-        assert len(action_results) == 1
-        action_result = action_results[0]
+            pages.append(p)

-        assert action_result.action_type == AnalyzeActionsType.ANALYZE_SENTIMENT
-        assert len(action_result.document_results) == len(documents)
+        assert len(pages) == len(documents)

-        for idx, doc in enumerate(action_result.document_results):
-            for sentence in doc.sentences:
+        for idx, document_results in enumerate(pages):
+            assert len(document_results) == 1
+            document_result = document_results[0]
+            assert isinstance(document_result, AnalyzeSentimentResult)
+            for sentence in document_result.sentences:
                 if idx == 0:
                     for mined_opinion in sentence.mined_opinions:
                         target = mined_opinion.target
@@ -211,7 +213,6 @@ async def test_sentiment_analysis_task_with_opinion_mining(self, client):
             self.assertEqual('food', food_target.text)
             self.assertEqual('negative', food_target.sentiment)
             self.assertEqual(0.0, food_target.confidence_scores.neutral)
-
     @GlobalTextAnalyticsAccountPreparer()
     @TextAnalyticsClientPreparer()
     async def test_all_successful_passing_text_document_input_entities_task(self, client):
@@ -230,20 +231,19 @@ async def test_all_successful_passing_text_document_input_entities_task(self, cl
         )
         response = await poller.result()

-        action_results = []
+        pages = []
         async for p in response:
-            action_results.append(p)
-        assert len(action_results) == 1
-        action_result = action_results[0]
-
-        assert action_result.action_type == AnalyzeActionsType.RECOGNIZE_ENTITIES
-        assert len(action_result.document_results) == len(docs)
-
-        for doc in action_result.document_results:
-            self.assertEqual(len(doc.entities), 4)
-            self.assertIsNotNone(doc.id)
-            for entity in doc.entities:
-                self.assertIsNotNone(entity.text)
+            pages.append(p)
+        assert len(pages) == len(docs)
+
+        for document_results in pages:
+            assert len(document_results) == 1
+            document_result = document_results[0]
+            assert isinstance(document_result, RecognizeEntitiesResult)
+            assert len(document_result.entities) == 4
+            assert document_result.id is not None
+            for entity in document_result.entities:
+                assert entity.text is not None
+                assert entity.category is not None
+                assert entity.offset is not None
+                assert entity.confidence_score is not None
-                self.assertIsNotNone(entity.category)
-                self.assertIsNotNone(entity.offset)
-                self.assertIsNotNone(entity.confidence_score)
@@ -265,30 +268,25 @@ async def test_all_successful_passing_string_pii_entities_task(self, client):
             polling_interval=self._interval()
         )).result()

-        action_results = []
+        pages = []
         async for p in response:
-            action_results.append(p)
-        assert len(action_results) == 1
-        action_result = action_results[0]
-
-        assert action_result.action_type == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES
-        assert len(action_result.document_results) == len(docs)
-
-        self.assertEqual(action_result.document_results[0].entities[0].text, "859-98-0987")
-        self.assertEqual(action_result.document_results[0].entities[0].category, "USSocialSecurityNumber")
-        self.assertEqual(action_result.document_results[1].entities[0].text, "111000025")
-        # self.assertEqual(results[1].entities[0].category, "ABA Routing Number")  # Service is currently returning PhoneNumber here
-
-        # commenting out brazil cpf, currently service is not returning it
-        # self.assertEqual(action_result.document_results[2].entities[0].text, "998.214.865-68")
-        # self.assertEqual(action_result.document_results[2].entities[0].category, "Brazil CPF Number")
-        for doc in action_result.document_results:
-            self.assertIsNotNone(doc.id)
-            for entity in doc.entities:
-                self.assertIsNotNone(entity.text)
-                self.assertIsNotNone(entity.category)
-                self.assertIsNotNone(entity.offset)
-                self.assertIsNotNone(entity.confidence_score)
+            pages.append(p)
+        assert len(pages) == len(docs)
+
+        for idx, document_results in enumerate(pages):
+            assert len(document_results) == 1
+            document_result = document_results[0]
+            assert isinstance(document_result, RecognizePiiEntitiesResult)
+            if idx == 0:
+                assert document_result.entities[0].text == "859-98-0987"
+                assert document_result.entities[0].category == "USSocialSecurityNumber"
+            elif idx == 1:
+                assert document_result.entities[0].text == "111000025"
+            for entity in document_result.entities:
+                assert entity.text is not None
+                assert entity.category is not None
+                assert entity.offset is not None
+                assert entity.confidence_score is not None

     @GlobalTextAnalyticsAccountPreparer()
     @TextAnalyticsClientPreparer()
@@ -297,7 +295,7 @@ async def test_bad_request_on_empty_document(self, client):

         with self.assertRaises(HttpResponseError):
             async with client:
-                response = await (await client.begin_analyze_actions(
+                await (await client.begin_analyze_actions(
                     docs,
                     actions=[ExtractKeyPhrasesAction()],
                     polling_interval=self._interval()
@@ -310,7 +308,7 @@ async def test_bad_request_on_empty_document(self, client):
     async def test_empty_credential_class(self, client):
         with self.assertRaises(ClientAuthenticationError):
             async with client:
-                response = await (await client.begin_analyze_actions(
+                await (await client.begin_analyze_actions(
                     ["This is written in English."],
                     actions=[
                         RecognizeEntitiesAction(),
@@ -329,7 +327,7 @@ async def test_empty_credential_class(self, client):
     async def test_bad_credentials(self, client):
         with self.assertRaises(ClientAuthenticationError):
             async with client:
-                response = await (await client.begin_analyze_actions(
+                await (await client.begin_analyze_actions(
                     ["This is written in English."],
                     actions=[
                         RecognizeEntitiesAction(),
@@ -362,30 +360,40 @@ async def test_out_of_order_ids_multiple_tasks(self, client):
             polling_interval=self._interval()
         )).result()

-        action_results = []
+        results = []
         async for p in response:
-            action_results.append(p)
-        assert len(action_results) == 5
-
-        assert action_results[0].action_type == AnalyzeActionsType.RECOGNIZE_ENTITIES
-        assert action_results[1].action_type == AnalyzeActionsType.EXTRACT_KEY_PHRASES
-        assert action_results[2].action_type == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES
-        assert action_results[3].action_type == AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
-        assert action_results[4].action_type == AnalyzeActionsType.ANALYZE_SENTIMENT
-
-        action_results = [r for r in action_results if not r.is_error]
-
-        assert all([action_result for action_result in action_results if len(action_result.document_results) == len(docs)])
+            results.append(p)
+        assert len(results) == len(docs)
+
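+        # mirrors the sync test: document order preserved, action order preserved within each page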
+        document_order = ["56", "0", "19", "1"]
+        action_order = [
+            _AnalyzeActionsType.RECOGNIZE_ENTITIES,
+            _AnalyzeActionsType.EXTRACT_KEY_PHRASES,
+            _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES,
+            _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES,
+            _AnalyzeActionsType.ANALYZE_SENTIMENT,
+        ]
+        for doc_idx, document_results in enumerate(results):
+            assert len(document_results) == 5
+            for action_idx, document_result in enumerate(document_results):
+                self.assertEqual(document_result.id, document_order[doc_idx])
+                self.assertEqual(self.document_result_to_action_type(document_result), action_order[action_idx])

-        in_order = ["56", "0", "19", "1"]
-
-        for action_result in action_results:
-            for idx, resp in enumerate(action_result.document_results):
-                self.assertEqual(resp.id, in_order[idx])

     @GlobalTextAnalyticsAccountPreparer()
     @TextAnalyticsClientPreparer()
     async def test_show_stats_and_model_version_multiple_tasks(self, client):
+
+        def callback(resp):
+            if not resp.raw_response:
+                # this is the initial post call
+                request_body = json.loads(resp.http_request.body)
+                assert len(request_body["tasks"]) == 5
+                for task in request_body["tasks"].values():
+                    assert len(task) == 1
+                    assert task[0]['parameters']['model-version'] == 'latest'
+                    assert not task[0]['parameters']['loggingOptOut']
+
         docs = [{"id": "56", "text": ":)"},
                 {"id": "0", "text": ":("},
                 {"id": "19", "text": ":P"},
@@ -425,22 +433,24 @@ def callback(resp):
             raw_response_hook=callback,
         )).result()

-        action_results = []
+        pages = []
         async for p in response:
-            action_results.append(p)
-        assert len(action_results) == 5
-        assert action_results[0].action_type == AnalyzeActionsType.RECOGNIZE_ENTITIES
-        assert action_results[1].action_type == AnalyzeActionsType.EXTRACT_KEY_PHRASES
-        assert action_results[2].action_type == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES
-        assert action_results[3].action_type == AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
-        assert action_results[4].action_type == AnalyzeActionsType.ANALYZE_SENTIMENT
-
-        assert all([action_result for action_result in action_results if len(action_result.document_results) == len(docs)])
-
-        for action_result in action_results:
-            assert not hasattr(action_result, "statistics")
-            for doc in action_result.document_results:
-                assert doc.statistics
+            pages.append(p)
+        assert len(pages) == len(docs)
+
+        action_order = [
+            _AnalyzeActionsType.RECOGNIZE_ENTITIES,
+            _AnalyzeActionsType.EXTRACT_KEY_PHRASES,
+            _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES,
+            _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES,
+            _AnalyzeActionsType.ANALYZE_SENTIMENT,
+        ]
+        for document_results in pages:
+            assert len(document_results) == len(action_order)
+            for document_result in document_results:
+                assert document_result.statistics
+                assert document_result.statistics.character_count
+                assert document_result.statistics.transaction_count

     @GlobalTextAnalyticsAccountPreparer()
     @TextAnalyticsClientPreparer()
@@ -667,38 +677,26 @@ async def test_multiple_pages_of_results_returned_successfully(self, client):
         async for p in result:
             pages.append(p)

-        recognize_entities_results = []
-        extract_key_phrases_results = []
-        recognize_pii_entities_results = []
-        recognize_linked_entities_results = []
-        analyze_sentiment_results = []
-
-        for idx, action_result in enumerate(pages):
-            if idx % 5 == 0:
-                assert action_result.action_type == AnalyzeActionsType.RECOGNIZE_ENTITIES
-                recognize_entities_results.append(action_result)
-            elif idx % 5 == 1:
-                assert action_result.action_type == AnalyzeActionsType.EXTRACT_KEY_PHRASES
-                extract_key_phrases_results.append(action_result)
-            elif idx % 5 == 2:
-                assert action_result.action_type == AnalyzeActionsType.RECOGNIZE_PII_ENTITIES
-                recognize_pii_entities_results.append(action_result)
-            elif idx % 5 == 3:
-                assert action_result.action_type == AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
-                recognize_linked_entities_results.append(action_result)
-            else:
-                assert action_result.action_type == AnalyzeActionsType.ANALYZE_SENTIMENT
-                analyze_sentiment_results.append(action_result)
-            if idx < 5:  # first page of task results
-                assert len(action_result.document_results) == 20
-            else:
-                assert len(action_result.document_results) == 5
-
-        assert all([action_result for action_result in recognize_entities_results if len(action_result.document_results) == len(docs)])
-        assert all([action_result for action_result in extract_key_phrases_results if len(action_result.document_results) == len(docs)])
-        assert all([action_result for action_result in recognize_pii_entities_results if len(action_result.document_results) == len(docs)])
-        assert all([action_result for action_result in recognize_linked_entities_results if len(action_result.document_results) == len(docs)])
-        assert all([action_result for action_result in analyze_sentiment_results if len(action_result.document_results) == len(docs)])
+        assert len(pages) == len(docs)
+        action_order = [
+            _AnalyzeActionsType.RECOGNIZE_ENTITIES,
+            _AnalyzeActionsType.EXTRACT_KEY_PHRASES,
+            _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES,
+            _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES,
+            _AnalyzeActionsType.ANALYZE_SENTIMENT,
+        ]
+        action_type_to_document_results = defaultdict(list)
+
+        for doc_idx, page in enumerate(pages):
+            for action_idx, document_result in enumerate(page):
+                self.assertEqual(document_result.id, str(doc_idx))
+                action_type = self.document_result_to_action_type(document_result)
+                self.assertEqual(action_type, action_order[action_idx])
+                action_type_to_document_results[action_type].append(document_result)
+
+        assert len(action_type_to_document_results) == len(action_order)
+        for document_results in action_type_to_document_results.values():
+            assert len(document_results) == len(docs)

     @GlobalTextAnalyticsAccountPreparer()
     @TextAnalyticsClientPreparer()
diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_repr.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_repr.py
index b522be697b19..e754d291b0bb 100644
--- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_repr.py
+++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_repr.py
@@ -441,54 +441,6 @@ def test_inner_error_takes_precedence(self):
         assert error.code == "UnsupportedLanguageCode"
         assert error.message == "Supplied language not supported. Pass in one of: de,en,es,fr,it,ja,ko,nl,pt-PT,zh-Hans,zh-Hant"

-    def test_analyze_actions_result_recognize_entities(self, recognize_entities_result):
-        model = _models.AnalyzeActionsResult(
-            document_results=[recognize_entities_result[0]],
-            is_error=False,
-            action_type=_models.AnalyzeActionsType.RECOGNIZE_ENTITIES,
-            completed_on=datetime.datetime(1, 1, 1)
-        )
-
-        model_repr = (
-            "AnalyzeActionsResult(document_results=[{}], is_error={}, action_type={}, completed_on={})".format(
-                recognize_entities_result[1], False, "recognize_entities", datetime.datetime(1, 1, 1)
-            )
-        )
-
-        assert repr(model) == model_repr
-
-    def test_analyze_actions_result_recognize_pii_entities(self, recognize_pii_entities_result):
-        model = _models.AnalyzeActionsResult(
-            document_results=[recognize_pii_entities_result[0]],
-            is_error=False,
-            action_type=_models.AnalyzeActionsType.RECOGNIZE_PII_ENTITIES,
-            completed_on=datetime.datetime(1, 1, 1)
-        )
-
-        model_repr = (
-            "AnalyzeActionsResult(document_results=[{}], is_error={}, action_type={}, completed_on={})".format(
-                recognize_pii_entities_result[1], False, "recognize_pii_entities", datetime.datetime(1, 1, 1)
-            )
-        )
-
-        assert repr(model) == model_repr
-
-    def test_analyze_actions_result_extract_key_phrases(self, extract_key_phrases_result):
-        model = _models.AnalyzeActionsResult(
-            document_results=[extract_key_phrases_result[0]],
-            is_error=False,
-            action_type=_models.AnalyzeActionsType.EXTRACT_KEY_PHRASES,
-            completed_on=datetime.datetime(1, 1, 1)
-        )
-
-        model_repr = (
-            "AnalyzeActionsResult(document_results=[{}], is_error={}, action_type={}, completed_on={})".format(
-                extract_key_phrases_result[1], False, "extract_key_phrases", datetime.datetime(1, 1, 1)
-            )
-        )
-
-        assert repr(model) == model_repr
-
     def test_analyze_healthcare_entities_result_item(
         self, healthcare_entity, healthcare_relation, text_analytics_warning, text_document_statistics
     ):
diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/testcase.py b/sdk/textanalytics/azure-ai-textanalytics/tests/testcase.py
index af101458e0c8..7b2b5dc6d288 100644
--- a/sdk/textanalytics/azure-ai-textanalytics/tests/testcase.py
+++ b/sdk/textanalytics/azure-ai-textanalytics/tests/testcase.py
@@ -14,6 +14,14 @@
     FakeResource,
     ResourceGroupPreparer,
 )
+from azure.ai.textanalytics import (
+    RecognizeEntitiesResult,
+    RecognizeLinkedEntitiesResult,
+    RecognizePiiEntitiesResult,
+    AnalyzeSentimentResult,
+    ExtractKeyPhrasesResult,
+    _AnalyzeActionsType
+)
 from devtools_testutils.cognitiveservices_testcase import CognitiveServicesAccountPreparer
 from azure_devtools.scenario_tests import ReplayableTest
@@ -87,6 +95,19 @@ def assert_healthcare_entities_equal(self, entity_a, entity_b):
         assert entity_a.length == entity_b.length
         assert entity_a.offset == entity_b.offset

+    def document_result_to_action_type(self, document_result):
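+        # infer the action type from the result model's class, since results no longer carry an action_type attribute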
+        if isinstance(document_result, RecognizePiiEntitiesResult):
+            return _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES
+        if isinstance(document_result, RecognizeEntitiesResult):
+            return _AnalyzeActionsType.RECOGNIZE_ENTITIES
+        if isinstance(document_result, RecognizeLinkedEntitiesResult):
+            return _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
+        if isinstance(document_result, AnalyzeSentimentResult):
+            return _AnalyzeActionsType.ANALYZE_SENTIMENT
+        if isinstance(document_result, ExtractKeyPhrasesResult):
+            return _AnalyzeActionsType.EXTRACT_KEY_PHRASES
+        raise ValueError("Your action result doesn't match any of the action types")
+

 class GlobalResourceGroupPreparer(AzureMgmtPreparer):
     def __init__(self):