Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ private protected override NaiveBayesMulticlassModelParameters TrainModelCore(Tr
if (labelCol.Type is KeyDataViewType labelKeyType)
labelCount = labelKeyType.GetCountAsInt32(Host);

int[] labelHistogram = new int[labelCount];
int[][] featureHistogram = new int[labelCount][];
long[] labelHistogram = new long[labelCount];
long[][] featureHistogram = new long[labelCount][];
using (var pch = Host.StartProgressChannel("Multi Class Naive Bayes training"))
using (var ch = Host.Start("Training"))
using (var cursor = new MulticlassLabelCursor(labelCount, data, CursOpt.Features | CursOpt.Label))
Expand All @@ -169,7 +169,7 @@ private protected override NaiveBayesMulticlassModelParameters TrainModelCore(Tr
Utils.EnsureSize(ref labelHistogram, size);
Utils.EnsureSize(ref featureHistogram, size);
if (featureHistogram[cursor.Label] == null)
featureHistogram[cursor.Label] = new int[featureCount];
featureHistogram[cursor.Label] = new long[featureCount];
labelHistogram[cursor.Label] += 1;
labelCount = labelCount < size ? size : labelCount;

Expand Down Expand Up @@ -231,17 +231,18 @@ private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "MNABYPRD",
verWrittenCur: 0x00010001, // Initial
//verWrittenCur: 0x00010001, // Initial
verWrittenCur: 0x00010002, // Histograms are of type long
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(NaiveBayesMulticlassModelParameters).Assembly.FullName);
}

private readonly int[] _labelHistogram;
private readonly int[][] _featureHistogram;
private readonly long[] _labelHistogram;
private readonly long[][] _featureHistogram;
private readonly double[] _absentFeaturesLogProb;
private readonly int _totalTrainingCount;
private readonly long _totalTrainingCount;
private readonly int _labelCount;
private readonly int _featureCount;
private readonly VectorDataViewType _inputType;
Expand All @@ -259,12 +260,26 @@ private static VersionInfo GetVersionInfo()
/// <summary>
/// Get the label histogram.
/// </summary>
public IReadOnlyList<int> GetLabelHistogram() => _labelHistogram;
[Obsolete("This API is deprecated, please use GetLabelHistogramLong() which returns _labelHistogram " +
"with type IReadOnlyList<long> to avoid overflow errors with large datasets.", true)]
public IReadOnlyList<int> GetLabelHistogram() => Array.ConvertAll(_labelHistogram, x => (int)x);

/// <summary>
/// Get the label histogram with generic type long.
/// </summary>
public IReadOnlyList<long> GetLabelHistogramLong() => _labelHistogram;

/// <summary>
/// Get the feature histogram.
/// </summary>
public IReadOnlyList<IReadOnlyList<int>> GetFeatureHistogram() => _featureHistogram;
[Obsolete("This API is deprecated, please use GetFeatureHistogramLong() which returns _featureHistogram " +
"with type IReadOnlyList<long> to avoid overflow errors with large datasets.", true)]
public IReadOnlyList<IReadOnlyList<int>> GetFeatureHistogram() => Array.ConvertAll(_featureHistogram, x => Array.ConvertAll(x, y=> (int)y));

/// <summary>
/// Get the feature histogram with generic type long.
/// </summary>
public IReadOnlyList<IReadOnlyList<long>> GetFeatureHistogramLong() => _featureHistogram;

