Skip to content

Commit

Permalink
[TA] Return action error as in ActionResult (#19013)
Browse files Browse the repository at this point in the history
Return Actions error in the ...ActionResult classes instead of throwing.
  • Loading branch information
mssfang authored Feb 5, 2021
1 parent b41a7ab commit e34b2cb
Show file tree
Hide file tree
Showing 11 changed files with 1,247 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@
import com.azure.ai.textanalytics.implementation.models.TasksStateTasksEntityRecognitionPiiTasksItem;
import com.azure.ai.textanalytics.implementation.models.TasksStateTasksEntityRecognitionTasksItem;
import com.azure.ai.textanalytics.implementation.models.TasksStateTasksKeyPhraseExtractionTasksItem;
import com.azure.ai.textanalytics.implementation.models.TextAnalyticsError;
import com.azure.ai.textanalytics.models.AnalyzeBatchActionsOperationDetail;
import com.azure.ai.textanalytics.models.AnalyzeBatchActionsOptions;
import com.azure.ai.textanalytics.models.AnalyzeBatchActionsResult;
import com.azure.ai.textanalytics.models.ExtractKeyPhrasesActionResult;
import com.azure.ai.textanalytics.models.RecognizeEntitiesActionResult;
import com.azure.ai.textanalytics.models.RecognizePiiEntitiesActionResult;
import com.azure.ai.textanalytics.models.TextAnalyticsActionResult;
import com.azure.ai.textanalytics.models.TextAnalyticsActions;
import com.azure.ai.textanalytics.models.TextAnalyticsErrorCode;
import com.azure.ai.textanalytics.models.TextDocumentBatchStatistics;
import com.azure.ai.textanalytics.models.TextDocumentInput;
import com.azure.core.http.rest.PagedFlux;
Expand All @@ -51,10 +54,13 @@
import com.azure.core.util.polling.PollingContext;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

Expand All @@ -71,6 +77,9 @@
import static com.azure.core.util.tracing.Tracer.AZ_TRACING_NAMESPACE_KEY;

class AnalyzeBatchActionsAsyncClient {
private static final String REGEX_ACTION_ERROR_TARGET =
"#/tasks/(keyPhraseExtractionTasks|entityRecognitionPiiTasks|entityRecognitionTasks)/(\\d+)";

private final ClientLogger logger = new ClientLogger(AnalyzeBatchActionsAsyncClient.class);
private final TextAnalyticsClientImpl service;

Expand Down Expand Up @@ -311,45 +320,71 @@ private AnalyzeBatchActionsResult toAnalyzeTasks(AnalyzeJobState analyzeJobState
tasksStateTasks.getEntityRecognitionTasks();
final List<TasksStateTasksKeyPhraseExtractionTasksItem> keyPhraseExtractionTasks =
tasksStateTasks.getKeyPhraseExtractionTasks();
IterableStream<RecognizeEntitiesActionResult> recognizeEntitiesActionResults = null;
IterableStream<RecognizePiiEntitiesActionResult> recognizePiiEntitiesActionResults = null;
IterableStream<ExtractKeyPhrasesActionResult> extractKeyPhrasesActionResults = null;

List<RecognizeEntitiesActionResult> recognizeEntitiesActionResults = new ArrayList<>();
List<RecognizePiiEntitiesActionResult> recognizePiiEntitiesActionResults = new ArrayList<>();
List<ExtractKeyPhrasesActionResult> extractKeyPhrasesActionResults = new ArrayList<>();
if (!CoreUtils.isNullOrEmpty(entityRecognitionTasksItems)) {
recognizeEntitiesActionResults = IterableStream.of(entityRecognitionTasksItems.stream()
.map(taskItem -> {
RecognizeEntitiesActionResult actionResult = new RecognizeEntitiesActionResult();
RecognizeEntitiesActionResultPropertiesHelper.setResult(actionResult,
toRecognizeEntitiesResultCollectionResponse(taskItem.getResults()));
TextAnalyticsActionResultPropertiesHelper.setCompletedAt(actionResult,
taskItem.getLastUpdateDateTime());
return actionResult;
})
.collect(Collectors.toList()));
for (int i = 0; i < entityRecognitionTasksItems.size(); i++) {
final TasksStateTasksEntityRecognitionTasksItem taskItem = entityRecognitionTasksItems.get(i);
final RecognizeEntitiesActionResult actionResult = new RecognizeEntitiesActionResult();
RecognizeEntitiesActionResultPropertiesHelper.setResult(actionResult,
toRecognizeEntitiesResultCollectionResponse(taskItem.getResults()));
TextAnalyticsActionResultPropertiesHelper.setCompletedAt(actionResult,
taskItem.getLastUpdateDateTime());
recognizeEntitiesActionResults.add(actionResult);
}
}
if (!CoreUtils.isNullOrEmpty(piiTasksItems)) {
recognizePiiEntitiesActionResults = IterableStream.of(piiTasksItems.stream()
.map(taskItem -> {
RecognizePiiEntitiesActionResult actionResult = new RecognizePiiEntitiesActionResult();
RecognizePiiEntitiesActionResultPropertiesHelper.setResult(actionResult,
toRecognizePiiEntitiesResultCollection(taskItem.getResults()));
TextAnalyticsActionResultPropertiesHelper.setCompletedAt(actionResult,
taskItem.getLastUpdateDateTime());
return actionResult;
})
.collect(Collectors.toList()));
for (int i = 0; i < piiTasksItems.size(); i++) {
final TasksStateTasksEntityRecognitionPiiTasksItem taskItem = piiTasksItems.get(i);
final RecognizePiiEntitiesActionResult actionResult = new RecognizePiiEntitiesActionResult();
RecognizePiiEntitiesActionResultPropertiesHelper.setResult(actionResult,
toRecognizePiiEntitiesResultCollection(taskItem.getResults()));
TextAnalyticsActionResultPropertiesHelper.setCompletedAt(actionResult,
taskItem.getLastUpdateDateTime());
recognizePiiEntitiesActionResults.add(actionResult);
}
}
if (!CoreUtils.isNullOrEmpty(keyPhraseExtractionTasks)) {
extractKeyPhrasesActionResults = IterableStream.of(keyPhraseExtractionTasks.stream()
.map(taskItem -> {
ExtractKeyPhrasesActionResult actionResult = new ExtractKeyPhrasesActionResult();
ExtractKeyPhrasesActionResultPropertiesHelper.setResult(actionResult,
toExtractKeyPhrasesResultCollection(taskItem.getResults()));
TextAnalyticsActionResultPropertiesHelper.setCompletedAt(actionResult,
taskItem.getLastUpdateDateTime());
return actionResult;
})
.collect(Collectors.toList()));
for (int i = 0; i < keyPhraseExtractionTasks.size(); i++) {
final TasksStateTasksKeyPhraseExtractionTasksItem taskItem = keyPhraseExtractionTasks.get(i);
final ExtractKeyPhrasesActionResult actionResult = new ExtractKeyPhrasesActionResult();
ExtractKeyPhrasesActionResultPropertiesHelper.setResult(actionResult,
toExtractKeyPhrasesResultCollection(taskItem.getResults()));
TextAnalyticsActionResultPropertiesHelper.setCompletedAt(actionResult,
taskItem.getLastUpdateDateTime());
extractKeyPhrasesActionResults.add(actionResult);
}
}

final List<TextAnalyticsError> errors = analyzeJobState.getErrors();
if (!CoreUtils.isNullOrEmpty(errors)) {
for (TextAnalyticsError error : errors) {
final String[] targetPair = parseActionErrorTarget(error.getTarget());
final String taskName = targetPair[0];
final Integer taskIndex = Integer.valueOf(targetPair[1]);
final TextAnalyticsActionResult actionResult;
if ("entityRecognitionTasks".equals(taskName)) {
actionResult = recognizeEntitiesActionResults.get(taskIndex);
} else if ("entityRecognitionPiiTasks".equals(taskName)) {
actionResult = recognizePiiEntitiesActionResults.get(taskIndex);
} else if ("keyPhraseExtractionTasks".equals(taskName)) {
actionResult = extractKeyPhrasesActionResults.get(taskIndex);
} else {
throw logger.logExceptionAsError(new RuntimeException(
"Invalid task name in target reference, " + taskName));
}

TextAnalyticsActionResultPropertiesHelper.setIsError(actionResult, true);
TextAnalyticsActionResultPropertiesHelper.setError(actionResult,
new com.azure.ai.textanalytics.models.TextAnalyticsError(
TextAnalyticsErrorCode.fromString(
error.getCode() == null ? null : error.getCode().toString()),
error.getMessage(), null));
}
}

final AnalyzeBatchActionsResult analyzeBatchActionsResult = new AnalyzeBatchActionsResult();

final RequestStatistics requestStatistics = analyzeJobState.getStatistics();
Expand All @@ -363,11 +398,11 @@ private AnalyzeBatchActionsResult toAnalyzeTasks(AnalyzeJobState analyzeJobState

AnalyzeBatchActionsResultPropertiesHelper.setStatistics(analyzeBatchActionsResult, batchStatistics);
AnalyzeBatchActionsResultPropertiesHelper.setRecognizeEntitiesActionResults(analyzeBatchActionsResult,
recognizeEntitiesActionResults);
IterableStream.of(recognizeEntitiesActionResults));
AnalyzeBatchActionsResultPropertiesHelper.setRecognizePiiEntitiesActionResults(analyzeBatchActionsResult,
recognizePiiEntitiesActionResults);
IterableStream.of(recognizePiiEntitiesActionResults));
AnalyzeBatchActionsResultPropertiesHelper.setExtractKeyPhrasesActionResults(analyzeBatchActionsResult,
extractKeyPhrasesActionResults);
IterableStream.of(extractKeyPhrasesActionResults));
return analyzeBatchActionsResult;
}

