From 726c4871d7635609faba3ea91d290351fc91c92e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Fri, 27 Mar 2020 08:36:41 +0100 Subject: [PATCH 1/3] Do not fail Evaluate API when the actual and predicted fields' types differ (#54255) --- .../MulticlassConfusionMatrix.java | 4 +- .../evaluation/classification/Precision.java | 2 +- .../ClassificationEvaluationIT.java | 499 +++++++++++------- 3 files changed, 313 insertions(+), 192 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index 1dc3614723dfa..13c08098776f5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -143,7 +143,7 @@ public final Tuple, List> a if (result.get() == null) { // These are steps 2, 3, 4 etc. KeyedFilter[] keyedFiltersPredicted = topActualClassNames.get().stream() - .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) + .map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(predictedField, className).lenient(true))) .toArray(KeyedFilter[]::new); // Knowing exactly how many buckets does each aggregation use, we can choose the size of the batch so that // too_many_buckets_exception exception is not thrown. @@ -154,7 +154,7 @@ public final Tuple, List> a topActualClassNames.get().stream() .skip(actualClasses.size()) .limit(actualClassesPerBatch) - .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className))) + .map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(actualField, className).lenient(true))) .toArray(KeyedFilter[]::new); if (keyedFiltersActual.length > 0) { return Tuple.tuple( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index 0ffdc22ab1c7d..b90bfd8cce6c6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -108,7 +108,7 @@ public final Tuple, List> a if (result.get() == null) { // This is step 2 KeyedFilter[] keyedFiltersPredicted = topActualClassNames.get().stream() - .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) + .map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(predictedField, className).lenient(true))) .toArray(KeyedFilter[]::new); Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField); return Tuple.tuple( diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 70f8e7ca8ae12..5e961f97cc611 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -29,23 +29,25 @@ import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notANumber; public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index"; private static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword"; - private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction"; + private static final String ANIMAL_NAME_PREDICTION_KEYWORD_FIELD = "animal_name_keyword_prediction"; private static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword"; private static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer"; - private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction"; + private static final String NO_LEGS_PREDICTION_INTEGER_FIELD = "no_legs_integer_prediction"; private static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword"; private static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean"; - private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction"; + private static final String IS_PREDATOR_PREDICTION_BOOLEAN_FIELD = "predator_boolean_prediction"; @Before public void setup() { @@ -64,7 +66,8 @@ public void cleanup() { public void testEvaluate_DefaultMetrics() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null)); + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat( @@ -78,8 +81,8 @@ public void testEvaluate_AllMetrics() { ANIMALS_DATA_INDEX, new Classification( ANIMAL_NAME_KEYWORD_FIELD, - ANIMAL_NAME_PREDICTION_FIELD, - Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); + ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, + List.of(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat( @@ -91,163 +94,257 @@ public void testEvaluate_AllMetrics() { Recall.NAME.getPreferredName())); } - public void testEvaluate_Accuracy_KeywordField() { + public void testEvaluate_AllMetrics_KeywordField_CaseSensitivity() { + String indexName = "some-index"; + String actualField = "fieldA"; + String predictedField = "fieldB"; + client().admin().indices().prepareCreate(indexName) + .setMapping( + actualField, "type=keyword", + predictedField, "type=keyword") + .get(); + client().prepareIndex(indexName) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setSource( + actualField, "crocodile", + predictedField, "cRoCoDiLe") + .get(); + EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( - ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); - - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + indexName, + new Classification( + actualField, + predictedField, + List.of(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); + assertThat(accuracyResult.getClasses(), contains(new Accuracy.PerClassResult("crocodile", 0.0))); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.0)); + + MulticlassConfusionMatrix.Result confusionMatrixResult = + (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(1); assertThat( - accuracyResult.getClasses(), - equalTo( - Arrays.asList( - new Accuracy.PerClassResult("ant", 47.0 / 75), - new Accuracy.PerClassResult("cat", 47.0 / 75), - new Accuracy.PerClassResult("dog", 47.0 / 75), - new Accuracy.PerClassResult("fox", 47.0 / 75), - new Accuracy.PerClassResult("mouse", 47.0 / 75)))); - assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75)); + confusionMatrixResult.getConfusionMatrix(), + equalTo(List.of( + new MulticlassConfusionMatrix.ActualClass( + "crocodile", 1, List.of(new MulticlassConfusionMatrix.PredictedClass("crocodile", 0L)), 1)))); + + Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(2); + assertThat(precisionResult.getClasses(), empty()); + assertThat(precisionResult.getAvgPrecision(), is(notANumber())); + + Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(3); + assertThat(recallResult.getClasses(), contains(new Recall.PerClassResult("crocodile", 0.0))); + assertThat(recallResult.getAvgRecall(), equalTo(0.0)); } - private void evaluateAccuracy_IntegerField(String actualField) { + private Accuracy.Result evaluateAccuracy(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, List.of(new Accuracy()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); - assertThat( - accuracyResult.getClasses(), - equalTo( - Arrays.asList( - new Accuracy.PerClassResult("1", 57.0 / 75), - new Accuracy.PerClassResult("2", 54.0 / 75), - new Accuracy.PerClassResult("3", 51.0 / 75), - new Accuracy.PerClassResult("4", 48.0 / 75), - new Accuracy.PerClassResult("5", 45.0 / 75)))); - assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75)); - } - - public void testEvaluate_Accuracy_IntegerField() { - evaluateAccuracy_IntegerField(NO_LEGS_INTEGER_FIELD); - } - - public void testEvaluate_Accuracy_IntegerField_MappingTypeMismatch() { - evaluateAccuracy_IntegerField(NO_LEGS_KEYWORD_FIELD); + return accuracyResult; } - private void evaluateAccuracy_BooleanField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); + public void testEvaluate_Accuracy_KeywordField() { + List expectedPerClassResults = + List.of( + new Accuracy.PerClassResult("ant", 47.0 / 75), + new Accuracy.PerClassResult("cat", 47.0 / 75), + new Accuracy.PerClassResult("dog", 47.0 / 75), + new Accuracy.PerClassResult("fox", 47.0 / 75), + new Accuracy.PerClassResult("mouse", 47.0 / 75)); + double expectedOverallAccuracy = 5.0 / 75; - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + Accuracy.Result accuracyResult = evaluateAccuracy(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); - Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); - assertThat( - accuracyResult.getClasses(), - equalTo( - Arrays.asList( - new Accuracy.PerClassResult("false", 18.0 / 30), - new Accuracy.PerClassResult("true", 27.0 / 45)))); - assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75)); + // Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric) + accuracyResult = evaluateAccuracy(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); } - public void testEvaluate_Accuracy_BooleanField() { - evaluateAccuracy_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + public void testEvaluate_Accuracy_IntegerField() { + List expectedPerClassResults = + List.of( + new Accuracy.PerClassResult("1", 57.0 / 75), + new Accuracy.PerClassResult("2", 54.0 / 75), + new Accuracy.PerClassResult("3", 51.0 / 75), + new Accuracy.PerClassResult("4", 48.0 / 75), + new Accuracy.PerClassResult("5", 45.0 / 75)); + double expectedOverallAccuracy = 15.0 / 75; + + Accuracy.Result accuracyResult = evaluateAccuracy(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric) + accuracyResult = evaluateAccuracy(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are of different types but the values are matched correctly + accuracyResult = evaluateAccuracy(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are of different types but the values are matched correctly + accuracyResult = evaluateAccuracy(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); } - public void testEvaluate_Accuracy_BooleanField_MappingTypeMismatch() { - evaluateAccuracy_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + public void testEvaluate_Accuracy_BooleanField() { + List expectedPerClassResults = + List.of( + new Accuracy.PerClassResult("false", 18.0 / 30), + new Accuracy.PerClassResult("true", 27.0 / 45)); + double expectedOverallAccuracy = 45.0 / 75; + + Accuracy.Result accuracyResult = evaluateAccuracy(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric) + accuracyResult = evaluateAccuracy(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are of different types but the values are matched correctly + accuracyResult = evaluateAccuracy(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are of different types but the values are matched correctly + accuracyResult = evaluateAccuracy(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + } + + public void testEvaluate_Accuracy_FieldTypeMismatch() { + { + // When actual and predicted fields have different types, the sets of classes are disjoint + List expectedPerClassResults = + List.of( + new Accuracy.PerClassResult("1", 0.8), + new Accuracy.PerClassResult("2", 0.8), + new Accuracy.PerClassResult("3", 0.8), + new Accuracy.PerClassResult("4", 0.8), + new Accuracy.PerClassResult("5", 0.8)); + double expectedOverallAccuracy = 0.0; + + Accuracy.Result accuracyResult = evaluateAccuracy(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + } + { + // When actual and predicted fields have different types, the sets of classes are disjoint + List expectedPerClassResults = + List.of( + new Accuracy.PerClassResult("false", 0.6), + new Accuracy.PerClassResult("true", 0.4)); + double expectedOverallAccuracy = 0.0; + + Accuracy.Result accuracyResult = evaluateAccuracy(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + } } - public void testEvaluate_Precision_KeywordField() { + private Precision.Result evaluatePrecision(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, List.of(new Precision()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); - assertThat( - precisionResult.getClasses(), - equalTo( - Arrays.asList( - new Precision.PerClassResult("ant", 1.0 / 15), - new Precision.PerClassResult("cat", 1.0 / 15), - new Precision.PerClassResult("dog", 1.0 / 15), - new Precision.PerClassResult("fox", 1.0 / 15), - new Precision.PerClassResult("mouse", 1.0 / 15)))); - assertThat(precisionResult.getAvgPrecision(), equalTo(5.0 / 75)); + return precisionResult; } - private void evaluatePrecision_IntegerField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Precision()))); + public void testEvaluate_Precision_KeywordField() { + List expectedPerClassResults = + List.of( + new Precision.PerClassResult("ant", 1.0 / 15), + new Precision.PerClassResult("cat", 1.0 / 15), + new Precision.PerClassResult("dog", 1.0 / 15), + new Precision.PerClassResult("fox", 1.0 / 15), + new Precision.PerClassResult("mouse", 1.0 / 15)); + double expectedAvgPrecision = 5.0 / 75; - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + Precision.Result precisionResult = evaluatePrecision(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); - Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); - assertThat( - precisionResult.getClasses(), - equalTo( - Arrays.asList( - new Precision.PerClassResult("1", 0.2), - new Precision.PerClassResult("2", 0.2), - new Precision.PerClassResult("3", 0.2), - new Precision.PerClassResult("4", 0.2), - new Precision.PerClassResult("5", 0.2)))); - assertThat(precisionResult.getAvgPrecision(), equalTo(0.2)); + evaluatePrecision(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD); } public void testEvaluate_Precision_IntegerField() { - evaluatePrecision_IntegerField(NO_LEGS_INTEGER_FIELD); - } + List expectedPerClassResults = + List.of( + new Precision.PerClassResult("1", 0.2), + new Precision.PerClassResult("2", 0.2), + new Precision.PerClassResult("3", 0.2), + new Precision.PerClassResult("4", 0.2), + new Precision.PerClassResult("5", 0.2)); + double expectedAvgPrecision = 0.2; - public void testEvaluate_Precision_IntegerField_MappingTypeMismatch() { - evaluatePrecision_IntegerField(NO_LEGS_KEYWORD_FIELD); - } + Precision.Result precisionResult = evaluatePrecision(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); - private void evaluatePrecision_BooleanField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Precision()))); + // Actual and predicted fields are of different types but the values are matched correctly + precisionResult = evaluatePrecision(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + evaluatePrecision(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD); - Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); - assertThat( - precisionResult.getClasses(), - equalTo( - Arrays.asList( - new Precision.PerClassResult("false", 0.5), - new Precision.PerClassResult("true", 9.0 / 13)))); - assertThat(precisionResult.getAvgPrecision(), equalTo(31.0 / 52)); + evaluatePrecision(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD); } public void testEvaluate_Precision_BooleanField() { - evaluatePrecision_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + List expectedPerClassResults = + List.of( + new Precision.PerClassResult("false", 0.5), + new Precision.PerClassResult("true", 9.0 / 13)); + double expectedAvgPrecision = 31.0 / 52; + + Precision.Result precisionResult = evaluatePrecision(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); + + // Actual and predicted fields are of different types but the values are matched correctly + precisionResult = evaluatePrecision(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); + + evaluatePrecision(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + + evaluatePrecision(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD); } - public void testEvaluate_Precision_BooleanField_MappingTypeMismatch() { - evaluatePrecision_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + public void testEvaluate_Precision_FieldTypeMismatch() { + { + Precision.Result precisionResult = evaluatePrecision(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + // When actual and predicted fields have different types, the sets of classes are disjoint, hence empty results here + assertThat(precisionResult.getClasses(), empty()); + assertThat(precisionResult.getAvgPrecision(), is(notANumber())); + } + { + Precision.Result precisionResult = evaluatePrecision(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD); + // When actual and predicted fields have different types, the sets of classes are disjoint, hence empty results here + assertThat(precisionResult.getClasses(), empty()); + assertThat(precisionResult.getAvgPrecision(), is(notANumber())); + } } public void testEvaluate_Precision_CardinalityTooHigh() { @@ -257,88 +354,112 @@ public void testEvaluate_Precision_CardinalityTooHigh() { ElasticsearchStatusException.class, () -> evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())))); + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new Precision())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } - public void testEvaluate_Recall_KeywordField() { + private Recall.Result evaluateRecall(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, List.of(new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); - assertThat( - recallResult.getClasses(), - equalTo( - Arrays.asList( - new Recall.PerClassResult("ant", 1.0 / 15), - new Recall.PerClassResult("cat", 1.0 / 15), - new Recall.PerClassResult("dog", 1.0 / 15), - new Recall.PerClassResult("fox", 1.0 / 15), - new Recall.PerClassResult("mouse", 1.0 / 15)))); - assertThat(recallResult.getAvgRecall(), equalTo(5.0 / 75)); + return recallResult; } - private void evaluateRecall_IntegerField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_INTEGER_FIELD, Arrays.asList(new Recall()))); + public void testEvaluate_Recall_KeywordField() { + List expectedPerClassResults = + List.of( + new Recall.PerClassResult("ant", 1.0 / 15), + new Recall.PerClassResult("cat", 1.0 / 15), + new Recall.PerClassResult("dog", 1.0 / 15), + new Recall.PerClassResult("fox", 1.0 / 15), + new Recall.PerClassResult("mouse", 1.0 / 15)); + double expectedAvgRecall = 5.0 / 75; - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + Recall.Result recallResult = evaluateRecall(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); - Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); - assertThat( - recallResult.getClasses(), - equalTo( - Arrays.asList( - new Recall.PerClassResult("1", 1.0), - new Recall.PerClassResult("2", 1.0), - new Recall.PerClassResult("3", 1.0), - new Recall.PerClassResult("4", 1.0), - new Recall.PerClassResult("5", 1.0)))); - assertThat(recallResult.getAvgRecall(), equalTo(1.0)); + evaluateRecall(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD); } public void testEvaluate_Recall_IntegerField() { - evaluateRecall_IntegerField(NO_LEGS_INTEGER_FIELD); - } + List expectedPerClassResults = + List.of( + new Recall.PerClassResult("1", 1.0 / 15), + new Recall.PerClassResult("2", 2.0 / 15), + new Recall.PerClassResult("3", 3.0 / 15), + new Recall.PerClassResult("4", 4.0 / 15), + new Recall.PerClassResult("5", 5.0 / 15)); + double expectedAvgRecall = 3.0 / 15; - public void testEvaluate_Recall_IntegerField_MappingTypeMismatch() { - evaluateRecall_IntegerField(NO_LEGS_KEYWORD_FIELD); - } + Recall.Result recallResult = evaluateRecall(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); - private void evaluateRecall_BooleanField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Recall()))); + // Actual and predicted fields are of different types but the values are matched correctly + recallResult = evaluateRecall(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + evaluateRecall(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD); - Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); - assertThat( - recallResult.getClasses(), - equalTo( - Arrays.asList( - new Recall.PerClassResult("true", 0.6), - new Recall.PerClassResult("false", 0.6)))); - assertThat(recallResult.getAvgRecall(), equalTo(0.6)); + evaluateRecall(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD); } public void testEvaluate_Recall_BooleanField() { - evaluateRecall_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); - } - - public void testEvaluate_Recall_BooleanField_MappingTypeMismatch() { - evaluateRecall_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + List expectedPerClassResults = + List.of( + new Recall.PerClassResult("true", 0.6), + new Recall.PerClassResult("false", 0.6)); + double expectedAvgRecall = 0.6; + + Recall.Result recallResult = evaluateRecall(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); + + // Actual and predicted fields are of different types but the values are matched correctly + recallResult = evaluateRecall(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); + + evaluateRecall(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + + evaluateRecall(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD); + } + + public void testEvaluate_Recall_FieldTypeMismatch() { + { + // When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here + List expectedPerClassResults = + List.of( + new Recall.PerClassResult("1", 0.0), + new Recall.PerClassResult("2", 0.0), + new Recall.PerClassResult("3", 0.0), + new Recall.PerClassResult("4", 0.0), + new Recall.PerClassResult("5", 0.0)); + double expectedAvgRecall = 0.0; + + Recall.Result recallResult = evaluateRecall(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); + } + { + // When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here + List expectedPerClassResults = + List.of( + new Recall.PerClassResult("true", 0.0), + new Recall.PerClassResult("false", 0.0)); + double expectedAvgRecall = 0.0; + + Recall.Result recallResult = evaluateRecall(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); + } } public void testEvaluate_Recall_CardinalityTooHigh() { @@ -348,16 +469,16 @@ public void testEvaluate_Recall_CardinalityTooHigh() { ElasticsearchStatusException.class, () -> evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())))); + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new Recall())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } - private void evaluateWithMulticlassConfusionMatrix() { + private void evaluateMulticlassConfusionMatrix() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, new Classification( - ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix()))); + ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new MulticlassConfusionMatrix()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -417,16 +538,16 @@ private void evaluateWithMulticlassConfusionMatrix() { } public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() { - evaluateWithMulticlassConfusionMatrix(); + evaluateMulticlassConfusionMatrix(); client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 20)).get(); - evaluateWithMulticlassConfusionMatrix(); + evaluateMulticlassConfusionMatrix(); client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 7)).get(); - evaluateWithMulticlassConfusionMatrix(); + evaluateMulticlassConfusionMatrix(); client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 6)).get(); - ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateWithMulticlassConfusionMatrix); + ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateMulticlassConfusionMatrix); assertThat(e.getCause(), is(instanceOf(TooManyBucketsException.class))); TooManyBucketsException tmbe = (TooManyBucketsException) e.getCause(); @@ -438,7 +559,7 @@ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { evaluateDataFrame( ANIMALS_DATA_INDEX, new Classification( - ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null)))); + ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new MulticlassConfusionMatrix(3, null)))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -476,13 +597,13 @@ private static void createAnimalsIndex(String indexName) { client().admin().indices().prepareCreate(indexName) .addMapping("_doc", ANIMAL_NAME_KEYWORD_FIELD, "type=keyword", - ANIMAL_NAME_PREDICTION_FIELD, "type=keyword", + ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, "type=keyword", NO_LEGS_KEYWORD_FIELD, "type=keyword", NO_LEGS_INTEGER_FIELD, "type=integer", - NO_LEGS_PREDICTION_FIELD, "type=integer", + NO_LEGS_PREDICTION_INTEGER_FIELD, "type=integer", IS_PREDATOR_KEYWORD_FIELD, "type=keyword", IS_PREDATOR_BOOLEAN_FIELD, "type=boolean", - IS_PREDATOR_PREDICTION_FIELD, "type=boolean") + IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, "type=boolean") .get(); } @@ -497,13 +618,13 @@ private static void indexAnimalsData(String indexName) { new IndexRequest(indexName) .source( ANIMAL_NAME_KEYWORD_FIELD, animalNames.get(i), - ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()), + ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, animalNames.get((i + j) % animalNames.size()), NO_LEGS_KEYWORD_FIELD, String.valueOf(i + 1), NO_LEGS_INTEGER_FIELD, i + 1, - NO_LEGS_PREDICTION_FIELD, j + 1, + NO_LEGS_PREDICTION_INTEGER_FIELD, j + 1, IS_PREDATOR_KEYWORD_FIELD, String.valueOf(i % 2 == 0), IS_PREDATOR_BOOLEAN_FIELD, i % 2 == 0, - IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0)); + IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, (i + j) % 2 == 0)); } } } @@ -519,7 +640,7 @@ private static void indexDistinctAnimals(String indexName, int distinctAnimalCou for (int i = 0; i < distinctAnimalCount; i++) { bulkRequestBuilder.add( new IndexRequest(indexName) - .source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5))); + .source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, randomAlphaOfLength(5))); } BulkResponse bulkResponse = bulkRequestBuilder.get(); if (bulkResponse.hasFailures()) { From 2887a31066d4d8db7bd8fc2c6488d02c9aa85bbf Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Fri, 27 Mar 2020 08:47:03 +0100 Subject: [PATCH 2/3] Replace List.of with Arrays.asList --- .../ClassificationEvaluationIT.java | 50 ++++++++++--------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 5e961f97cc611..c0735ffe980d8 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -82,7 +82,7 @@ public void testEvaluate_AllMetrics() { new Classification( ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, - List.of(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); + Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat( @@ -116,7 +116,7 @@ public void testEvaluate_AllMetrics_KeywordField_CaseSensitivity() { new Classification( actualField, predictedField, - List.of(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); + Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getClasses(), contains(new Accuracy.PerClassResult("crocodile", 0.0))); @@ -126,9 +126,9 @@ public void testEvaluate_AllMetrics_KeywordField_CaseSensitivity() { (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(1); assertThat( confusionMatrixResult.getConfusionMatrix(), - equalTo(List.of( + equalTo(Arrays.asList( new MulticlassConfusionMatrix.ActualClass( - "crocodile", 1, List.of(new MulticlassConfusionMatrix.PredictedClass("crocodile", 0L)), 1)))); + "crocodile", 1, Arrays.asList(new MulticlassConfusionMatrix.PredictedClass("crocodile", 0L)), 1)))); Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(2); assertThat(precisionResult.getClasses(), empty()); @@ -141,7 +141,7 @@ public void testEvaluate_AllMetrics_KeywordField_CaseSensitivity() { private Accuracy.Result evaluateAccuracy(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, List.of(new Accuracy()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Accuracy()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -153,7 +153,7 @@ private Accuracy.Result evaluateAccuracy(String actualField, String predictedFie public void testEvaluate_Accuracy_KeywordField() { List expectedPerClassResults = - List.of( + Arrays.asList( new Accuracy.PerClassResult("ant", 47.0 / 75), new Accuracy.PerClassResult("cat", 47.0 / 75), new Accuracy.PerClassResult("dog", 47.0 / 75), @@ -173,7 +173,7 @@ public void testEvaluate_Accuracy_KeywordField() { public void testEvaluate_Accuracy_IntegerField() { List expectedPerClassResults = - List.of( + Arrays.asList( new Accuracy.PerClassResult("1", 57.0 / 75), new Accuracy.PerClassResult("2", 54.0 / 75), new Accuracy.PerClassResult("3", 51.0 / 75), @@ -203,7 +203,7 @@ public void testEvaluate_Accuracy_IntegerField() { public void testEvaluate_Accuracy_BooleanField() { List expectedPerClassResults = - List.of( + Arrays.asList( new Accuracy.PerClassResult("false", 18.0 / 30), new Accuracy.PerClassResult("true", 27.0 / 45)); double expectedOverallAccuracy = 45.0 / 75; @@ -232,7 +232,7 @@ public void testEvaluate_Accuracy_FieldTypeMismatch() { { // When actual and predicted fields have different types, the sets of classes are disjoint List expectedPerClassResults = - List.of( + Arrays.asList( new Accuracy.PerClassResult("1", 0.8), new Accuracy.PerClassResult("2", 0.8), new Accuracy.PerClassResult("3", 0.8), @@ -247,7 +247,7 @@ public void testEvaluate_Accuracy_FieldTypeMismatch() { { // When actual and predicted fields have different types, the sets of classes are disjoint List expectedPerClassResults = - List.of( + Arrays.asList( new Accuracy.PerClassResult("false", 0.6), new Accuracy.PerClassResult("true", 0.4)); double expectedOverallAccuracy = 0.0; @@ -260,7 +260,7 @@ public void testEvaluate_Accuracy_FieldTypeMismatch() { private Precision.Result evaluatePrecision(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, List.of(new Precision()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Precision()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -272,7 +272,7 @@ private Precision.Result evaluatePrecision(String actualField, String predictedF public void testEvaluate_Precision_KeywordField() { List expectedPerClassResults = - List.of( + Arrays.asList( new Precision.PerClassResult("ant", 1.0 / 15), new Precision.PerClassResult("cat", 1.0 / 15), new Precision.PerClassResult("dog", 1.0 / 15), @@ -289,7 +289,7 @@ public void testEvaluate_Precision_KeywordField() { public void testEvaluate_Precision_IntegerField() { List expectedPerClassResults = - List.of( + Arrays.asList( new Precision.PerClassResult("1", 0.2), new Precision.PerClassResult("2", 0.2), new Precision.PerClassResult("3", 0.2), @@ -313,7 +313,7 @@ public void testEvaluate_Precision_IntegerField() { public void testEvaluate_Precision_BooleanField() { List expectedPerClassResults = - List.of( + Arrays.asList( new Precision.PerClassResult("false", 0.5), new Precision.PerClassResult("true", 9.0 / 13)); double expectedAvgPrecision = 31.0 / 52; @@ -354,13 +354,13 @@ public void testEvaluate_Precision_CardinalityTooHigh() { ElasticsearchStatusException.class, () -> evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new Precision())))); + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Precision())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } private Recall.Result evaluateRecall(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, List.of(new Recall()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -372,7 +372,7 @@ private Recall.Result evaluateRecall(String actualField, String predictedField) public void testEvaluate_Recall_KeywordField() { List expectedPerClassResults = - List.of( + Arrays.asList( new Recall.PerClassResult("ant", 1.0 / 15), new Recall.PerClassResult("cat", 1.0 / 15), new Recall.PerClassResult("dog", 1.0 / 15), @@ -389,7 +389,7 @@ public void testEvaluate_Recall_KeywordField() { public void testEvaluate_Recall_IntegerField() { List expectedPerClassResults = - List.of( + Arrays.asList( new Recall.PerClassResult("1", 1.0 / 15), new Recall.PerClassResult("2", 2.0 / 15), new Recall.PerClassResult("3", 3.0 / 15), @@ -413,7 +413,7 @@ public void testEvaluate_Recall_IntegerField() { public void testEvaluate_Recall_BooleanField() { List expectedPerClassResults = - List.of( + Arrays.asList( new Recall.PerClassResult("true", 0.6), new Recall.PerClassResult("false", 0.6)); double expectedAvgRecall = 0.6; @@ -436,7 +436,7 @@ public void testEvaluate_Recall_FieldTypeMismatch() { { // When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here List expectedPerClassResults = - List.of( + Arrays.asList( new Recall.PerClassResult("1", 0.0), new Recall.PerClassResult("2", 0.0), new Recall.PerClassResult("3", 0.0), @@ -451,7 +451,7 @@ public void testEvaluate_Recall_FieldTypeMismatch() { { // When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here List expectedPerClassResults = - List.of( + Arrays.asList( new Recall.PerClassResult("true", 0.0), new Recall.PerClassResult("false", 0.0)); double expectedAvgRecall = 0.0; @@ -469,7 +469,7 @@ public void testEvaluate_Recall_CardinalityTooHigh() { ElasticsearchStatusException.class, () -> evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new Recall())))); + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Recall())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } @@ -478,7 +478,7 @@ private void evaluateMulticlassConfusionMatrix() { evaluateDataFrame( ANIMALS_DATA_INDEX, new Classification( - ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new MulticlassConfusionMatrix()))); + ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new MulticlassConfusionMatrix()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -559,7 +559,9 @@ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { evaluateDataFrame( ANIMALS_DATA_INDEX, new Classification( - ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, List.of(new MulticlassConfusionMatrix(3, null)))); + ANIMAL_NAME_KEYWORD_FIELD, + ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, + Arrays.asList(new MulticlassConfusionMatrix(3, null)))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); From bec1b29ec89b1669fb489ce5810da9c5f3658ca4 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Fri, 27 Mar 2020 08:57:49 +0100 Subject: [PATCH 3/3] Fix compile errors --- .../xpack/ml/integration/ClassificationEvaluationIT.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index c0735ffe980d8..6e135c9995d1d 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -99,11 +99,11 @@ public void testEvaluate_AllMetrics_KeywordField_CaseSensitivity() { String actualField = "fieldA"; String predictedField = "fieldB"; client().admin().indices().prepareCreate(indexName) - .setMapping( + .addMapping("_doc", actualField, "type=keyword", predictedField, "type=keyword") .get(); - client().prepareIndex(indexName) + client().prepareIndex(indexName, "_doc") .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .setSource( actualField, "crocodile",