From f3ccd195ccb02bd2ae617f9974d0a667f04d048f Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 11 Jun 2020 12:59:20 -0400 Subject: [PATCH 1/9] [ML] handles compressed model stream from native process --- .../xpack/core/ml/job/messages/Messages.java | 1 + .../ChunkedTrainedMoodelPersisterIT.java | 129 +++++++++ .../process/AnalyticsResultProcessor.java | 155 ++--------- .../process/ChunkedTrainedModelPersister.java | 253 ++++++++++++++++++ .../process/results/AnalyticsResult.java | 53 ++-- .../results/TrainedModelDefinitionChunk.java | 76 ++++++ .../TrainedModelDefinitionDoc.java | 4 + .../persistence/TrainedModelProvider.java | 86 +++++- .../AnalyticsResultProcessorTests.java | 100 +------ .../ChunkedTrainedModelPersisterTests.java | 223 +++++++++++++++ .../process/results/AnalyticsResultTests.java | 16 +- 11 files changed, 823 insertions(+), 273 deletions(-) create mode 100644 x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 5287ce035318d..95c0ef2d085e1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -88,6 +88,7 @@ public final class Messages { " (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric"; public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists"; + public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]"; public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}"; diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java new file mode 100644 index 0000000000000..28ff17c6c47d5 --- /dev/null +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.license.License; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister; +import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; +import org.elasticsearch.xpack.ml.extractor.DocValueField; +import org.elasticsearch.xpack.ml.extractor.ExtractedField; +import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import static org.hamcrest.Matchers.equalTo; + +public class ChunkedTrainedMoodelPersisterIT extends MlSingleNodeTestCase { + + private TrainedModelProvider trainedModelProvider; + + @Before + public void createComponents() throws Exception { + trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry()); + waitForMlTemplates(); + } + + public void testStoreModelViaChunkedPersister() throws IOException { + String modelId = "stored-chunked-model"; + DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder() + .setId(modelId) + .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null)) + .setDest(new DataFrameAnalyticsDest("my_dest", null)) + .setAnalysis(new Regression("foo")) + .build(); + List extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet())); + TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId); + String compressedDefinition = configBuilder.build().getCompressedDefinition(); + int totalSize = compressedDefinition.length(); + List chunks = chunkStringWithSize(compressedDefinition, totalSize/3); + + ChunkedTrainedModelPersister persister = new ChunkedTrainedModelPersister(trainedModelProvider, + analyticsConfig, + new DataFrameAnalyticsAuditor(client(), "test-node"), + (ex) -> { throw new ElasticsearchException(ex); }, + extractedFieldList + ); + + //Accuracy for size is not tested here + ModelSizeInfo 
modelSizeInfo = ModelSizeInfoTests.createRandom(); + persister.createAndIndexInferenceModelMetadata(modelSizeInfo); + for (String chunk : chunks) { + persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunk, totalSize)); + } + + PlainActionFuture>> getIdsFuture = new PlainActionFuture<>(); + trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture); + Tuple> ids = getIdsFuture.actionGet(); + assertThat(ids.v1(), equalTo(1L)); + + PlainActionFuture getTrainedModelFuture = new PlainActionFuture<>(); + trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture); + + TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet(); + assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition)); + assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations())); + assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed())); + } + + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) { + TrainedModelDefinition.Builder definitionBuilder = TrainedModelDefinitionTests.createRandomBuilder(); + long bytesUsed = definitionBuilder.build().ramBytesUsed(); + long operations = definitionBuilder.build().getTrainedModel().estimatedNumOperations(); + return TrainedModelConfig.builder() + .setCreatedBy("ml_test") + .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION)) + .setDescription("trained model config for test") + .setModelId(modelId) + .setVersion(Version.CURRENT) + .setLicenseLevel(License.OperationMode.PLATINUM.description()) + .setEstimatedHeapMemory(bytesUsed) + .setEstimatedOperations(operations) + .setInput(TrainedModelInputTests.createRandomInput()); + } + + private static List chunkStringWithSize(String str, int chunkSize) { + List subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize)); + for (int i = 0; i < str.length();i += chunkSize) { + subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length()))); + } + return subStrings; + } + + @Override + public NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index cd9ad2baf1f0d..b77b028bd8365 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -8,53 +8,31 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.elasticsearch.Version; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.common.xcontent.json.JsonXContent; -import org.elasticsearch.license.License; import 
org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; -import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.extractor.ExtractedField; -import org.elasticsearch.xpack.ml.extractor.MultiField; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; -import java.time.Instant; import java.util.Collections; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import static java.util.stream.Collectors.toList; public class AnalyticsResultProcessor { @@ -80,6 +58,7 @@ public class AnalyticsResultProcessor { private final StatsPersister statsPersister; private final List fieldNames; private final CountDownLatch completionLatch = new CountDownLatch(1); + private final ChunkedTrainedModelPersister chunkedTrainedModelPersister; private volatile String failure; private volatile boolean isCancelled; @@ -93,6 +72,13 @@ public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRow this.auditor = Objects.requireNonNull(auditor); this.statsPersister = Objects.requireNonNull(statsPersister); this.fieldNames = Collections.unmodifiableList(Objects.requireNonNull(fieldNames)); + this.chunkedTrainedModelPersister = new ChunkedTrainedModelPersister( + trainedModelProvider, + analytics, + auditor, + this::setAndReportFailure, + fieldNames + ); } @Nullable @@ -171,9 +157,13 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo phaseProgress.getProgressPercent()); statsHolder.getProgressTracker().updatePhase(phaseProgress); } - 
TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder(); - if (inferenceModelBuilder != null) { - createAndIndexInferenceModel(inferenceModelBuilder); + ModelSizeInfo modelSize = result.getModelSizeInfo(); + if (modelSize != null) { + chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelSize); + } + TrainedModelDefinitionChunk trainedModelDefinitionChunk = result.getTrainedModelDefinitionChunk(); + if (trainedModelDefinitionChunk != null) { + chunkedTrainedModelPersister.createAndIndexInferenceModelDoc(trainedModelDefinitionChunk); } MemoryUsage memoryUsage = result.getMemoryUsage(); if (memoryUsage != null) { @@ -197,117 +187,6 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo } } - private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) { - TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel); - CountDownLatch latch = storeTrainedModel(trainedModelConfig); - - try { - if (latch.await(30, TimeUnit.SECONDS) == false) { - LOGGER.error("[{}] Timed out (30s) waiting for inference model to be stored", analytics.getId()); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for inference model to be stored")); - } - } - - private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Builder inferenceModel) { - Instant createTime = Instant.now(); - String modelId = analytics.getId() + "-" + createTime.toEpochMilli(); - TrainedModelDefinition definition = inferenceModel.build(); - String dependentVariable = getDependentVariable(); - List fieldNamesWithoutDependentVariable = fieldNames.stream() - .map(ExtractedField::getName) - .filter(f -> f.equals(dependentVariable) == false) - .collect(toList()); - Map defaultFieldMapping = fieldNames.stream() - .filter(ef -> ef instanceof MultiField && (ef.getName().equals(dependentVariable) == false)) - .collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName)); - return TrainedModelConfig.builder() - .setModelId(modelId) - .setCreatedBy(XPackUser.NAME) - .setVersion(Version.CURRENT) - .setCreateTime(createTime) - // NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags - .setTags(Collections.singletonList(analytics.getId())) - .setDescription(analytics.getDescription()) - .setMetadata(Collections.singletonMap("analytics_config", - XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) - .setEstimatedHeapMemory(definition.ramBytesUsed()) - .setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations()) - .setParsedDefinition(inferenceModel) - .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable)) - .setLicenseLevel(License.OperationMode.PLATINUM.description()) - .setDefaultFieldMap(defaultFieldMapping) - .setInferenceConfig(buildInferenceConfig(definition.getTrainedModel().targetType())) - .build(); - } - - private InferenceConfig buildInferenceConfig(TargetType targetType) { - switch (targetType) { - case CLASSIFICATION: - assert analytics.getAnalysis() instanceof Classification; - Classification classification = ((Classification)analytics.getAnalysis()); - PredictionFieldType predictionFieldType = getPredictionFieldType(classification); - return ClassificationConfig.builder() - .setNumTopClasses(classification.getNumTopClasses()) - 
.setNumTopFeatureImportanceValues(classification.getBoostedTreeParams().getNumTopFeatureImportanceValues()) - .setPredictionFieldType(predictionFieldType) - .build(); - case REGRESSION: - assert analytics.getAnalysis() instanceof Regression; - Regression regression = ((Regression)analytics.getAnalysis()); - return RegressionConfig.builder() - .setNumTopFeatureImportanceValues(regression.getBoostedTreeParams().getNumTopFeatureImportanceValues()) - .build(); - default: - throw ExceptionsHelper.serverError( - "process created a model with an unsupported target type [{}]", - null, - targetType); - } - } - - PredictionFieldType getPredictionFieldType(Classification classification) { - String dependentVariable = classification.getDependentVariable(); - Optional extractedField = fieldNames.stream() - .filter(f -> f.getName().equals(dependentVariable)) - .findAny(); - PredictionFieldType predictionFieldType = Classification.getPredictionFieldType( - extractedField.isPresent() ? extractedField.get().getTypes() : null - ); - return predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType; - } - - private String getDependentVariable() { - if (analytics.getAnalysis() instanceof Classification) { - return ((Classification)analytics.getAnalysis()).getDependentVariable(); - } - if (analytics.getAnalysis() instanceof Regression) { - return ((Regression)analytics.getAnalysis()).getDependentVariable(); - } - return null; - } - - private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) { - CountDownLatch latch = new CountDownLatch(1); - ActionListener storeListener = ActionListener.wrap( - aBoolean -> { - if (aBoolean == false) { - LOGGER.error("[{}] Storing trained model responded false", analytics.getId()); - setAndReportFailure(ExceptionsHelper.serverError("storing trained model responded false")); - } else { - LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId()); - auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]"); - } - }, - e -> setAndReportFailure(ExceptionsHelper.serverError("error storing trained model with id [{}]", e, - trainedModelConfig.getModelId())) - ); - trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch)); - return latch; - } - private void setAndReportFailure(Exception e) { LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e); failure = "error processing results; " + e.getMessage(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java new file mode 100644 index 0000000000000..6f1df92e822f1 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java @@ -0,0 +1,253 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */
+
+package org.elasticsearch.xpack.ml.dataframe.process;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.LatchedActionListener;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.json.JsonXContent;
+import org.elasticsearch.license.License;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.security.user.XPackUser;
+import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
+import org.elasticsearch.xpack.ml.extractor.ExtractedField;
+import org.elasticsearch.xpack.ml.extractor.MultiField;
+import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
+import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
+
+import java.time.Instant;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+
+import static java.util.stream.Collectors.toList;
+
+public class ChunkedTrainedModelPersister {
+
+    private static final Logger LOGGER = LogManager.getLogger(ChunkedTrainedModelPersister.class);
+    private final TrainedModelProvider provider;
+    private final AtomicReference<String> currentModelId;
+    private final AtomicInteger currentChunkedDoc;
+    private final AtomicLong persistedChunkLengths;
+    private final DataFrameAnalyticsConfig analytics;
+    private final DataFrameAnalyticsAuditor auditor;
+    private final Consumer<Exception> failureHandler;
+    private final List<ExtractedField> fieldNames;
+    private volatile boolean readyToStoreNewModel = true;
+
+    public ChunkedTrainedModelPersister(TrainedModelProvider provider,
+                                        DataFrameAnalyticsConfig analytics,
+                                        DataFrameAnalyticsAuditor auditor,
+                                        Consumer<Exception> failureHandler,
+                                        List<ExtractedField> fieldNames) {
+        this.provider = provider;
+        this.currentModelId = new AtomicReference<>("");
+        this.currentChunkedDoc = new AtomicInteger(0);
+        this.persistedChunkLengths = new AtomicLong(0L);
+        this.analytics = analytics;
+        this.auditor = auditor;
+        this.failureHandler = failureHandler;
+        this.fieldNames = fieldNames;
+    }
+
+    public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk
+                                                trainedModelDefinitionChunk) {
+        if (Strings.isNullOrEmpty(this.currentModelId.get())) {
+            failureHandler.accept(ExceptionsHelper.serverError(
+                "chunked inference model definition is attempting to be stored before trained model configuration"
+            ));
+            return;
+        }
+        TrainedModelDefinitionDoc trainedModelDefinitionDoc = trainedModelDefinitionChunk.createTrainedModelDoc(
+            this.currentModelId.get(),
+            this.currentChunkedDoc.getAndIncrement());
+
+        CountDownLatch latch = new CountDownLatch(1);
+        ActionListener<Void> storeListener = ActionListener.wrap(
+            r -> {
+                LOGGER.debug(() -> new ParameterizedMessage(
+                    "[{}] stored trained model definition chunk [{}] [{}]",
+                    analytics.getId(),
+                    trainedModelDefinitionDoc.getModelId(),
+                    trainedModelDefinitionDoc.getDocNum()));
+
+                long persistedChunkLengths = this.persistedChunkLengths.addAndGet(trainedModelDefinitionDoc.getDefinitionLength());
+                if (persistedChunkLengths >= trainedModelDefinitionDoc.getTotalDefinitionLength()) {
+                    readyToStoreNewModel = true;
+                    LOGGER.info(
+                        "[{}] finished stored trained model definition chunks with id [{}]",
+                        analytics.getId(),
+                        this.currentModelId.get());
+                    auditor.info(analytics.getId(), "Stored trained model with id [" + this.currentModelId.get() + "]");
+                    CountDownLatch refreshLatch = new CountDownLatch(1);
+                    provider.refreshInferenceIndex(
+                        new LatchedActionListener<>(ActionListener.wrap(
+                            refreshResponse -> LOGGER.debug(() -> new ParameterizedMessage(
+                                "[{}] refreshed inference index after model store",
+                                analytics.getId()
+                            )),
+                            e -> LOGGER.warn("[{}] failed to refresh inference index after model store", analytics.getId())),
+                            refreshLatch));
+                    try {
+                        if (refreshLatch.await(30, TimeUnit.SECONDS) == false) {
+                            LOGGER.error("[{}] Timed out (30s) waiting for index refresh", analytics.getId());
+                        }
+                    } catch (InterruptedException e) {
+                        Thread.currentThread().interrupt();
+                    }
+                }
+            },
+            e -> failureHandler.accept(ExceptionsHelper.serverError("error storing trained model definition chunk [{}] with id [{}]", e,
+                trainedModelDefinitionDoc.getModelId(), trainedModelDefinitionDoc.getDocNum()))
+        );
+        provider.storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, new LatchedActionListener<>(storeListener, latch));
+        try {
+            if (latch.await(30, TimeUnit.SECONDS) == false) {
+                LOGGER.error("[{}] Timed out (30s) waiting for chunked inference definition to be stored", analytics.getId());
+            }
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for chunked inference definition to be stored"));
+        }
+    }
+
+    public void createAndIndexInferenceModelMetadata(ModelSizeInfo inferenceModelSize) {
+        if (readyToStoreNewModel == false) {
+            failureHandler.accept(ExceptionsHelper.serverError(
+                "new inference model is attempting to be stored before completion previous model storage"
+            ));
+            return;
+        }
+        TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModelSize);
+        CountDownLatch latch = storeTrainedModelMetadata(trainedModelConfig);
+        try {
+            readyToStoreNewModel = false;
+            if (latch.await(30, TimeUnit.SECONDS) == false) {
+                LOGGER.error("[{}] Timed out (30s) waiting for inference model metadata to be stored", analytics.getId());
+            }
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for inference model metadata to be stored"));
+        }
+    }
+
+    private CountDownLatch storeTrainedModelMetadata(TrainedModelConfig
+                                                     trainedModelConfig) {
+        CountDownLatch latch = new CountDownLatch(1);
+        ActionListener<Boolean> storeListener = ActionListener.wrap(
+            aBoolean -> {
+                if (aBoolean == false) {
+                    LOGGER.error("[{}] Storing trained model metadata responded false", analytics.getId());
+                    failureHandler.accept(ExceptionsHelper.serverError("storing trained model responded false"));
+                } else {
+                    LOGGER.info("[{}] Stored trained model metadata with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
+                }
+            },
+            e -> failureHandler.accept(ExceptionsHelper.serverError("error storing trained model metadata with id [{}]", e,
+                trainedModelConfig.getModelId()))
+        );
+        provider.storeTrainedModelMetadata(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
+        return latch;
+    }
+
+    private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) {
+        Instant createTime = Instant.now();
+        String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
+        currentModelId.set(modelId);
+        currentChunkedDoc.set(0);
+        persistedChunkLengths.set(0);
+        String dependentVariable = getDependentVariable();
+        List<String> fieldNamesWithoutDependentVariable = fieldNames.stream()
+            .map(ExtractedField::getName)
+            .filter(f -> f.equals(dependentVariable) == false)
+            .collect(toList());
+        Map<String, String> defaultFieldMapping = fieldNames.stream()
+            .filter(ef -> ef instanceof MultiField && (ef.getName().equals(dependentVariable) == false))
+            .collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName));
+        return TrainedModelConfig.builder()
+            .setModelId(modelId)
+            .setCreatedBy(XPackUser.NAME)
+            .setVersion(Version.CURRENT)
+            .setCreateTime(createTime)
+            // NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags
+            .setTags(Collections.singletonList(analytics.getId()))
+            .setDescription(analytics.getDescription())
+            .setMetadata(Collections.singletonMap("analytics_config",
+                XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
+            .setEstimatedHeapMemory(modelSize.ramBytesUsed())
+            .setEstimatedOperations(modelSize.numOperations())
+            .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
+            .setLicenseLevel(License.OperationMode.PLATINUM.description())
+            .setDefaultFieldMap(defaultFieldMapping)
+            .setInferenceConfig(buildInferenceConfigByAnalyticsType())
+            .build();
+    }
+
+    private String getDependentVariable() {
+        if (analytics.getAnalysis() instanceof Classification) {
+            return ((Classification)analytics.getAnalysis()).getDependentVariable();
+        }
+        if (analytics.getAnalysis() instanceof Regression) {
+            return ((Regression)analytics.getAnalysis()).getDependentVariable();
+        }
+        return null;
+    }
+
+    InferenceConfig buildInferenceConfigByAnalyticsType() {
+        if (analytics.getAnalysis() instanceof Classification) {
+            Classification classification = ((Classification)analytics.getAnalysis());
+            PredictionFieldType predictionFieldType = getPredictionFieldType(classification);
+            return ClassificationConfig.builder()
+                .setNumTopClasses(classification.getNumTopClasses())
+                .setNumTopFeatureImportanceValues(classification.getBoostedTreeParams().getNumTopFeatureImportanceValues())
+                .setPredictionFieldType(predictionFieldType)
+                .build();
+        } else if (analytics.getAnalysis() instanceof Regression) {
+            Regression regression = ((Regression)analytics.getAnalysis());
+            return RegressionConfig.builder()
+                .setNumTopFeatureImportanceValues(regression.getBoostedTreeParams().getNumTopFeatureImportanceValues())
+                .build();
+        }
+        throw
ExceptionsHelper.serverError( + "analytics type [{}] does not support model creation", + null, + analytics.getAnalysis().getWriteableName()); + } + + PredictionFieldType getPredictionFieldType(Classification classification) { + String dependentVariable = classification.getDependentVariable(); + Optional extractedField = fieldNames.stream() + .filter(f -> f.getName().equals(dependentVariable)) + .findAny(); + PredictionFieldType predictionFieldType = Classification.getPredictionFieldType( + extractedField.isPresent() ? extractedField.get().getTypes() : null + ); + return predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index 0a05f13c11869..c7e6f9a1c4377 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -8,31 +8,27 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; -import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import java.io.IOException; -import java.util.Collections; import java.util.Objects; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; -import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; public class AnalyticsResult implements ToXContentObject { public static final ParseField TYPE = new ParseField("analytics_result"); private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress"); - private static final ParseField INFERENCE_MODEL = new ParseField("inference_model"); private static final ParseField MODEL_SIZE_INFO = new ParseField("model_size_info"); + private static final ParseField COMPRESSED_INFERENCE_MODEL = new ParseField("compressed_inference_model"); private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage"); private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats"); private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats"); @@ -42,53 +38,50 @@ public class AnalyticsResult implements ToXContentObject { a -> new AnalyticsResult( (RowResults) a[0], (PhaseProgress) a[1], - (TrainedModelDefinition.Builder) a[2], - (MemoryUsage) a[3], - (OutlierDetectionStats) a[4], - (ClassificationStats) a[5], - (RegressionStats) a[6], - (ModelSizeInfo) a[7] + (MemoryUsage) a[2], + (OutlierDetectionStats) a[3], + (ClassificationStats) a[4], + (RegressionStats) a[5], + 
(ModelSizeInfo) a[6], + (TrainedModelDefinitionChunk) a[7] )); static { PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE); PARSER.declareObject(optionalConstructorArg(), PhaseProgress.PARSER, PHASE_PROGRESS); - // TODO change back to STRICT_PARSER once native side is aligned - PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL); PARSER.declareObject(optionalConstructorArg(), MemoryUsage.STRICT_PARSER, ANALYTICS_MEMORY_USAGE); PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS); PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS); PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS); PARSER.declareObject(optionalConstructorArg(), ModelSizeInfo.PARSER, MODEL_SIZE_INFO); + PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinitionChunk.PARSER, COMPRESSED_INFERENCE_MODEL); } private final RowResults rowResults; private final PhaseProgress phaseProgress; - private final TrainedModelDefinition.Builder inferenceModelBuilder; - private final TrainedModelDefinition inferenceModel; private final MemoryUsage memoryUsage; private final OutlierDetectionStats outlierDetectionStats; private final ClassificationStats classificationStats; private final RegressionStats regressionStats; private final ModelSizeInfo modelSizeInfo; + private final TrainedModelDefinitionChunk trainedModelDefinitionChunk; public AnalyticsResult(@Nullable RowResults rowResults, @Nullable PhaseProgress phaseProgress, - @Nullable TrainedModelDefinition.Builder inferenceModelBuilder, @Nullable MemoryUsage memoryUsage, @Nullable OutlierDetectionStats outlierDetectionStats, @Nullable ClassificationStats classificationStats, @Nullable RegressionStats regressionStats, - @Nullable ModelSizeInfo modelSizeInfo) { + @Nullable ModelSizeInfo modelSizeInfo, + @Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk) { this.rowResults = rowResults; this.phaseProgress = phaseProgress; - this.inferenceModelBuilder = inferenceModelBuilder; - this.inferenceModel = inferenceModelBuilder == null ? 
null : inferenceModelBuilder.build(); this.memoryUsage = memoryUsage; this.outlierDetectionStats = outlierDetectionStats; this.classificationStats = classificationStats; this.regressionStats = regressionStats; this.modelSizeInfo = modelSizeInfo; + this.trainedModelDefinitionChunk = trainedModelDefinitionChunk; } public RowResults getRowResults() { @@ -99,10 +92,6 @@ public PhaseProgress getPhaseProgress() { return phaseProgress; } - public TrainedModelDefinition.Builder getInferenceModelBuilder() { - return inferenceModelBuilder; - } - public MemoryUsage getMemoryUsage() { return memoryUsage; } @@ -123,6 +112,10 @@ public ModelSizeInfo getModelSizeInfo() { return modelSizeInfo; } + public TrainedModelDefinitionChunk getTrainedModelDefinitionChunk() { + return trainedModelDefinitionChunk; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -132,11 +125,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (phaseProgress != null) { builder.field(PHASE_PROGRESS.getPreferredName(), phaseProgress); } - if (inferenceModel != null) { - builder.field(INFERENCE_MODEL.getPreferredName(), - inferenceModel, - new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"))); - } if (memoryUsage != null) { builder.field(ANALYTICS_MEMORY_USAGE.getPreferredName(), memoryUsage, params); } @@ -152,6 +140,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (modelSizeInfo != null) { builder.field(MODEL_SIZE_INFO.getPreferredName(), modelSizeInfo); } + if (trainedModelDefinitionChunk != null) { + builder.field(COMPRESSED_INFERENCE_MODEL.getPreferredName(), trainedModelDefinitionChunk); + } builder.endObject(); return builder; } @@ -168,17 +159,17 @@ public boolean equals(Object other) { AnalyticsResult that = (AnalyticsResult) other; return Objects.equals(rowResults, that.rowResults) && Objects.equals(phaseProgress, that.phaseProgress) - && Objects.equals(inferenceModel, that.inferenceModel) && Objects.equals(memoryUsage, that.memoryUsage) && Objects.equals(outlierDetectionStats, that.outlierDetectionStats) && Objects.equals(classificationStats, that.classificationStats) && Objects.equals(modelSizeInfo, that.modelSizeInfo) + && Objects.equals(trainedModelDefinitionChunk, that.trainedModelDefinitionChunk) && Objects.equals(regressionStats, that.regressionStats); } @Override public int hashCode() { - return Objects.hash(rowResults, phaseProgress, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats, - regressionStats); + return Objects.hash(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, classificationStats, + regressionStats, modelSizeInfo, trainedModelDefinitionChunk); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java new file mode 100644 index 0000000000000..9b28b15971e67 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */
+
+package org.elasticsearch.xpack.ml.dataframe.process.results;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+public class TrainedModelDefinitionChunk implements ToXContentObject {
+
+    private static final ParseField DEFINITION = new ParseField("definition");
+    private static final ParseField TOTAL_DEFINITION_LENGTH = new ParseField("total_definition_length");
+
+    public static final ConstructingObjectParser<TrainedModelDefinitionChunk, Void> PARSER = new ConstructingObjectParser<>(
+        "chunked_trained_model_definition",
+        a -> new TrainedModelDefinitionChunk((String) a[0], (Long) a[1]));
+
+    static {
+        PARSER.declareString(constructorArg(), DEFINITION);
+        PARSER.declareLong(constructorArg(), TOTAL_DEFINITION_LENGTH);
+    }
+
+    private final String definition;
+    private final long totalDefinitionLength;
+
+    public TrainedModelDefinitionChunk(String definition, long totalDefinitionLength) {
+        this.definition = definition;
+        this.totalDefinitionLength = totalDefinitionLength;
+    }
+
+    public TrainedModelDefinitionDoc createTrainedModelDoc(String modelId, int docNum) {
+        return new TrainedModelDefinitionDoc.Builder()
+            .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
+            .setTotalDefinitionLength(totalDefinitionLength)
+            .setModelId(modelId)
+            .setDefinitionLength(definition.length())
+            .setDocNum(docNum)
+            .setCompressedString(definition)
+            .build();
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(DEFINITION.getPreferredName(), definition);
+        builder.field(TOTAL_DEFINITION_LENGTH.getPreferredName(), totalDefinitionLength);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TrainedModelDefinitionChunk that = (TrainedModelDefinitionChunk) o;
+        return totalDefinitionLength == that.totalDefinitionLength &&
+            Objects.equals(definition, that.definition);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(definition, totalDefinitionLength);
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java
index b2b53ec445c2d..a378828756516 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java
@@ -114,6 +114,10 @@ public int getCompressionVersion() {
         return compressionVersion;
     }
 
+    public String getDocId() {
+        return docId(modelId, docNum);
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 0280eb3a86e3d..f185b92583076 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -14,10 +14,14 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.admin.indices.refresh.RefreshAction; +import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; +import org.elasticsearch.action.admin.indices.refresh.RefreshResponse; import org.elasticsearch.action.bulk.BulkAction; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.MultiSearchAction; import org.elasticsearch.action.search.MultiSearchRequest; @@ -142,6 +146,74 @@ public void storeTrainedModel(TrainedModelConfig trainedModelConfig, storeTrainedModelAndDefinition(trainedModelConfig, listener); } + public void storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig, + ActionListener listener) { + if (MODELS_STORED_AS_RESOURCE.contains(trainedModelConfig.getModelId())) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); + return; + } + assert trainedModelConfig.getModelDefinition() == null; + + executeAsyncWithOrigin(client, + ML_ORIGIN, + IndexAction.INSTANCE, + createRequest(trainedModelConfig.getModelId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelConfig), + ActionListener.wrap( + indexResponse -> listener.onResponse(true), + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); + } else { + listener.onFailure( + new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, + RestStatus.INTERNAL_SERVER_ERROR, + e, + trainedModelConfig.getModelId())); + } + } + )); + } + + public void storeTrainedModelDefinitionDoc(TrainedModelDefinitionDoc trainedModelDefinitionDoc, ActionListener listener) { + if (MODELS_STORED_AS_RESOURCE.contains(trainedModelDefinitionDoc.getModelId())) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelDefinitionDoc.getModelId()))); + return; + } + + executeAsyncWithOrigin(client, + ML_ORIGIN, + IndexAction.INSTANCE, + createRequest(trainedModelDefinitionDoc.getDocId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelDefinitionDoc), + ActionListener.wrap( + indexResponse -> listener.onResponse(null), + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_DOC_EXISTS, + trainedModelDefinitionDoc.getModelId(), + trainedModelDefinitionDoc.getDocNum()))); + } else { + listener.onFailure( + new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, + 
RestStatus.INTERNAL_SERVER_ERROR, + e, + trainedModelDefinitionDoc.getModelId())); + } + } + )); + } + + public void refreshInferenceIndex(ActionListener listener) { + executeAsyncWithOrigin(client, + ML_ORIGIN, + RefreshAction.INSTANCE, + new RefreshRequest(InferenceIndexConstants.INDEX_PATTERN), + listener); + } + private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfig, ActionListener listener) { @@ -831,14 +903,18 @@ private TrainedModelDefinitionDoc parseModelDefinitionDocLenientlyFromSource(Byt } } + private IndexRequest createRequest(String docId, String index, ToXContentObject body) { + return createRequest(new IndexRequest(index), docId, body); + } + private IndexRequest createRequest(String docId, ToXContentObject body) { + return createRequest(new IndexRequest(), docId, body); + } + + private IndexRequest createRequest(IndexRequest request, String docId, ToXContentObject body) { try (XContentBuilder builder = XContentFactory.jsonBuilder()) { XContentBuilder source = body.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS); - - return new IndexRequest() - .opType(DocWriteRequest.OpType.CREATE) - .id(docId) - .source(source); + return request.opType(DocWriteRequest.OpType.CREATE).id(docId).source(source); } catch (IOException ex) { // This should never happen. If we were able to deserialize the object (from Native or REST) and then fail to serialize it again // that is not the users fault. We did something wrong and should throw. diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 4450c10ead41d..d625533dee7d7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -5,34 +5,24 @@ */ package org.elasticsearch.xpack.ml.dataframe.process; -import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.common.xcontent.json.JsonXContent; -import org.elasticsearch.license.License; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; -import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; import 
org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; -import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; -import org.elasticsearch.xpack.ml.extractor.MultiField; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.junit.Before; @@ -40,17 +30,12 @@ import org.mockito.InOrder; import org.mockito.Mockito; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.Set; -import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.startsWith; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; @@ -159,74 +144,6 @@ public void testProcess_GivenDataFrameRowsJoinerFails() { assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); } - @SuppressWarnings("unchecked") - public void testProcess_GivenInferenceModelIsStoredSuccessfully() { - givenDataFrameRows(0); - - doAnswer(invocationOnMock -> { - ActionListener storeListener = (ActionListener) invocationOnMock.getArguments()[1]; - storeListener.onResponse(true); - return null; - }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class)); - - List extractedFieldList = new ArrayList<>(3); - extractedFieldList.add(new DocValueField("foo", Collections.emptySet())); - extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet()))); - extractedFieldList.add(new DocValueField("baz", Collections.emptySet())); - TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? 
TargetType.REGRESSION : TargetType.CLASSIFICATION; - TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null))); - AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList); - - resultProcessor.process(process); - resultProcessor.awaitForCompletion(); - - ArgumentCaptor storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class); - verify(trainedModelProvider).storeTrainedModel(storedModelCaptor.capture(), any(ActionListener.class)); - - TrainedModelConfig storedModel = storedModelCaptor.getValue(); - assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM)); - assertThat(storedModel.getModelId(), containsString(JOB_ID)); - assertThat(storedModel.getVersion(), equalTo(Version.CURRENT)); - assertThat(storedModel.getCreatedBy(), equalTo(XPackUser.NAME)); - assertThat(storedModel.getTags(), contains(JOB_ID)); - assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); - assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build())); - assertThat(storedModel.getDefaultFieldMap(), equalTo(Collections.singletonMap("bar", "bar.keyword"))); - assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar.keyword", "baz"))); - assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed())); - assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations())); - if (targetType.equals(TargetType.CLASSIFICATION)) { - assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification")); - } else { - assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression")); - } - Map metadata = storedModel.getMetadata(); - assertThat(metadata.size(), equalTo(1)); - assertThat(metadata, hasKey("analytics_config")); - Map analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(), - true); - assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config"))); - - ArgumentCaptor auditCaptor = ArgumentCaptor.forClass(String.class); - verify(auditor).info(eq(JOB_ID), auditCaptor.capture()); - assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID)); - Mockito.verifyNoMoreInteractions(auditor); - } - - public void testGetPredictionFieldType() { - List extractedFieldList = Arrays.asList( - new DocValueField("foo", Collections.emptySet()), - new DocValueField("bar", Set.of("keyword")), - new DocValueField("baz", Set.of("long")), - new DocValueField("bingo", Set.of("boolean"))); - AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList); - assertThat(resultProcessor.getPredictionFieldType(new Classification("foo")), equalTo(PredictionFieldType.STRING)); - assertThat(resultProcessor.getPredictionFieldType(new Classification("bar")), equalTo(PredictionFieldType.STRING)); - assertThat(resultProcessor.getPredictionFieldType(new Classification("baz")), equalTo(PredictionFieldType.NUMBER)); - assertThat(resultProcessor.getPredictionFieldType(new Classification("bingo")), equalTo(PredictionFieldType.BOOLEAN)); - } - @SuppressWarnings("unchecked") public void testProcess_GivenInferenceModelFailedToStore() { givenDataFrameRows(0); @@ -235,11 +152,10 @@ public void testProcess_GivenInferenceModelFailedToStore() { ActionListener 
storeListener = (ActionListener) invocationOnMock.getArguments()[1]; storeListener.onFailure(new RuntimeException("some failure")); return null; - }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class)); + }).when(trainedModelProvider).storeTrainedModelMetadata(any(TrainedModelConfig.class), any(ActionListener.class)); - TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION; - TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null))); + ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); + givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, null, null, null, null, modelSizeInfo, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -248,10 +164,12 @@ public void testProcess_GivenInferenceModelFailedToStore() { // This test verifies the processor knows how to handle a failure on storing the model and completes normally ArgumentCaptor auditCaptor = ArgumentCaptor.forClass(String.class); verify(auditor).error(eq(JOB_ID), auditCaptor.capture()); - assertThat(auditCaptor.getValue(), containsString("Error processing results; error storing trained model with id [" + JOB_ID)); + assertThat(auditCaptor.getValue(), + containsString("Error processing results; error storing trained model metadata with id [" + JOB_ID)); Mockito.verifyNoMoreInteractions(auditor); - assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID)); + assertThat(resultProcessor.getFailure(), + startsWith("error processing results; error storing trained model metadata with id [" + JOB_ID)); assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java new file mode 100644 index 0000000000000..277673a4a834b --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java @@ -0,0 +1,223 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.ml.dataframe.process; + +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.license.License; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.security.user.XPackUser; +import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; +import org.elasticsearch.xpack.ml.extractor.DocValueField; +import org.elasticsearch.xpack.ml.extractor.ExtractedField; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class ChunkedTrainedModelPersisterTests extends ESTestCase { + + private static final String JOB_ID = "analytics-result-processor-tests"; + private static final String JOB_DESCRIPTION = "This describes the job of these tests"; + + private TrainedModelProvider trainedModelProvider; + private DataFrameAnalyticsAuditor auditor; + + @Before + public void setUpMocks() { + trainedModelProvider = mock(TrainedModelProvider.class); + auditor = mock(DataFrameAnalyticsAuditor.class); + } + + @SuppressWarnings("unchecked") + public void testPersistAllDocs() { + DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder() + .setId(JOB_ID) + .setDescription(JOB_DESCRIPTION) + .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null)) + .setDest(new DataFrameAnalyticsDest("my_dest", null)) + .setAnalysis(randomBoolean() ? 
new Regression("foo") : new Classification("foo")) + .build(); + List extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet())); + + doAnswer(invocationOnMock -> { + ActionListener storeListener = (ActionListener) invocationOnMock.getArguments()[1]; + storeListener.onResponse(true); + return null; + }).when(trainedModelProvider).storeTrainedModelMetadata(any(TrainedModelConfig.class), any(ActionListener.class)); + + doAnswer(invocationOnMock -> { + ActionListener storeListener = (ActionListener) invocationOnMock.getArguments()[1]; + storeListener.onResponse(null); + return null; + }).when(trainedModelProvider).storeTrainedModelDefinitionDoc(any(TrainedModelDefinitionDoc.class), any(ActionListener.class)); + + ChunkedTrainedModelPersister resultProcessor = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig); + ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); + TrainedModelDefinitionChunk chunk1 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 20L); + TrainedModelDefinitionChunk chunk2 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 20L); + + resultProcessor.createAndIndexInferenceModelMetadata(modelSizeInfo); + resultProcessor.createAndIndexInferenceModelDoc(chunk1); + resultProcessor.createAndIndexInferenceModelDoc(chunk2); + + ArgumentCaptor storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class); + verify(trainedModelProvider).storeTrainedModelMetadata(storedModelCaptor.capture(), any(ActionListener.class)); + + ArgumentCaptor storedDocCapture = ArgumentCaptor.forClass(TrainedModelDefinitionDoc.class); + verify(trainedModelProvider, times(2)) + .storeTrainedModelDefinitionDoc(storedDocCapture.capture(), any(ActionListener.class)); + + TrainedModelConfig storedModel = storedModelCaptor.getValue(); + assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM)); + assertThat(storedModel.getModelId(), containsString(JOB_ID)); + assertThat(storedModel.getVersion(), equalTo(Version.CURRENT)); + assertThat(storedModel.getCreatedBy(), equalTo(XPackUser.NAME)); + assertThat(storedModel.getTags(), contains(JOB_ID)); + assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); + assertThat(storedModel.getModelDefinition(), is(nullValue())); + assertThat(storedModel.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed())); + assertThat(storedModel.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations())); + if (analyticsConfig.getAnalysis() instanceof Classification) { + assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification")); + } else { + assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression")); + } + Map metadata = storedModel.getMetadata(); + assertThat(metadata.size(), equalTo(1)); + assertThat(metadata, hasKey("analytics_config")); + Map analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(), + true); + assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config"))); + + TrainedModelDefinitionDoc storedDoc1 = storedDocCapture.getAllValues().get(0); + assertThat(storedDoc1.getDocNum(), equalTo(0)); + TrainedModelDefinitionDoc storedDoc2 = storedDocCapture.getAllValues().get(1); + assertThat(storedDoc2.getDocNum(), equalTo(1)); + + assertThat(storedModel.getModelId(), equalTo(storedDoc1.getModelId())); + assertThat(storedModel.getModelId(), equalTo(storedDoc2.getModelId())); + + ArgumentCaptor auditCaptor = 
ArgumentCaptor.forClass(String.class); + verify(auditor).info(eq(JOB_ID), auditCaptor.capture()); + assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID)); + Mockito.verifyNoMoreInteractions(auditor); + } + + public void testGetPredictionFieldType() { + DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder() + .setId(JOB_ID) + .setDescription(JOB_DESCRIPTION) + .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null)) + .setDest(new DataFrameAnalyticsDest("my_dest", null)) + .setAnalysis(randomBoolean() ? new Regression("foo") : new Classification("foo")) + .build(); + List extractedFieldList = Arrays.asList( + new DocValueField("foo", Collections.emptySet()), + new DocValueField("bar", Set.of("keyword")), + new DocValueField("baz", Set.of("long")), + new DocValueField("bingo", Set.of("boolean"))); + ChunkedTrainedModelPersister resultProcessor = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig); + assertThat(resultProcessor.getPredictionFieldType(new Classification("foo")), equalTo(PredictionFieldType.STRING)); + assertThat(resultProcessor.getPredictionFieldType(new Classification("bar")), equalTo(PredictionFieldType.STRING)); + assertThat(resultProcessor.getPredictionFieldType(new Classification("baz")), equalTo(PredictionFieldType.NUMBER)); + assertThat(resultProcessor.getPredictionFieldType(new Classification("bingo")), equalTo(PredictionFieldType.BOOLEAN)); + } + + public void testBuildInferenceConfigByAnalyticsType() { + List extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet())); + DataFrameAnalyticsConfig.Builder analyticsConfigBuilder = new DataFrameAnalyticsConfig.Builder() + .setId(JOB_ID) + .setDescription(JOB_DESCRIPTION) + .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null)) + .setDest(new DataFrameAnalyticsDest("my_dest", null)); + { + DataFrameAnalyticsConfig analyticsConfig = analyticsConfigBuilder + .setAnalysis(new Regression("foo", + new BoostedTreeParams(null, null, null, null, null, 2), + null, + null, + null, + null, + null + )) + .build(); + ChunkedTrainedModelPersister resultProcessor = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig); + InferenceConfig inferenceConfig = resultProcessor.buildInferenceConfigByAnalyticsType(); + + assertThat(inferenceConfig, instanceOf(RegressionConfig.class)); + assertThat(((RegressionConfig)inferenceConfig).getNumTopFeatureImportanceValues(), equalTo(2)); + } + { + DataFrameAnalyticsConfig analyticsConfig = analyticsConfigBuilder + .setAnalysis(new Classification("foo", + new BoostedTreeParams(null, null, null, null, null, 2), + null, + null, + 1, + null, + null + )) + .build(); + ChunkedTrainedModelPersister resultProcessor = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig); + InferenceConfig inferenceConfig = resultProcessor.buildInferenceConfigByAnalyticsType(); + + assertThat(inferenceConfig, instanceOf(ClassificationConfig.class)); + ClassificationConfig classificationConfig = (ClassificationConfig)inferenceConfig; + assertThat(classificationConfig.getNumTopFeatureImportanceValues(), equalTo(2)); + assertThat(classificationConfig.getNumTopClasses(), equalTo(1)); + assertThat(classificationConfig.getPredictionFieldType(), equalTo(PredictionFieldType.STRING)); + } + } + + private ChunkedTrainedModelPersister createChunkedTrainedModelPersister(List fieldNames, + DataFrameAnalyticsConfig analyticsConfig) { + 
return new ChunkedTrainedModelPersister(trainedModelProvider, + analyticsConfig, + auditor, + (unused)->{}, + fieldNames); + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java index bb484d8dd651e..60426f77aa79b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java @@ -20,8 +20,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; @@ -46,21 +44,18 @@ protected NamedXContentRegistry xContentRegistry() { protected AnalyticsResult createTestInstance() { RowResults rowResults = null; PhaseProgress phaseProgress = null; - TrainedModelDefinition.Builder inferenceModel = null; MemoryUsage memoryUsage = null; OutlierDetectionStats outlierDetectionStats = null; ClassificationStats classificationStats = null; RegressionStats regressionStats = null; ModelSizeInfo modelSizeInfo = null; + TrainedModelDefinitionChunk trainedModelDefinitionChunk = null; if (randomBoolean()) { rowResults = RowResultsTests.createRandom(); } if (randomBoolean()) { phaseProgress = new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)); } - if (randomBoolean()) { - inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(); - } if (randomBoolean()) { memoryUsage = MemoryUsageTests.createRandom(); } @@ -76,8 +71,13 @@ protected AnalyticsResult createTestInstance() { if (randomBoolean()) { modelSizeInfo = ModelSizeInfoTests.createRandom(); } - return new AnalyticsResult(rowResults, phaseProgress, inferenceModel, memoryUsage, outlierDetectionStats, - classificationStats, regressionStats, modelSizeInfo); + if (randomBoolean()) { + String def = randomAlphaOfLengthBetween(100, 1000); + long totallength = def.length() * randomLongBetween(1, 10); + trainedModelDefinitionChunk = new TrainedModelDefinitionChunk(def, totallength); + } + return new AnalyticsResult(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, + classificationStats, regressionStats, modelSizeInfo, trainedModelDefinitionChunk); } @Override From 3a5ce9301d227876ea178c118b8b97b7cd1bc6c5 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Fri, 26 Jun 2020 09:55:10 -0400 Subject: [PATCH 2/9] fixing after merge --- .../ChunkedTrainedMoodelPersisterIT.java | 3 +- .../process/AnalyticsResultProcessor.java | 90 +---------------- .../process/ChunkedTrainedModelPersister.java | 46 ++------- .../AnalyticsResultProcessorTests.java | 96 ------------------- .../ChunkedTrainedModelPersisterTests.java | 77 +-------------- 5 files changed, 11 insertions(+), 301 deletions(-) diff --git 
a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java index 28ff17c6c47d5..542c64b15349f 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; @@ -70,7 +71,7 @@ public void testStoreModelViaChunkedPersister() throws IOException { analyticsConfig, new DataFrameAnalyticsAuditor(client(), "test-node"), (ex) -> { throw new ElasticsearchException(ex); }, - extractedFieldList + new ExtractedFields(extractedFieldList, Collections.emptyMap()) ); //Accuracy for size is not tested here diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 555775eb0312b..3c1ffca7afe5b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -14,9 +14,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; @@ -25,16 +22,12 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; -import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; -import org.elasticsearch.xpack.ml.extractor.MultiField; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; -import java.util.Collections; import java.util.Iterator; -import java.util.List; import java.util.Objects; import java.util.concurrent.CountDownLatch; @@ -58,10 +51,8 @@ public class AnalyticsResultProcessor { private final DataFrameAnalyticsConfig analytics; private final DataFrameRowsJoiner dataFrameRowsJoiner; private final StatsHolder 
statsHolder; - private final TrainedModelProvider trainedModelProvider; private final DataFrameAnalyticsAuditor auditor; private final StatsPersister statsPersister; - private final ExtractedFields extractedFields; private final CountDownLatch completionLatch = new CountDownLatch(1); private final ChunkedTrainedModelPersister chunkedTrainedModelPersister; private volatile String failure; @@ -73,18 +64,15 @@ public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRow this.analytics = Objects.requireNonNull(analytics); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); this.statsHolder = Objects.requireNonNull(statsHolder); - this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); this.auditor = Objects.requireNonNull(auditor); this.statsPersister = Objects.requireNonNull(statsPersister); - this.fieldNames = Collections.unmodifiableList(Objects.requireNonNull(fieldNames)); this.chunkedTrainedModelPersister = new ChunkedTrainedModelPersister( trainedModelProvider, analytics, auditor, this::setAndReportFailure, - fieldNames + extractedFields ); - this.extractedFields = Objects.requireNonNull(extractedFields); } @Nullable @@ -193,82 +181,6 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo } } - private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) { - TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel); - CountDownLatch latch = storeTrainedModel(trainedModelConfig); - - try { - if (latch.await(30, TimeUnit.SECONDS) == false) { - LOGGER.error("[{}] Timed out (30s) waiting for inference model to be stored", analytics.getId()); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for inference model to be stored")); - } - } - - private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Builder inferenceModel) { - Instant createTime = Instant.now(); - String modelId = analytics.getId() + "-" + createTime.toEpochMilli(); - TrainedModelDefinition definition = inferenceModel.build(); - String dependentVariable = getDependentVariable(); - List fieldNames = extractedFields.getAllFields(); - List fieldNamesWithoutDependentVariable = fieldNames.stream() - .map(ExtractedField::getName) - .filter(f -> f.equals(dependentVariable) == false) - .collect(toList()); - Map defaultFieldMapping = fieldNames.stream() - .filter(ef -> ef instanceof MultiField && (ef.getName().equals(dependentVariable) == false)) - .collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName)); - return TrainedModelConfig.builder() - .setModelId(modelId) - .setCreatedBy(XPackUser.NAME) - .setVersion(Version.CURRENT) - .setCreateTime(createTime) - // NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags - .setTags(Collections.singletonList(analytics.getId())) - .setDescription(analytics.getDescription()) - .setMetadata(Collections.singletonMap("analytics_config", - XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) - .setEstimatedHeapMemory(definition.ramBytesUsed()) - .setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations()) - .setParsedDefinition(inferenceModel) - .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable)) - .setLicenseLevel(License.OperationMode.PLATINUM.description()) - .setDefaultFieldMap(defaultFieldMapping) - 
.setInferenceConfig(analytics.getAnalysis().inferenceConfig(new AnalysisFieldInfo(extractedFields))) - .build(); - } - - private String getDependentVariable() { - if (analytics.getAnalysis() instanceof Classification) { - return ((Classification)analytics.getAnalysis()).getDependentVariable(); - } - if (analytics.getAnalysis() instanceof Regression) { - return ((Regression)analytics.getAnalysis()).getDependentVariable(); - } - return null; - } - - private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) { - CountDownLatch latch = new CountDownLatch(1); - ActionListener storeListener = ActionListener.wrap( - aBoolean -> { - if (aBoolean == false) { - LOGGER.error("[{}] Storing trained model responded false", analytics.getId()); - setAndReportFailure(ExceptionsHelper.serverError("storing trained model responded false")); - } else { - LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId()); - auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]"); - } - }, - e -> setAndReportFailure(ExceptionsHelper.serverError("error storing trained model with id [{}]", e, - trainedModelConfig.getModelId())) - ); - trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch)); - return latch; - } - private void setAndReportFailure(Exception e) { LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e); failure = "error processing results; " + e.getMessage(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java index 6f1df92e822f1..591a04c4a6fe2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java @@ -21,14 +21,11 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; import org.elasticsearch.xpack.ml.extractor.ExtractedField; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.extractor.MultiField; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; @@ -39,7 +36,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -60,14 +56,14 @@ public class ChunkedTrainedModelPersister { private final DataFrameAnalyticsConfig analytics; 
private final DataFrameAnalyticsAuditor auditor; private final Consumer failureHandler; - private final List fieldNames; + private final ExtractedFields extractedFields; private volatile boolean readyToStoreNewModel = true; public ChunkedTrainedModelPersister(TrainedModelProvider provider, DataFrameAnalyticsConfig analytics, DataFrameAnalyticsAuditor auditor, Consumer failureHandler, - List fieldNames) { + ExtractedFields extractedFields) { this.provider = provider; this.currentModelId = new AtomicReference<>(""); this.currentChunkedDoc = new AtomicInteger(0); @@ -75,7 +71,7 @@ public ChunkedTrainedModelPersister(TrainedModelProvider provider, this.analytics = analytics; this.auditor = auditor; this.failureHandler = failureHandler; - this.fieldNames = fieldNames; + this.extractedFields = extractedFields; } public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk trainedModelDefinitionChunk) { @@ -182,6 +178,7 @@ private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) { currentModelId.set(modelId); currentChunkedDoc.set(0); persistedChunkLengths.set(0); + List fieldNames = extractedFields.getAllFields(); String dependentVariable = getDependentVariable(); List fieldNamesWithoutDependentVariable = fieldNames.stream() .map(ExtractedField::getName) @@ -205,7 +202,7 @@ private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) { .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable)) .setLicenseLevel(License.OperationMode.PLATINUM.description()) .setDefaultFieldMap(defaultFieldMapping) - .setInferenceConfig(buildInferenceConfigByAnalyticsType()) + .setInferenceConfig(analytics.getAnalysis().inferenceConfig(new AnalysisFieldInfo(extractedFields))) .build(); } @@ -219,35 +216,4 @@ private String getDependentVariable() { return null; } - InferenceConfig buildInferenceConfigByAnalyticsType() { - if (analytics.getAnalysis() instanceof Classification) { - Classification classification = ((Classification)analytics.getAnalysis()); - PredictionFieldType predictionFieldType = getPredictionFieldType(classification); - return ClassificationConfig.builder() - .setNumTopClasses(classification.getNumTopClasses()) - .setNumTopFeatureImportanceValues(classification.getBoostedTreeParams().getNumTopFeatureImportanceValues()) - .setPredictionFieldType(predictionFieldType) - .build(); - } else if (analytics.getAnalysis() instanceof Regression) { - Regression regression = ((Regression)analytics.getAnalysis()); - return RegressionConfig.builder() - .setNumTopFeatureImportanceValues(regression.getBoostedTreeParams().getNumTopFeatureImportanceValues()) - .build(); - } - throw ExceptionsHelper.serverError( - "analytics type [{}] does not support model creation", - null, - analytics.getAnalysis().getWriteableName()); - } - - PredictionFieldType getPredictionFieldType(Classification classification) { - String dependentVariable = classification.getDependentVariable(); - Optional extractedField = fieldNames.stream() - .filter(f -> f.getName().equals(dependentVariable)) - .findAny(); - PredictionFieldType predictionFieldType = Classification.getPredictionFieldType( - extractedField.isPresent() ? extractedField.get().getTypes() : null - ); - return predictionFieldType == null ? 
PredictionFieldType.STRING : predictionFieldType; - } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 44a6ae90d96eb..5c3e2b9ce1985 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.ml.dataframe.process; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -13,11 +12,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; -import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; @@ -25,8 +19,6 @@ import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; -import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; -import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.junit.Before; @@ -37,14 +29,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Map; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.startsWith; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -149,91 +138,6 @@ public void testProcess_GivenDataFrameRowsJoinerFails() { assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); } - @SuppressWarnings("unchecked") - public void testProcess_GivenInferenceModelIsStoredSuccessfully() { - givenDataFrameRows(0); - - doAnswer(invocationOnMock -> { - ActionListener storeListener = (ActionListener) invocationOnMock.getArguments()[1]; - storeListener.onResponse(true); - return null; - }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class)); - - List extractedFieldList = new ArrayList<>(3); - extractedFieldList.add(new DocValueField("foo", Collections.emptySet())); - extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet()))); - 
extractedFieldList.add(new DocValueField("baz", Collections.emptySet())); - TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION; - TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null))); - AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList); - - resultProcessor.process(process); - resultProcessor.awaitForCompletion(); - - ArgumentCaptor storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class); - verify(trainedModelProvider).storeTrainedModel(storedModelCaptor.capture(), any(ActionListener.class)); - - TrainedModelConfig storedModel = storedModelCaptor.getValue(); - assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM)); - assertThat(storedModel.getModelId(), containsString(JOB_ID)); - assertThat(storedModel.getVersion(), equalTo(Version.CURRENT)); - assertThat(storedModel.getCreatedBy(), equalTo(XPackUser.NAME)); - assertThat(storedModel.getTags(), contains(JOB_ID)); - assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); - assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build())); - assertThat(storedModel.getDefaultFieldMap(), equalTo(Collections.singletonMap("bar", "bar.keyword"))); - assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar.keyword", "baz"))); - assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed())); - assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations())); - if (targetType.equals(TargetType.CLASSIFICATION)) { - assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification")); - } else { - assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression")); - } - Map metadata = storedModel.getMetadata(); - assertThat(metadata.size(), equalTo(1)); - assertThat(metadata, hasKey("analytics_config")); - Map analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(), - true); - assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config"))); - - ArgumentCaptor auditCaptor = ArgumentCaptor.forClass(String.class); - verify(auditor).info(eq(JOB_ID), auditCaptor.capture()); - assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID)); - Mockito.verifyNoMoreInteractions(auditor); - } - - - @SuppressWarnings("unchecked") - public void testProcess_GivenInferenceModelFailedToStore() { - givenDataFrameRows(0); - - doAnswer(invocationOnMock -> { - ActionListener storeListener = (ActionListener) invocationOnMock.getArguments()[1]; - storeListener.onFailure(new RuntimeException("some failure")); - return null; - }).when(trainedModelProvider).storeTrainedModelMetadata(any(TrainedModelConfig.class), any(ActionListener.class)); - - ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, null, null, null, null, modelSizeInfo, null))); - AnalyticsResultProcessor resultProcessor = createResultProcessor(); - - resultProcessor.process(process); - resultProcessor.awaitForCompletion(); - - // This test verifies the processor knows how to handle a failure on storing the model and completes normally - ArgumentCaptor auditCaptor 
= ArgumentCaptor.forClass(String.class); - verify(auditor).error(eq(JOB_ID), auditCaptor.capture()); - assertThat(auditCaptor.getValue(), - containsString("Error processing results; error storing trained model metadata with id [" + JOB_ID)); - Mockito.verifyNoMoreInteractions(auditor); - - assertThat(resultProcessor.getFailure(), - startsWith("error processing results; error storing trained model metadata with id [" + JOB_ID)); - assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); - } - private void givenProcessResults(List results) { when(process.readAnalyticsResults()).thenReturn(results.iterator()); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java index 277673a4a834b..9b2c099de1098 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java @@ -15,18 +15,14 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; @@ -36,17 +32,14 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mockito; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Set; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasKey; -import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.mockito.Matchers.any; @@ -145,79 +138,13 @@ public void testPersistAllDocs() { Mockito.verifyNoMoreInteractions(auditor); } - public void testGetPredictionFieldType() { - DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder() - .setId(JOB_ID) - .setDescription(JOB_DESCRIPTION) - .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null)) - .setDest(new 
DataFrameAnalyticsDest("my_dest", null)) - .setAnalysis(randomBoolean() ? new Regression("foo") : new Classification("foo")) - .build(); - List extractedFieldList = Arrays.asList( - new DocValueField("foo", Collections.emptySet()), - new DocValueField("bar", Set.of("keyword")), - new DocValueField("baz", Set.of("long")), - new DocValueField("bingo", Set.of("boolean"))); - ChunkedTrainedModelPersister resultProcessor = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig); - assertThat(resultProcessor.getPredictionFieldType(new Classification("foo")), equalTo(PredictionFieldType.STRING)); - assertThat(resultProcessor.getPredictionFieldType(new Classification("bar")), equalTo(PredictionFieldType.STRING)); - assertThat(resultProcessor.getPredictionFieldType(new Classification("baz")), equalTo(PredictionFieldType.NUMBER)); - assertThat(resultProcessor.getPredictionFieldType(new Classification("bingo")), equalTo(PredictionFieldType.BOOLEAN)); - } - - public void testBuildInferenceConfigByAnalyticsType() { - List extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet())); - DataFrameAnalyticsConfig.Builder analyticsConfigBuilder = new DataFrameAnalyticsConfig.Builder() - .setId(JOB_ID) - .setDescription(JOB_DESCRIPTION) - .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null)) - .setDest(new DataFrameAnalyticsDest("my_dest", null)); - { - DataFrameAnalyticsConfig analyticsConfig = analyticsConfigBuilder - .setAnalysis(new Regression("foo", - new BoostedTreeParams(null, null, null, null, null, 2), - null, - null, - null, - null, - null - )) - .build(); - ChunkedTrainedModelPersister resultProcessor = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig); - InferenceConfig inferenceConfig = resultProcessor.buildInferenceConfigByAnalyticsType(); - - assertThat(inferenceConfig, instanceOf(RegressionConfig.class)); - assertThat(((RegressionConfig)inferenceConfig).getNumTopFeatureImportanceValues(), equalTo(2)); - } - { - DataFrameAnalyticsConfig analyticsConfig = analyticsConfigBuilder - .setAnalysis(new Classification("foo", - new BoostedTreeParams(null, null, null, null, null, 2), - null, - null, - 1, - null, - null - )) - .build(); - ChunkedTrainedModelPersister resultProcessor = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig); - InferenceConfig inferenceConfig = resultProcessor.buildInferenceConfigByAnalyticsType(); - - assertThat(inferenceConfig, instanceOf(ClassificationConfig.class)); - ClassificationConfig classificationConfig = (ClassificationConfig)inferenceConfig; - assertThat(classificationConfig.getNumTopFeatureImportanceValues(), equalTo(2)); - assertThat(classificationConfig.getNumTopClasses(), equalTo(1)); - assertThat(classificationConfig.getPredictionFieldType(), equalTo(PredictionFieldType.STRING)); - } - } - private ChunkedTrainedModelPersister createChunkedTrainedModelPersister(List fieldNames, DataFrameAnalyticsConfig analyticsConfig) { return new ChunkedTrainedModelPersister(trainedModelProvider, analyticsConfig, auditor, (unused)->{}, - fieldNames); + new ExtractedFields(fieldNames, Collections.emptyMap())); } } From 354fc25be4b26de9d1625d69e7b7f75f26099eb4 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Fri, 26 Jun 2020 11:27:47 -0400 Subject: [PATCH 3/9] adjusting doc storage format --- .../process/ChunkedTrainedModelPersister.java | 15 ++------ .../results/TrainedModelDefinitionChunk.java | 36 
++++++++++++------- .../TrainedModelDefinitionDoc.java | 36 ++++++++++++++----- .../persistence/TrainedModelProvider.java | 27 +++++++++++--- .../ChunkedTrainedModelPersisterTests.java | 4 +-- .../process/results/AnalyticsResultTests.java | 3 +- 6 files changed, 78 insertions(+), 43 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java index 591a04c4a6fe2..bc8a2744c11f2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java @@ -38,8 +38,6 @@ import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.stream.Collectors; @@ -51,8 +49,6 @@ public class ChunkedTrainedModelPersister { private static final Logger LOGGER = LogManager.getLogger(ChunkedTrainedModelPersister.class); private final TrainedModelProvider provider; private final AtomicReference currentModelId; - private final AtomicInteger currentChunkedDoc; - private final AtomicLong persistedChunkLengths; private final DataFrameAnalyticsConfig analytics; private final DataFrameAnalyticsAuditor auditor; private final Consumer failureHandler; @@ -66,8 +62,6 @@ public ChunkedTrainedModelPersister(TrainedModelProvider provider, ExtractedFields extractedFields) { this.provider = provider; this.currentModelId = new AtomicReference<>(""); - this.currentChunkedDoc = new AtomicInteger(0); - this.persistedChunkLengths = new AtomicLong(0L); this.analytics = analytics; this.auditor = auditor; this.failureHandler = failureHandler; @@ -81,9 +75,7 @@ public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk trainedM )); return; } - TrainedModelDefinitionDoc trainedModelDefinitionDoc = trainedModelDefinitionChunk.createTrainedModelDoc( - this.currentModelId.get(), - this.currentChunkedDoc.getAndIncrement()); + TrainedModelDefinitionDoc trainedModelDefinitionDoc = trainedModelDefinitionChunk.createTrainedModelDoc(this.currentModelId.get()); CountDownLatch latch = new CountDownLatch(1); ActionListener storeListener = ActionListener.wrap( @@ -94,8 +86,7 @@ public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk trainedM trainedModelDefinitionDoc.getModelId(), trainedModelDefinitionDoc.getDocNum())); - long persistedChunkLengths = this.persistedChunkLengths.addAndGet(trainedModelDefinitionDoc.getDefinitionLength()); - if (persistedChunkLengths >= trainedModelDefinitionDoc.getTotalDefinitionLength()) { + if (trainedModelDefinitionChunk.isEos()) { readyToStoreNewModel = true; LOGGER.info( "[{}] finished stored trained model definition chunks with id [{}]", @@ -176,8 +167,6 @@ private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) { Instant createTime = Instant.now(); String modelId = analytics.getId() + "-" + createTime.toEpochMilli(); currentModelId.set(modelId); - currentChunkedDoc.set(0); - persistedChunkLengths.set(0); List fieldNames = extractedFields.getAllFields(); String dependentVariable = getDependentVariable(); List fieldNamesWithoutDependentVariable = fieldNames.stream() diff 
--git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java index 9b28b15971e67..e0b611aada03a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java @@ -17,33 +17,37 @@ import java.util.Objects; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; public class TrainedModelDefinitionChunk implements ToXContentObject { private static final ParseField DEFINITION = new ParseField("definition"); - private static final ParseField TOTAL_DEFINITION_LENGTH = new ParseField("total_definition_length"); + private static final ParseField DOC_NUM = new ParseField("doc_num"); + private static final ParseField EOS = new ParseField("eos"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "chunked_trained_model_definition", - a -> new TrainedModelDefinitionChunk((String) a[0], (Long) a[1])); + a -> new TrainedModelDefinitionChunk((String) a[0], (Integer) a[1], (Boolean) a[2])); static { PARSER.declareString(constructorArg(), DEFINITION); - PARSER.declareLong(constructorArg(), TOTAL_DEFINITION_LENGTH); + PARSER.declareInt(constructorArg(), DOC_NUM); + PARSER.declareBoolean(optionalConstructorArg(), EOS); } private final String definition; - private final long totalDefinitionLength; + private final int docNum; + private final Boolean eos; - public TrainedModelDefinitionChunk(String definition, long totalDefinitionLength) { + public TrainedModelDefinitionChunk(String definition, int docNum, Boolean eos) { this.definition = definition; - this.totalDefinitionLength = totalDefinitionLength; + this.docNum = docNum; + this.eos = eos; } - public TrainedModelDefinitionDoc createTrainedModelDoc(String modelId, int docNum) { + public TrainedModelDefinitionDoc createTrainedModelDoc(String modelId) { return new TrainedModelDefinitionDoc.Builder() .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) - .setTotalDefinitionLength(totalDefinitionLength) .setModelId(modelId) .setDefinitionLength(definition.length()) .setDocNum(docNum) @@ -51,11 +55,18 @@ public TrainedModelDefinitionDoc createTrainedModelDoc(String modelId, int docNu .build(); } + public boolean isEos() { + return eos != null && eos; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(DEFINITION.getPreferredName(), definition); - builder.field(TOTAL_DEFINITION_LENGTH.getPreferredName(), totalDefinitionLength); + builder.field(DOC_NUM.getPreferredName(), docNum); + if (eos != null) { + builder.field(EOS.getPreferredName(), eos); + } builder.endObject(); return builder; } @@ -65,12 +76,13 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; TrainedModelDefinitionChunk that = (TrainedModelDefinitionChunk) o; - return totalDefinitionLength == that.totalDefinitionLength && - Objects.equals(definition, that.definition); + return docNum == that.docNum + && Objects.equals(definition, that.definition) + && Objects.equals(eos, 
that.eos); } @Override public int hashCode() { - return Objects.hash(definition, totalDefinitionLength); + return Objects.hash(definition, docNum, eos); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java index a378828756516..de332b77af6f5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java @@ -33,6 +33,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { public static final ParseField COMPRESSION_VERSION = new ParseField("compression_version"); public static final ParseField TOTAL_DEFINITION_LENGTH = new ParseField("total_definition_length"); public static final ParseField DEFINITION_LENGTH = new ParseField("definition_length"); + public static final ParseField EOS = new ParseField("eos"); // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly public static final ObjectParser LENIENT_PARSER = createParser(true); @@ -48,6 +49,7 @@ private static ObjectParser createParse parser.declareInt(TrainedModelDefinitionDoc.Builder::setCompressionVersion, COMPRESSION_VERSION); parser.declareLong(TrainedModelDefinitionDoc.Builder::setDefinitionLength, DEFINITION_LENGTH); parser.declareLong(TrainedModelDefinitionDoc.Builder::setTotalDefinitionLength, TOTAL_DEFINITION_LENGTH); + parser.declareBoolean(TrainedModelDefinitionDoc.Builder::setEos, EOS); return parser; } @@ -63,23 +65,26 @@ public static String docId(String modelId, int docNum) { private final String compressedString; private final String modelId; private final int docNum; - private final long totalDefinitionLength; + // for BWC + private final Long totalDefinitionLength; private final long definitionLength; private final int compressionVersion; + private final boolean eos; private TrainedModelDefinitionDoc(String compressedString, String modelId, int docNum, - long totalDefinitionLength, + Long totalDefinitionLength, long definitionLength, - int compressionVersion) { + int compressionVersion, + boolean eos) { this.compressedString = ExceptionsHelper.requireNonNull(compressedString, DEFINITION); this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); if (docNum < 0) { throw new IllegalArgumentException("[doc_num] must be greater than or equal to 0"); } this.docNum = docNum; - if (totalDefinitionLength <= 0L) { + if (totalDefinitionLength != null && totalDefinitionLength <= 0L) { throw new IllegalArgumentException("[total_definition_length] must be greater than 0"); } this.totalDefinitionLength = totalDefinitionLength; @@ -88,6 +93,7 @@ private TrainedModelDefinitionDoc(String compressedString, } this.definitionLength = definitionLength; this.compressionVersion = compressionVersion; + this.eos = eos; } public String getCompressedString() { @@ -102,7 +108,7 @@ public int getDocNum() { return docNum; } - public long getTotalDefinitionLength() { + public Long getTotalDefinitionLength() { return totalDefinitionLength; } @@ -114,6 +120,10 @@ public int getCompressionVersion() { return compressionVersion; } + public boolean isEos() { + return eos; + } + public String getDocId() { return docId(modelId, docNum); } @@ -124,10 +134,10 @@ public 
XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); builder.field(DOC_NUM.getPreferredName(), docNum); - builder.field(TOTAL_DEFINITION_LENGTH.getPreferredName(), totalDefinitionLength); builder.field(DEFINITION_LENGTH.getPreferredName(), definitionLength); builder.field(COMPRESSION_VERSION.getPreferredName(), compressionVersion); builder.field(DEFINITION.getPreferredName(), compressedString); + builder.field(EOS.getPreferredName(), eos); builder.endObject(); return builder; } @@ -147,12 +157,13 @@ public boolean equals(Object o) { Objects.equals(definitionLength, that.definitionLength) && Objects.equals(totalDefinitionLength, that.totalDefinitionLength) && Objects.equals(compressionVersion, that.compressionVersion) && + Objects.equals(eos, that.eos) && Objects.equals(compressedString, that.compressedString); } @Override public int hashCode() { - return Objects.hash(modelId, docNum, totalDefinitionLength, definitionLength, compressionVersion, compressedString); + return Objects.hash(modelId, docNum, definitionLength, totalDefinitionLength, compressionVersion, compressedString, eos); } public static class Builder { @@ -160,9 +171,10 @@ public static class Builder { private String modelId; private String compressedString; private int docNum; - private long totalDefinitionLength; + private Long totalDefinitionLength; private long definitionLength; private int compressionVersion; + private boolean eos; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -194,6 +206,11 @@ public Builder setCompressionVersion(int compressionVersion) { return this; } + public Builder setEos(boolean eos) { + this.eos = eos; + return this; + } + public TrainedModelDefinitionDoc build() { return new TrainedModelDefinitionDoc( this.compressedString, @@ -201,7 +218,8 @@ public TrainedModelDefinitionDoc build() { this.docNum, this.totalDefinitionLength, this.definitionLength, - this.compressionVersion); + this.compressionVersion, + this.eos); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 5b72b9cd82109..a7eb60bb4bc83 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -236,7 +236,8 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi .setCompressedString(chunkedStrings.get(i)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) .setDefinitionLength(chunkedStrings.get(i).length()) - .setTotalDefinitionLength(compressedString.length()) + // If it is the last doc, it is the EOS + .setEos(i == chunkedStrings.size() - 1) .build()); } } catch (IOException ex) { @@ -336,6 +337,9 @@ public void getTrainedModelForInference(final String modelId, final ActionListen .unmappedType("long")) .request(); executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap( + // TODO how could we stream in the model definition WHILE parsing it? + // This would reduce the overall memory usage as we won't have to load the whole compressed string + // XContentParser supports streams. 
searchResponse -> { if (searchResponse.getHits().getHits().length == 0) { listener.onFailure(new ResourceNotFoundException( @@ -348,11 +352,24 @@ public void getTrainedModelForInference(final String modelId, final ActionListen String compressedString = docs.stream() .map(TrainedModelDefinitionDoc::getCompressedString) .collect(Collectors.joining()); - if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { - listener.onFailure(ExceptionsHelper.serverError( - Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); - return; + // BWC for when we tracked the total definition length + // TODO: remove in 9 + if (docs.get(0).getTotalDefinitionLength() != null) { + if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { + listener.onFailure(ExceptionsHelper.serverError( + Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); + return; + } + } else { + TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1); + // Either we are missing the last doc, or some previous doc + if(lastDoc.isEos() == false || lastDoc.getDocNum() != docs.size() - 1) { + listener.onFailure(ExceptionsHelper.serverError( + Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); + return; + } } + InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate( compressedString, InferenceDefinition::fromXContent, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java index 9b2c099de1098..ee01e297907d6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java @@ -88,8 +88,8 @@ public void testPersistAllDocs() { ChunkedTrainedModelPersister resultProcessor = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig); ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); - TrainedModelDefinitionChunk chunk1 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 20L); - TrainedModelDefinitionChunk chunk2 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 20L); + TrainedModelDefinitionChunk chunk1 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 0, false); + TrainedModelDefinitionChunk chunk2 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 1, true); resultProcessor.createAndIndexInferenceModelMetadata(modelSizeInfo); resultProcessor.createAndIndexInferenceModelDoc(chunk1); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java index 60426f77aa79b..66dce7d70a58e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java @@ -73,8 +73,7 @@ protected AnalyticsResult createTestInstance() { } if (randomBoolean()) { String def = randomAlphaOfLengthBetween(100, 1000); - long totallength = def.length() * randomLongBetween(1, 10); - trainedModelDefinitionChunk = new TrainedModelDefinitionChunk(def, totallength); + trainedModelDefinitionChunk = new 
TrainedModelDefinitionChunk(def, randomIntBetween(0, 10), randomBoolean()); } return new AnalyticsResult(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, classificationStats, regressionStats, modelSizeInfo, trainedModelDefinitionChunk); From ff192a4704a855cd5aa0b6a8e2e5b0ac9d8df351 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Fri, 26 Jun 2020 13:58:45 -0400 Subject: [PATCH 4/9] fixing model storage --- .../ChunkedTrainedMoodelPersisterIT.java | 4 ++-- .../results/TrainedModelDefinitionChunk.java | 1 + .../persistence/TrainedModelProvider.java | 20 +++++++++++++++---- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java index 542c64b15349f..2ee9c6431ca5b 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java @@ -77,8 +77,8 @@ public void testStoreModelViaChunkedPersister() throws IOException { //Accuracy for size is not tested here ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); persister.createAndIndexInferenceModelMetadata(modelSizeInfo); - for (String chunk : chunks) { - persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunk, totalSize)); + for (int i = 0; i < chunks.size(); i++) { + persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunks.get(i), i, i == (chunks.size() - 1))); } PlainActionFuture>> getIdsFuture = new PlainActionFuture<>(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java index e0b611aada03a..3d5ce84a6af35 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java @@ -52,6 +52,7 @@ public TrainedModelDefinitionDoc createTrainedModelDoc(String modelId) { .setDefinitionLength(definition.length()) .setDocNum(docNum) .setCompressedString(definition) + .setEos(isEos()) .build(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index a7eb60bb4bc83..1dc9212201828 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -452,10 +452,22 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio String compressedString = docs.stream() .map(TrainedModelDefinitionDoc::getCompressedString) .collect(Collectors.joining()); - if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { - listener.onFailure(ExceptionsHelper.serverError( - Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); - return; + // BWC for 
when we tracked the total definition length + // TODO: remove in 9 + if (docs.get(0).getTotalDefinitionLength() != null) { + if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { + listener.onFailure(ExceptionsHelper.serverError( + Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); + return; + } + } else { + TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1); + // Either we are missing the last doc, or some previous doc + if(lastDoc.isEos() == false || lastDoc.getDocNum() != docs.size() - 1) { + listener.onFailure(ExceptionsHelper.serverError( + Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); + return; + } } builder.setDefinitionFromString(compressedString); } catch (ResourceNotFoundException ex) { From 40571e2ad2daa05eda3f1130f7c923a8f65bae5b Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Fri, 26 Jun 2020 15:30:09 -0400 Subject: [PATCH 5/9] unmuting tests --- .../org/elasticsearch/xpack/ml/integration/ClassificationIT.java | 1 - .../org/elasticsearch/xpack/ml/integration/RegressionIT.java | 1 - 2 files changed, 2 deletions(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index c015bf7bcd0c4..2e6090efd25eb 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -64,7 +64,6 @@ import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.startsWith; -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349") public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final String BOOLEAN_FIELD = "boolean-field"; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index d87ee1f2235aa..38e725ebf352d 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -43,7 +43,6 @@ import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349") public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final String NUMERICAL_FEATURE_FIELD = "feature"; From 8a340bfac24b19cbb32170f16a8d7323ae296854 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Fri, 26 Jun 2020 15:30:33 -0400 Subject: [PATCH 6/9] unmuting tests --- .../org/elasticsearch/xpack/ml/integration/ClassificationIT.java | 1 - .../org/elasticsearch/xpack/ml/integration/RegressionIT.java | 1 - 2 files changed, 2 deletions(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 
2e6090efd25eb..2fcdf6c5ffad4 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.ml.integration; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.apache.lucene.util.LuceneTestCase; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionModule; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 38e725ebf352d..84d6f39df4369 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.ml.integration; -import org.apache.lucene.util.LuceneTestCase; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionModule; import org.elasticsearch.action.DocWriteRequest; From b20fb0586874ab6bdd92e4ccda019245c25a708a Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 30 Jun 2020 10:36:45 -0400 Subject: [PATCH 7/9] addressing pr comments --- ...va => ChunkedTrainedModelPersisterIT.java} | 8 +- .../integration/TrainedModelProviderIT.java | 52 +++++++- .../process/ChunkedTrainedModelPersister.java | 119 +++++++++++------- .../persistence/TrainedModelProvider.java | 81 ++++++------ 4 files changed, 163 insertions(+), 97 deletions(-) rename x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/{ChunkedTrainedMoodelPersisterIT.java => ChunkedTrainedModelPersisterIT.java} (95%) diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java similarity index 95% rename from x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java rename to x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index 2ee9c6431ca5b..74581ac3d45ad 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedMoodelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -43,7 +43,7 @@ import static org.hamcrest.Matchers.equalTo; -public class ChunkedTrainedMoodelPersisterIT extends MlSingleNodeTestCase { +public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase { private TrainedModelProvider trainedModelProvider; @@ -111,9 +111,9 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setInput(TrainedModelInputTests.createRandomInput()); } - private static List chunkStringWithSize(String str, int chunkSize) { - List subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize)); - for (int i = 0; i < str.length();i += 
chunkSize) { + public static List chunkStringWithSize(String str, int chunkSize) { + List subStrings = new ArrayList<>((str.length() + chunkSize - 1) / chunkSize); + for (int i = 0; i < str.length(); i += chunkSize) { subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length()))); } return subStrings; diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index a687124066d5c..51ac866647928 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -31,8 +31,11 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; +import static org.elasticsearch.xpack.ml.integration.ChunkedTrainedModelPersisterIT.chunkStringWithSize; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.not; @@ -156,8 +159,8 @@ public void testGetMissingTrainingModelConfigDefinition() throws Exception { equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); } - public void testGetTruncatedModelDefinition() throws Exception { - String modelId = "test-get-truncated-model-config"; + public void testGetTruncatedModelDeprecatedDefinition() throws Exception { + String modelId = "test-get-truncated-legacy-model-config"; TrainedModelConfig config = buildTrainedModelConfig(modelId); AtomicReference putConfigHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); @@ -195,6 +198,51 @@ public void testGetTruncatedModelDefinition() throws Exception { assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); } + public void testGetTruncatedModelDefinition() throws Exception { + String modelId = "test-get-truncated-model-config"; + TrainedModelConfig config = buildTrainedModelConfig(modelId); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + + List chunks = chunkStringWithSize(config.getCompressedDefinition(), config.getCompressedDefinition().length()/3); + + List docBuilders = IntStream.range(0, chunks.size() - 1) + .mapToObj(i -> new TrainedModelDefinitionDoc.Builder() + .setDocNum(i) + .setCompressedString(chunks.get(i)) + .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) + .setDefinitionLength(chunks.get(i).length()) + .setEos(i == chunks.size() - 1) + .setModelId(modelId)) + .collect(Collectors.toList()); + boolean missingEos = randomBoolean(); + docBuilders.get(docBuilders.size() - 1).setEos(missingEos == false); + for (int i = missingEos ? 
0 : 1 ; i < docBuilders.size(); ++i) { + TrainedModelDefinitionDoc doc = docBuilders.get(i).build(); + try(XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(), + new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")))) { + AtomicReference putDocHolder = new AtomicReference<>(); + blockingCall(listener -> client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setSource(xContentBuilder) + .setId(TrainedModelDefinitionDoc.docId(modelId, 0)) + .execute(listener), + putDocHolder, + exceptionHolder); + assertThat(exceptionHolder.get(), is(nullValue())); + } + } + AtomicReference getConfigHolder = new AtomicReference<>(); + blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + assertThat(getConfigHolder.get(), is(nullValue())); + assertThat(exceptionHolder.get(), is(not(nullValue()))); + assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); + } + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) { return TrainedModelConfig.builder() .setCreatedBy("ml_test") diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java index bc8a2744c11f2..0741f42204e20 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java @@ -12,6 +12,7 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.LatchedActionListener; +import org.elasticsearch.action.admin.indices.refresh.RefreshResponse; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.json.JsonXContent; @@ -38,6 +39,7 @@ import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.stream.Collectors; @@ -47,13 +49,14 @@ public class ChunkedTrainedModelPersister { private static final Logger LOGGER = LogManager.getLogger(ChunkedTrainedModelPersister.class); + private static final int STORE_TIMEOUT_SEC = 30; private final TrainedModelProvider provider; private final AtomicReference currentModelId; private final DataFrameAnalyticsConfig analytics; private final DataFrameAnalyticsAuditor auditor; private final Consumer failureHandler; private final ExtractedFields extractedFields; - private volatile boolean readyToStoreNewModel = true; + private final AtomicBoolean readyToStoreNewModel = new AtomicBoolean(true); public ChunkedTrainedModelPersister(TrainedModelProvider provider, DataFrameAnalyticsConfig analytics, @@ -77,56 +80,23 @@ public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk trainedM } TrainedModelDefinitionDoc trainedModelDefinitionDoc = trainedModelDefinitionChunk.createTrainedModelDoc(this.currentModelId.get()); - CountDownLatch latch = new CountDownLatch(1); - ActionListener storeListener = ActionListener.wrap( - r -> { - LOGGER.debug(() -> new 
ParameterizedMessage( - "[{}] stored trained model definition chunk [{}] [{}]", - analytics.getId(), - trainedModelDefinitionDoc.getModelId(), - trainedModelDefinitionDoc.getDocNum())); - - if (trainedModelDefinitionChunk.isEos()) { - readyToStoreNewModel = true; - LOGGER.info( - "[{}] finished stored trained model definition chunks with id [{}]", - analytics.getId(), - this.currentModelId.get()); - auditor.info(analytics.getId(), "Stored trained model with id [" + this.currentModelId.get() + "]"); - CountDownLatch refreshLatch = new CountDownLatch(1); - provider.refreshInferenceIndex( - new LatchedActionListener<>(ActionListener.wrap( - refreshResponse -> LOGGER.debug(() -> new ParameterizedMessage( - "[{}] refreshed inference index after model store", - analytics.getId() - )), - e -> LOGGER.warn("[{}] failed to refresh inference index after model store", analytics.getId())), - refreshLatch)); - try { - if (refreshLatch.await(30, TimeUnit.SECONDS) == false) { - LOGGER.error("[{}] Timed out (30s) waiting for index refresh", analytics.getId()); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - } - }, - e -> failureHandler.accept(ExceptionsHelper.serverError("error storing trained model definition chunk [{}] with id [{}]", e, - trainedModelDefinitionDoc.getModelId(), trainedModelDefinitionDoc.getDocNum())) - ); - provider.storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, new LatchedActionListener<>(storeListener, latch)); + CountDownLatch latch = storeTrainedModelDoc(trainedModelDefinitionDoc); try { - if (latch.await(30, TimeUnit.SECONDS) == false) { + if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) { LOGGER.error("[{}] Timed out (30s) waiting for chunked inference definition to be stored", analytics.getId()); + if (trainedModelDefinitionChunk.isEos()) { + this.readyToStoreNewModel.set(true); + } } } catch (InterruptedException e) { Thread.currentThread().interrupt(); + this.readyToStoreNewModel.set(true); failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for chunked inference definition to be stored")); } } public void createAndIndexInferenceModelMetadata(ModelSizeInfo inferenceModelSize) { - if (readyToStoreNewModel == false) { + if (readyToStoreNewModel.compareAndSet(true, false) == false) { failureHandler.accept(ExceptionsHelper.serverError( "new inference model is attempting to be stored before completion previous model storage" )); @@ -135,29 +105,86 @@ public void createAndIndexInferenceModelMetadata(ModelSizeInfo inferenceModelSiz TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModelSize); CountDownLatch latch = storeTrainedModelMetadata(trainedModelConfig); try { - readyToStoreNewModel = false; - if (latch.await(30, TimeUnit.SECONDS) == false) { + if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) { LOGGER.error("[{}] Timed out (30s) waiting for inference model metadata to be stored", analytics.getId()); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); + this.readyToStoreNewModel.set(true); failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for inference model metadata to be stored")); } } + private CountDownLatch storeTrainedModelDoc(TrainedModelDefinitionDoc trainedModelDefinitionDoc) { + CountDownLatch latch = new CountDownLatch(1); + + // Latch is attached to this action as it is the last one to execute. 
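
For the listener wiring that follows: the persister now serializes model stores with an AtomicBoolean guard and blocks each call on a CountDownLatch with a 30-second timeout. A self-contained plain-Java sketch of that guard-and-wait pattern, where GuardedStoreSketch and the storeAsync callback are hypothetical stand-ins for the persister and the provider calls:

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.CountDownLatch;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.atomic.AtomicBoolean;

    // Hypothetical sketch of the persister's guard; not the actual ChunkedTrainedModelPersister.
    final class GuardedStoreSketch {

        private static final int STORE_TIMEOUT_SEC = 30;
        private final AtomicBoolean readyToStoreNewModel = new AtomicBoolean(true);

        // Rejects a new model while the previous one is still being persisted,
        // mirroring readyToStoreNewModel.compareAndSet(true, false).
        void storeNewModel(Runnable storeAsync) throws InterruptedException {
            if (readyToStoreNewModel.compareAndSet(true, false) == false) {
                throw new IllegalStateException("previous model storage has not completed");
            }
            CountDownLatch latch = new CountDownLatch(1);
            // The async store re-arms the guard and counts the latch down when it finishes.
            CompletableFuture.runAsync(storeAsync).whenComplete((r, e) -> {
                readyToStoreNewModel.set(true);
                latch.countDown();
            });
            if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) {
                System.err.println("timed out waiting for model storage");
            }
        }

        public static void main(String[] args) throws InterruptedException {
            GuardedStoreSketch sketch = new GuardedStoreSketch();
            sketch.storeNewModel(() -> System.out.println("stored model metadata"));
        }
    }

In the production code the latch is counted down by the refresh listener, the last action in the chain, and the flag is re-armed on failure, on interrupt, and, for the EOS chunk, on timeout, so a stuck store cannot permanently block later models.
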
+ ActionListener refreshListener = new LatchedActionListener<>(ActionListener.wrap( + refreshed -> { + if (refreshed != null) { + LOGGER.debug(() -> new ParameterizedMessage( + "[{}] refreshed inference index after model store", + analytics.getId() + )); + } + }, + e -> LOGGER.warn( + new ParameterizedMessage("[{}] failed to refresh inference index after model store", analytics.getId()), + e) + ), latch); + + // First, store the model and refresh is necessary + ActionListener storeListener = ActionListener.wrap( + r -> { + LOGGER.debug(() -> new ParameterizedMessage( + "[{}] stored trained model definition chunk [{}] [{}]", + analytics.getId(), + trainedModelDefinitionDoc.getModelId(), + trainedModelDefinitionDoc.getDocNum())); + if (trainedModelDefinitionDoc.isEos() == false) { + refreshListener.onResponse(null); + return; + } + readyToStoreNewModel.set(true); + LOGGER.info( + "[{}] finished storing trained model with id [{}]", + analytics.getId(), + this.currentModelId.get()); + auditor.info(analytics.getId(), "Stored trained model with id [" + this.currentModelId.get() + "]"); + this.currentModelId.set(""); + provider.refreshInferenceIndex(refreshListener); + }, + e -> { + this.readyToStoreNewModel.set(true); + failureHandler.accept(ExceptionsHelper.serverError( + "error storing trained model definition chunk [{}] with id [{}]", + e, + trainedModelDefinitionDoc.getModelId(), + trainedModelDefinitionDoc.getDocNum())); + refreshListener.onResponse(null); + } + ); + provider.storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, storeListener); + return latch; + } private CountDownLatch storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig) { CountDownLatch latch = new CountDownLatch(1); ActionListener storeListener = ActionListener.wrap( aBoolean -> { if (aBoolean == false) { LOGGER.error("[{}] Storing trained model metadata responded false", analytics.getId()); + readyToStoreNewModel.set(true); failureHandler.accept(ExceptionsHelper.serverError("storing trained model responded false")); } else { - LOGGER.info("[{}] Stored trained model metadata with id [{}]", analytics.getId(), trainedModelConfig.getModelId()); + LOGGER.debug("[{}] Stored trained model metadata with id [{}]", analytics.getId(), trainedModelConfig.getModelId()); } }, - e -> failureHandler.accept(ExceptionsHelper.serverError("error storing trained model metadata with id [{}]", e, - trainedModelConfig.getModelId())) + e -> { + readyToStoreNewModel.set(true); + failureHandler.accept(ExceptionsHelper.serverError("error storing trained model metadata with id [{}]", + e, + trainedModelConfig.getModelId())); + } ); provider.storeTrainedModelMetadata(trainedModelConfig, new LatchedActionListener<>(storeListener, latch)); return latch; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 1dc9212201828..e460be8b59094 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -349,32 +349,16 @@ public void getTrainedModelForInference(final String modelId, final ActionListen List docs = handleHits(searchResponse.getHits().getHits(), modelId, this::parseModelDefinitionDocLenientlyFromSource); - String compressedString = docs.stream() - 
.map(TrainedModelDefinitionDoc::getCompressedString) - .collect(Collectors.joining()); - // BWC for when we tracked the total definition length - // TODO: remove in 9 - if (docs.get(0).getTotalDefinitionLength() != null) { - if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { - listener.onFailure(ExceptionsHelper.serverError( - Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); - return; - } - } else { - TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1); - // Either we are missing the last doc, or some previous doc - if(lastDoc.isEos() == false || lastDoc.getDocNum() != docs.size() - 1) { - listener.onFailure(ExceptionsHelper.serverError( - Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); - return; - } + try { + String compressedString = getDefinitionFromDocs(docs, modelId); + InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate( + compressedString, + InferenceDefinition::fromXContent, + xContentRegistry); + listener.onResponse(inferenceDefinition); + } catch (ElasticsearchException elasticsearchException) { + listener.onFailure(elasticsearchException); } - - InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate( - compressedString, - InferenceDefinition::fromXContent, - xContentRegistry); - listener.onResponse(inferenceDefinition); }, e -> { if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { @@ -449,27 +433,14 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio List docs = handleSearchItems(multiSearchResponse.getResponses()[1], modelId, this::parseModelDefinitionDocLenientlyFromSource); - String compressedString = docs.stream() - .map(TrainedModelDefinitionDoc::getCompressedString) - .collect(Collectors.joining()); - // BWC for when we tracked the total definition length - // TODO: remove in 9 - if (docs.get(0).getTotalDefinitionLength() != null) { - if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { - listener.onFailure(ExceptionsHelper.serverError( - Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); - return; - } - } else { - TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1); - // Either we are missing the last doc, or some previous doc - if(lastDoc.isEos() == false || lastDoc.getDocNum() != docs.size() - 1) { - listener.onFailure(ExceptionsHelper.serverError( - Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); - return; - } + try { + String compressedString = getDefinitionFromDocs(docs, modelId); + builder.setDefinitionFromString(compressedString); + } catch (ElasticsearchException elasticsearchException) { + listener.onFailure(elasticsearchException); + return; } - builder.setDefinitionFromString(compressedString); + } catch (ResourceNotFoundException ex) { listener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); @@ -906,6 +877,26 @@ private static List handleHits(SearchHit[] hits, return results; } + private static String getDefinitionFromDocs(List docs, String modelId) throws ElasticsearchException { + String compressedString = docs.stream() + .map(TrainedModelDefinitionDoc::getCompressedString) + .collect(Collectors.joining()); + // BWC for when we tracked the total definition length + // TODO: remove in 9 + if (docs.get(0).getTotalDefinitionLength() != null) { + if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { + throw 
ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)); + } + } else { + TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1); + // Either we are missing the last doc, or some previous doc + if(lastDoc.isEos() == false || lastDoc.getDocNum() != docs.size() - 1) { + throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)); + } + } + return compressedString; + } + static List chunkStringWithSize(String str, int chunkSize) { List subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize)); for (int i = 0; i < str.length();i += chunkSize) { From 398d9f94f7e200f968f9f5411c690c62cd7bd711 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 30 Jun 2020 10:58:40 -0400 Subject: [PATCH 8/9] moving boolean flag --- .../ml/dataframe/process/ChunkedTrainedModelPersister.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java index 0741f42204e20..58e2227f4dad0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java @@ -145,13 +145,13 @@ private CountDownLatch storeTrainedModelDoc(TrainedModelDefinitionDoc trainedMod refreshListener.onResponse(null); return; } - readyToStoreNewModel.set(true); LOGGER.info( "[{}] finished storing trained model with id [{}]", analytics.getId(), this.currentModelId.get()); auditor.info(analytics.getId(), "Stored trained model with id [" + this.currentModelId.get() + "]"); this.currentModelId.set(""); + readyToStoreNewModel.set(true); provider.refreshInferenceIndex(refreshListener); }, e -> { From caed31555c00c2060cfb0530ce3ed0647660d076 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 1 Jul 2020 06:52:37 -0400 Subject: [PATCH 9/9] fixing test --- .../xpack/ml/integration/TrainedModelProviderIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index 51ac866647928..6c9b5634b9448 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -210,7 +210,7 @@ public void testGetTruncatedModelDefinition() throws Exception { List chunks = chunkStringWithSize(config.getCompressedDefinition(), config.getCompressedDefinition().length()/3); - List docBuilders = IntStream.range(0, chunks.size() - 1) + List docBuilders = IntStream.range(0, chunks.size()) .mapToObj(i -> new TrainedModelDefinitionDoc.Builder() .setDocNum(i) .setCompressedString(chunks.get(i))
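
The series closes with a one-line test fix: IntStream.range(0, chunks.size() - 1) left the last chunk without a doc builder, so the range is widened to chunks.size(), making the builders cover every chunk before the test induces truncation by dropping the first doc or clearing the eos flag. For reference, the chunking helper that both the integration test and the provider build on splits the compressed definition into fixed-size substrings; this standalone version uses the integer ceiling-division form from the test helper and is a sketch rather than the exact provider code:

    import java.util.ArrayList;
    import java.util.List;

    // Standalone sketch of the chunkStringWithSize helper shown in the patch.
    final class ChunkSplitter {

        // Splits str into chunkSize-sized pieces; the last piece may be shorter.
        // (str.length() + chunkSize - 1) / chunkSize is the integer ceiling of length / chunkSize.
        static List<String> chunkStringWithSize(String str, int chunkSize) {
            List<String> subStrings = new ArrayList<>((str.length() + chunkSize - 1) / chunkSize);
            for (int i = 0; i < str.length(); i += chunkSize) {
                subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length())));
            }
            return subStrings;
        }

        public static void main(String[] args) {
            // 8 characters into chunks of 3 -> [abc, def, gh]
            System.out.println(chunkStringWithSize("abcdefgh", 3));
        }
    }

The ceiling-division capacity is equivalent to the provider's Math.ceil(str.length() / (double) chunkSize) without the double round-trip; in both variants chunkSize must be positive.
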