Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1545,22 +1545,19 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB
{
flocks.Add(CreateOneHotFlockCategorical(ch, pending, binnedValues,
lastOn, true));

if (FeatureMap == null)
FeatureMap = Enumerable.Range(0, NumFeatures)
.Where(f => BinUpperBounds[f].Length > 1).ToArray();
}
iFeature = CategoricalFeatureIndices[catRangeIndex + 1] + 1;
catRangeIndex += 2;
}
else
{
GetFeatureValues(cursor, iFeature, getter, ref temp, ref doubleTemp, copier);

double[] upperBounds = BinUpperBounds[iFeature];
double[] upperBounds = BinUpperBounds[iFeature++];
Host.AssertValue(upperBounds);
if (upperBounds.Length == 1)
continue; //trivial feature, skip it.

flocks.Add(CreateSingletonFlock(ch, ref doubleTemp, binnedValues, upperBounds));
iFeature++;
}
}
}
Expand All @@ -1571,10 +1568,16 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB
GetFeatureValues(cursor, i, getter, ref temp, ref doubleTemp, copier);
double[] upperBounds = BinUpperBounds[i];
Host.AssertValue(upperBounds);
if (upperBounds.Length == 1)
continue; //trivial feature, skip it.

flocks.Add(CreateSingletonFlock(ch, ref doubleTemp, binnedValues, upperBounds));
}
}

Contracts.Assert(FeatureMap == null);

FeatureMap = Enumerable.Range(0, NumFeatures).Where(f => BinUpperBounds[f].Length > 1).ToArray();
features = flocks.ToArray();
}
}
Expand Down
18 changes: 11 additions & 7 deletions src/Microsoft.ML.FastTree/GamTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -574,8 +574,7 @@ protected internal GamPredictorBase(IHostEnvironment env, string name, int input
{
Host.CheckValue(trainSet, nameof(trainSet));
Host.CheckParam(trainSet.NumFeatures <= inputLength, nameof(inputLength), "Must be at least as large as dataset number of features");
Host.CheckValue(featureMap, nameof(featureMap));
Host.CheckParam(featureMap.Length == trainSet.NumFeatures, nameof(featureMap), "Not of right size");
Host.CheckParam(featureMap == null || featureMap.Length == trainSet.NumFeatures, nameof(featureMap), "Not of right size");
Host.CheckValue(binEffects, nameof(binEffects));
Host.CheckParam(binEffects.Length == trainSet.NumFeatures, nameof(binEffects), "Not of right size");

Expand All @@ -584,12 +583,17 @@ protected internal GamPredictorBase(IHostEnvironment env, string name, int input
_numFeatures = binEffects.Length;
_inputType = new VectorType(NumberType.Float, _inputLength);
_featureMap = featureMap;
_inputFeatureToDatasetFeatureMap = new Dictionary<int, int>(featureMap.Length);
for (int i = 0; i < featureMap.Length; i++)

//No features were filtered.
if (_featureMap == null)
_featureMap = Enumerable.Range(0, trainSet.NumFeatures).ToArray();

@TomFinley TomFinley May 11, 2018

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Utils.GetIdentityPermutation? #Resolved


_inputFeatureToDatasetFeatureMap = new Dictionary<int, int>(_featureMap.Length);
for (int i = 0; i < _featureMap.Length; i++)
{
Host.CheckParam(0 <= featureMap[i] && featureMap[i] < inputLength, nameof(featureMap), "Contains out of range feature vaule");
Host.CheckParam(!_inputFeatureToDatasetFeatureMap.ContainsValue(featureMap[i]), nameof(featureMap), "Contains duplicate mappings");
_inputFeatureToDatasetFeatureMap[featureMap[i]] = i;
Host.CheckParam(0 <= _featureMap[i] && _featureMap[i] < inputLength, nameof(_featureMap), "Contains out of range feature vaule");
Host.CheckParam(!_inputFeatureToDatasetFeatureMap.ContainsValue(_featureMap[i]), nameof(_featureMap), "Contains duplicate mappings");
_inputFeatureToDatasetFeatureMap[_featureMap[i]] = i;
}

//keep only bin effect and upperbounds where the effect changes.
Expand Down