diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs index 62b3e7ef98..963d612ec7 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs @@ -145,8 +145,8 @@ private protected override NaiveBayesMulticlassModelParameters TrainModelCore(Tr if (labelCol.Type is KeyDataViewType labelKeyType) labelCount = labelKeyType.GetCountAsInt32(Host); - int[] labelHistogram = new int[labelCount]; - int[][] featureHistogram = new int[labelCount][]; + long[] labelHistogram = new long[labelCount]; + long[][] featureHistogram = new long[labelCount][]; using (var pch = Host.StartProgressChannel("Multi Class Naive Bayes training")) using (var ch = Host.Start("Training")) using (var cursor = new MulticlassLabelCursor(labelCount, data, CursOpt.Features | CursOpt.Label)) @@ -169,7 +169,7 @@ private protected override NaiveBayesMulticlassModelParameters TrainModelCore(Tr Utils.EnsureSize(ref labelHistogram, size); Utils.EnsureSize(ref featureHistogram, size); if (featureHistogram[cursor.Label] == null) - featureHistogram[cursor.Label] = new int[featureCount]; + featureHistogram[cursor.Label] = new long[featureCount]; labelHistogram[cursor.Label] += 1; labelCount = labelCount < size ? size : labelCount; @@ -231,17 +231,18 @@ private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "MNABYPRD", - verWrittenCur: 0x00010001, // Initial + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Histograms are of type long verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, loaderAssemblyName: typeof(NaiveBayesMulticlassModelParameters).Assembly.FullName); } - private readonly int[] _labelHistogram; - private readonly int[][] _featureHistogram; + private readonly long[] _labelHistogram; + private readonly long[][] _featureHistogram; private readonly double[] _absentFeaturesLogProb; - private readonly int _totalTrainingCount; + private readonly long _totalTrainingCount; private readonly int _labelCount; private readonly int _featureCount; private readonly VectorDataViewType _inputType; @@ -259,12 +260,26 @@ private static VersionInfo GetVersionInfo() /// /// Get the label histogram. /// - public IReadOnlyList GetLabelHistogram() => _labelHistogram; + [Obsolete("This API is deprecated, please use GetLabelHistogramLong() which returns _labelHistogram " + + "with type IReadOnlyList to avoid overflow errors with large datasets.", true)] + public IReadOnlyList GetLabelHistogram() => Array.ConvertAll(_labelHistogram, x => (int)x); + + /// + /// Get the label histogram with generic type long. + /// + public IReadOnlyList GetLabelHistogramLong() => _labelHistogram; /// /// Get the feature histogram. /// - public IReadOnlyList> GetFeatureHistogram() => _featureHistogram; + [Obsolete("This API is deprecated, please use GetFeatureHistogramLong() which returns _featureHistogram " + + "with type IReadOnlyList to avoid overflow errors with large datasets.", true)] + public IReadOnlyList> GetFeatureHistogram() => Array.ConvertAll(_featureHistogram, x => Array.ConvertAll(x, y=> (int)y)); + + /// + /// Get the feature histogram with generic type long. + /// + public IReadOnlyList> GetFeatureHistogramLong() => _featureHistogram; /// /// Instantiates new model parameters from trained model. @@ -273,7 +288,7 @@ private static VersionInfo GetVersionInfo() /// The histogram of labels. /// The feature histogram. /// The number of features. - internal NaiveBayesMulticlassModelParameters(IHostEnvironment env, int[] labelHistogram, int[][] featureHistogram, int featureCount) + internal NaiveBayesMulticlassModelParameters(IHostEnvironment env, long[] labelHistogram, long[][] featureHistogram, int featureCount) : base(env, LoaderSignature) { Host.AssertValue(labelHistogram); @@ -290,16 +305,26 @@ internal NaiveBayesMulticlassModelParameters(IHostEnvironment env, int[] labelHi _outputType = new VectorDataViewType(NumberDataViewType.Single, _labelCount); } + /// + /// The unit test TestEntryPoints.LoadEntryPointModel() exercises the ReadIntArrary(int size) codepath below + /// as its ctx.Header.ModelVerWritten is 0x00010001, and the persistent model that gets loaded and executed + /// for this unit test is located at test\data\backcompat\ep_model3.zip/> + /// private NaiveBayesMulticlassModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, ctx) { // *** Binary format *** - // int: _labelCount - // int[_labelCount]: _labelHistogram + // int: _labelCount (read during reading of _labelHistogram in ReadLongArray()) + // long[_labelCount]: _labelHistogram // int: _featureCount - // int[_labelCount][_featureCount]: _featureHistogram + // long[_labelCount][_featureCount]: _featureHistogram // int[_labelCount]: _absentFeaturesLogProb - _labelHistogram = ctx.Reader.ReadIntArray() ?? new int[0]; + if (ctx.Header.ModelVerWritten >= 0x00010002) + _labelHistogram = ctx.Reader.ReadLongArray() ?? new long[0]; + else + { + _labelHistogram = Array.ConvertAll(ctx.Reader.ReadIntArray() ?? new int[0], x => (long)x); + } _labelCount = _labelHistogram.Length; foreach (int labelCount in _labelHistogram) @@ -307,12 +332,15 @@ private NaiveBayesMulticlassModelParameters(IHostEnvironment env, ModelLoadConte _featureCount = ctx.Reader.ReadInt32(); Host.CheckDecode(_featureCount >= 0); - _featureHistogram = new int[_labelCount][]; + _featureHistogram = new long[_labelCount][]; for (int iLabel = 0; iLabel < _labelCount; iLabel += 1) { if (_labelHistogram[iLabel] > 0) { - _featureHistogram[iLabel] = ctx.Reader.ReadIntArray(_featureCount); + if (ctx.Header.ModelVerWritten >= 0x00010002) + _featureHistogram[iLabel] = ctx.Reader.ReadLongArray(_featureCount); + else + _featureHistogram[iLabel] = Array.ConvertAll(ctx.Reader.ReadIntArray(_featureCount) ?? new int[0], x => (long)x); for (int iFeature = 0; iFeature < _featureCount; iFeature += 1) Host.CheckDecode(_featureHistogram[iLabel][iFeature] >= 0); } @@ -339,22 +367,23 @@ private protected override void SaveCore(ModelSaveContext ctx) // *** Binary format *** // int: _labelCount - // int[_labelCount]: _labelHistogram + // long[_labelCount]: _labelHistogram // int: _featureCount - // int[_labelCount][_featureCount]: _featureHistogram + // long[_labelCount][_featureCount]: _featureHistogram // int[_labelCount]: _absentFeaturesLogProb - ctx.Writer.WriteIntArray(_labelHistogram.AsSpan(0, _labelCount)); + ctx.Writer.Write(_labelCount); + ctx.Writer.WriteLongStream(_labelHistogram); ctx.Writer.Write(_featureCount); for (int i = 0; i < _labelCount; i += 1) { if (_labelHistogram[i] > 0) - ctx.Writer.WriteIntsNoCount(_featureHistogram[i].AsSpan(0, _featureCount)); + ctx.Writer.WriteLongStream(_featureHistogram[i]); } ctx.Writer.WriteDoublesNoCount(_absentFeaturesLogProb.AsSpan(0, _labelCount)); } - private static double[] CalculateAbsentFeatureLogProbabilities(int[] labelHistogram, int[][] featureHistogram, int featureCount) + private static double[] CalculateAbsentFeatureLogProbabilities(long[] labelHistogram, long[][] featureHistogram, int featureCount) { int labelCount = labelHistogram.Length; double[] absentFeaturesLogProb = new double[labelCount]; @@ -365,7 +394,7 @@ private static double[] CalculateAbsentFeatureLogProbabilities(int[] labelHistog double logProb = 0; for (int iFeature = 0; iFeature < featureCount; iFeature += 1) { - int labelOccuranceCount = labelHistogram[iLabel]; + long labelOccuranceCount = labelHistogram[iLabel]; logProb += Math.Log(1 + ((double)labelOccuranceCount - featureHistogram[iLabel][iFeature])) - Math.Log(labelOccuranceCount + labelCount);