From 72ce7845dc93599d6c77abd1214c53f9e1dc2c01 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 24 Mar 2020 16:50:47 -0400 Subject: [PATCH] [ML] relaxing parameters on stratified split test --- .../StratifiedCrossValidationSplitterTests.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java index 417ce0a83ff1e..9ef0f773ee852 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java @@ -179,8 +179,8 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs double expectedTotalTrainingCount = ROWS_COUNT * trainingFraction; assertThat(trainingDocsCount + testDocsCount, equalTo((long) ROWS_COUNT)); - assertThat(trainingDocsCount, greaterThanOrEqualTo((long) Math.floor(expectedTotalTrainingCount - 1))); - assertThat(trainingDocsCount, lessThanOrEqualTo((long) Math.ceil(expectedTotalTrainingCount + 1))); + assertThat(trainingDocsCount, greaterThanOrEqualTo((long) (expectedTotalTrainingCount - 2))); + assertThat(trainingDocsCount, lessThanOrEqualTo((long) Math.ceil(expectedTotalTrainingCount) + 2)); for (String classValue : classCardinalities.keySet()) { double expectedClassTrainingCount = totalRowsPerClass.get(classValue) * trainingFraction; @@ -221,7 +221,7 @@ public void testProcess_SelectsTrainingRowsUniformly() { // should be close to the training percent, which is set to 0.5 for (int rowTrainingCount : trainingCountPerRow) { double meanCount = rowTrainingCount / (double) runCount; - assertThat(meanCount, is(closeTo(0.5, 0.1))); + assertThat(meanCount, is(closeTo(0.5, 0.12))); } }