Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ private MatrixFactorizationModelParameters TrainCore(IChannel ch, RoleMappedData

private SafeTrainingAndModelBuffer PrepareBuffer()
{
return new SafeTrainingAndModelBuffer(_host, _fun, _k, _threads, _threads == 1 ? 1 : Math.Max(20, 2 * _threads),
return new SafeTrainingAndModelBuffer(_host, _fun, _k, _threads, 2 * _threads + 1,
_iter, _lambda, _eta, _alpha, _c, _doNmf, _quiet, copyData: false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ public void MatrixFactorizationSimpleTrainAndPredict()
// MF produce different matrices on different platforms, so check their content on Windows.
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
Assert.Equal(0.301269173622131, leftMatrix[0], 5);
Assert.Equal(0.558746933937073, leftMatrix[leftMatrix.Count - 1], 5);
Assert.Equal(0.27028301358223, rightMatrix[0], 5);
Assert.Equal(0.390790820121765, rightMatrix[rightMatrix.Count - 1], 5);
Assert.Equal(0.290507137775421, leftMatrix[0], 5);
Assert.Equal(0.558072924613953, leftMatrix[leftMatrix.Count - 1], 5);
Comment thread
kere-nel marked this conversation as resolved.
Assert.Equal(0.270811557769775, rightMatrix[0], 5);
Assert.Equal(0.376706808805466, rightMatrix[rightMatrix.Count - 1], 5);
}
// Read the test data set as an IDataView
var testData = reader.Load(new MultiFileSource(GetDataPath(TestDatasets.trivialMatrixFactorization.testFilename)));
Expand Down Expand Up @@ -142,7 +142,7 @@ public void MatrixFactorizationSimpleTrainAndPredict()
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
// Windows case
var expectedWindowsL2Error = 0.622283290742721; // Windows baseline
var expectedWindowsL2Error = 0.60226203382884; // Windows baseline
Assert.InRange(metrices.MeanSquaredError, expectedWindowsL2Error - windowsTolerance, expectedWindowsL2Error + windowsTolerance);
}

Expand Down