Skip to content

Commit bfcfcde

Browse files
authored
[7.x] Do not copy mapping from dependent variable to prediction field in regression analysis (#51227) (#51288)
1 parent 1009f92 commit bfcfcde

File tree

12 files changed

+179
-101
lines changed

12 files changed

+179
-101
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1515
import org.elasticsearch.common.xcontent.XContentBuilder;
1616
import org.elasticsearch.common.xcontent.XContentParser;
17+
import org.elasticsearch.index.mapper.FieldAliasMapper;
1718
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1819

1920
import java.io.IOException;
@@ -28,6 +29,7 @@
2829

2930
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
3031
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
32+
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
3133

3234
public class Classification implements DataFrameAnalysis {
3335

@@ -248,12 +250,32 @@ public Map<String, Long> getFieldCardinalityLimits() {
248250
return Collections.singletonMap(dependentVariable, 2L);
249251
}
250252

253+
@SuppressWarnings("unchecked")
251254
@Override
252-
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
253-
return new HashMap<String, String>() {{
254-
put(resultsFieldName + "." + predictionFieldName, dependentVariable);
255-
put(resultsFieldName + ".top_classes.class_name", dependentVariable);
256-
}};
255+
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
256+
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
257+
if ((dependentVariableMapping instanceof Map) == false) {
258+
return Collections.emptyMap();
259+
}
260+
Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping;
261+
// If the source field is an alias, fetch the concrete field that the alias points to.
262+
if (FieldAliasMapper.CONTENT_TYPE.equals(dependentVariableMappingAsMap.get("type"))) {
263+
String path = (String) dependentVariableMappingAsMap.get(FieldAliasMapper.Names.PATH);
264+
dependentVariableMapping = extractMapping(path, mappingsProperties);
265+
}
266+
// We may have updated the value of {@code dependentVariableMapping} in the "if" block above.
267+
// Hence, we need to check the "instanceof" condition again.
268+
if ((dependentVariableMapping instanceof Map) == false) {
269+
return Collections.emptyMap();
270+
}
271+
Map<String, Object> additionalProperties = new HashMap<>();
272+
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
273+
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
274+
return additionalProperties;
275+
}
276+
277+
private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
278+
return extractValue(String.join(".properties.", path.split("\\.")), mappingsProperties);
257279
}
258280

259281
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,13 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
4242
Map<String, Long> getFieldCardinalityLimits();
4343

4444
/**
45-
* Returns fields for which the mappings should be copied from source index to destination index.
46-
* Each entry of the returned {@link Map} is of the form:
47-
* key - field path in the destination index
48-
* value - field path in the source index from which the mapping should be taken
45+
* Returns the fields whose mappings should be either predefined or copied from the source index to the destination index.
4946
*
47+
* @param mappingsProperties mappings.properties portion of the index mappings
5048
* @param resultsFieldName name of the results field under which all the results are stored
51-
* @return {@link Map} containing fields for which the mappings should be copied from source index to destination index
49+
* @return {@link Map} containing fields for which the mappings should be handled explicitly
5250
*/
53-
Map<String, String> getExplicitlyMappedFields(String resultsFieldName);
51+
Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName);
5452

