diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs
index 62b3e7ef98..963d612ec7 100644
--- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs
+++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs
@@ -145,8 +145,8 @@ private protected override NaiveBayesMulticlassModelParameters TrainModelCore(Tr
if (labelCol.Type is KeyDataViewType labelKeyType)
labelCount = labelKeyType.GetCountAsInt32(Host);
- int[] labelHistogram = new int[labelCount];
- int[][] featureHistogram = new int[labelCount][];
+ long[] labelHistogram = new long[labelCount];
+ long[][] featureHistogram = new long[labelCount][];
using (var pch = Host.StartProgressChannel("Multi Class Naive Bayes training"))
using (var ch = Host.Start("Training"))
using (var cursor = new MulticlassLabelCursor(labelCount, data, CursOpt.Features | CursOpt.Label))
@@ -169,7 +169,7 @@ private protected override NaiveBayesMulticlassModelParameters TrainModelCore(Tr
Utils.EnsureSize(ref labelHistogram, size);
Utils.EnsureSize(ref featureHistogram, size);
if (featureHistogram[cursor.Label] == null)
- featureHistogram[cursor.Label] = new int[featureCount];
+ featureHistogram[cursor.Label] = new long[featureCount];
labelHistogram[cursor.Label] += 1;
labelCount = labelCount < size ? size : labelCount;
@@ -231,17 +231,18 @@ private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "MNABYPRD",
- verWrittenCur: 0x00010001, // Initial
+ //verWrittenCur: 0x00010001, // Initial
+ verWrittenCur: 0x00010002, // Histograms are of type long
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(NaiveBayesMulticlassModelParameters).Assembly.FullName);
}
- private readonly int[] _labelHistogram;
- private readonly int[][] _featureHistogram;
+ private readonly long[] _labelHistogram;
+ private readonly long[][] _featureHistogram;
private readonly double[] _absentFeaturesLogProb;
- private readonly int _totalTrainingCount;
+ private readonly long _totalTrainingCount;
private readonly int _labelCount;
private readonly int _featureCount;
private readonly VectorDataViewType _inputType;
@@ -259,12 +260,26 @@ private static VersionInfo GetVersionInfo()
///
/// Get the label histogram.
///
- public IReadOnlyList GetLabelHistogram() => _labelHistogram;
+ [Obsolete("This API is deprecated, please use GetLabelHistogramLong() which returns _labelHistogram " +
+ "with type IReadOnlyList to avoid overflow errors with large datasets.", true)]
+ public IReadOnlyList GetLabelHistogram() => Array.ConvertAll(_labelHistogram, x => (int)x);
+
+ ///
+ /// Get the label histogram with generic type long.
+ ///
+ public IReadOnlyList GetLabelHistogramLong() => _labelHistogram;
///
/// Get the feature histogram.
///
- public IReadOnlyList> GetFeatureHistogram() => _featureHistogram;
+ [Obsolete("This API is deprecated, please use GetFeatureHistogramLong() which returns _featureHistogram " +
+ "with type IReadOnlyList to avoid overflow errors with large datasets.", true)]
+ public IReadOnlyList> GetFeatureHistogram() => Array.ConvertAll(_featureHistogram, x => Array.ConvertAll(x, y=> (int)y));
+
+ ///
+ /// Get the feature histogram with generic type long.
+ ///
+ public IReadOnlyList> GetFeatureHistogramLong() => _featureHistogram;
///
/// Instantiates new model parameters from trained model.
@@ -273,7 +288,7 @@ private static VersionInfo GetVersionInfo()
/// The histogram of labels.
/// The feature histogram.
/// The number of features.
- internal NaiveBayesMulticlassModelParameters(IHostEnvironment env, int[] labelHistogram, int[][] featureHistogram, int featureCount)
+ internal NaiveBayesMulticlassModelParameters(IHostEnvironment env, long[] labelHistogram, long[][] featureHistogram, int featureCount)
: base(env, LoaderSignature)
{
Host.AssertValue(labelHistogram);
@@ -290,16 +305,26 @@ internal NaiveBayesMulticlassModelParameters(IHostEnvironment env, int[] labelHi
_outputType = new VectorDataViewType(NumberDataViewType.Single, _labelCount);
}
+ ///
+ /// The unit test TestEntryPoints.LoadEntryPointModel() exercises the ReadIntArrary(int size) codepath below
+ /// as its ctx.Header.ModelVerWritten is 0x00010001, and the persistent model that gets loaded and executed
+ /// for this unit test is located at test\data\backcompat\ep_model3.zip/>
+ ///
private NaiveBayesMulticlassModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, LoaderSignature, ctx)
{
// *** Binary format ***
- // int: _labelCount
- // int[_labelCount]: _labelHistogram
+ // int: _labelCount (read during reading of _labelHistogram in ReadLongArray())
+ // long[_labelCount]: _labelHistogram
// int: _featureCount
- // int[_labelCount][_featureCount]: _featureHistogram
+ // long[_labelCount][_featureCount]: _featureHistogram
// int[_labelCount]: _absentFeaturesLogProb
- _labelHistogram = ctx.Reader.ReadIntArray() ?? new int[0];
+ if (ctx.Header.ModelVerWritten >= 0x00010002)
+ _labelHistogram = ctx.Reader.ReadLongArray() ?? new long[0];
+ else
+ {
+ _labelHistogram = Array.ConvertAll(ctx.Reader.ReadIntArray() ?? new int[0], x => (long)x);
+ }
_labelCount = _labelHistogram.Length;
foreach (int labelCount in _labelHistogram)
@@ -307,12 +332,15 @@ private NaiveBayesMulticlassModelParameters(IHostEnvironment env, ModelLoadConte
_featureCount = ctx.Reader.ReadInt32();
Host.CheckDecode(_featureCount >= 0);
- _featureHistogram = new int[_labelCount][];
+ _featureHistogram = new long[_labelCount][];
for (int iLabel = 0; iLabel < _labelCount; iLabel += 1)
{
if (_labelHistogram[iLabel] > 0)
{
- _featureHistogram[iLabel] = ctx.Reader.ReadIntArray(_featureCount);
+ if (ctx.Header.ModelVerWritten >= 0x00010002)
+ _featureHistogram[iLabel] = ctx.Reader.ReadLongArray(_featureCount);
+ else
+ _featureHistogram[iLabel] = Array.ConvertAll(ctx.Reader.ReadIntArray(_featureCount) ?? new int[0], x => (long)x);
for (int iFeature = 0; iFeature < _featureCount; iFeature += 1)
Host.CheckDecode(_featureHistogram[iLabel][iFeature] >= 0);
}
@@ -339,22 +367,23 @@ private protected override void SaveCore(ModelSaveContext ctx)
// *** Binary format ***
// int: _labelCount
- // int[_labelCount]: _labelHistogram
+ // long[_labelCount]: _labelHistogram
// int: _featureCount
- // int[_labelCount][_featureCount]: _featureHistogram
+ // long[_labelCount][_featureCount]: _featureHistogram
// int[_labelCount]: _absentFeaturesLogProb
- ctx.Writer.WriteIntArray(_labelHistogram.AsSpan(0, _labelCount));
+ ctx.Writer.Write(_labelCount);
+ ctx.Writer.WriteLongStream(_labelHistogram);
ctx.Writer.Write(_featureCount);
for (int i = 0; i < _labelCount; i += 1)
{
if (_labelHistogram[i] > 0)
- ctx.Writer.WriteIntsNoCount(_featureHistogram[i].AsSpan(0, _featureCount));
+ ctx.Writer.WriteLongStream(_featureHistogram[i]);
}
ctx.Writer.WriteDoublesNoCount(_absentFeaturesLogProb.AsSpan(0, _labelCount));
}
- private static double[] CalculateAbsentFeatureLogProbabilities(int[] labelHistogram, int[][] featureHistogram, int featureCount)
+ private static double[] CalculateAbsentFeatureLogProbabilities(long[] labelHistogram, long[][] featureHistogram, int featureCount)
{
int labelCount = labelHistogram.Length;
double[] absentFeaturesLogProb = new double[labelCount];
@@ -365,7 +394,7 @@ private static double[] CalculateAbsentFeatureLogProbabilities(int[] labelHistog
double logProb = 0;
for (int iFeature = 0; iFeature < featureCount; iFeature += 1)
{
- int labelOccuranceCount = labelHistogram[iLabel];
+ long labelOccuranceCount = labelHistogram[iLabel];
logProb +=
Math.Log(1 + ((double)labelOccuranceCount - featureHistogram[iLabel][iFeature])) -
Math.Log(labelOccuranceCount + labelCount);