diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs
index f120007347..f2259b5782 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs
@@ -18,6 +18,9 @@ namespace Microsoft.ML.Trainers.FastTree.Internal
{
public class TreeEnsemble
{
+ /// <summary>
+ /// String appended to the text representation of <see cref="TreeEnsemble"/>. This is mainly used when that text representation is produced.
+ /// </summary>
private readonly string _firstInputInitializationContent;
private readonly List _trees;
diff --git a/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs b/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs
index 60a9a12ee5..90f8417d1f 100644
--- a/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs
+++ b/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs
@@ -257,7 +257,7 @@ public static IEnumerable GenerateBinaryLa
// Initialize an example with a random label and an empty feature vector.
var sample = new BinaryLabelFloatFeatureVectorSample() { Label = rnd.Next() % 2 == 0, Features = new float[_simpleBinaryClassSampleFeatureLength] };
// Fill feature vector according the assigned label.
- for (int j = 0; j < 10; ++j)
+ for (int j = 0; j < _simpleBinaryClassSampleFeatureLength; ++j)
{
var value = (float)rnd.NextDouble();
// Positive class gets larger feature value.
@@ -271,6 +271,58 @@ public static IEnumerable GenerateBinaryLa
return data;
}
+ /// <summary>
+ /// Example with a boolean label and three float feature vectors, one per field. Each vector is declared
+ /// with length <c>_simpleBinaryClassSampleFeatureLength</c>. Instances are produced by <see cref="GenerateFfmSamples"/>.
+ /// </summary>
+ public class FfmExample
+ {
+ /// <summary>
+ /// Binary label of the example.
+ /// </summary>
+ public bool Label;
+
+ /// <summary>
+ /// Features of the first field. <see cref="GenerateFfmSamples"/> adds 0.2 to each random value when <see cref="Label"/> is true.
+ /// </summary>
+ [VectorType(_simpleBinaryClassSampleFeatureLength)]
+ public float[] Field0;
+
+ /// <summary>
+ /// Features of the second field. <see cref="GenerateFfmSamples"/> subtracts 0.2 from each random value when <see cref="Label"/> is true.
+ /// </summary>
+ [VectorType(_simpleBinaryClassSampleFeatureLength)]
+ public float[] Field1;
+
+ /// <summary>
+ /// Features of the third field. <see cref="GenerateFfmSamples"/> adds 0.8 to each random value when <see cref="Label"/> is true.
+ /// </summary>
+ [VectorType(_simpleBinaryClassSampleFeatureLength)]
+ public float[] Field2;
+ }
+
+ /// <summary>
+ /// Generates <paramref name="exampleCount"/> synthetic <see cref="FfmExample"/>s using a fixed random seed,
+ /// so repeated calls return the same data. Every feature starts as a uniform random value in [0, 1) and is
+ /// then shifted for positive examples: +0.2 in Field0, -0.2 in Field1, and +0.8 in Field2.
+ /// </summary>
+ /// <param name="exampleCount">Number of examples to generate.</param>
+ /// <returns>The generated examples, in generation order.</returns>
+ public static IEnumerable<FfmExample> GenerateFfmSamples(int exampleCount)
+ {
+     var rnd = new Random(0);
+     var data = new List<FfmExample>();
+     for (int i = 0; i < exampleCount; ++i)
+     {
+         // Initialize an example with a random label and one empty feature vector per field.
+         var sample = new FfmExample()
+         {
+             Label = rnd.Next() % 2 == 0,
+             Field0 = new float[_simpleBinaryClassSampleFeatureLength],
+             Field1 = new float[_simpleBinaryClassSampleFeatureLength],
+             Field2 = new float[_simpleBinaryClassSampleFeatureLength]
+         };
+
+         // Fill the feature vectors according to the assigned label. The bound is the declared vector
+         // length rather than a literal, so the loop always covers exactly the allocated slots.
+         for (int j = 0; j < _simpleBinaryClassSampleFeatureLength; ++j)
+         {
+             var value0 = (float)rnd.NextDouble();
+             // Positive class gets larger feature value.
+             if (sample.Label)
+                 value0 += 0.2f;
+             sample.Field0[j] = value0;
+
+             var value1 = (float)rnd.NextDouble();
+             // Positive class gets smaller feature value.
+             if (sample.Label)
+                 value1 -= 0.2f;
+             sample.Field1[j] = value1;
+
+             var value2 = (float)rnd.NextDouble();
+             // Positive class gets larger feature value.
+             if (sample.Label)
+                 value2 += 0.8f;
+             sample.Field2[j] = value2;
+         }
+
+         data.Add(sample);
+     }
+     return data;
+ }
+
/// <summary>
/// Feature vector's length in <see cref="BinaryLabelFloatFeatureVectorSample"/>.
/// </summary>
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
index 60ad74bbe2..247c85c766 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
@@ -40,7 +40,7 @@ public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase
+ /// <summary>
+ /// Extra feature column names. The column named <see cref="FeatureColumn"/> stores features from the first field.
+ /// The i-th string in <see cref="ExtraFeatureColumns"/> stores the name of the (i+1)-th field's feature column.
+ /// </summary>
+ [Argument(ArgumentType.Multiple, HelpText = "Extra columns to use for feature vectors. The i-th specified string denotes the column containing features form the (i+1)-th field." +
+ " Note that the first field is specified by \"feat\" instead of \"exfeat\".",
+ ShortName = "exfeat", SortOrder = 7)]
+ public string[] ExtraFeatureColumns;
+
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to shuffle for each training iteration", ShortName = "shuf", SortOrder = 90)]
public bool Shuffle = true;
@@ -122,13 +131,26 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments arg
{
Initialize(env, args);
Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);
+
+ // There can be multiple feature columns in FFM, jointly specified by args.FeatureColumn and args.ExtraFeatureColumns.
+ // args.ExtraFeatureColumns is optional and null unless the caller sets it, so treat null as "no extra fields"
+ // instead of dereferencing it.
+ int extraFieldCount = args.ExtraFeatureColumns?.Length ?? 0;
+ FeatureColumns = new SchemaShape.Column[1 + extraFieldCount];
+
+ // Treat the default feature column as the 1st field.
+ FeatureColumns[0] = new SchemaShape.Column(args.FeatureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
+
+ // Add 2nd, 3rd, and other fields from a FFM-specific argument, args.ExtraFeatureColumns.
+ for (int i = 0; i < extraFieldCount; i++)
+     FeatureColumns[i + 1] = new SchemaShape.Column(args.ExtraFeatureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
+
+ // The label is a scalar boolean; the weight column is only declared when the caller named one explicitly.
+ LabelColumn = new SchemaShape.Column(args.LabelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
+ WeightColumn = args.WeightColumn.IsExplicit ? new SchemaShape.Column(args.WeightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default;
}
///
/// Initializing a new instance of .
///
/// The private instance of .
- /// The name of column hosting the features.
+ /// The name of column hosting the features. The i-th element stores feature column of the i-th field.
/// The name of the label column.
/// A delegate to apply all the advanced arguments to the algorithm.
/// The name of the optional weights' column.
diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
index 38e3c12b60..b7417b9392 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -10222,6 +10222,18 @@
"IsLogScale": true
}
},
+ {
+ "Name": "WeightColumn",
+ "Type": "String",
+ "Desc": "Column to use for example weight",
+ "Aliases": [
+ "weight"
+ ],
+ "Required": false,
+ "SortOrder": 4.0,
+ "IsNullable": false,
+ "Default": "Weight"
+ },
{
"Name": "LambdaLatent",
"Type": "Float",
@@ -10292,6 +10304,21 @@
"IsNullable": false,
"Default": "Auto"
},
+ {
+ "Name": "ExtraFeatureColumns",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": "String"
+ },
+ "Desc": "Extra columns to use for feature vectors. The i-th specified string denotes the column containing features form the (i+1)-th field. Note that the first field is specified by \"feat\" instead of \"exfeat\".",
+ "Aliases": [
+ "exfeat"
+ ],
+ "Required": false,
+ "SortOrder": 7.0,
+ "IsNullable": false,
+ "Default": null
+ },
{
"Name": "Shuffle",
"Type": "Bool",
@@ -10342,6 +10369,7 @@
}
],
"InputKind": [
+ "ITrainerInputWithWeight",
"ITrainerInputWithLabel",
"ITrainerInput"
],
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs
index 85b096ad30..1e9be1492f 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs
@@ -2,15 +2,43 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.FactorizationMachine;
using Microsoft.ML.RunTests;
+using Microsoft.ML.SamplesUtils;
using Xunit;
namespace Microsoft.ML.Tests.TrainerEstimators
{
public partial class TrainerEstimators : TestDataPipeBase
{
+ /// <summary>
+ /// Trains a field-aware factorization machine on three feature fields configured through
+ /// <see cref="FieldAwareFactorizationMachineTrainer.Arguments"/> and checks that accuracy, AUC and AUPRC
+ /// on the training data all fall within [0.9, 1].
+ /// </summary>
+ [Fact]
+ public void FfmBinaryClassificationWithAdvancedArguments()
+ {
+     var mlContext = new MLContext(seed: 0);
+
+     // Materialize 500 synthetic examples and expose them as a data view.
+     var examples = DatasetUtils.GenerateFfmSamples(500).ToList();
+     var trainingData = ComponentCreation.CreateDataView(mlContext, examples);
+
+     // Customize the field names: Field0 is the first field, Field1 and Field2 are the extra fields.
+     var trainerArgs = new FieldAwareFactorizationMachineTrainer.Arguments
+     {
+         FeatureColumn = nameof(DatasetUtils.FfmExample.Field0),
+         ExtraFeatureColumns = new[]
+         {
+             nameof(DatasetUtils.FfmExample.Field1),
+             nameof(DatasetUtils.FfmExample.Field2)
+         }
+     };
+     var trainer = new FieldAwareFactorizationMachineTrainer(mlContext, trainerArgs);
+
+     // Fit on the generated data, then score that same data.
+     var scored = trainer.Fit(trainingData).Transform(trainingData);
+     var metrics = mlContext.BinaryClassification.Evaluate(scored);
+
+     // Run a sanity check against a few of the metrics.
+     Assert.InRange(metrics.Accuracy, 0.9, 1);
+     Assert.InRange(metrics.Auc, 0.9, 1);
+     Assert.InRange(metrics.Auprc, 0.9, 1);
+ }
+
[Fact]
public void FieldAwareFactorizationMachine_Estimator()
{