From dbb462036e67b786ffd8fc3043d0c330ab63ba89 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 27 Jul 2020 12:22:02 +0300 Subject: [PATCH] [7.x][ML] DFA result processor should only skip rows and model chunks on cancel (#60113) When the job is force-closed or shutting down due to a fatal error we clean up all cancellable job operations. This includes cancelling the results processor. However, this means that we might not persist objects that are written from the process like stats, memory usage, etc. In hindsight, we do not gain from cancelling the results processor in its entirety. It makes more sense to skip row results and model chunks but keep stats and instrumentation about the job as the latter may contain useful information to understand what happened to the job. Backport of #60113 --- .../process/AnalyticsProcessManager.java | 1 - .../process/AnalyticsResultProcessor.java | 35 ++-- .../process/results/AnalyticsResult.java | 87 ++++++++- .../ml/dataframe/stats/StatsPersister.java | 11 +- .../process/AnalyticsProcessManagerTests.java | 2 +- .../AnalyticsResultProcessorTests.java | 166 +++++++++++++++++- .../process/results/AnalyticsResultTests.java | 36 ++-- 7 files changed, 269 insertions(+), 69 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index f125d46097f1e..31baca95cf044 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -434,7 +434,6 @@ synchronized void stop() { if (inferenceRunner.get() != null) { inferenceRunner.get().cancel(); } - statsPersister.cancel(); if (process.get() != null) { try { process.get().kill(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 32af30a93da4a..41dcea15577e8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -56,6 +56,7 @@ public class AnalyticsResultProcessor { private final ChunkedTrainedModelPersister chunkedTrainedModelPersister; private volatile String failure; private volatile boolean isCancelled; + private long processedRows; private volatile String latestModelId; @@ -92,31 +93,17 @@ public void awaitForCompletion() { public void cancel() { dataFrameRowsJoiner.cancel(); - statsPersister.cancel(); isCancelled = true; } public void process(AnalyticsProcess process) { long totalRows = process.getConfig().rows(); - long processedRows = 0; // TODO When java 9 features can be used, we will not need the local variable here try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) { Iterator iterator = process.readAnalyticsResults(); while (iterator.hasNext()) { - if (isCancelled) { - break; - } - AnalyticsResult result = iterator.next(); - processResult(result, resultsJoiner); - if (result.getRowResults() != null) { - if (processedRows == 0) { - LOGGER.info("[{}] Started writing results", analytics.getId()); - auditor.info(analytics.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_WRITING_RESULTS)); - } - processedRows++; - updateResultsProgress(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows)); - } + processResult(iterator.next(), resultsJoiner, totalRows); } } catch (Exception e) { if (isCancelled) { @@ -141,10 +128,10 @@ private void completeResultsProgress() { statsHolder.getProgressTracker().updateWritingResultsProgress(100); } - private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner) { + private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner, long totalRows) { RowResults rowResults = result.getRowResults(); - if (rowResults != null) { - resultsJoiner.processRowResults(rowResults); + if (rowResults != null && isCancelled == false) { + processRowResult(resultsJoiner, totalRows, rowResults); } PhaseProgress phaseProgress = result.getPhaseProgress(); if (phaseProgress != null) { @@ -157,7 +144,7 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo latestModelId = chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelSize); } TrainedModelDefinitionChunk trainedModelDefinitionChunk = result.getTrainedModelDefinitionChunk(); - if (trainedModelDefinitionChunk != null) { + if (trainedModelDefinitionChunk != null && isCancelled == false) { chunkedTrainedModelPersister.createAndIndexInferenceModelDoc(trainedModelDefinitionChunk); } MemoryUsage memoryUsage = result.getMemoryUsage(); @@ -181,6 +168,16 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo } } + private void processRowResult(DataFrameRowsJoiner rowsJoiner, long totalRows, RowResults rowResults) { + rowsJoiner.processRowResults(rowResults); + if (processedRows == 0) { + LOGGER.info("[{}] Started writing results", analytics.getId()); + auditor.info(analytics.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_WRITING_RESULTS)); + } + processedRows++; + updateResultsProgress(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows)); + } + private void setAndReportFailure(Exception e) { LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e); failure = "error processing results; " + e.getMessage(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index c7e6f9a1c4377..0020c2df8bc88 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -66,14 +66,14 @@ public class AnalyticsResult implements ToXContentObject { private final ModelSizeInfo modelSizeInfo; private final TrainedModelDefinitionChunk trainedModelDefinitionChunk; - public AnalyticsResult(@Nullable RowResults rowResults, - @Nullable PhaseProgress phaseProgress, - @Nullable MemoryUsage memoryUsage, - @Nullable OutlierDetectionStats outlierDetectionStats, - @Nullable ClassificationStats classificationStats, - @Nullable RegressionStats regressionStats, - @Nullable ModelSizeInfo modelSizeInfo, - @Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk) { + private AnalyticsResult(@Nullable RowResults rowResults, + @Nullable PhaseProgress phaseProgress, + @Nullable MemoryUsage memoryUsage, + @Nullable OutlierDetectionStats outlierDetectionStats, + @Nullable ClassificationStats classificationStats, + @Nullable RegressionStats regressionStats, + @Nullable ModelSizeInfo modelSizeInfo, + @Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk) { this.rowResults = rowResults; this.phaseProgress = phaseProgress; this.memoryUsage = memoryUsage; @@ -172,4 +172,75 @@ public int hashCode() { return Objects.hash(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, classificationStats, regressionStats, modelSizeInfo, trainedModelDefinitionChunk); } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private RowResults rowResults; + private PhaseProgress phaseProgress; + private MemoryUsage memoryUsage; + private OutlierDetectionStats outlierDetectionStats; + private ClassificationStats classificationStats; + private RegressionStats regressionStats; + private ModelSizeInfo modelSizeInfo; + private TrainedModelDefinitionChunk trainedModelDefinitionChunk; + + private Builder() {} + + public Builder setRowResults(RowResults rowResults) { + this.rowResults = rowResults; + return this; + } + + public Builder setPhaseProgress(PhaseProgress phaseProgress) { + this.phaseProgress = phaseProgress; + return this; + } + + public Builder setMemoryUsage(MemoryUsage memoryUsage) { + this.memoryUsage = memoryUsage; + return this; + } + + public Builder setOutlierDetectionStats(OutlierDetectionStats outlierDetectionStats) { + this.outlierDetectionStats = outlierDetectionStats; + return this; + } + + public Builder setClassificationStats(ClassificationStats classificationStats) { + this.classificationStats = classificationStats; + return this; + } + + public Builder setRegressionStats(RegressionStats regressionStats) { + this.regressionStats = regressionStats; + return this; + } + + public Builder setModelSizeInfo(ModelSizeInfo modelSizeInfo) { + this.modelSizeInfo = modelSizeInfo; + return this; + } + + public Builder setTrainedModelDefinitionChunk(TrainedModelDefinitionChunk trainedModelDefinitionChunk) { + this.trainedModelDefinitionChunk = trainedModelDefinitionChunk; + return this; + } + + public AnalyticsResult build() { + return new AnalyticsResult( + rowResults, + phaseProgress, + memoryUsage, + outlierDetectionStats, + classificationStats, + regressionStats, + modelSizeInfo, + trainedModelDefinitionChunk + ); + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java index e553ac1810729..4d8179e9fcc94 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java @@ -29,7 +29,6 @@ public class StatsPersister { private final String jobId; private final ResultsPersisterService resultsPersisterService; private final DataFrameAnalyticsAuditor auditor; - private volatile boolean isCancelled; public StatsPersister(String jobId, ResultsPersisterService resultsPersisterService, DataFrameAnalyticsAuditor auditor) { this.jobId = Objects.requireNonNull(jobId); @@ -38,10 +37,6 @@ public StatsPersister(String jobId, ResultsPersisterService resultsPersisterServ } public void persistWithRetry(ToXContentObject result, Function docIdSupplier) { - if (isCancelled) { - return; - } - try { resultsPersisterService.indexWithRetry(jobId, MlStatsIndex.writeAlias(), @@ -49,7 +44,7 @@ public void persistWithRetry(ToXContentObject result, Function d new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), WriteRequest.RefreshPolicy.NONE, docIdSupplier.apply(jobId), - () -> isCancelled == false, + () -> true, errorMsg -> auditor.error(jobId, "failed to persist result with id [" + docIdSupplier.apply(jobId) + "]; " + errorMsg) ); @@ -59,8 +54,4 @@ public void persistWithRetry(ToXContentObject result, Function d LOGGER.error(() -> new ParameterizedMessage("[{}] Failed indexing stats result", jobId), e); } } - - public void cancel() { - isCancelled = true; - } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index 1f1c2851820b6..dd155e30b4002 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -60,7 +60,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase { private static final String CONFIG_ID = "config-id"; private static final int NUM_ROWS = 100; private static final int NUM_COLS = 4; - private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null, null); + private static final AnalyticsResult PROCESS_RESULT = AnalyticsResult.builder().build(); private Client client; private DataFrameAnalyticsAuditor auditor; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index f2b33c30755f6..1e404360ae738 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -12,8 +12,17 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; +import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; @@ -26,12 +35,15 @@ import org.mockito.InOrder; import org.mockito.Mockito; +import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Optional; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doThrow; @@ -85,8 +97,8 @@ public void testProcess_GivenNoResults() { public void testProcess_GivenEmptyResults() { givenDataFrameRows(2); givenProcessResults(Arrays.asList( - new AnalyticsResult(null, null, null,null, null, null, null, null), - new AnalyticsResult(null, null, null, null, null, null, null, null))); + AnalyticsResult.builder().build(), + AnalyticsResult.builder().build())); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -101,8 +113,9 @@ public void testProcess_GivenRowResults() { givenDataFrameRows(2); RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); - givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null,null, null, null, null, null), - new AnalyticsResult(rowResults2, null, null, null, null, null, null, null))); + givenProcessResults(Arrays.asList( + AnalyticsResult.builder().setRowResults(rowResults1).build(), + AnalyticsResult.builder().setRowResults(rowResults2).build())); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -119,8 +132,9 @@ public void testProcess_GivenDataFrameRowsJoinerFails() { givenDataFrameRows(2); RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); - givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null,null, null, null, null, null), - new AnalyticsResult(rowResults2, null, null, null, null, null, null, null))); + givenProcessResults(Arrays.asList( + AnalyticsResult.builder().setRowResults(rowResults1).build(), + AnalyticsResult.builder().setRowResults(rowResults2).build())); doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class)); @@ -138,6 +152,146 @@ public void testProcess_GivenDataFrameRowsJoinerFails() { assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); } + public void testCancel_GivenRowResults() { + givenDataFrameRows(2); + RowResults rowResults1 = mock(RowResults.class); + RowResults rowResults2 = mock(RowResults.class); + givenProcessResults(Arrays.asList( + AnalyticsResult.builder().setRowResults(rowResults1).build(), + AnalyticsResult.builder().setRowResults(rowResults2).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + } + + public void testCancel_GivenModelChunk() { + givenDataFrameRows(2); + TrainedModelDefinitionChunk modelChunk = mock(TrainedModelDefinitionChunk.class); + givenProcessResults(Arrays.asList(AnalyticsResult.builder().setTrainedModelDefinitionChunk(modelChunk).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + } + + public void testCancel_GivenPhaseProgress() { + givenDataFrameRows(2); + PhaseProgress phaseProgress = new PhaseProgress("analyzing", 18); + givenProcessResults(Arrays.asList(AnalyticsResult.builder().setPhaseProgress(phaseProgress).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + + Optional testPhaseProgress = statsHolder.getProgressTracker().report().stream() + .filter(p -> p.getPhase().equals(phaseProgress.getPhase())) + .findAny(); + assertThat(testPhaseProgress.isPresent(), is(true)); + assertThat(testPhaseProgress.get().getProgressPercent(), equalTo(18)); + } + + public void testCancel_GivenMemoryUsage() { + givenDataFrameRows(2); + MemoryUsage memoryUsage = new MemoryUsage(analyticsConfig.getId(), Instant.now(), 1000L, MemoryUsage.Status.HARD_LIMIT, null); + givenProcessResults(Arrays.asList(AnalyticsResult.builder().setMemoryUsage(memoryUsage).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + + assertThat(statsHolder.getMemoryUsage(), equalTo(memoryUsage)); + verify(statsPersister).persistWithRetry(eq(memoryUsage), any()); + } + + public void testCancel_GivenOutlierDetectionStats() { + givenDataFrameRows(2); + OutlierDetectionStats outlierDetectionStats = OutlierDetectionStatsTests.createRandom(); + givenProcessResults(Arrays.asList(AnalyticsResult.builder().setOutlierDetectionStats(outlierDetectionStats).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + + assertThat(statsHolder.getAnalysisStats(), equalTo(outlierDetectionStats)); + verify(statsPersister).persistWithRetry(eq(outlierDetectionStats), any()); + } + + public void testCancel_GivenClassificationStats() { + givenDataFrameRows(2); + ClassificationStats classificationStats = ClassificationStatsTests.createRandom(); + givenProcessResults(Arrays.asList(AnalyticsResult.builder().setClassificationStats(classificationStats).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + + assertThat(statsHolder.getAnalysisStats(), equalTo(classificationStats)); + verify(statsPersister).persistWithRetry(eq(classificationStats), any()); + } + + public void testCancel_GivenRegressionStats() { + givenDataFrameRows(2); + RegressionStats regressionStats = RegressionStatsTests.createRandom(); + givenProcessResults(Arrays.asList(AnalyticsResult.builder().setRegressionStats(regressionStats).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + + assertThat(statsHolder.getAnalysisStats(), equalTo(regressionStats)); + verify(statsPersister).persistWithRetry(eq(regressionStats), any()); + } + private void givenProcessResults(List results) { when(process.readAnalyticsResults()).thenReturn(results.iterator()); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java index 3f48583ef3451..7e79dcc7d84aa 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java @@ -11,19 +11,14 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractXContentTestCase; -import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; -import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsageTests; -import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests; -import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsageTests; import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests; -import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; -import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; import java.util.ArrayList; @@ -42,41 +37,34 @@ protected NamedXContentRegistry xContentRegistry() { } protected AnalyticsResult createTestInstance() { - RowResults rowResults = null; - PhaseProgress phaseProgress = null; - MemoryUsage memoryUsage = null; - OutlierDetectionStats outlierDetectionStats = null; - ClassificationStats classificationStats = null; - RegressionStats regressionStats = null; - ModelSizeInfo modelSizeInfo = null; - TrainedModelDefinitionChunk trainedModelDefinitionChunk = null; + AnalyticsResult.Builder builder = AnalyticsResult.builder(); + if (randomBoolean()) { - rowResults = RowResultsTests.createRandom(); + builder.setRowResults(RowResultsTests.createRandom()); } if (randomBoolean()) { - phaseProgress = new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)); + builder.setPhaseProgress(new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100))); } if (randomBoolean()) { - memoryUsage = MemoryUsageTests.createRandom(); + builder.setMemoryUsage(MemoryUsageTests.createRandom()); } if (randomBoolean()) { - outlierDetectionStats = OutlierDetectionStatsTests.createRandom(); + builder.setOutlierDetectionStats(OutlierDetectionStatsTests.createRandom()); } if (randomBoolean()) { - classificationStats = ClassificationStatsTests.createRandom(); + builder.setClassificationStats(ClassificationStatsTests.createRandom()); } if (randomBoolean()) { - regressionStats = RegressionStatsTests.createRandom(); + builder.setRegressionStats(RegressionStatsTests.createRandom()); } if (randomBoolean()) { - modelSizeInfo = ModelSizeInfoTests.createRandom(); + builder.setModelSizeInfo(ModelSizeInfoTests.createRandom()); } if (randomBoolean()) { String def = randomAlphaOfLengthBetween(100, 1000); - trainedModelDefinitionChunk = new TrainedModelDefinitionChunk(def, randomIntBetween(0, 10), randomBoolean()); + builder.setTrainedModelDefinitionChunk(new TrainedModelDefinitionChunk(def, randomIntBetween(0, 10), randomBoolean())); } - return new AnalyticsResult(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, - classificationStats, regressionStats, modelSizeInfo, trainedModelDefinitionChunk); + return builder.build(); } @Override