diff --git a/src/Microsoft.ML.FastTree/RegressionTree.cs b/src/Microsoft.ML.FastTree/RegressionTree.cs index 2e9d1bd0fd..15b3a67f00 100644 --- a/src/Microsoft.ML.FastTree/RegressionTree.cs +++ b/src/Microsoft.ML.FastTree/RegressionTree.cs @@ -57,7 +57,7 @@ public abstract class RegressionTreeBase /// (2) the categorical features indexed by 's /// returned value with nodeIndex=i is NOT a sub-set of with /// nodeIndex=i. - /// Note that the case (1) happens only when [i] is true and otherwise (2) + /// Note that the case (1) happens only when [i] is false and otherwise (2) /// occurs. A non-negative returned value means a node (i.e., not a leaf); for example, 2 means the 3rd node in /// the underlying . A negative returned value means a leaf; for example, -1 stands for the /// (-1)-th leaf in the underlying . Note that is the diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs index be32e79e6e..28c27af81b 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs @@ -1515,28 +1515,83 @@ public void AppendFeatureContributions(in VBuffer src, BufferBuilder= 0) { - int ifeat = SplitFeatures[node]; - var val = src.GetItemOrDefault(ifeat); - val = GetFeatureValue(val, node); int otherWay; - if (val <= RawThresholds[node]) + if (CategoricalSplit[node]) { - otherWay = GtChild[node]; - node = LteChild[node]; + Contracts.Assert(CategoricalSplitFeatures != null); + bool match = false; + int selectedIndex = -1; + int newNode = 0; + foreach (var index in CategoricalSplitFeatures[node]) + { + float fv = GetFeatureValue(src.GetItemOrDefault(index), node); + if (fv > 0.0f) + { + match = true; + selectedIndex = index; // We only expect at most one match + break; + } + } + + // If the ghost got a smaller output, the contribution of the categorical features is positive, so + // the contribution is true 
minus ghost. + if (match) + { + newNode = GtChild[node]; + otherWay = LteChild[node]; + + var ghostLeaf = GetLeafFrom(in src, otherWay); + var ghostOutput = GetOutput(ghostLeaf); + var diff = (float)(trueOutput - ghostOutput); + foreach (var index in CategoricalSplitFeatures[node]) + { + if (index == selectedIndex) // this index caused the input to go to the GtChild + contributions.AddFeature(index, diff); + else // All of the others wouldn't cause it + contributions.AddFeature(index, -diff); + } + } + else + { + newNode = LteChild[node]; + otherWay = GtChild[node]; + + var ghostLeaf = GetLeafFrom(in src, otherWay); + var ghostOutput = GetOutput(ghostLeaf); + var diff = (float)(trueOutput - ghostOutput); + + // None of the indices caused the input to go to the GtChild, + // So all of them caused it to go to the Lte. + foreach (var index in CategoricalSplitFeatures[node]) + contributions.AddFeature(index, diff); + } + + node = newNode; } else { - otherWay = LteChild[node]; - node = GtChild[node]; - } + int ifeat = SplitFeatures[node]; + var val = src.GetItemOrDefault(ifeat); + val = GetFeatureValue(val, node); + if (val <= RawThresholds[node]) + { + otherWay = GtChild[node]; + node = LteChild[node]; + } + else + { + otherWay = LteChild[node]; + node = GtChild[node]; + } - // What if we went the other way? - var ghostLeaf = GetLeafFrom(in src, otherWay); - var ghostOutput = GetOutput(ghostLeaf); + // What if we went the other way? + var ghostLeaf = GetLeafFrom(in src, otherWay); + var ghostOutput = GetOutput(ghostLeaf); - // If the ghost got a smaller output, the contribution of the feature is positive, so - // the contribution is true minus ghost. - contributions.AddFeature(ifeat, (float)(trueOutput - ghostOutput)); + // If the ghost got a smaller output, the contribution of the feature is positive, so + // the contribution is true minus ghost. 
+ contributions.AddFeature(ifeat, (float)(trueOutput - ghostOutput)); + } } } } diff --git a/test/BaselineOutput/Common/FeatureContribution/LightGbmRegressionWithCategoricalSplit.tsv b/test/BaselineOutput/Common/FeatureContribution/LightGbmRegressionWithCategoricalSplit.tsv new file mode 100644 index 0000000000..ca60f7819c --- /dev/null +++ b/test/BaselineOutput/Common/FeatureContribution/LightGbmRegressionWithCategoricalSplit.tsv @@ -0,0 +1,29 @@ +#@ TextLoader{ +#@ sep=tab +#@ col=VendorId:TX:0 +#@ col=RateCode:R4:1 +#@ col=PassengerCount:R4:2 +#@ col=PassengerCount:R4:3 +#@ col=TripTime:R4:4 +#@ col=TripTime:R4:5 +#@ col=TripDistance:R4:6 +#@ col=TripDistance:R4:7 +#@ col=PaymentType:TX:8 +#@ col=FareAmount:R4:9 +#@ col=Label:R4:10 +#@ col=VendorIdEncoded:U4[1]:11 +#@ col=VendorIdEncoded:R4:12-12 +#@ col=RateCodeEncoded:U4[2]:13 +#@ col=RateCodeEncoded:R4:14-15 +#@ col=PaymentTypeEncoded:U4[3]:16 +#@ col=PaymentTypeEncoded:R4:17-19 +#@ col=Features:R4:20-28 +#@ col=FeatureContributions:R4:29-37 +#@ col=FeatureContributions:R4:38-46 +#@ col=FeatureContributions:R4:47-55 +#@ col=FeatureContributions:R4:56-64 +#@ } +CMT 1 1 0.7088812 1271 1.64874518 3.8 1.0118916 CRD 17.5 17.5 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0.7088812 1.64874518 1.0118916 36 4:0.107879594 7:0.725665748 8:1 15:-1 24:-0.0418495 26:1 33:-0.370121539 35:8.844109 +CMT 1 1 0.7088812 474 0.6148743 1.5 0.3994309 CRD 8 8 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0.7088812 0.6148743 0.3994309 36 4:1 15:-0.0364986733 16:-0.847436965 17:-1 22:0.011381451 26:-1 31:0.115415707 35:-10.1406841 +CMT 1 1 0.7088812 637 0.8263184 1.4 0.372802168 CRD 8.5 8.5 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0.7088812 0.8263184 0.372802168 36 4:1 15:-0.0366709046 16:-0.5593253 17:-1 22:0.0182117485 26:-1 31:0.183812216 35:-10.0930576 +CMT 1 1 0.7088812 181 0.234793767 0.6 0.159772366 CSH 4.5 4.5 0 1 0 1 0 1 0 1 0 1 1 0 0 1 0 0.7088812 0.234793767 0.159772366 36 6:1 13:-0.293414325 16:-0.7202999 17:-1 24:0.0291313324 26:-1 33:0.33991462 35:-11.6683512 
diff --git a/test/Microsoft.ML.Tests/FeatureContributionTests.cs b/test/Microsoft.ML.Tests/FeatureContributionTests.cs index fa5132d2ce..cbe7927e7e 100644 --- a/test/Microsoft.ML.Tests/FeatureContributionTests.cs +++ b/test/Microsoft.ML.Tests/FeatureContributionTests.cs @@ -12,6 +12,7 @@ using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.TestFrameworkCommon.Attributes; using Microsoft.ML.Trainers; +using Microsoft.ML.Trainers.LightGbm; using Xunit; using Xunit.Abstractions; @@ -52,6 +53,12 @@ public void TestLightGbmRegression() TestFeatureContribution(ML.Regression.Trainers.LightGbm(), GetSparseDataset(numberOfInstances: 100), "LightGbmRegression"); } + [LightGBMFact] + public void TestLightGbmRegressionWithCategoricalSplit() + { + TestFeatureContribution(ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options() { UseCategoricalSplit = true }), GetOneHotEncodedData(numberOfInstances: 100), "LightGbmRegressionWithCategoricalSplit"); + } + [Fact] public void TestFastTreeRegression() { @@ -377,5 +384,55 @@ private enum TaskType Ranking, Clustering } + + public class TaxiTrip + { + [LoadColumn(0)] + public string VendorId; + + [LoadColumn(1)] + public float RateCode; + + [LoadColumn(2)] + public float PassengerCount; + + [LoadColumn(3)] + public float TripTime; + + [LoadColumn(4)] + public float TripDistance; + + [LoadColumn(5)] + public string PaymentType; + + [LoadColumn(6)] + public float FareAmount; + } + + /// <summary> + /// Loads taxi-fare-train.csv, restricts it to numberOfInstances rows via TakeRows, and returns a DataView whose Features column concatenates the one-hot encoded VendorId, RateCode and PaymentType columns with the mean-variance normalized PassengerCount, TripTime and TripDistance columns; FareAmount is copied to the Label column. + /// </summary> + private IDataView GetOneHotEncodedData(int numberOfInstances = 100) + { + var trainDataPath = GetDataPath("taxi-fare-train.csv"); + IDataView trainingDataView = ML.Data.LoadFromTextFile<TaxiTrip>(trainDataPath, hasHeader: true, separatorChar: ','); + + var vendorIdEncoded = "VendorIdEncoded"; + var rateCodeEncoded = "RateCodeEncoded"; + var paymentTypeEncoded = "PaymentTypeEncoded"; + + var dataProcessPipeline = ML.Transforms.CopyColumns(outputColumnName: 
DefaultColumnNames.Label, inputColumnName: nameof(TaxiTrip.FareAmount)) + .Append(ML.Transforms.Categorical.OneHotEncoding(outputColumnName: vendorIdEncoded, inputColumnName: nameof(TaxiTrip.VendorId))) + .Append(ML.Transforms.Categorical.OneHotEncoding(outputColumnName: rateCodeEncoded, inputColumnName: nameof(TaxiTrip.RateCode))) + .Append(ML.Transforms.Categorical.OneHotEncoding(outputColumnName: paymentTypeEncoded, inputColumnName: nameof(TaxiTrip.PaymentType))) + .Append(ML.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TaxiTrip.PassengerCount))) + .Append(ML.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TaxiTrip.TripTime))) + .Append(ML.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TaxiTrip.TripDistance))) + .Append(ML.Transforms.Concatenate(DefaultColumnNames.Features, vendorIdEncoded, rateCodeEncoded, paymentTypeEncoded, + nameof(TaxiTrip.PassengerCount), nameof(TaxiTrip.TripTime), nameof(TaxiTrip.TripDistance))); + + var someRows = ML.Data.TakeRows(trainingDataView, numberOfInstances); + return dataProcessPipeline.Fit(someRows).Transform(someRows); + } } }