5553
/**
5654
* @return {@code true} if this analysis supports data frame rows with missing values

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ public Map<String, Long> getFieldCardinalityLimits() {
230230
}
231231

232232
@Override
233-
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
233+
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
234234
return Collections.emptyMap();
235235
}
236236

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,10 @@ public Map<String, Long> getFieldCardinalityLimits() {
187187
}
188188

189189
@Override
190-
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
191-
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, dependentVariable);
190+
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
191+
// The prediction field should always be mapped as "double" rather than "float" in order to increase precision in case of
192+
// high (over 10M) values of the dependent variable.
193+
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", "double"));
192194
}
193195

194196
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Map;
2626
import java.util.Set;
2727

28+
import static org.hamcrest.Matchers.allOf;
2829
import static org.hamcrest.Matchers.anEmptyMap;
2930
import static org.hamcrest.Matchers.containsString;
3031
import static org.hamcrest.Matchers.empty;
@@ -171,8 +172,40 @@ public void testFieldCardinalityLimitsIsNonEmpty() {
171172
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap())));
172173
}
173174

174-
public void testFieldMappingsToCopyIsNonEmpty() {
175-
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
175+
public void testGetExplicitlyMappedFields() {
176+
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), is(anEmptyMap()));
177+
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), is(anEmptyMap()));
178+
assertThat(
179+
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
180+
is(anEmptyMap()));
181+
assertThat(
182+
new Classification("foo").getExplicitlyMappedFields(
183+
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
184+
"results"),
185+
allOf(
186+
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
187+
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
188+
assertThat(
189+
new Classification("foo").getExplicitlyMappedFields(
190+
new HashMap<String, Object>() {{
191+
put("foo", new HashMap<String, String>() {{
192+
put("type", "alias");
193+
put("path", "bar");
194+
}});
195+
put("bar", Collections.singletonMap("type", "long"));
196+
}},
197+
"results"),
198+
allOf(
199+
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
200+
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
201+
assertThat(
202+
new Classification("foo").getExplicitlyMappedFields(
203+
Collections.singletonMap("foo", new HashMap<String, String>() {{
204+
put("type", "alias");
205+
put("path", "missing");
206+
}}),
207+
"results"),
208+
is(anEmptyMap()));
176209
}
177210

178211
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ public void testFieldCardinalityLimitsIsEmpty() {
9292
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
9393
}
9494

95-
public void testFieldMappingsToCopyIsEmpty() {
96-
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(anEmptyMap()));
95+
public void testGetExplicitlyMappedFields() {
96+
assertThat(createTestInstance().getExplicitlyMappedFields(null, null), is(anEmptyMap()));
9797
}
9898

9999
public void testGetStateDocId() {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ protected Regression createTestInstance() {
4343
return createRandom();
4444
}
4545

46-
public static Regression createRandom() {
46+
private static Regression createRandom() {
4747
String dependentVariableName = randomAlphaOfLength(10);
4848
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
4949
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
@@ -110,8 +110,10 @@ public void testFieldCardinalityLimitsIsEmpty() {
110110
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
111111
}
112112

113-
public void testFieldMappingsToCopyIsNonEmpty() {
114-
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
113+
public void testGetExplicitlyMappedFields() {
114+
assertThat(
115+
new Regression("foo").getExplicitlyMappedFields(null, "results"),
116+
hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
115117
}
116118

117119
public void testGetStateDocId() {

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

Lines changed: 7 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
import com.google.common.collect.Ordering;
99
import org.elasticsearch.ElasticsearchStatusException;
10-
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
11-
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
1210
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
1311
import org.elasticsearch.action.bulk.BulkRequestBuilder;
1412
import org.elasticsearch.action.bulk.BulkResponse;
@@ -42,7 +40,6 @@
4240
import java.util.Set;
4341

4442
import static java.util.stream.Collectors.toList;
45-
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
4643
import static org.hamcrest.Matchers.allOf;
4744
import static org.hamcrest.Matchers.anyOf;
4845
import static org.hamcrest.Matchers.equalTo;
@@ -116,7 +113,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
116113
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
117114
assertModelStatePersisted(stateDocId());
118115
assertInferenceModelPersisted(jobId);
119-
assertMlResultsFieldMappings(predictedClassField, "keyword");
116+
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
120117
assertThatAuditMessagesMatch(jobId,
121118
"Created analytics with analysis type [classification]",
122119
"Estimated memory usage for this analytics to be",
@@ -157,7 +154,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
157154
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
158155
assertModelStatePersisted(stateDocId());
159156
assertInferenceModelPersisted(jobId);
160-
assertMlResultsFieldMappings(predictedClassField, "keyword");
157+
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
161158
assertThatAuditMessagesMatch(jobId,
162159
"Created analytics with analysis type [classification]",
163160
"Estimated memory usage for this analytics to be",
@@ -220,7 +217,7 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId,
220217
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
221218
assertModelStatePersisted(stateDocId());
222219
assertInferenceModelPersisted(jobId);
223-
assertMlResultsFieldMappings(predictedClassField, expectedMappingTypeForPredictedField);
220+
assertMlResultsFieldMappings(destIndex, predictedClassField, expectedMappingTypeForPredictedField);
224221
assertThatAuditMessagesMatch(jobId,
225222
"Created analytics with analysis type [classification]",
226223
"Estimated memory usage for this analytics to be",
@@ -308,7 +305,7 @@ public void testStopAndRestart() throws Exception {
308305
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
309306
assertModelStatePersisted(stateDocId());
310307
assertInferenceModelPersisted(jobId);
311-
assertMlResultsFieldMappings(predictedClassField, "keyword");
308+
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
312309
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
313310
}
314311

@@ -365,7 +362,7 @@ public void testDependentVariableIsNested() throws Exception {
365362
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
366363
assertModelStatePersisted(stateDocId());
367364
assertInferenceModelPersisted(jobId);
368-
assertMlResultsFieldMappings(predictedClassField, "keyword");
365+
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
369366
assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
370367
}
371368

@@ -384,7 +381,7 @@ public void testDependentVariableIsAliasToKeyword() throws Exception {
384381
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
385382
assertModelStatePersisted(stateDocId());
386383
assertInferenceModelPersisted(jobId);
387-
assertMlResultsFieldMappings(predictedClassField, "keyword");
384+
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
388385
assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
389386
}
390387

@@ -403,7 +400,7 @@ public void testDependentVariableIsAliasToNested() throws Exception {
403400
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
404401
assertModelStatePersisted(stateDocId());
405402
assertInferenceModelPersisted(jobId);
406-
assertMlResultsFieldMappings(predictedClassField, "keyword");
403+
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
407404
assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
408405
}
409406

@@ -564,15 +561,6 @@ private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, S
564561
return destDoc;
565562
}
566563

567-
/**
568-
* Wrapper around extractValue that:
569-
* - allows dots (".") in the path elements provided as arguments
570-
* - supports implicit casting to the appropriate type
571-
*/
572-
private static <T> T getFieldValue(Map<String, Object> doc, String... path) {
573-
return (T)extractValue(String.join(".", path), doc);
574-
}
575-
576564
private static <T> void assertTopClasses(Map<String, Object> resultsObject,
577565
int numTopClasses,
578566
String dependentVariable,
@@ -656,27 +644,6 @@ private <T> void assertEvaluation(String dependentVariable, List<T> dependentVar
656644
}
657645
}
658646