Expand Down Expand Up @@ -425,4 +460,20 @@ private AnalyzeBatchActionsOptions getNotNullAnalyzeBatchActionsOptions(AnalyzeB
private String getNotNullModelVersion(String modelVersion) {
return modelVersion == null ? "latest" : modelVersion;
}

private String[] parseActionErrorTarget(String targetReference) {
if (CoreUtils.isNullOrEmpty(targetReference)) {
throw logger.logExceptionAsError(new RuntimeException(
"Expected an error with a target field referencing an action but did not get one"));
}
// action could be failed and the target reference is "#/tasks/keyPhraseExtractionTasks/0";
final Pattern pattern = Pattern.compile(REGEX_ACTION_ERROR_TARGET, Pattern.MULTILINE);
final Matcher matcher = pattern.matcher(targetReference);
String[] taskNameIdPair = new String[2];
while (matcher.find()) {
taskNameIdPair[0] = matcher.group(1);
taskNameIdPair[1] = matcher.group(2);
}
return taskNameIdPair;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public static void main(String[] args) {
// Task operation statistics
while (syncPoller.poll().getStatus() == LongRunningOperationStatus.IN_PROGRESS) {
final AnalyzeBatchActionsOperationDetail operationResult = syncPoller.poll().getValue();
System.out.printf("Action display name: %s, Successfully completed tasks: %d, in-process tasks: %d, failed tasks: %d, total tasks: %d%n",
System.out.printf("Action display name: %s, Successfully completed actions: %d, in-process actions: %d, failed actions: %d, total actions: %d%n",
operationResult.getDisplayName(), operationResult.getActionsSucceeded(),
operationResult.getActionsInProgress(), operationResult.getActionsFailed(),
operationResult.getActionsInTotal());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public static void main(String[] args) {
new AnalyzeBatchActionsOptions().setIncludeStatistics(false))
.flatMap(result -> {
AnalyzeBatchActionsOperationDetail operationResult = result.getValue();
System.out.printf("Action display name: %s, Successfully completed tasks: %d, in-process tasks: %d, failed tasks: %d, total tasks: %d%n",
System.out.printf("Action display name: %s, Successfully completed actions: %d, in-process actions: %d, failed actions: %d, total actions: %d%n",
operationResult.getDisplayName(), operationResult.getActionsSucceeded(),
operationResult.getActionsInProgress(), operationResult.getActionsFailed(), operationResult.getActionsInTotal());
return result.getFinalResult();
Expand Down
Loading

0 comments on commit e34b2cb

Please sign in to comment.