/// <summary>
/// Instantiates new model parameters from trained model.
Expand All @@ -273,7 +288,7 @@ private static VersionInfo GetVersionInfo()
/// <param name="labelHistogram">The histogram of labels.</param>
/// <param name="featureHistogram">The feature histogram.</param>
/// <param name="featureCount">The number of features.</param>
internal NaiveBayesMulticlassModelParameters(IHostEnvironment env, int[] labelHistogram, int[][] featureHistogram, int featureCount)
internal NaiveBayesMulticlassModelParameters(IHostEnvironment env, long[] labelHistogram, long[][] featureHistogram, int featureCount)
: base(env, LoaderSignature)
{
Host.AssertValue(labelHistogram);
Expand All @@ -290,29 +305,42 @@ internal NaiveBayesMulticlassModelParameters(IHostEnvironment env, int[] labelHi
_outputType = new VectorDataViewType(NumberDataViewType.Single, _labelCount);
}

/// <remarks>
/// The unit test TestEntryPoints.LoadEntryPointModel() exercises the ReadIntArrary(int size) codepath below
/// as its ctx.Header.ModelVerWritten is 0x00010001, and the persistent model that gets loaded and executed
/// for this unit test is located at test\data\backcompat\ep_model3.zip/>
/// </remarks>
private NaiveBayesMulticlassModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, LoaderSignature, ctx)
{
// *** Binary format ***
// int: _labelCount
// int[_labelCount]: _labelHistogram
// int: _labelCount (read during reading of _labelHistogram in ReadLongArray())
Copy link
Contributor Author

@mstfbl mstfbl Apr 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this comment as we are not explicitly reading _labelCount here, but in ReadLongArray() as shown below:

public static long[] ReadLongArray(this BinaryReader reader)
{
Contracts.AssertValue(reader);
int size = reader.ReadInt32();
Contracts.CheckDecode(size >= 0);
return ReadLongArray(reader, size);
}
#Resolved

// long[_labelCount]: _labelHistogram
// int: _featureCount
// int[_labelCount][_featureCount]: _featureHistogram
// long[_labelCount][_featureCount]: _featureHistogram
// int[_labelCount]: _absentFeaturesLogProb
_labelHistogram = ctx.Reader.ReadIntArray() ?? new int[0];
if (ctx.Header.ModelVerWritten >= 0x00010002)
_labelHistogram = ctx.Reader.ReadLongArray() ?? new long[0];
else
{
_labelHistogram = Array.ConvertAll(ctx.Reader.ReadIntArray() ?? new int[0], x => (long)x);
}
Copy link
Contributor

@harishsk harishsk Apr 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If ReadIntArray returns null, it likely means the file is bad. Should you be throwing an error in this case? The old behavior seems wrong. #Resolved

Copy link
Contributor Author

@mstfbl mstfbl Apr 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey Harish, the array being read from ctx.Reader.ReadIntArray(int size) can return null if the size of the array being loaded is 0. Here's the source code:

public static int[] ReadIntArray(this BinaryReader reader)
{
Contracts.AssertValue(reader);
int size = reader.ReadInt32();
Contracts.CheckDecode(size >= 0);
return ReadIntArray(reader, size);
}
public static int[] ReadIntArray(this BinaryReader reader, int size)
{
Contracts.AssertValue(reader);
Contracts.Assert(size >= 0);
if (size == 0)
return null;
var values = new int[size];
long bufferSizeInBytes = (long)size * sizeof(int);
if (bufferSizeInBytes < _bulkReadThresholdInBytes)
{
for (int i = 0; i < size; i++)
values[i] = reader.ReadInt32();
}
else
{
unsafe
{
fixed (void* dst = values)
{
ReadBytes(reader, dst, bufferSizeInBytes, bufferSizeInBytes);
}
}
}
return values;
}
#Resolved

_labelCount = _labelHistogram.Length;

foreach (int labelCount in _labelHistogram)
Host.CheckDecode(labelCount >= 0);

_featureCount = ctx.Reader.ReadInt32();
Host.CheckDecode(_featureCount >= 0);
_featureHistogram = new int[_labelCount][];
_featureHistogram = new long[_labelCount][];
for (int iLabel = 0; iLabel < _labelCount; iLabel += 1)
{
if (_labelHistogram[iLabel] > 0)
{
_featureHistogram[iLabel] = ctx.Reader.ReadIntArray(_featureCount);
if (ctx.Header.ModelVerWritten >= 0x00010002)
_featureHistogram[iLabel] = ctx.Reader.ReadLongArray(_featureCount);
else
_featureHistogram[iLabel] = Array.ConvertAll(ctx.Reader.ReadIntArray(_featureCount) ?? new int[0], x => (long)x);
for (int iFeature = 0; iFeature < _featureCount; iFeature += 1)
Copy link
Contributor

@harishsk harishsk Apr 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above #Resolved

Host.CheckDecode(_featureHistogram[iLabel][iFeature] >= 0);
}
Expand All @@ -339,22 +367,23 @@ private protected override void SaveCore(ModelSaveContext ctx)

// *** Binary format ***
// int: _labelCount
// int[_labelCount]: _labelHistogram
// long[_labelCount]: _labelHistogram
// int: _featureCount
// int[_labelCount][_featureCount]: _featureHistogram
// long[_labelCount][_featureCount]: _featureHistogram
// int[_labelCount]: _absentFeaturesLogProb
ctx.Writer.WriteIntArray(_labelHistogram.AsSpan(0, _labelCount));
ctx.Writer.Write(_labelCount);
ctx.Writer.WriteLongStream(_labelHistogram);
Copy link
Contributor Author

@mstfbl mstfbl Apr 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.AsSpan() is not implemented for converting from 'System.Span' to 'System.Collections.Generic.IEnumerable. The only other way around this is using Array.Copy(), but that introduces needless array copying. I also don't see a case where not all of _labelHistogram is not serialized. As such, _labelHistogram is always serialized whole here. #Resolved

ctx.Writer.Write(_featureCount);
for (int i = 0; i < _labelCount; i += 1)
{
if (_labelHistogram[i] > 0)
ctx.Writer.WriteIntsNoCount(_featureHistogram[i].AsSpan(0, _featureCount));
ctx.Writer.WriteLongStream(_featureHistogram[i]);
Copy link
Contributor Author

@mstfbl mstfbl Apr 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sames as above.

.AsSpan() is not implemented for converting from 'System.Span' to 'System.Collections.Generic.IEnumerable. The only other way around this is using Array.Copy(), but that introduces needless array copying. I also don't see a case where not all of _labelHistogram is not serialized. As such, _labelHistogram is always serialized whole here. #Resolved

}

ctx.Writer.WriteDoublesNoCount(_absentFeaturesLogProb.AsSpan(0, _labelCount));
}

private static double[] CalculateAbsentFeatureLogProbabilities(int[] labelHistogram, int[][] featureHistogram, int featureCount)
private static double[] CalculateAbsentFeatureLogProbabilities(long[] labelHistogram, long[][] featureHistogram, int featureCount)
{
int labelCount = labelHistogram.Length;
double[] absentFeaturesLogProb = new double[labelCount];
Expand All @@ -365,7 +394,7 @@ private static double[] CalculateAbsentFeatureLogProbabilities(int[] labelHistog
double logProb = 0;
for (int iFeature = 0; iFeature < featureCount; iFeature += 1)
{
int labelOccuranceCount = labelHistogram[iLabel];
long labelOccuranceCount = labelHistogram[iLabel];
logProb +=
Math.Log(1 + ((double)labelOccuranceCount - featureHistogram[iLabel][iFeature])) -
Math.Log(labelOccuranceCount + labelCount);
Expand Down