659-
private void assertMlResultsFieldMappings(String predictedClassField, String expectedType) {
660-
Map<String, Object> mappings =
661-
client()
662-
.execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(destIndex))
663-
.actionGet()
664-
.mappings()
665-
.get(destIndex)
666-
.get("_doc")
667-
.sourceAsMap();
668-
assertThat(
669-
mappings.toString(),
670-
getFieldValue(
671-
mappings,
672-
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
673-
equalTo(expectedType));
674-
assertThat(
675-
mappings.toString(),
676-
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
677-
equalTo(expectedType));
678-
}
679-
680647
private String stateDocId() {
681648
return jobId + "_classification_state#1";
682649
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
*/
66
package org.elasticsearch.xpack.ml.integration;
77

8+
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
9+
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
810
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
911
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
1012
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
@@ -53,6 +55,7 @@
5355
import java.util.concurrent.TimeUnit;
5456
import java.util.stream.Collectors;
5557

58+
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
5659
import static org.hamcrest.Matchers.anyOf;
5760
import static org.hamcrest.Matchers.arrayWithSize;
5861
import static org.hamcrest.Matchers.equalTo;
@@ -281,4 +284,36 @@ protected static void assertModelStatePersisted(String stateDocId) {
281284
.get();
282285
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
283286
}
287+
288+
protected static void assertMlResultsFieldMappings(String index, String predictedClassField, String expectedType) {
289+
Map<String, Object> mappings =
290+
client()
291+
.execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(index))
292+
.actionGet()
293+
.mappings()
294+
.get(index)
295+
.get("_doc")
296+
.sourceAsMap();
297+
assertThat(
298+
mappings.toString(),
299+
getFieldValue(
300+
mappings,
301+
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
302+
equalTo(expectedType));
303+
if (getFieldValue(mappings, "properties", "ml", "properties", "top_classes") != null) {
304+
assertThat(
305+
mappings.toString(),
306+
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
307+
equalTo(expectedType));
308+
}
309+
}
310+
311+
/**
312+
* Wrapper around extractValue that:
313+
* - allows dots (".") in the path elements provided as arguments
314+
* - supports implicit casting to the appropriate type
315+
*/
316+
protected static <T> T getFieldValue(Map<String, Object> doc, String... path) {
317+
return (T)extractValue(String.join(".", path), doc);
318+
}
284319
}

0 commit comments

Comments
 (0)