From 7881429407ba9a28be427b320db094797d5cff89 Mon Sep 17 00:00:00 2001 From: Sameeksha Vaity Date: Thu, 19 Dec 2019 01:39:57 -0800 Subject: [PATCH] refactor methods --- .../TextAnalyticsAsyncClient.java | 338 ++++++++---------- .../ai/textanalytics/TextAnalyticsClient.java | 23 +- .../batch/RecognizePiiBatchDocuments.java | 4 +- .../TextAnalyticsClientTestBase.java | 9 +- 4 files changed, 166 insertions(+), 208 deletions(-) diff --git a/sdk/textanalytics/azure-ai-textanalytics/src/main/java/com/azure/ai/textanalytics/TextAnalyticsAsyncClient.java b/sdk/textanalytics/azure-ai-textanalytics/src/main/java/com/azure/ai/textanalytics/TextAnalyticsAsyncClient.java index 3e89a57f5548..2cfd3abc03bb 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/src/main/java/com/azure/ai/textanalytics/TextAnalyticsAsyncClient.java +++ b/sdk/textanalytics/azure-ai-textanalytics/src/main/java/com/azure/ai/textanalytics/TextAnalyticsAsyncClient.java @@ -10,24 +10,20 @@ import com.azure.ai.textanalytics.implementation.models.DocumentLanguage; import com.azure.ai.textanalytics.implementation.models.DocumentLinkedEntities; import com.azure.ai.textanalytics.implementation.models.DocumentSentiment; -import com.azure.ai.textanalytics.implementation.models.DocumentSentimentValue; import com.azure.ai.textanalytics.implementation.models.DocumentStatistics; import com.azure.ai.textanalytics.implementation.models.EntitiesResult; -import com.azure.ai.textanalytics.implementation.models.Entity; import com.azure.ai.textanalytics.implementation.models.EntityLinkingResult; import com.azure.ai.textanalytics.implementation.models.LanguageBatchInput; import com.azure.ai.textanalytics.implementation.models.LanguageInput; import com.azure.ai.textanalytics.implementation.models.LanguageResult; import com.azure.ai.textanalytics.implementation.models.LinkedEntity; -import com.azure.ai.textanalytics.implementation.models.Match; import com.azure.ai.textanalytics.implementation.models.MultiLanguageBatchInput; import com.azure.ai.textanalytics.implementation.models.MultiLanguageInput; import com.azure.ai.textanalytics.implementation.models.RequestStatistics; -import com.azure.ai.textanalytics.implementation.models.SentenceSentiment; -import com.azure.ai.textanalytics.implementation.models.SentenceSentimentValue; import com.azure.ai.textanalytics.implementation.models.SentimentConfidenceScorePerLabel; import com.azure.ai.textanalytics.implementation.models.SentimentResponse; import com.azure.ai.textanalytics.implementation.models.TextAnalyticsError; +import com.azure.ai.textanalytics.models.AnalyzeSentimentResult; import com.azure.ai.textanalytics.models.DetectLanguageInput; import com.azure.ai.textanalytics.models.DetectLanguageResult; import com.azure.ai.textanalytics.models.DetectedLanguage; @@ -38,6 +34,7 @@ import com.azure.ai.textanalytics.models.NamedEntity; import com.azure.ai.textanalytics.models.RecognizeEntitiesResult; import com.azure.ai.textanalytics.models.RecognizeLinkedEntitiesResult; +import com.azure.ai.textanalytics.models.RecognizePiiEntitiesResult; import com.azure.ai.textanalytics.models.TextAnalyticsClientOptions; import com.azure.ai.textanalytics.models.TextAnalyticsRequestOptions; import com.azure.ai.textanalytics.models.TextDocumentBatchStatistics; @@ -45,7 +42,6 @@ import com.azure.ai.textanalytics.models.TextDocumentStatistics; import com.azure.ai.textanalytics.models.TextSentiment; import com.azure.ai.textanalytics.models.TextSentimentClass; -import com.azure.ai.textanalytics.models.AnalyzeSentimentResult; import com.azure.core.annotation.ReturnType; import com.azure.core.annotation.ServiceClient; import com.azure.core.annotation.ServiceMethod; @@ -143,8 +139,8 @@ public Mono detectLanguage(String text) { * @param countryHint Accepts two letter country codes specified by ISO 3166-1 alpha-2. Defaults to "US" if not * specified. * - * @return A {@link Mono} containing a {@link Response} whose {@link Response#getValue() value} has the {@link - * DetectLanguageResult detected language} of the text. + * @return A {@link Mono} containing a {@link Response} whose {@link Response#getValue() value} has the + * {@link DetectLanguageResult detected language} of the text. * * @throws NullPointerException if {@code text} is {@code null}. */ @@ -262,7 +258,10 @@ Mono>> detectBatchLangua List textInputs, TextAnalyticsRequestOptions options, Context context) { final LanguageBatchInput languageBatchInput = new LanguageBatchInput() - .setDocuments(convertToLanguageInput(textInputs)); + .setDocuments(textInputs.stream().map(detectLanguageInput -> new LanguageInput() + .setId(detectLanguageInput.getId()).setText(detectLanguageInput.getText()) + .setCountryHint(detectLanguageInput.getCountryHint())).collect(Collectors.toList())); + return service.languagesWithRestResponseAsync( languageBatchInput, options == null ? null : options.getModelVersion(), options == null ? null : options.showStatistics(), context) @@ -430,7 +429,6 @@ Mono>> recognizeBatch } // PII Entity - /** * Returns a list of personal information entities ("SSN", "Bank Account", etc) in the text. For the list of * supported entity types, check https://aka.ms/tanerpii. See https://aka.ms/talangs for the list of enabled @@ -443,7 +441,7 @@ Mono>> recognizeBatch * @throws NullPointerException if {@code text} is {@code null}. */ @ServiceMethod(returns = ReturnType.SINGLE) - public Mono recognizePiiEntities(String text) { + public Mono recognizePiiEntities(String text) { try { return recognizePiiEntitiesWithResponse(text, defaultLanguage).flatMap(FluxUtil::toMono); } catch (RuntimeException ex) { @@ -466,7 +464,7 @@ public Mono recognizePiiEntities(String text) { * @throws NullPointerException if {@code text} is {@code null}. */ @ServiceMethod(returns = ReturnType.SINGLE) - public Mono> recognizePiiEntitiesWithResponse(String text, String language) { + public Mono> recognizePiiEntitiesWithResponse(String text, String language) { try { return withContext(context -> recognizePiiEntitiesWithResponse(text, language, context)); } catch (RuntimeException ex) { @@ -474,12 +472,12 @@ public Mono> recognizePiiEntitiesWithResponse( } } - Mono> recognizePiiEntitiesWithResponse(String text, String language, + Mono> recognizePiiEntitiesWithResponse(String text, String language, Context context) { return recognizeBatchPiiEntitiesWithResponse( Arrays.asList(new TextDocumentInput(Integer.toString(0), text, language)), null, context) .flatMap(response -> { - Iterator responseItem = response.getValue().iterator(); + Iterator responseItem = response.getValue().iterator(); return Mono.just(new SimpleResponse<>(response, responseItem.next())); }); } @@ -497,7 +495,7 @@ Mono> recognizePiiEntitiesWithResponse(String * @throws NullPointerException if {@code textInputs} is {@code null}. */ @ServiceMethod(returns = ReturnType.SINGLE) - public Mono> recognizePiiEntities(List textInputs) { + public Mono> recognizePiiEntities(List textInputs) { try { return recognizePiiEntitiesWithResponse(textInputs, defaultLanguage) .flatMap(FluxUtil::toMono); @@ -521,7 +519,7 @@ public Mono> recognizePiiEntit * @throws NullPointerException if {@code textInputs} is {@code null}. */ @ServiceMethod(returns = ReturnType.SINGLE) - public Mono>> recognizePiiEntitiesWithResponse( + public Mono>> recognizePiiEntitiesWithResponse( List textInputs, String language) { try { return withContext(context -> recognizePiiEntitiesWithResponse(textInputs, language, context)); @@ -530,7 +528,7 @@ public Mono>> recogni } } - Mono>> recognizePiiEntitiesWithResponse( + Mono>> recognizePiiEntitiesWithResponse( List textInputs, String language, Context context) { List documentInputs = mapByIndex(textInputs, (index, value) -> new TextDocumentInput(index, value, language)); @@ -554,7 +552,7 @@ Mono>> recognizePiiEn * @throws NullPointerException if {@code textInputs} is {@code null}. */ @ServiceMethod(returns = ReturnType.SINGLE) - public Mono> recognizeBatchPiiEntities( + public Mono> recognizeBatchPiiEntities( List textInputs) { try { return recognizeBatchPiiEntitiesWithResponse(textInputs, null).flatMap(FluxUtil::toMono); @@ -578,7 +576,7 @@ public Mono> recognizeBatchPii * @throws NullPointerException if {@code textInputs} is {@code null}. */ @ServiceMethod(returns = ReturnType.SINGLE) - public Mono>> recognizeBatchPiiEntitiesWithResponse( + public Mono>> recognizeBatchPiiEntitiesWithResponse( List textInputs, TextAnalyticsRequestOptions options) { try { return withContext(context -> recognizeBatchPiiEntitiesWithResponse(textInputs, options, context)); @@ -587,7 +585,7 @@ public Mono>> recogni } } - Mono>> recognizeBatchPiiEntitiesWithResponse( + Mono>> recognizeBatchPiiEntitiesWithResponse( List documents, TextAnalyticsRequestOptions options, Context context) { final MultiLanguageBatchInput batchInput = new MultiLanguageBatchInput() .setDocuments(convertToMultiLanguageInput(documents)); @@ -598,11 +596,10 @@ Mono>> recognizeBatch .doOnSubscribe(ignoredValue -> logger.info("A batch of PII entities input - {}", ignoredValue)) .doOnSuccess(response -> logger.info("A batch of PII entities output - {}", response.getValue())) .doOnError(error -> logger.warning("Failed to PII entities - {}", error)) - .map(response -> new SimpleResponse<>(response, toDocumentResultCollection(response.getValue()))); + .map(response -> new SimpleResponse<>(response, toPiiDocumentResultCollection(response.getValue()))); } // Linked Entity - /** * Returns a list of recognized entities with links to a well-known knowledge base for the provided text. See * https://aka.ms/talangs for supported languages in Text Analytics API. @@ -747,8 +744,8 @@ public Mono> recognizeBa */ @ServiceMethod(returns = ReturnType.SINGLE) public Mono>> - recognizeBatchLinkedEntitiesWithResponse( - List textInputs, TextAnalyticsRequestOptions options) { + recognizeBatchLinkedEntitiesWithResponse(List textInputs, + TextAnalyticsRequestOptions options) { try { return withContext(context -> recognizeBatchLinkedEntitiesWithResponse(textInputs, options, context)); } catch (RuntimeException ex) { @@ -770,28 +767,6 @@ Mono>> recogniz .map(response -> new SimpleResponse<>(response, toDocumentResultCollection(response.getValue()))); } - private DocumentResultCollection toDocumentResultCollection( - final EntityLinkingResult entityLinkingResult) { - return new DocumentResultCollection<>(getDocumentLinkedEntities(entityLinkingResult), - entityLinkingResult.getModelVersion(), entityLinkingResult.getStatistics() == null ? null - : mapBatchStatistics(entityLinkingResult.getStatistics())); - } - - private List getDocumentLinkedEntities(final EntityLinkingResult entitiesResult) { - List validDocumentList = new ArrayList<>(); - for (DocumentLinkedEntities documentLinkedEntities : entitiesResult.getDocuments()) { - validDocumentList.add(new RecognizeLinkedEntitiesResult(documentLinkedEntities.getId(), - documentLinkedEntities.getStatistics() == null ? null - : convertToTextDocumentStatistics(documentLinkedEntities.getStatistics()), - null, mapLinkedEntity(documentLinkedEntities.getEntities()))); - } - List errorDocumentList = new ArrayList<>(); - for (DocumentError documentError : entitiesResult.getErrors()) { - final com.azure.ai.textanalytics.models.TextAnalyticsError error = convertToError(documentError.getError()); - errorDocumentList.add(new RecognizeLinkedEntitiesResult(documentError.getId(), null, error, null)); - } - return Stream.concat(validDocumentList.stream(), errorDocumentList.stream()).collect(Collectors.toList()); - } // Key Phrases /** @@ -980,7 +955,6 @@ private List getKeyPhraseResults( } // Sentiment - /** * Returns a sentiment prediction, as well as sentiment scores for each sentiment class (Positive, Negative, and * Neutral) for the document and each sentence within i @@ -1157,15 +1131,6 @@ public String getDefaultLanguage() { return defaultLanguage; } - private List convertToLanguageInput(List textInputs) { - List languageInputList = new ArrayList<>(); - for (DetectLanguageInput detectLanguageInput : textInputs) { - languageInputList.add(new LanguageInput().setId(detectLanguageInput.getId()) - .setText(detectLanguageInput.getText()).setCountryHint(detectLanguageInput.getCountryHint())); - } - return languageInputList; - } - private List convertToMultiLanguageInput(List textInputs) { List multiLanguageInputs = new ArrayList<>(); for (TextDocumentInput textDocumentInput : textInputs) { @@ -1175,34 +1140,57 @@ private List convertToMultiLanguageInput(List toDocumentResultCollection( final SentimentResponse sentimentResponse) { - return new DocumentResultCollection<>(getDocumentTextSentiment(sentimentResponse), + List analyzeSentimentResults = new ArrayList<>(); + for (DocumentSentiment documentSentiment : sentimentResponse.getDocuments()) { + analyzeSentimentResults.add(convertToTextSentimentResult(documentSentiment)); + } + for (DocumentError documentError : sentimentResponse.getErrors()) { + final com.azure.ai.textanalytics.models.TextAnalyticsError error = convertToError(documentError.getError()); + analyzeSentimentResults.add(new AnalyzeSentimentResult(documentError.getId(), null, error, null, + null)); + } + return new DocumentResultCollection<>(analyzeSentimentResults, sentimentResponse.getModelVersion(), sentimentResponse.getStatistics() == null ? null : mapBatchStatistics(sentimentResponse.getStatistics())); } - private List getDocumentTextSentiment(final SentimentResponse sentimentResponse) { - Stream validDocumentList = sentimentResponse.getDocuments().stream() - .map(this::convertToTextSentimentResult); - Stream errorDocumentList = sentimentResponse.getErrors().stream() - .map(this::convertToErrorTextSentimentResult); - - return Stream.concat(validDocumentList, errorDocumentList).collect(Collectors.toList()); - } - private AnalyzeSentimentResult convertToTextSentimentResult(final DocumentSentiment documentSentiment) { // Document text sentiment - final TextSentimentClass documentSentimentClass = convertToTextSentimentClass(documentSentiment.getSentiment()); + final TextSentimentClass documentSentimentClass = TextSentimentClass.fromString(documentSentiment. + getSentiment().toString()); if (documentSentimentClass == null) { - return null; + throw logger.logExceptionAsWarning( + new RuntimeException(String.format("'%s' is not valid text sentiment.", + documentSentiment.getSentiment()))); } - final SentimentConfidenceScorePerLabel confidenceScorePerLabel = documentSentiment.getDocumentScores(); // Sentence text sentiment - final List sentenceSentimentTexts = - convertToSentenceSentiments(documentSentiment.getSentences()); + final List sentenceSentimentTexts = documentSentiment.getSentences().stream() + .map(sentenceSentiment -> { + TextSentimentClass sentimentClass = TextSentimentClass.fromString(sentenceSentiment + .getSentiment().toString()); + if (sentimentClass == null) { + throw logger.logExceptionAsWarning( + new RuntimeException(String.format("'%s' is not valid text sentiment.", + sentenceSentiment.getSentiment()))); + } + SentimentConfidenceScorePerLabel confidenceScorePerSentence = sentenceSentiment.getSentenceScores(); + + return new TextSentiment(sentimentClass, confidenceScorePerSentence.getNegative(), + confidenceScorePerSentence.getNeutral(), confidenceScorePerSentence.getPositive(), + sentenceSentiment.getLength(), sentenceSentiment.getOffset()); + + }).collect(Collectors.toList()); return new AnalyzeSentimentResult(documentSentiment.getId(), documentSentiment.getStatistics() == null ? null @@ -1213,56 +1201,6 @@ private AnalyzeSentimentResult convertToTextSentimentResult(final DocumentSentim sentenceSentimentTexts); } - private List convertToSentenceSentiments(final List sentenceSentiments) { - final List sentenceSentimentCollection = new ArrayList<>(); - sentenceSentiments.forEach(sentenceSentiment -> { - final TextSentimentClass sentimentClass = convertToTextSentimentClass(sentenceSentiment.getSentiment()); - - final SentimentConfidenceScorePerLabel confidenceScorePerLabel = sentenceSentiment.getSentenceScores(); - - sentenceSentimentCollection.add(new TextSentiment(sentimentClass, confidenceScorePerLabel.getNegative(), - confidenceScorePerLabel.getNeutral(), confidenceScorePerLabel.getPositive(), - sentenceSentiment.getLength(), sentenceSentiment.getOffset())); - }); - return sentenceSentimentCollection; - } - - private TextSentimentClass convertToTextSentimentClass(final DocumentSentimentValue sentimentValue) { - switch (sentimentValue) { - case POSITIVE: - return TextSentimentClass.POSITIVE; - case NEUTRAL: - return TextSentimentClass.NEUTRAL; - case NEGATIVE: - return TextSentimentClass.NEGATIVE; - case MIXED: - return TextSentimentClass.MIXED; - default: - throw logger.logExceptionAsWarning( - new RuntimeException(String.format("'%s' is not valid text sentiment.", sentimentValue))); - } - } - - private TextSentimentClass convertToTextSentimentClass(final SentenceSentimentValue sentimentValue) { - switch (sentimentValue) { - case POSITIVE: - return TextSentimentClass.POSITIVE; - case NEUTRAL: - return TextSentimentClass.NEUTRAL; - case NEGATIVE: - return TextSentimentClass.NEGATIVE; - default: - throw logger.logExceptionAsWarning( - new RuntimeException(String.format("'%s' is not valid text sentiment.", sentimentValue))); - } - } - - private AnalyzeSentimentResult convertToErrorTextSentimentResult(final DocumentError documentError) { - final com.azure.ai.textanalytics.models.TextAnalyticsError error = convertToError(documentError.getError()); - return new AnalyzeSentimentResult(documentError.getId(), null, error, null, - null); - } - /** * Helper method to convert the service response of {@link LanguageResult} to {@link DocumentResultCollection}. * @@ -1272,26 +1210,24 @@ private AnalyzeSentimentResult convertToErrorTextSentimentResult(final DocumentE */ private DocumentResultCollection toDocumentResultCollection( final LanguageResult languageResult) { - return new DocumentResultCollection<>(getDocumentLanguages(languageResult), languageResult.getModelVersion(), - languageResult.getStatistics() == null ? null : mapBatchStatistics(languageResult.getStatistics())); - } - /** - * Helper method to get a combined list of error documents and valid documents. - * - * @param languageResult the {@link LanguageResult} containing both the error and document list. - * - * @return the combined error and document list. - */ - private List getDocumentLanguages(final LanguageResult languageResult) { final List detectLanguageResults = new ArrayList<>(); for (DocumentLanguage documentLanguage : languageResult.getDocuments()) { + DetectedLanguage primaryLanguage = null; + if (documentLanguage.getDetectedLanguages().size() >= 1) { + com.azure.ai.textanalytics.implementation.models.DetectedLanguage detectedLanguageResult = + documentLanguage.getDetectedLanguages().get(0); + primaryLanguage = new DetectedLanguage(detectedLanguageResult.getName(), + detectedLanguageResult.getIso6391Name(), detectedLanguageResult.getScore()); + } detectLanguageResults.add(new DetectLanguageResult(documentLanguage.getId(), documentLanguage.getStatistics() == null ? null : convertToTextDocumentStatistics(documentLanguage.getStatistics()), null, - setPrimaryLanguage(documentLanguage.getDetectedLanguages()), - convertToDetectLanguages(documentLanguage.getDetectedLanguages()))); + primaryLanguage, + documentLanguage.getDetectedLanguages().stream().map(detectedLanguage -> + new DetectedLanguage(detectedLanguage.getName(), detectedLanguage.getIso6391Name(), + detectedLanguage.getScore())).collect(Collectors.toList()))); } for (DocumentError documentError : languageResult.getErrors()) { @@ -1300,65 +1236,92 @@ private List getDocumentLanguages(final LanguageResult lan new DetectLanguageResult(documentError.getId(), null, error, null, null)); } - return detectLanguageResults; + return new DocumentResultCollection<>(detectLanguageResults, languageResult.getModelVersion(), + languageResult.getStatistics() == null ? null : mapBatchStatistics(languageResult.getStatistics())); } - private List convertToDetectLanguages( - List detectedLanguages) { - List detectedLanguagesList = new ArrayList<>(); - for (com.azure.ai.textanalytics.implementation.models.DetectedLanguage detectedLanguage : detectedLanguages) { - detectedLanguagesList.add(new DetectedLanguage(detectedLanguage.getName(), - detectedLanguage.getIso6391Name(), detectedLanguage.getScore())); + /** + * Helper method to convert the service response of {@link EntitiesResult} to {@link DocumentResultCollection}. + * + * @param entitiesResult the {@link EntitiesResult} returned by the service. + * + * @return the {@link DocumentResultCollection} of {@link DetectLanguageResult} to be returned by the SDK. + */ + private DocumentResultCollection toDocumentResultCollection( + final EntitiesResult entitiesResult) { + List recognizeEntitiesResults = new ArrayList<>(); + for (DocumentEntities documentEntities : entitiesResult.getDocuments()) { + recognizeEntitiesResults.add(new RecognizeEntitiesResult(documentEntities.getId(), + documentEntities.getStatistics() == null ? null + : convertToTextDocumentStatistics(documentEntities.getStatistics()), + null, documentEntities.getEntities().stream().map(entity -> + new NamedEntity(entity.getText(), entity.getType(), entity.getSubtype(), entity.getOffset(), + entity.getLength(), entity.getScore())).collect(Collectors.toList()))); } - return detectedLanguagesList; - } - private DetectedLanguage setPrimaryLanguage( - List detectedLanguages) { - if (detectedLanguages.size() >= 1) { - com.azure.ai.textanalytics.implementation.models.DetectedLanguage detectedLanguageResult = - detectedLanguages.get(0); - return new DetectedLanguage(detectedLanguageResult.getName(), detectedLanguageResult.getIso6391Name(), - detectedLanguageResult.getScore()); + for (DocumentError documentError : entitiesResult.getErrors()) { + final com.azure.ai.textanalytics.models.TextAnalyticsError error = convertToError(documentError.getError()); + recognizeEntitiesResults.add(new RecognizeEntitiesResult(documentError.getId(), null, error, null)); } - return null; - } - private TextDocumentBatchStatistics mapBatchStatistics(RequestStatistics statistics) { - return new TextDocumentBatchStatistics(statistics.getDocumentsCount(), statistics.getErroneousDocumentsCount(), - statistics.getValidDocumentsCount(), statistics.getTransactionsCount()); - } - - private DocumentResultCollection toDocumentResultCollection( - final EntitiesResult entitiesResult) { - return new DocumentResultCollection<>(getDocumentNamedEntities(entitiesResult), + return new DocumentResultCollection<>(recognizeEntitiesResults, entitiesResult.getModelVersion(), entitiesResult.getStatistics() == null ? null : mapBatchStatistics(entitiesResult.getStatistics())); } - private List getDocumentNamedEntities(final EntitiesResult entitiesResult) { - List validDocumentList = new ArrayList<>(); + /** + * Helper method to convert the service response of {@link EntitiesResult} to {@link DocumentResultCollection}. + * + * @param entitiesResult the {@link EntitiesResult} returned by the service. + * + * @return the {@link DocumentResultCollection} of {@link RecognizePiiEntitiesResult} to be returned by the SDK. + */ + private DocumentResultCollection toPiiDocumentResultCollection( + final EntitiesResult entitiesResult) { + List recognizePiiEntitiesResults = new ArrayList<>(); for (DocumentEntities documentEntities : entitiesResult.getDocuments()) { - validDocumentList.add(new RecognizeEntitiesResult(documentEntities.getId(), + recognizePiiEntitiesResults.add(new RecognizePiiEntitiesResult(documentEntities.getId(), documentEntities.getStatistics() == null ? null : convertToTextDocumentStatistics(documentEntities.getStatistics()), - null, mapToNamedEntities(documentEntities.getEntities()))); + null, documentEntities.getEntities().stream().map(entity -> + new NamedEntity(entity.getText(), entity.getType(), entity.getSubtype(), entity.getOffset(), + entity.getLength(), entity.getScore())).collect(Collectors.toList()))); } - List errorDocumentList = new ArrayList<>(); + for (DocumentError documentError : entitiesResult.getErrors()) { final com.azure.ai.textanalytics.models.TextAnalyticsError error = convertToError(documentError.getError()); - errorDocumentList.add(new RecognizeEntitiesResult(documentError.getId(), null, error, null)); + recognizePiiEntitiesResults.add(new RecognizePiiEntitiesResult(documentError.getId(), null, error, null)); } - return Stream.concat(validDocumentList.stream(), errorDocumentList.stream()).collect(Collectors.toList()); + + return new DocumentResultCollection<>(recognizePiiEntitiesResults, + entitiesResult.getModelVersion(), entitiesResult.getStatistics() == null ? null + : mapBatchStatistics(entitiesResult.getStatistics())); } - private List mapToNamedEntities(List entities) { - List namedEntityList = new ArrayList<>(); - for (Entity entity : entities) { - namedEntityList.add(new NamedEntity(entity.getText(), entity.getType(), entity.getSubtype(), - entity.getOffset(), entity.getLength(), entity.getScore())); + /** + * Helper method to convert the service response of {@link EntityLinkingResult} to {@link DocumentResultCollection}. + * + * @param entityLinkingResult the {@link EntityLinkingResult} returned by the service. + * + * @return the {@link DocumentResultCollection} of {@link RecognizeLinkedEntitiesResult} to be returned by the SDK. + */ + private DocumentResultCollection toDocumentResultCollection( + final EntityLinkingResult entityLinkingResult) { + List linkedEntitiesResults = new ArrayList<>(); + for (DocumentLinkedEntities documentLinkedEntities : entityLinkingResult.getDocuments()) { + linkedEntitiesResults.add(new RecognizeLinkedEntitiesResult(documentLinkedEntities.getId(), + documentLinkedEntities.getStatistics() == null ? null + : convertToTextDocumentStatistics(documentLinkedEntities.getStatistics()), + null, mapLinkedEntity(documentLinkedEntities.getEntities()))); + } + for (DocumentError documentError : entityLinkingResult.getErrors()) { + final com.azure.ai.textanalytics.models.TextAnalyticsError error = convertToError(documentError.getError()); + linkedEntitiesResults.add(new RecognizeLinkedEntitiesResult(documentError.getId(), null, error, null)); } - return namedEntityList; + + return new DocumentResultCollection<>(linkedEntitiesResults, + entityLinkingResult.getModelVersion(), entityLinkingResult.getStatistics() == null ? null + : mapBatchStatistics(entityLinkingResult.getStatistics())); } private static List mapByIndex(List textInputs, BiFunction mappingFunction) { @@ -1371,28 +1334,26 @@ private TextDocumentStatistics convertToTextDocumentStatistics(DocumentStatistic return new TextDocumentStatistics(statistics.getCharactersCount(), statistics.getTransactionsCount()); } + private TextDocumentBatchStatistics mapBatchStatistics(RequestStatistics statistics) { + return new TextDocumentBatchStatistics(statistics.getDocumentsCount(), statistics.getErroneousDocumentsCount(), + statistics.getValidDocumentsCount(), statistics.getTransactionsCount()); + } + private List mapLinkedEntity(List linkedEntities) { List linkedEntitiesList = new ArrayList<>(); for (LinkedEntity linkedEntity : linkedEntities) { linkedEntitiesList.add(new com.azure.ai.textanalytics.models.LinkedEntity(linkedEntity.getName(), - mapLinkedEntityMatches(linkedEntity.getMatches()), linkedEntity.getLanguage(), linkedEntity.getId(), - linkedEntity.getUrl(), linkedEntity.getDataSource())); + linkedEntity.getMatches().stream().map(match -> + new LinkedEntityMatch(match.getText(), match.getScore(), match.getLength(), + match.getOffset())).collect(Collectors.toList()), linkedEntity.getLanguage(), + linkedEntity.getId(), linkedEntity.getUrl(), linkedEntity.getDataSource())); } return linkedEntitiesList; } - private List mapLinkedEntityMatches(List matches) { - List linkedEntityMatchesList = new ArrayList<>(); - for (Match match : matches) { - linkedEntityMatchesList.add(new LinkedEntityMatch(match.getText(), match.getScore(), match.getLength(), - match.getOffset())); - } - return linkedEntityMatchesList; - } - private com.azure.ai.textanalytics.models.TextAnalyticsError convertToError(TextAnalyticsError textAnalyticsError) { return new com.azure.ai.textanalytics.models.TextAnalyticsError( - convertToErrorCodeValue(textAnalyticsError.getCode()), textAnalyticsError.getMessage(), + ErrorCodeValue.fromString(textAnalyticsError.getCode().toString()), textAnalyticsError.getMessage(), textAnalyticsError.getTarget(), textAnalyticsError.getDetails() == null ? null : setErrors(textAnalyticsError.getDetails())); } @@ -1401,15 +1362,10 @@ private List setErrors(Lis List detailsList = new ArrayList<>(); for (TextAnalyticsError error : details) { detailsList.add(new com.azure.ai.textanalytics.models.TextAnalyticsError( - convertToErrorCodeValue(error.getCode()), + ErrorCodeValue.fromString(error.getCode().toString()), error.getMessage(), error.getTarget(), error.getDetails() == null ? null : setErrors(error.getDetails()))); } return detailsList; } - - private ErrorCodeValue convertToErrorCodeValue( - com.azure.ai.textanalytics.implementation.models.ErrorCodeValue errorCodeValue) { - return ErrorCodeValue.fromString(errorCodeValue.toString()); - } } diff --git a/sdk/textanalytics/azure-ai-textanalytics/src/main/java/com/azure/ai/textanalytics/TextAnalyticsClient.java b/sdk/textanalytics/azure-ai-textanalytics/src/main/java/com/azure/ai/textanalytics/TextAnalyticsClient.java index 0c60bd0e0b0f..9a8a53d777c2 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/src/main/java/com/azure/ai/textanalytics/TextAnalyticsClient.java +++ b/sdk/textanalytics/azure-ai-textanalytics/src/main/java/com/azure/ai/textanalytics/TextAnalyticsClient.java @@ -9,6 +9,7 @@ import com.azure.ai.textanalytics.models.ExtractKeyPhraseResult; import com.azure.ai.textanalytics.models.RecognizeLinkedEntitiesResult; import com.azure.ai.textanalytics.models.RecognizeEntitiesResult; +import com.azure.ai.textanalytics.models.RecognizePiiEntitiesResult; import com.azure.ai.textanalytics.models.TextAnalyticsRequestOptions; import com.azure.ai.textanalytics.models.TextDocumentInput; import com.azure.ai.textanalytics.models.AnalyzeSentimentResult; @@ -259,12 +260,12 @@ public Response> recognizeBatc * See https://aka.ms/talangs for the list of enabled languages. * * @param text the text to recognize pii entities for. - * @return A {@link RecognizeEntitiesResult PII entity} of the text. + * @return A {@link RecognizePiiEntitiesResult PII entity} of the text. * * @throws NullPointerException if {@code text} is {@code null}. */ @ServiceMethod(returns = ReturnType.SINGLE) - public RecognizeEntitiesResult recognizePiiEntities(String text) { + public RecognizePiiEntitiesResult recognizePiiEntities(String text) { return recognizePiiEntitiesWithResponse(text, client.getDefaultLanguage(), Context.NONE).getValue(); } @@ -279,11 +280,11 @@ public RecognizeEntitiesResult recognizePiiEntities(String text) { * @param context Additional context that is passed through the Http pipeline during the service call. * * @return A {@link Response} whose {@link Response#getValue() value} has the - * {@link RecognizeEntitiesResult named entity} of the text. + * {@link RecognizePiiEntitiesResult named entity} of the text. * @throws NullPointerException if {@code text} is {@code null}. */ @ServiceMethod(returns = ReturnType.SINGLE) - public Response recognizePiiEntitiesWithResponse(String text, String language, + public Response recognizePiiEntitiesWithResponse(String text, String language, Context context) { return client.recognizePiiEntitiesWithResponse(text, language, context).block(); } @@ -295,11 +296,12 @@ public Response recognizePiiEntitiesWithResponse(String * * @param textInputs A list of text to recognize pii entities for. * - * @return A {@link DocumentResultCollection batch} of the {@link RecognizeEntitiesResult named entity} of the text. + * @return A {@link DocumentResultCollection batch} of the {@link RecognizePiiEntitiesResult named entity} + * of the text. * @throws NullPointerException if {@code textInputs} is {@code null}. */ @ServiceMethod(returns = ReturnType.SINGLE) - public DocumentResultCollection recognizePiiEntities(List textInputs) { + public DocumentResultCollection recognizePiiEntities(List textInputs) { return recognizePiiEntitiesWithResponse(textInputs, client.getDefaultLanguage(), Context.NONE).getValue(); } @@ -314,11 +316,11 @@ public DocumentResultCollection recognizePiiEntities(Li * @param context Additional context that is passed through the Http pipeline during the service call. * * @return A {@link Response} containing the {@link DocumentResultCollection batch} of the - * {@link RecognizeEntitiesResult named entity}. + * {@link RecognizePiiEntitiesResult named entity}. * @throws NullPointerException if {@code textInputs} is {@code null}. */ @ServiceMethod(returns = ReturnType.SINGLE) - public Response> recognizePiiEntitiesWithResponse( + public Response> recognizePiiEntitiesWithResponse( List textInputs, String language, Context context) { return client.recognizePiiEntitiesWithResponse(textInputs, language, context).block(); } @@ -334,7 +336,7 @@ public Response> recognizePiiE * @throws NullPointerException if {@code textInputs} is {@code null}. */ @ServiceMethod(returns = ReturnType.SINGLE) - public DocumentResultCollection recognizeBatchPiiEntities( + public DocumentResultCollection recognizeBatchPiiEntities( List textInputs) { return recognizeBatchPiiEntitiesWithResponse(textInputs, null, Context.NONE).getValue(); } @@ -354,13 +356,12 @@ public DocumentResultCollection recognizeBatchPiiEntiti * @throws NullPointerException if {@code textInputs} is {@code null}. */ @ServiceMethod(returns = ReturnType.SINGLE) - public Response> recognizeBatchPiiEntitiesWithResponse( + public Response> recognizeBatchPiiEntitiesWithResponse( List textInputs, TextAnalyticsRequestOptions options, Context context) { return client.recognizeBatchPiiEntitiesWithResponse(textInputs, options, context).block(); } // Linked Entities - /** * Returns a list of recognized entities with links to a well-known knowledge base for the provided text. * See https://aka.ms/talangs for supported languages in Text Analytics API. diff --git a/sdk/textanalytics/azure-ai-textanalytics/src/samples/java/com/azure/ai/textanalytics/batch/RecognizePiiBatchDocuments.java b/sdk/textanalytics/azure-ai-textanalytics/src/samples/java/com/azure/ai/textanalytics/batch/RecognizePiiBatchDocuments.java index ed70915b5dce..4986f308012e 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/src/samples/java/com/azure/ai/textanalytics/batch/RecognizePiiBatchDocuments.java +++ b/sdk/textanalytics/azure-ai-textanalytics/src/samples/java/com/azure/ai/textanalytics/batch/RecognizePiiBatchDocuments.java @@ -6,7 +6,7 @@ import com.azure.ai.textanalytics.TextAnalyticsClient; import com.azure.ai.textanalytics.TextAnalyticsClientBuilder; import com.azure.ai.textanalytics.models.DocumentResultCollection; -import com.azure.ai.textanalytics.models.RecognizeEntitiesResult; +import com.azure.ai.textanalytics.models.RecognizePiiEntitiesResult; import com.azure.ai.textanalytics.models.TextAnalyticsRequestOptions; import com.azure.ai.textanalytics.models.TextDocumentBatchStatistics; import com.azure.ai.textanalytics.models.TextDocumentInput; @@ -38,7 +38,7 @@ public static void main(String[] args) { ); final TextAnalyticsRequestOptions requestOptions = new TextAnalyticsRequestOptions().setShowStatistics(true); - final DocumentResultCollection detectedBatchResult = client.recognizeBatchPiiEntitiesWithResponse(inputs, requestOptions, Context.NONE).getValue(); + final DocumentResultCollection detectedBatchResult = client.recognizeBatchPiiEntitiesWithResponse(inputs, requestOptions, Context.NONE).getValue(); System.out.printf("Model version: %s%n", detectedBatchResult.getModelVersion()); final TextDocumentBatchStatistics batchStatistics = detectedBatchResult.getStatistics(); diff --git a/sdk/textanalytics/azure-ai-textanalytics/src/test/java/com/azure/ai/textanalytics/TextAnalyticsClientTestBase.java b/sdk/textanalytics/azure-ai-textanalytics/src/test/java/com/azure/ai/textanalytics/TextAnalyticsClientTestBase.java index afc3bdc50395..6133516ef6fc 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/src/test/java/com/azure/ai/textanalytics/TextAnalyticsClientTestBase.java +++ b/sdk/textanalytics/azure-ai-textanalytics/src/test/java/com/azure/ai/textanalytics/TextAnalyticsClientTestBase.java @@ -13,6 +13,7 @@ import com.azure.ai.textanalytics.models.NamedEntity; import com.azure.ai.textanalytics.models.RecognizeEntitiesResult; import com.azure.ai.textanalytics.models.RecognizeLinkedEntitiesResult; +import com.azure.ai.textanalytics.models.RecognizePiiEntitiesResult; import com.azure.ai.textanalytics.models.TextAnalyticsError; import com.azure.ai.textanalytics.models.TextAnalyticsRequestOptions; import com.azure.ai.textanalytics.models.TextDocumentBatchStatistics; @@ -915,7 +916,7 @@ static DocumentResultCollection getExpectedBatchNamedEn return new DocumentResultCollection<>(recognizeEntitiesResultList, MODEL_VERSION, textDocumentBatchStatistics); } - static DocumentResultCollection getExpectedBatchPiiEntities() { + static DocumentResultCollection getExpectedBatchPiiEntities() { NamedEntity namedEntity1 = new NamedEntity("859-98-0987", "U.S. Social Security Number (SSN)", "", 28, 11, 0.65); NamedEntity namedEntity2 = new NamedEntity("111000025", "ABA Routing Number", "", 18, 9, 0.75); @@ -925,11 +926,11 @@ static DocumentResultCollection getExpectedBatchPiiEnti TextDocumentStatistics textDocumentStatistics1 = new TextDocumentStatistics(67, 1); TextDocumentStatistics textDocumentStatistics2 = new TextDocumentStatistics(105, 1); - RecognizeEntitiesResult recognizeEntitiesResult1 = new RecognizeEntitiesResult("0", textDocumentStatistics1, null, namedEntityList1); - RecognizeEntitiesResult recognizeEntitiesResult2 = new RecognizeEntitiesResult("1", textDocumentStatistics2, null, namedEntityList2); + RecognizePiiEntitiesResult recognizeEntitiesResult1 = new RecognizePiiEntitiesResult("0", textDocumentStatistics1, null, namedEntityList1); + RecognizePiiEntitiesResult recognizeEntitiesResult2 = new RecognizePiiEntitiesResult("1", textDocumentStatistics2, null, namedEntityList2); TextDocumentBatchStatistics textDocumentBatchStatistics = new TextDocumentBatchStatistics(2, 0, 2, 2); - List recognizeEntitiesResultList = Arrays.asList(recognizeEntitiesResult1, recognizeEntitiesResult2); + List recognizeEntitiesResultList = Arrays.asList(recognizeEntitiesResult1, recognizeEntitiesResult2); return new DocumentResultCollection<>(recognizeEntitiesResultList, MODEL_VERSION, textDocumentBatchStatistics); }