@@ -1409,13 +1409,14 @@ def __repr__(self, **kwargs):
         return "RecognizeEntitiesAction(model_version={}, string_index_type={}, disable_service_logs={})" \
             .format(self.model_version, self.string_index_type, self.disable_service_logs)[:1024]

-    def to_generated(self):
+    def to_generated(self, task_id):
         return _latest_preview_models.EntitiesTask(
             parameters=_latest_preview_models.EntitiesTaskParameters(
                 model_version=self.model_version,
                 string_index_type=self.string_index_type,
                 logging_opt_out=self.disable_service_logs,
-            )
+            ),
+            task_name=task_id
         )

@@ -1480,14 +1481,15 @@ def __repr__(self, **kwargs):
             self.disable_service_logs,
         )[:1024]

-    def to_generated(self):
+    def to_generated(self, task_id):
         return _latest_preview_models.SentimentAnalysisTask(
             parameters=_latest_preview_models.SentimentAnalysisTaskParameters(
                 model_version=self.model_version,
                 opinion_mining=self.show_opinion_mining,
                 string_index_type=self.string_index_type,
                 logging_opt_out=self.disable_service_logs,
-            )
+            ),
+            task_name=task_id
         )

@@ -1546,14 +1548,15 @@ def __repr__(self, **kwargs):
             self.disable_service_logs,
         )[:1024]

-    def to_generated(self):
+    def to_generated(self, task_id):
         return _latest_preview_models.PiiTask(
             parameters=_latest_preview_models.PiiTaskParameters(
                 model_version=self.model_version,
                 domain=self.domain_filter,
                 string_index_type=self.string_index_type,
                 logging_opt_out=self.disable_service_logs
-            )
+            ),
+            task_name=task_id
         )

@@ -1593,12 +1596,13 @@ def __repr__(self, **kwargs):
         return "ExtractKeyPhrasesAction(model_version={}, disable_service_logs={})" \
             .format(self.model_version, self.disable_service_logs)[:1024]

-    def to_generated(self):
+    def to_generated(self, task_id):
         return _latest_preview_models.KeyPhrasesTask(
             parameters=_latest_preview_models.KeyPhrasesTaskParameters(
                 model_version=self.model_version,
                 logging_opt_out=self.disable_service_logs,
-            )
+            ),
+            task_name=task_id
         )

@@ -1649,11 +1653,12 @@ def __repr__(self, **kwargs):
             self.model_version, self.string_index_type, self.disable_service_logs
         )[:1024]

-    def to_generated(self):
+    def to_generated(self, task_id):
         return _latest_preview_models.EntityLinkingTask(
             parameters=_latest_preview_models.EntityLinkingTaskParameters(
                 model_version=self.model_version,
                 string_index_type=self.string_index_type,
                 logging_opt_out=self.disable_service_logs,
-            )
+            ),
+            task_name=task_id
         )
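
Note on the hunks above: each action's `to_generated` now takes a `task_id` and stamps it onto the generated task as `task_name`. That name is what later lets the client tell apart two actions of the same type in one request (e.g. two PII actions with different domain filters). A minimal sketch of the pattern, using stand-in classes rather than the SDK's generated models:

```python
# Sketch only: FakeEntitiesTask stands in for the generated
# _latest_preview_models.EntitiesTask; the real SDK types differ.
class FakeEntitiesTask:
    def __init__(self, parameters, task_name):
        self.parameters = parameters
        self.task_name = task_name  # unique per action: "0", "1", ...


class RecognizeEntitiesActionSketch:
    def __init__(self, model_version="latest"):
        self.model_version = model_version

    def to_generated(self, task_id):
        # task_id is the action's index in the caller's actions list
        return FakeEntitiesTask(
            parameters={"model_version": self.model_version},
            task_name=task_id,
        )


actions = [RecognizeEntitiesActionSketch(), RecognizeEntitiesActionSketch()]
tasks = [a.to_generated(str(i)) for i, a in enumerate(actions)]
assert [t.task_name for t in tasks] == ["0", "1"]
```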
@@ -6,7 +6,7 @@


 import six
-
+from ._generated.models import EntitiesTask, PiiTask, EntityLinkingTask, SentimentAnalysisTask
 from ._models import (
     DetectLanguageInput,
     TextDocumentInput,
@@ -80,6 +80,17 @@ def _determine_action_type(action):
         return _AnalyzeActionsType.ANALYZE_SENTIMENT
     return _AnalyzeActionsType.EXTRACT_KEY_PHRASES

+def _determine_task_type(action):
+    if isinstance(action, EntitiesTask):
+        return _AnalyzeActionsType.RECOGNIZE_ENTITIES
+    if isinstance(action, PiiTask):
+        return _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES
+    if isinstance(action, EntityLinkingTask):
+        return _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
+    if isinstance(action, SentimentAnalysisTask):
+        return _AnalyzeActionsType.ANALYZE_SENTIMENT
+    return _AnalyzeActionsType.EXTRACT_KEY_PHRASES
+
 def _check_string_index_type_arg(string_index_type_arg, api_version, string_index_type_default="UnicodeCodePoint"):
     string_index_type = None

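
The new `_determine_task_type` mirrors the existing `_determine_action_type`, but dispatches on the generated model types (`EntitiesTask`, `PiiTask`, ...) instead of the user-facing action classes, since the client now builds the generated tasks up front. Note the fall-through: any task that matches none of the four checked types is treated as key-phrase extraction, because `KeyPhrasesTask` is the only remaining kind. A self-contained sketch of that dispatch shape (class names here are illustrative):

```python
from enum import Enum


class ActionType(str, Enum):
    RECOGNIZE_ENTITIES = "recognize_entities"
    EXTRACT_KEY_PHRASES = "extract_key_phrases"


# Stand-ins for the generated models imported above.
class EntitiesTaskSketch:
    pass


class KeyPhrasesTaskSketch:
    pass


def determine_task_type(task):
    # isinstance chain with a fall-through default: anything that is not an
    # explicitly checked generated type is assumed to be key-phrase extraction.
    if isinstance(task, EntitiesTaskSketch):
        return ActionType.RECOGNIZE_ENTITIES
    return ActionType.EXTRACT_KEY_PHRASES


assert determine_task_type(EntitiesTaskSketch()) is ActionType.RECOGNIZE_ENTITIES
assert determine_task_type(KeyPhrasesTaskSketch()) is ActionType.EXTRACT_KEY_PHRASES
```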
@@ -200,31 +200,29 @@ def _get_property_name_from_task_type(task_type):
         return "sentiment_analysis_tasks"
     return "key_phrase_extraction_tasks"

-def _get_good_result(current_task_type, index_of_task_result, doc_id_order, response_headers, returned_tasks_object):
+def _get_good_result(task, doc_id_order, response_headers, returned_tasks_object):
+    current_task_type, task_name = task
     deserialization_callback = _get_deserialization_callback_from_task_type(current_task_type)
     property_name = _get_property_name_from_task_type(current_task_type)
-    response_task_to_deserialize = getattr(returned_tasks_object, property_name)[index_of_task_result]
+    response_task_to_deserialize = \
+        [task for task in getattr(returned_tasks_object, property_name) if task.task_name == task_name][0]
     return deserialization_callback(
         doc_id_order, response_task_to_deserialize.results, response_headers, lro=True
     )

 def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state):
-    iter_items = defaultdict(list) # map doc id to action results
-    task_type_to_index = defaultdict(int) # need to keep track of how many of each type of tasks we've seen
+    iter_items = defaultdict(list)  # map doc id to action results
     returned_tasks_object = analyze_job_state.tasks
-    for current_task_type in task_order:
-        index_of_task_result = task_type_to_index[current_task_type]
-
+    for task in task_order:
         results = _get_good_result(
-            current_task_type,
-            index_of_task_result,
+            task,
             doc_id_order,
             response_headers,
             returned_tasks_object,
         )
         for result in results:
             iter_items[result.id].append(result)

-        task_type_to_index[current_task_type] += 1
     return [
         iter_items[doc_id]
         for doc_id in doc_id_order
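
This hunk is the consumer side of the `task_name` plumbing: instead of counting how many tasks of each type have been consumed (`task_type_to_index`) and indexing positionally into the response, `_get_good_result` now receives a `(task_type, task_name)` pair and looks the task up by name in the returned tasks object. A simplified, runnable sketch of that matching (container shapes here are illustrative, not the service's wire format):

```python
# Stand-in for a deserialized task in the job state: it echoes back the
# task_name sent in the request alongside its per-document results.
class ReturnedTaskSketch:
    def __init__(self, task_name, results):
        self.task_name = task_name
        self.results = results


def get_good_result(task, returned_tasks_by_property):
    task_type, task_name = task  # task_order entries are (type, name) pairs
    candidates = returned_tasks_by_property[task_type]
    # Match by task_name instead of relying on response ordering.
    return [t for t in candidates if t.task_name == task_name][0].results


returned = {"pii": [ReturnedTaskSketch("1", ["r1"]), ReturnedTaskSketch("2", ["r2"])]}
assert get_good_result(("pii", "2"), returned) == ["r2"]
```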
@@ -18,7 +18,7 @@
 from ._base_client import TextAnalyticsClientBase
 from ._request_handlers import (
     _validate_input,
-    _determine_action_type,
+    _determine_task_type,
     _check_string_index_type_arg
 )
 from ._response_handlers import (
@@ -896,32 +896,26 @@ def begin_analyze_actions(  # type: ignore
         continuation_token = kwargs.pop("continuation_token", None)

         doc_id_order = [doc.get("id") for doc in docs.documents]
-        task_order = [_determine_action_type(action) for action in actions]
+        generated_tasks = [action.to_generated(str(idx)) for idx, action in enumerate(actions)]
+        task_order = [(_determine_task_type(a), a.task_name) for a in generated_tasks]

         try:
             analyze_tasks = self._client.models(api_version='v3.1').JobManifestTasks(
                 entity_recognition_tasks=[
-                    t.to_generated() for t in
-                    [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_ENTITIES]
+                    a for a in generated_tasks if _determine_task_type(a) == _AnalyzeActionsType.RECOGNIZE_ENTITIES
                 ],
                 entity_recognition_pii_tasks=[
-                    t.to_generated() for t in
-                    [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES]
+                    a for a in generated_tasks if _determine_task_type(a) == _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES
                 ],
                 key_phrase_extraction_tasks=[
-                    t.to_generated() for t in
-                    [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.EXTRACT_KEY_PHRASES]
+                    a for a in generated_tasks if _determine_task_type(a) == _AnalyzeActionsType.EXTRACT_KEY_PHRASES
                 ],
                 entity_linking_tasks=[
-                    t.to_generated() for t in
-                    [
-                        a for a in actions
-                        if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
-                    ]
+                    a for a in generated_tasks
+                    if _determine_task_type(a) == _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
                 ],
                 sentiment_analysis_tasks=[
-                    t.to_generated() for t in
-                    [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.ANALYZE_SENTIMENT]
+                    a for a in generated_tasks if _determine_task_type(a) == _AnalyzeActionsType.ANALYZE_SENTIMENT
                 ]
             )
             analyze_body = self._client.models(api_version='v3.1').AnalyzeBatchInput(
@@ -17,7 +17,7 @@
 from azure.core.exceptions import HttpResponseError
 from azure.core.credentials import AzureKeyCredential
 from ._base_client_async import AsyncTextAnalyticsClientBase
-from .._request_handlers import _validate_input, _determine_action_type, _check_string_index_type_arg
+from .._request_handlers import _validate_input, _determine_task_type, _check_string_index_type_arg
 from .._response_handlers import (
     process_http_response_error,
     entities_result,
@@ -838,7 +838,7 @@ async def begin_analyze_actions(  # type: ignore
         :keyword bool show_stats: If set to true, response will contain document level statistics.
         :keyword int polling_interval: Waiting time between two polls for LRO operations
             if no Retry-After header is present. Defaults to 30 seconds.
-        :return: An instance of an LROPoller. Call `result()` on the poller
+        :return: An instance of an AsyncAnalyzeActionsLROPoller. Call `result()` on the poller
             object to return a pageable heterogeneous list of lists. This list of lists is first ordered
             by the documents you input, then ordered by the actions you input. For example,
             if you have documents input ["Hello", "world"], and actions
@@ -850,7 +850,7 @@ async def begin_analyze_actions(  # type: ignore
            Then, you will get the :class:`~azure.ai.textanalytics.RecognizeEntitiesResult` and
            :class:`~azure.ai.textanalytics.AnalyzeSentimentResult` of "world".
         :rtype:
-            ~azure.core.polling.AsyncLROPoller[~azure.core.async_paging.AsyncItemPaged[
+            ~azure.core.polling.AsyncAnalyzeActionsLROPoller[~azure.core.async_paging.AsyncItemPaged[
                 list[
                     RecognizeEntitiesResult or RecognizeLinkedEntitiesResult or RecognizePiiEntitiesResult or
                     ExtractKeyPhrasesResult or AnalyzeSentimentResult
@@ -879,32 +879,26 @@ async def begin_analyze_actions(  # type: ignore
         continuation_token = kwargs.pop("continuation_token", None)

         doc_id_order = [doc.get("id") for doc in docs.documents]
-        task_order = [_determine_action_type(action) for action in actions]
+        generated_tasks = [action.to_generated(str(idx)) for idx, action in enumerate(actions)]
+        task_order = [(_determine_task_type(a), a.task_name) for a in generated_tasks]

         try:
             analyze_tasks = self._client.models(api_version='v3.1').JobManifestTasks(
                 entity_recognition_tasks=[
-                    t.to_generated() for t in
-                    [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_ENTITIES]
+                    a for a in generated_tasks if _determine_task_type(a) == _AnalyzeActionsType.RECOGNIZE_ENTITIES
                 ],
                 entity_recognition_pii_tasks=[
-                    t.to_generated() for t in
-                    [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES]
+                    a for a in generated_tasks if _determine_task_type(a) == _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES
                 ],
                 key_phrase_extraction_tasks=[
-                    t.to_generated() for t in
-                    [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.EXTRACT_KEY_PHRASES]
+                    a for a in generated_tasks if _determine_task_type(a) == _AnalyzeActionsType.EXTRACT_KEY_PHRASES
                 ],
                 entity_linking_tasks=[
-                    t.to_generated() for t in
-                    [
-                        a for a in actions if \
-                        _determine_action_type(a) == _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
-                    ]
+                    a for a in generated_tasks
+                    if _determine_task_type(a) == _AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES
                 ],
                 sentiment_analysis_tasks=[
-                    t.to_generated() for t in
-                    [a for a in actions if _determine_action_type(a) == _AnalyzeActionsType.ANALYZE_SENTIMENT]
+                    a for a in generated_tasks if _determine_task_type(a) == _AnalyzeActionsType.ANALYZE_SENTIMENT
                 ]
             )
             analyze_body = self._client.models(api_version='v3.1').AnalyzeBatchInput(
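
End to end, the change means callers can pass the same action type more than once, and results still come back ordered first by document and then by action position, as the docstring above describes. A hedged usage sketch against the public client (endpoint and key are placeholders):

```python
from azure.core.credentials import AzureKeyCredential
from azure.ai.textanalytics import (
    TextAnalyticsClient,
    AnalyzeSentimentAction,
    RecognizePiiEntitiesAction,
)

# Placeholders; substitute a real endpoint and key.
client = TextAnalyticsClient(
    "https://<resource>.cognitiveservices.azure.com/",
    AzureKeyCredential("<api-key>"),
)

docs = [{"id": "1", "text": "My SSN is 859-98-0987."}]
poller = client.begin_analyze_actions(
    docs,
    actions=[
        AnalyzeSentimentAction(),
        RecognizePiiEntitiesAction(),                     # default domain
        RecognizePiiEntitiesAction(domain_filter="phi"),  # same type, different config
    ],
)
# One inner list per document, one result per action, in action order.
for doc_results in poller.result():
    sentiment_result, pii_default, pii_phi = doc_results
```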
sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py (81 additions, 0 deletions)
@@ -679,3 +679,84 @@ def callback(resp):
             polling_interval=self._interval(),
             raw_response_hook=callback,
         ).result()
+
+    @GlobalTextAnalyticsAccountPreparer()
+    @TextAnalyticsClientPreparer()
+    def test_partial_success_for_actions(self, client):
+        docs = [{"id": "1", "language": "tr", "text": "I did not like the hotel we stayed at."},
+                {"id": "2", "language": "en", "text": "I did not like the hotel we stayed at."}]
+
+        response = client.begin_analyze_actions(
+            docs,
+            actions=[
+                AnalyzeSentimentAction(),
+                RecognizePiiEntitiesAction(),
+            ],
+            polling_interval=self._interval(),
+        ).result()
+
+        action_results = list(response)
+        assert len(action_results) == len(docs)
+        action_order = [
+            _AnalyzeActionsType.ANALYZE_SENTIMENT,
+            _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES,
+        ]
+
+        assert len(action_results[0]) == len(action_order)
+        assert len(action_results[1]) == len(action_order)
+
+        # first doc
+        assert isinstance(action_results[0][0], AnalyzeSentimentResult)
+        assert action_results[0][0].id == "1"
+        assert action_results[0][1].is_error
+        assert action_results[0][1].id == "1"
+
+        # second doc
+        assert isinstance(action_results[1][0], AnalyzeSentimentResult)
+        assert action_results[1][0].id == "2"
+        assert isinstance(action_results[1][1], RecognizePiiEntitiesResult)
+        assert action_results[1][1].id == "2"
+
+    @pytest.mark.skip("Service bug - https://msazure.visualstudio.com/Cognitive%20Services/_workitems/edit/10145316")
+    @GlobalTextAnalyticsAccountPreparer()
+    @TextAnalyticsClientPreparer()
+    def test_multiple_of_same_action(self, client):
+        docs = [{"id": "1", "text": "My SSN is 859-98-0987."},
+                {"id": "2", "text": "Is 998.214.865-68 your Brazilian CPF number?"}]
+
+        response = client.begin_analyze_actions(
+            docs,
+            actions=[
+                AnalyzeSentimentAction(),
+                RecognizePiiEntitiesAction(),
+                RecognizePiiEntitiesAction(domain_filter="phi"),
+            ],
+            polling_interval=self._interval(),
+        ).result()
+
+        action_results = list(response)
+        assert len(action_results) == len(docs)
+        action_order = [
+            _AnalyzeActionsType.ANALYZE_SENTIMENT,
+            _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES,
+            _AnalyzeActionsType.RECOGNIZE_PII_ENTITIES,
+        ]
+
+        assert len(action_results[0]) == len(action_order)
+        assert len(action_results[1]) == len(action_order)
+
+        # first doc
+        assert isinstance(action_results[0][0], AnalyzeSentimentResult)
+        assert action_results[0][0].id == "1"
+        assert isinstance(action_results[0][1], RecognizePiiEntitiesResult)
+        assert action_results[0][1].id == "1"
+        assert isinstance(action_results[0][2], RecognizePiiEntitiesResult)
+        assert action_results[0][2].id == "1"
+
+        # second doc
+        assert isinstance(action_results[1][0], AnalyzeSentimentResult)
+        assert action_results[1][0].id == "2"
+        assert isinstance(action_results[1][1], RecognizePiiEntitiesResult)
+        assert action_results[1][1].id == "2"
+        assert isinstance(action_results[1][2], RecognizePiiEntitiesResult)
+        assert action_results[1][2].id == "2"
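
A note on the first test above: document "1" is marked as Turkish ("tr"), which the PII recognition task apparently does not support, so its slot in the results comes back as an error result (`is_error`) while the sentiment action for the same document still succeeds; partial success is surfaced per (document, action) pair rather than failing the whole job.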