Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ public final class Messages {
" (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric";

public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists";
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.elasticsearch.xpack.ml.integration;

import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionModule;
Expand Down Expand Up @@ -64,7 +63,6 @@
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.startsWith;

@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349")
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {

private static final String BOOLEAN_FIELD = "boolean-field";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.ml.integration;

import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionModule;
import org.elasticsearch.action.DocWriteRequest;
Expand Down Expand Up @@ -43,7 +42,6 @@
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.not;

@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349")
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {

private static final String NUMERICAL_FEATURE_FIELD = "feature";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;

import static org.hamcrest.Matchers.equalTo;

/**
 * Integration test verifying that a trained model definition streamed to the index
 * in multiple chunks via {@link ChunkedTrainedModelPersister} can be read back as a
 * single, fully reassembled {@link TrainedModelConfig}.
 *
 * NOTE(review): the class name contains a typo ("Moodel" instead of "Model").
 * Renaming it requires renaming the source file as well, so it is only flagged here.
 */
public class ChunkedTrainedMoodelPersisterIT extends MlSingleNodeTestCase {

    private TrainedModelProvider trainedModelProvider;

    @Before
    public void createComponents() throws Exception {
        trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry());
        waitForMlTemplates();
    }

    public void testStoreModelViaChunkedPersister() throws IOException {
        String modelId = "stored-chunked-model";
        DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder()
            .setId(modelId)
            .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null))
            .setDest(new DataFrameAnalyticsDest("my_dest", null))
            .setAnalysis(new Regression("foo"))
            .build();
        List<ExtractedField> extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet()));
        TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId);
        String compressedDefinition = configBuilder.build().getCompressedDefinition();
        int totalSize = compressedDefinition.length();
        // Split into roughly three chunks; Math.max guards against a definition shorter
        // than three characters, which would otherwise yield a zero chunk size.
        List<String> chunks = chunkStringWithSize(compressedDefinition, Math.max(1, totalSize / 3));

        ChunkedTrainedModelPersister persister = new ChunkedTrainedModelPersister(trainedModelProvider,
            analyticsConfig,
            new DataFrameAnalyticsAuditor(client(), "test-node"),
            // Any persistence error should fail the test immediately.
            (ex) -> { throw new ElasticsearchException(ex); },
            new ExtractedFields(extractedFieldList, Collections.emptyMap())
        );

        // Accuracy for size is not tested here
        ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom();
        persister.createAndIndexInferenceModelMetadata(modelSizeInfo);
        for (int i = 0; i < chunks.size(); i++) {
            // The final chunk carries the end-of-stream flag so the provider knows the definition is complete.
            persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunks.get(i), i, i == (chunks.size() - 1)));
        }

        PlainActionFuture<Tuple<Long, Set<String>>> getIdsFuture = new PlainActionFuture<>();
        trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
        Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
        assertThat(ids.v1(), equalTo(1L));

        PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
        trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture);

        // The reassembled definition must match the original, and the stored estimates
        // must reflect the model size info reported through the metadata call above.
        TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
        assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
        assertThat(storedConfig.getEstimatedOperations(), equalTo((long) modelSizeInfo.numOperations()));
        assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
    }

    /**
     * Builds a {@link TrainedModelConfig.Builder} around a random regression definition.
     * The same definition instance is used both to compute the estimated heap memory /
     * operation count and as the parsed definition, so the stored estimates are
     * consistent with the stored model (the previous version derived them from two
     * independent random definitions).
     */
    private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
        TrainedModelDefinition.Builder definitionBuilder = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION);
        TrainedModelDefinition definition = definitionBuilder.build();
        long bytesUsed = definition.ramBytesUsed();
        long operations = definition.getTrainedModel().estimatedNumOperations();
        return TrainedModelConfig.builder()
            .setCreatedBy("ml_test")
            .setParsedDefinition(definitionBuilder)
            .setDescription("trained model config for test")
            .setModelId(modelId)
            .setVersion(Version.CURRENT)
            .setLicenseLevel(License.OperationMode.PLATINUM.description())
            .setEstimatedHeapMemory(bytesUsed)
            .setEstimatedOperations(operations)
            .setInput(TrainedModelInputTests.createRandomInput());
    }

    /**
     * Splits {@code str} into consecutive substrings of at most {@code chunkSize} characters.
     *
     * @throws IllegalArgumentException if {@code chunkSize} is not positive (the previous
     *                                  version would loop forever on such input)
     */
    private static List<String> chunkStringWithSize(String str, int chunkSize) {
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("chunkSize must be positive; got [" + chunkSize + "]");
        }
        List<String> subStrings = new ArrayList<>((int) Math.ceil(str.length() / (double) chunkSize));
        for (int i = 0; i < str.length(); i += chunkSize) {
            subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length())));
        }
        return subStrings;
    }

    @Override
    public NamedXContentRegistry xContentRegistry() {
        // The registry must know how to parse both inference model definitions and model size info.
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
        return new NamedXContentRegistry(namedXContent);
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,48 +8,29 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.MultiField;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;

import java.time.Instant;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static java.util.stream.Collectors.toList;

public class AnalyticsResultProcessor {

Expand All @@ -70,11 +51,10 @@ public class AnalyticsResultProcessor {
private final DataFrameAnalyticsConfig analytics;
private final DataFrameRowsJoiner dataFrameRowsJoiner;
private final StatsHolder statsHolder;
private final TrainedModelProvider trainedModelProvider;
private final DataFrameAnalyticsAuditor auditor;
private final StatsPersister statsPersister;
private final ExtractedFields extractedFields;
private final CountDownLatch completionLatch = new CountDownLatch(1);
private final ChunkedTrainedModelPersister chunkedTrainedModelPersister;
private volatile String failure;
private volatile boolean isCancelled;

Expand All @@ -84,10 +64,15 @@ public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRow
this.analytics = Objects.requireNonNull(analytics);
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
this.statsHolder = Objects.requireNonNull(statsHolder);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
this.auditor = Objects.requireNonNull(auditor);
this.statsPersister = Objects.requireNonNull(statsPersister);
this.extractedFields = Objects.requireNonNull(extractedFields);
this.chunkedTrainedModelPersister = new ChunkedTrainedModelPersister(
trainedModelProvider,
analytics,
auditor,
this::setAndReportFailure,
extractedFields
);
}

@Nullable
Expand Down Expand Up @@ -166,9 +151,13 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo
phaseProgress.getProgressPercent());
statsHolder.getProgressTracker().updatePhase(phaseProgress);
}
TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder();
if (inferenceModelBuilder != null) {
createAndIndexInferenceModel(inferenceModelBuilder);
ModelSizeInfo modelSize = result.getModelSizeInfo();
if (modelSize != null) {
chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelSize);
}
TrainedModelDefinitionChunk trainedModelDefinitionChunk = result.getTrainedModelDefinitionChunk();
if (trainedModelDefinitionChunk != null) {
chunkedTrainedModelPersister.createAndIndexInferenceModelDoc(trainedModelDefinitionChunk);
}
MemoryUsage memoryUsage = result.getMemoryUsage();
if (memoryUsage != null) {
Expand All @@ -191,82 +180,6 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo
}
}

