diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs index 00153171cf..ac6a89f6d8 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs @@ -513,7 +513,7 @@ private MatrixFactorizationModelParameters TrainCore(IChannel ch, RoleMappedData private SafeTrainingAndModelBuffer PrepareBuffer() { - return new SafeTrainingAndModelBuffer(_host, _fun, _k, _threads, _threads == 1 ? 1 : Math.Max(20, 2 * _threads), + return new SafeTrainingAndModelBuffer(_host, _fun, _k, _threads, 2 * _threads + 1, _iter, _lambda, _eta, _alpha, _c, _doNmf, _quiet, copyData: false); } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs index a6c0093242..09ce334a79 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs @@ -96,10 +96,10 @@ public void MatrixFactorizationSimpleTrainAndPredict() // MF produce different matrices on different platforms, so check their content on Windows. if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - Assert.Equal(0.301269173622131, leftMatrix[0], 5); - Assert.Equal(0.558746933937073, leftMatrix[leftMatrix.Count - 1], 5); - Assert.Equal(0.27028301358223, rightMatrix[0], 5); - Assert.Equal(0.390790820121765, rightMatrix[rightMatrix.Count - 1], 5); + Assert.Equal(0.290507137775421, leftMatrix[0], 5); + Assert.Equal(0.558072924613953, leftMatrix[leftMatrix.Count - 1], 5); + Assert.Equal(0.270811557769775, rightMatrix[0], 5); + Assert.Equal(0.376706808805466, rightMatrix[rightMatrix.Count - 1], 5); } // Read the test data set as an IDataView var testData = reader.Load(new MultiFileSource(GetDataPath(TestDatasets.trivialMatrixFactorization.testFilename))); @@ -129,7 +129,7 @@ public void MatrixFactorizationSimpleTrainAndPredict() if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { // Linux case - var expectedUnixL2Error = 0.612974867782832; // Linux baseline + var expectedUnixL2Error = 0.610332110253861; // Linux baseline Assert.InRange(metrices.MeanSquaredError, expectedUnixL2Error - linuxTolerance, expectedUnixL2Error + linuxTolerance); } else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) @@ -142,7 +142,7 @@ public void MatrixFactorizationSimpleTrainAndPredict() else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { // Windows case - var expectedWindowsL2Error = 0.622283290742721; // Windows baseline + var expectedWindowsL2Error = 0.60226203382884; // Windows baseline Assert.InRange(metrices.MeanSquaredError, expectedWindowsL2Error - windowsTolerance, expectedWindowsL2Error + windowsTolerance); }