diff --git a/src/Microsoft.ML.FastTree/RegressionTree.cs b/src/Microsoft.ML.FastTree/RegressionTree.cs
index 2e9d1bd0fd..15b3a67f00 100644
--- a/src/Microsoft.ML.FastTree/RegressionTree.cs
+++ b/src/Microsoft.ML.FastTree/RegressionTree.cs
@@ -57,7 +57,7 @@ public abstract class RegressionTreeBase
/// (2) the categorical features indexed by 's
/// returned value with nodeIndex=i is NOT a sub-set of with
/// nodeIndex=i.
- /// Note that the case (1) happens only when [i] is true and otherwise (2)
+ /// Note that the case (1) happens only when [i] is false and otherwise (2)
/// occurs. A non-negative returned value means a node (i.e., not a leaf); for example, 2 means the 3rd node in
/// the underlying . A negative returned value means a leaf; for example, -1 stands for the
/// (-1)-th leaf in the underlying . Note that is the
diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs
index be32e79e6e..28c27af81b 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs
@@ -1515,28 +1515,83 @@ public void AppendFeatureContributions(in VBuffer src, BufferBuilder= 0)
{
- int ifeat = SplitFeatures[node];
- var val = src.GetItemOrDefault(ifeat);
- val = GetFeatureValue(val, node);
int otherWay;
- if (val <= RawThresholds[node])
+ if (CategoricalSplit[node])
{
- otherWay = GtChild[node];
- node = LteChild[node];
+ Contracts.Assert(CategoricalSplitFeatures != null);
+ bool match = false;
+ int selectedIndex = -1;
+ int newNode = 0;
+ foreach (var index in CategoricalSplitFeatures[node])
+ {
+ float fv = GetFeatureValue(src.GetItemOrDefault(index), node);
+ if (fv > 0.0f)
+ {
+ match = true;
+ selectedIndex = index; // We only expect at most one match
+ break;
+ }
+ }
+
+ // If the ghost got a smaller output, the contribution of the categorical features is positive, so
+ // the contribution is true minus ghost.
+ if (match)
+ {
+ newNode = GtChild[node];
+ otherWay = LteChild[node];
+
+ var ghostLeaf = GetLeafFrom(in src, otherWay);
+ var ghostOutput = GetOutput(ghostLeaf);
+ var diff = (float)(trueOutput - ghostOutput);
+ foreach (var index in CategoricalSplitFeatures[node])
+ {
+ if (index == selectedIndex) // this index caused the input to go to the GtChild
+ contributions.AddFeature(index, diff);
+ else // All of the others wouldn't cause it
+ contributions.AddFeature(index, -diff);
+ }
+ }
+ else
+ {
+ newNode = LteChild[node];
+ otherWay = GtChild[node];
+
+ var ghostLeaf = GetLeafFrom(in src, otherWay);
+ var ghostOutput = GetOutput(ghostLeaf);
+ var diff = (float)(trueOutput - ghostOutput);
+
+ // None of the indices caused the input to go to the GtChild,
+ // so all of them caused it to go to the LteChild.
+ foreach (var index in CategoricalSplitFeatures[node])
+ contributions.AddFeature(index, diff);
+ }
+
+ node = newNode;
}
else
{
- otherWay = LteChild[node];
- node = GtChild[node];
- }
+ int ifeat = SplitFeatures[node];
+ var val = src.GetItemOrDefault(ifeat);
+ val = GetFeatureValue(val, node);
+ if (val <= RawThresholds[node])
+ {
+ otherWay = GtChild[node];
+ node = LteChild[node];
+ }
+ else
+ {
+ otherWay = LteChild[node];
+ node = GtChild[node];
+ }
- // What if we went the other way?
- var ghostLeaf = GetLeafFrom(in src, otherWay);
- var ghostOutput = GetOutput(ghostLeaf);
+ // What if we went the other way?
+ var ghostLeaf = GetLeafFrom(in src, otherWay);
+ var ghostOutput = GetOutput(ghostLeaf);
- // If the ghost got a smaller output, the contribution of the feature is positive, so
- // the contribution is true minus ghost.
- contributions.AddFeature(ifeat, (float)(trueOutput - ghostOutput));
+ // If the ghost got a smaller output, the contribution of the feature is positive, so
+ // the contribution is true minus ghost.
+ contributions.AddFeature(ifeat, (float)(trueOutput - ghostOutput));
+ }
}
}
}
diff --git a/test/BaselineOutput/Common/FeatureContribution/LightGbmRegressionWithCategoricalSplit.tsv b/test/BaselineOutput/Common/FeatureContribution/LightGbmRegressionWithCategoricalSplit.tsv
new file mode 100644
index 0000000000..ca60f7819c
--- /dev/null
+++ b/test/BaselineOutput/Common/FeatureContribution/LightGbmRegressionWithCategoricalSplit.tsv
@@ -0,0 +1,29 @@
+#@ TextLoader{
+#@ sep=tab
+#@ col=VendorId:TX:0
+#@ col=RateCode:R4:1
+#@ col=PassengerCount:R4:2
+#@ col=PassengerCount:R4:3
+#@ col=TripTime:R4:4
+#@ col=TripTime:R4:5
+#@ col=TripDistance:R4:6
+#@ col=TripDistance:R4:7
+#@ col=PaymentType:TX:8
+#@ col=FareAmount:R4:9
+#@ col=Label:R4:10
+#@ col=VendorIdEncoded:U4[1]:11
+#@ col=VendorIdEncoded:R4:12-12
+#@ col=RateCodeEncoded:U4[2]:13
+#@ col=RateCodeEncoded:R4:14-15
+#@ col=PaymentTypeEncoded:U4[3]:16
+#@ col=PaymentTypeEncoded:R4:17-19
+#@ col=Features:R4:20-28
+#@ col=FeatureContributions:R4:29-37
+#@ col=FeatureContributions:R4:38-46
+#@ col=FeatureContributions:R4:47-55
+#@ col=FeatureContributions:R4:56-64
+#@ }
+CMT 1 1 0.7088812 1271 1.64874518 3.8 1.0118916 CRD 17.5 17.5 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0.7088812 1.64874518 1.0118916 36 4:0.107879594 7:0.725665748 8:1 15:-1 24:-0.0418495 26:1 33:-0.370121539 35:8.844109
+CMT 1 1 0.7088812 474 0.6148743 1.5 0.3994309 CRD 8 8 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0.7088812 0.6148743 0.3994309 36 4:1 15:-0.0364986733 16:-0.847436965 17:-1 22:0.011381451 26:-1 31:0.115415707 35:-10.1406841
+CMT 1 1 0.7088812 637 0.8263184 1.4 0.372802168 CRD 8.5 8.5 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0.7088812 0.8263184 0.372802168 36 4:1 15:-0.0366709046 16:-0.5593253 17:-1 22:0.0182117485 26:-1 31:0.183812216 35:-10.0930576
+CMT 1 1 0.7088812 181 0.234793767 0.6 0.159772366 CSH 4.5 4.5 0 1 0 1 0 1 0 1 0 1 1 0 0 1 0 0.7088812 0.234793767 0.159772366 36 6:1 13:-0.293414325 16:-0.7202999 17:-1 24:0.0291313324 26:-1 33:0.33991462 35:-11.6683512
diff --git a/test/Microsoft.ML.Tests/FeatureContributionTests.cs b/test/Microsoft.ML.Tests/FeatureContributionTests.cs
index fa5132d2ce..cbe7927e7e 100644
--- a/test/Microsoft.ML.Tests/FeatureContributionTests.cs
+++ b/test/Microsoft.ML.Tests/FeatureContributionTests.cs
@@ -12,6 +12,7 @@
using Microsoft.ML.TestFramework.Attributes;
using Microsoft.ML.TestFrameworkCommon.Attributes;
using Microsoft.ML.Trainers;
+using Microsoft.ML.Trainers.LightGbm;
using Xunit;
using Xunit.Abstractions;
@@ -52,6 +53,12 @@ public void TestLightGbmRegression()
TestFeatureContribution(ML.Regression.Trainers.LightGbm(), GetSparseDataset(numberOfInstances: 100), "LightGbmRegression");
}
+ [LightGBMFact]
+ public void TestLightGbmRegressionWithCategoricalSplit()
+ { // Hands a LightGBM regression trainer with UseCategoricalSplit = true and the one-hot encoded dataset to TestFeatureContribution; the string presumably selects the baseline file (confirm against TestFeatureContribution).
+ TestFeatureContribution(ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options() { UseCategoricalSplit = true }), GetOneHotEncodedData(numberOfInstances: 100), "LightGbmRegressionWithCategoricalSplit");
+ }
+
[Fact]
public void TestFastTreeRegression()
{
@@ -377,5 +384,55 @@ private enum TaskType
Ranking,
Clustering
}
+
+ public class TaxiTrip // Row type whose member names GetOneHotEncodedData references via nameof; each [LoadColumn(i)] maps a field to column index i of the text file.
+ {
+ [LoadColumn(0)]
+ public string VendorId; // One-hot encoded by GetOneHotEncodedData.
+
+ [LoadColumn(1)]
+ public float RateCode; // One-hot encoded by GetOneHotEncodedData.
+
+ [LoadColumn(2)]
+ public float PassengerCount; // Mean-variance normalized by GetOneHotEncodedData.
+
+ [LoadColumn(3)]
+ public float TripTime; // Mean-variance normalized by GetOneHotEncodedData.
+
+ [LoadColumn(4)]
+ public float TripDistance; // Mean-variance normalized by GetOneHotEncodedData.
+
+ [LoadColumn(5)]
+ public string PaymentType; // One-hot encoded by GetOneHotEncodedData.
+
+ [LoadColumn(6)]
+ public float FareAmount; // Copied to the Label column by GetOneHotEncodedData.
+ }
+
+ /// <summary>Builds a small taxi-fare dataset whose Features column concatenates the one-hot encoded categorical columns with the mean-variance normalized numeric columns; FareAmount is copied to the Label column.</summary>
+ /// <param name="numberOfInstances">Number of rows taken from taxi-fare-train.csv; the same rows are used to fit the pipeline and to produce the output.</param>
+ /// <returns>The transformed <see cref="IDataView"/>.</returns>
+ private IDataView GetOneHotEncodedData(int numberOfInstances = 100)
+ {
+ var trainDataPath = GetDataPath("taxi-fare-train.csv");
+ IDataView trainingDataView = ML.Data.LoadFromTextFile<TaxiTrip>(trainDataPath, hasHeader: true, separatorChar: ','); // Schema comes from TaxiTrip's [LoadColumn] attributes.
+
+ var vendorIdEncoded = "VendorIdEncoded";
+ var rateCodeEncoded = "RateCodeEncoded";
+ var paymentTypeEncoded = "PaymentTypeEncoded";
+
+ var dataProcessPipeline = ML.Transforms.CopyColumns(outputColumnName: DefaultColumnNames.Label, inputColumnName: nameof(TaxiTrip.FareAmount))
+ .Append(ML.Transforms.Categorical.OneHotEncoding(outputColumnName: vendorIdEncoded, inputColumnName: nameof(TaxiTrip.VendorId)))
+ .Append(ML.Transforms.Categorical.OneHotEncoding(outputColumnName: rateCodeEncoded, inputColumnName: nameof(TaxiTrip.RateCode)))
+ .Append(ML.Transforms.Categorical.OneHotEncoding(outputColumnName: paymentTypeEncoded, inputColumnName: nameof(TaxiTrip.PaymentType)))
+ .Append(ML.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TaxiTrip.PassengerCount)))
+ .Append(ML.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TaxiTrip.TripTime)))
+ .Append(ML.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TaxiTrip.TripDistance)))
+ .Append(ML.Transforms.Concatenate(DefaultColumnNames.Features, vendorIdEncoded, rateCodeEncoded, paymentTypeEncoded,
+ nameof(TaxiTrip.PassengerCount), nameof(TaxiTrip.TripTime), nameof(TaxiTrip.TripDistance)));
+
+ var someRows = ML.Data.TakeRows(trainingDataView, numberOfInstances); // Fit and transform on the same subset.
+ return dataProcessPipeline.Fit(someRows).Transform(someRows);
+ }
}
}