/**
 * Converts the given inference model into a {@link TrainedModelConfig} and stores it,
 * blocking up to 30 seconds for the store to complete.
 *
 * Both a timeout and an interrupt are reported as processing failures; the previous
 * version only logged the timeout, letting the job continue as if the model had been
 * stored successfully.
 */
private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) {
TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel);
CountDownLatch latch = storeTrainedModel(trainedModelConfig);

try {
if (latch.await(30, TimeUnit.SECONDS) == false) {
LOGGER.error("[{}] Timed out (30s) waiting for inference model to be stored", analytics.getId());
// Surface the timeout as a failure instead of silently continuing.
setAndReportFailure(ExceptionsHelper.serverError("timed out (30s) waiting for inference model to be stored"));
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for inference model to be stored"));
}
}

/**
 * Builds the {@link TrainedModelConfig} for the inference model produced by this
 * analytics job. The model id is the analytics id suffixed with the creation
 * timestamp in epoch millis, so repeated runs of the same job produce distinct ids.
 *
 * The model input lists all extracted field names except the dependent variable,
 * and the default field map routes multi-fields (again excluding the dependent
 * variable) from their parent field to the concrete sub-field used in training.
 */
private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Builder inferenceModel) {
Instant createTime = Instant.now();
String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
TrainedModelDefinition definition = inferenceModel.build();
String dependentVariable = getDependentVariable();
List<ExtractedField> fieldNames = extractedFields.getAllFields();
// The dependent variable is the training target, not a model input.
List<String> fieldNamesWithoutDependentVariable = fieldNames.stream()
.map(ExtractedField::getName)
.filter(f -> f.equals(dependentVariable) == false)
.collect(toList());
Map<String, String> defaultFieldMapping = fieldNames.stream()
.filter(ef -> ef instanceof MultiField && (ef.getName().equals(dependentVariable) == false))
.collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName));
return TrainedModelConfig.builder()
.setModelId(modelId)
.setCreatedBy(XPackUser.NAME)
.setVersion(Version.CURRENT)
.setCreateTime(createTime)
// NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags
.setTags(Collections.singletonList(analytics.getId()))
.setDescription(analytics.getDescription())
// Store the originating analytics config as metadata for traceability.
.setMetadata(Collections.singletonMap("analytics_config",
XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
.setEstimatedHeapMemory(definition.ramBytesUsed())
.setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations())
.setParsedDefinition(inferenceModel)
.setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setDefaultFieldMap(defaultFieldMapping)
.setInferenceConfig(analytics.getAnalysis().inferenceConfig(new AnalysisFieldInfo(extractedFields)))
.build();
}

/**
 * Returns the dependent (target) variable of this job's analysis, or {@code null}
 * for analyses that have no dependent variable (e.g. outlier detection).
 */
private String getDependentVariable() {
if (analytics.getAnalysis() instanceof Regression) {
return ((Regression) analytics.getAnalysis()).getDependentVariable();
}
if (analytics.getAnalysis() instanceof Classification) {
return ((Classification) analytics.getAnalysis()).getDependentVariable();
}
return null;
}

/**
 * Kicks off an asynchronous store of the given trained model config and returns a
 * latch that is counted down when the store completes (success or failure).
 *
 * A {@code false} acknowledgement or an exception is recorded via
 * {@code setAndReportFailure}; success is logged and audited. The caller is
 * responsible for awaiting the returned latch.
 */
private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) {
CountDownLatch latch = new CountDownLatch(1);
ActionListener<Boolean> storeListener = ActionListener.wrap(
aBoolean -> {
if (aBoolean == false) {
// The store API acknowledged but reported it did not persist the model.
LOGGER.error("[{}] Storing trained model responded false", analytics.getId());
setAndReportFailure(ExceptionsHelper.serverError("storing trained model responded false"));
} else {
LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]");
}
},
e -> setAndReportFailure(ExceptionsHelper.serverError("error storing trained model with id [{}]", e,
trainedModelConfig.getModelId()))
);
// LatchedActionListener counts the latch down after either listener branch runs.
trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
return latch;
}

private void setAndReportFailure(Exception e) {
LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e);
failure = "error processing results; " + e.getMessage();
Expand Down
Loading