
Commit f3ccd19

[ML] handles compressed model stream from native process
1 parent ef66191 commit f3ccd19
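
At a high level, this change streams the compressed model definition out of the native process in pieces: the result processor now indexes model metadata built from a ModelSizeInfo first, then indexes each TrainedModelDefinitionChunk as it arrives, instead of building one TrainedModelConfig from a fully materialized definition as the removed AnalyticsResultProcessor code did. A minimal standalone sketch of the chunking idea, mirroring the chunkStringWithSize helper in the new integration test below; the ChunkingSketch class and sample string are hypothetical, for illustration only:

import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration only: split a compressed model definition string
// into fixed-size chunks, each of which would be indexed as one chunked doc.
public class ChunkingSketch {

    static List<String> chunk(String compressedDefinition, int chunkSize) {
        List<String> chunks = new ArrayList<>();
        for (int i = 0; i < compressedDefinition.length(); i += chunkSize) {
            chunks.add(compressedDefinition.substring(i, Math.min(i + chunkSize, compressedDefinition.length())));
        }
        return chunks;
    }

    public static void main(String[] args) {
        String compressed = "H4sIAAAAAAAA/abcdefgh"; // stand-in value, not real model data
        List<String> chunks = chunk(compressed, 8);
        System.out.println(chunks);                  // [H4sIAAAA, AAAA/abc, defgh]
    }
}

Because each chunk is indexed as its own chunked model doc, the full definition never has to be buffered in a single request.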

File tree

11 files changed: +823 -273 lines


x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

Lines changed: 1 addition & 0 deletions
@@ -88,6 +88,7 @@ public final class Messages {
         " (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric";

     public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
+    public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists";
     public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
     public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
     public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";
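
For reference, the new key is a MessageFormat template like its neighbors; a minimal usage sketch, assuming the existing Messages.getMessage(String, Object...) helper in this class and hypothetical argument values:

    // hypothetical model id and chunked-doc number, purely illustrative
    String msg = Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_DOC_EXISTS, "my-model", 3);
    // msg: "Trained machine learning model chunked doc [my-model][3] already exists"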
ChunkedTrainedModelPersisterIT.java (new test file; full path not captured in this view)

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;

import static org.hamcrest.Matchers.equalTo;

public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {

    private TrainedModelProvider trainedModelProvider;

    @Before
    public void createComponents() throws Exception {
        trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry());
        waitForMlTemplates();
    }

    public void testStoreModelViaChunkedPersister() throws IOException {
        String modelId = "stored-chunked-model";
        DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder()
            .setId(modelId)
            .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null))
            .setDest(new DataFrameAnalyticsDest("my_dest", null))
            .setAnalysis(new Regression("foo"))
            .build();
        List<ExtractedField> extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet()));
        TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId);
        String compressedDefinition = configBuilder.build().getCompressedDefinition();
        int totalSize = compressedDefinition.length();
        List<String> chunks = chunkStringWithSize(compressedDefinition, totalSize / 3);

        ChunkedTrainedModelPersister persister = new ChunkedTrainedModelPersister(trainedModelProvider,
            analyticsConfig,
            new DataFrameAnalyticsAuditor(client(), "test-node"),
            (ex) -> { throw new ElasticsearchException(ex); },
            extractedFieldList
        );

        // Accuracy for size is not tested here
        ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom();
        persister.createAndIndexInferenceModelMetadata(modelSizeInfo);
        for (String chunk : chunks) {
            persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunk, totalSize));
        }

        PlainActionFuture<Tuple<Long, Set<String>>> getIdsFuture = new PlainActionFuture<>();
        trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
        Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
        assertThat(ids.v1(), equalTo(1L));

        PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
        trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture);

        TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
        assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
        assertThat(storedConfig.getEstimatedOperations(), equalTo((long) modelSizeInfo.numOperations()));
        assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
    }

    private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
        TrainedModelDefinition.Builder definitionBuilder = TrainedModelDefinitionTests.createRandomBuilder();
        long bytesUsed = definitionBuilder.build().ramBytesUsed();
        long operations = definitionBuilder.build().getTrainedModel().estimatedNumOperations();
        return TrainedModelConfig.builder()
            .setCreatedBy("ml_test")
            .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION))
            .setDescription("trained model config for test")
            .setModelId(modelId)
            .setVersion(Version.CURRENT)
            .setLicenseLevel(License.OperationMode.PLATINUM.description())
            .setEstimatedHeapMemory(bytesUsed)
            .setEstimatedOperations(operations)
            .setInput(TrainedModelInputTests.createRandomInput());
    }

    private static List<String> chunkStringWithSize(String str, int chunkSize) {
        List<String> subStrings = new ArrayList<>((int) Math.ceil(str.length() / (double) chunkSize));
        for (int i = 0; i < str.length(); i += chunkSize) {
            subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length())));
        }
        return subStrings;
    }

    @Override
    public NamedXContentRegistry xContentRegistry() {
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
        return new NamedXContentRegistry(namedXContent);
    }

}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java

Lines changed: 17 additions & 138 deletions
@@ -8,53 +8,31 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
-import org.elasticsearch.Version;
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.LatchedActionListener;
 import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.xcontent.XContentHelper;
-import org.elasticsearch.common.xcontent.json.JsonXContent;
-import org.elasticsearch.license.License;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
-import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
-import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
-import org.elasticsearch.xpack.core.security.user.XPackUser;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
+import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
 import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
 import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
-import org.elasticsearch.xpack.ml.extractor.MultiField;
+import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;

-import java.time.Instant;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
-import java.util.Map;
 import java.util.Objects;
-import java.util.Optional;
 import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;

-import static java.util.stream.Collectors.toList;

 public class AnalyticsResultProcessor {

@@ -80,6 +58,7 @@ public class AnalyticsResultProcessor {
     private final StatsPersister statsPersister;
     private final List<ExtractedField> fieldNames;
     private final CountDownLatch completionLatch = new CountDownLatch(1);
+    private final ChunkedTrainedModelPersister chunkedTrainedModelPersister;
     private volatile String failure;
     private volatile boolean isCancelled;

@@ -93,6 +72,13 @@ public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRow
         this.auditor = Objects.requireNonNull(auditor);
         this.statsPersister = Objects.requireNonNull(statsPersister);
         this.fieldNames = Collections.unmodifiableList(Objects.requireNonNull(fieldNames));
+        this.chunkedTrainedModelPersister = new ChunkedTrainedModelPersister(
+            trainedModelProvider,
+            analytics,
+            auditor,
+            this::setAndReportFailure,
+            fieldNames
+        );
     }

     @Nullable
@@ -171,9 +157,13 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo
                 phaseProgress.getProgressPercent());
             statsHolder.getProgressTracker().updatePhase(phaseProgress);
         }
-        TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder();
-        if (inferenceModelBuilder != null) {
-            createAndIndexInferenceModel(inferenceModelBuilder);
+        ModelSizeInfo modelSize = result.getModelSizeInfo();
+        if (modelSize != null) {
+            chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelSize);
+        }
+        TrainedModelDefinitionChunk trainedModelDefinitionChunk = result.getTrainedModelDefinitionChunk();
+        if (trainedModelDefinitionChunk != null) {
+            chunkedTrainedModelPersister.createAndIndexInferenceModelDoc(trainedModelDefinitionChunk);
         }
         MemoryUsage memoryUsage = result.getMemoryUsage();
         if (memoryUsage != null) {
@@ -197,117 +187,6 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo
         }
     }

-    private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) {
-        TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel);
-        CountDownLatch latch = storeTrainedModel(trainedModelConfig);
-
-        try {
-            if (latch.await(30, TimeUnit.SECONDS) == false) {
-                LOGGER.error("[{}] Timed out (30s) waiting for inference model to be stored", analytics.getId());
-            }
-        } catch (InterruptedException e) {
-            Thread.currentThread().interrupt();
-            setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for inference model to be stored"));
-        }
-    }
-
-    private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Builder inferenceModel) {
-        Instant createTime = Instant.now();
-        String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
-        TrainedModelDefinition definition = inferenceModel.build();
-        String dependentVariable = getDependentVariable();
-        List<String> fieldNamesWithoutDependentVariable = fieldNames.stream()
-            .map(ExtractedField::getName)
-            .filter(f -> f.equals(dependentVariable) == false)
-            .collect(toList());
-        Map<String, String> defaultFieldMapping = fieldNames.stream()
-            .filter(ef -> ef instanceof MultiField && (ef.getName().equals(dependentVariable) == false))
-            .collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName));
-        return TrainedModelConfig.builder()
-            .setModelId(modelId)
-            .setCreatedBy(XPackUser.NAME)
-            .setVersion(Version.CURRENT)
-            .setCreateTime(createTime)
-            // NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags
-            .setTags(Collections.singletonList(analytics.getId()))
-            .setDescription(analytics.getDescription())
-            .setMetadata(Collections.singletonMap("analytics_config",
-                XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
-            .setEstimatedHeapMemory(definition.ramBytesUsed())
-            .setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations())
-            .setParsedDefinition(inferenceModel)
-            .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
-            .setLicenseLevel(License.OperationMode.PLATINUM.description())
-            .setDefaultFieldMap(defaultFieldMapping)
-            .setInferenceConfig(buildInferenceConfig(definition.getTrainedModel().targetType()))
-            .build();
-    }
-
-    private InferenceConfig buildInferenceConfig(TargetType targetType) {
-        switch (targetType) {
-            case CLASSIFICATION:
-                assert analytics.getAnalysis() instanceof Classification;
-                Classification classification = ((Classification)analytics.getAnalysis());
-                PredictionFieldType predictionFieldType = getPredictionFieldType(classification);
-                return ClassificationConfig.builder()
-                    .setNumTopClasses(classification.getNumTopClasses())
-                    .setNumTopFeatureImportanceValues(classification.getBoostedTreeParams().getNumTopFeatureImportanceValues())
-                    .setPredictionFieldType(predictionFieldType)
-                    .build();
-            case REGRESSION:
-                assert analytics.getAnalysis() instanceof Regression;
-                Regression regression = ((Regression)analytics.getAnalysis());
-                return RegressionConfig.builder()
-                    .setNumTopFeatureImportanceValues(regression.getBoostedTreeParams().getNumTopFeatureImportanceValues())
-                    .build();
-            default:
-                throw ExceptionsHelper.serverError(
-                    "process created a model with an unsupported target type [{}]",
-                    null,
-                    targetType);
-        }
-    }
-
-    PredictionFieldType getPredictionFieldType(Classification classification) {
-        String dependentVariable = classification.getDependentVariable();
-        Optional<ExtractedField> extractedField = fieldNames.stream()
-            .filter(f -> f.getName().equals(dependentVariable))
-            .findAny();
-        PredictionFieldType predictionFieldType = Classification.getPredictionFieldType(
-            extractedField.isPresent() ? extractedField.get().getTypes() : null
-        );
-        return predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
-    }
-
-    private String getDependentVariable() {
-        if (analytics.getAnalysis() instanceof Classification) {
-            return ((Classification)analytics.getAnalysis()).getDependentVariable();
-        }
-        if (analytics.getAnalysis() instanceof Regression) {
-            return ((Regression)analytics.getAnalysis()).getDependentVariable();
-        }
-        return null;
-    }
-
-    private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) {
-        CountDownLatch latch = new CountDownLatch(1);
-        ActionListener<Boolean> storeListener = ActionListener.wrap(
-            aBoolean -> {
-                if (aBoolean == false) {
-                    LOGGER.error("[{}] Storing trained model responded false", analytics.getId());
-                    setAndReportFailure(ExceptionsHelper.serverError("storing trained model responded false"));
-                } else {
-                    LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
-                    auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]");
-                }
-            },
-            e -> setAndReportFailure(ExceptionsHelper.serverError("error storing trained model with id [{}]", e,
-                trainedModelConfig.getModelId()))
-        );
-        trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
-        return latch;
-    }
-
     private void setAndReportFailure(Exception e) {
         LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e);
         failure = "error processing results; " + e.getMessage();
