From b53d09eea3af59e36aef022e26df53171dd926d6 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 13 Sep 2018 23:36:56 -0700 Subject: [PATCH 01/10] FAFM to extend TrainerEstimatorBase --- .../FactorizationMachineTrainer.cs | 110 +++++++++++++++--- 1 file changed, 91 insertions(+), 19 deletions(-) diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index 0980cb22dc..aea709aa1a 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -30,12 +31,12 @@ namespace Microsoft.ML.Runtime.FactorizationMachine [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf */ /// - public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase + public sealed class FieldAwareFactorizationMachineTrainer : TrainerEstimatorBase, FieldAwareFactorizationMachinePredictor> { - public const string Summary = "Train a field-aware factorization machine for binary classification"; - public const string UserName = "Field-aware Factorization Machine"; - public const string LoadName = "FieldAwareFactorizationMachine"; - public const string ShortName = "ffm"; + internal const string Summary = "Train a field-aware factorization machine for binary classification"; + internal const string UserName = "Field-aware Factorization Machine"; + internal const string LoadName = "FieldAwareFactorizationMachine"; + internal const string ShortName = "ffm"; public sealed class Arguments : LearnerInputBaseWithLabel { @@ -74,19 +75,61 @@ public sealed class Arguments : LearnerInputBaseWithLabel } public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + + /// + /// The containing at least the training data for this trainer. + /// public override TrainerInfo Info { get; } - private readonly int _latentDim; - private readonly int _latentDimAligned; - private readonly float _lambdaLinear; - private readonly float _lambdaLatent; - private readonly float _learningRate; - private readonly int _numIterations; - private readonly bool _norm; - private readonly bool _shuffle; - private readonly bool _verbose; - private readonly float _radius; - - public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments args) : base(env, LoadName) + + private int _latentDim; + private int _latentDimAligned; + private float _lambdaLinear; + private float _lambdaLatent; + private float _learningRate; + private int _numIterations; + private bool _norm; + private bool _shuffle; + private bool _verbose; + private float _radius; + private SchemaShape.Column[] _outputColumns; + + /// + /// Legacy constructor initializing a new instance of through the legacy + /// class. + /// + /// The private instance of . + /// An instance of the legacy to apply advanced parameters to the algorithm. + public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments args) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), MakeFeatureColumn(args.FeatureColumn), MakeLabelColumn(args.LabelColumn)) + { + Initialize(env, args); + Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true); + } + + /// + /// Initializing a new instance of . + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// A delegate to apply all the advanced arguments to the algorithm. + public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, string labelColumn, string featureColumn, Action advancedSettings) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), MakeFeatureColumn(featureColumn), MakeLabelColumn(labelColumn)) + { + var args = new Arguments(); + advancedSettings?.Invoke(args); + + Initialize(env, args); + Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true); + } + + /// + /// Initializes the instance. Shared between the two constructors. + /// REVIEW: Once the legacy constructor goes away, this can move to the only constructor and most of the fields can be back to readonly. + /// + /// + /// + private void Initialize(IHostEnvironment env, Arguments args) { Host.CheckUserArg(args.LatentDim > 0, nameof(args.LatentDim), "Must be positive"); Host.CheckUserArg(args.LambdaLinear >= 0, nameof(args.LambdaLinear), "Must be non-negative"); @@ -103,7 +146,13 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments arg _shuffle = args.Shuffle; _verbose = args.Verbose; _radius = args.Radius; - Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true); + + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) + }; } private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachinePredictor predictor, out float[] linearWeights, @@ -345,7 +394,7 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress return new FieldAwareFactorizationMachinePredictor(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned); } - public override FieldAwareFactorizationMachinePredictor Train(TrainContext context) + protected override FieldAwareFactorizationMachinePredictor TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var initPredictor = context.InitialPredictor as FieldAwareFactorizationMachinePredictor; @@ -361,6 +410,23 @@ public override FieldAwareFactorizationMachinePredictor Train(TrainContext conte } } + private static SchemaShape.Column MakeWeightColumn(string weightColumn) + { + if (weightColumn == null) + return null; + return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); + } + + private static SchemaShape.Column MakeFeatureColumn(string featureColumn) + { + return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + } + + private static SchemaShape.Column MakeLabelColumn(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); + } + [TlcModule.EntryPoint(Name = "Trainers.FieldAwareFactorizationMachineBinaryClassifier", Desc = Summary, UserName = UserName, @@ -376,5 +442,11 @@ public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironm return LearnerEntryPointsUtils.Train(host, input, () => new FieldAwareFactorizationMachineTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); } + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + + protected override BinaryPredictionTransformer MakeTransformer(FieldAwareFactorizationMachinePredictor model, ISchema trainSchema) + => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + } } From 63587774e4737f11fdd0cd61c5764d3a63404ac2 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Mon, 17 Sep 2018 16:55:18 -0700 Subject: [PATCH 02/10] support for an array of features FAFM doesn't inherit from TrainerEstimatorBase, just implements ITrainerEstimator --- .../Prediction/IPredictionTransformer.cs | 9 +- .../Scorers/PredictionTransformer.cs | 93 ++++++++--- .../Training/TrainerEstimatorContext.cs | 50 ++++++ .../FactorizationMachineTrainer.cs | 152 +++++++++++++----- ...FieldAwareFactorizationMachinePredictor.cs | 7 +- .../Standard/LinearClassificationTrainer.cs | 2 +- .../Standard/MultiClass/Ova.cs | 4 +- .../Standard/MultiClass/Pkpd.cs | 4 +- .../Standard/Online/AveragedPerceptron.cs | 2 +- .../Standard/Online/LinearSvm.cs | 2 +- .../DataPipe/TestDataPipeBase.cs | 4 +- .../Microsoft.ML.Tests.csproj | 1 + .../Scenarios/Api/Estimators/Wrappers.cs | 23 ++- .../TrainerEstimators/FAFMEstimator.cs | 48 ++++++ .../TrainerEstimators/MetalinearEstimators.cs | 8 +- .../TrainerEstimators/OnlineLinearTests.cs | 11 +- .../TrainerEstimators/SdcaTests.cs | 8 +- .../TrainerEstimators/TrainerEstimators.cs | 20 +++ 18 files changed, 341 insertions(+), 107 deletions(-) create mode 100644 src/Microsoft.ML.Data/Training/TrainerEstimatorContext.cs create mode 100644 test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs create mode 100644 test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs diff --git a/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs b/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs index 899c7622dc..a84aaf3f4e 100644 --- a/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs @@ -2,21 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Calibration; -using System; -using System.Collections.Generic; -using System.Text; namespace Microsoft.ML.Runtime { public interface IPredictionTransformer : ITransformer where TModel : IPredictor { - string FeatureColumn { get; } + string[] FeatureColumn { get; } - ColumnType FeatureColumnType { get; } + ColumnType[] FeatureColumnType { get; } TModel Model { get; } } diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index f04febd055..251f273952 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -3,11 +3,13 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.IO; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Model; +using static Microsoft.ML.Runtime.Data.RoleMappedSchema; [assembly: LoadableClass(typeof(BinaryPredictionTransformer>), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel), "", BinaryPredictionTransformer.LoaderSignature)] @@ -30,23 +32,33 @@ public abstract class PredictionTransformerBase : IPredictionTransformer protected readonly ISchemaBindableMapper BindableMapper; protected readonly ISchema TrainSchema; - public string FeatureColumn { get; } + public string[] FeatureColumn { get; } - public ColumnType FeatureColumnType { get; } + public ColumnType[] FeatureColumnType { get; } public TModel Model { get; } - public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn) + public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string[] featureColumns) { Contracts.CheckValue(host, nameof(host)); Host = host; Host.CheckValue(trainSchema, nameof(trainSchema)); + Host.CheckValue(featureColumns, nameof(featureColumns)); + + int featCount = featureColumns.Length; + Host.Check(featCount >= 0 , "Empty features column."); Model = model; - FeatureColumn = featureColumn; - if (!trainSchema.TryGetColumnIndex(featureColumn, out int col)) - throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn); - FeatureColumnType = trainSchema.GetColumnType(col); + FeatureColumn = featureColumns; + FeatureColumnType = new ColumnType[featCount]; + + int i = 0; + foreach (var feat in featureColumns) + { + if (!trainSchema.TryGetColumnIndex(feat, out int col)) + throw Host.ExceptSchemaMismatch(nameof(featureColumns), RoleMappedSchema.ColumnRole.Feature.Value, feat); + FeatureColumnType[i++] = trainSchema.GetColumnType(col); + } TrainSchema = trainSchema; BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); @@ -62,7 +74,8 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx) // *** Binary format *** // model: prediction model. // stream: empty data view that contains train schema. - // id of string: feature column. + // count of features + // id of string: feature columns. // Clone the stream with the schema into memory. var ms = new MemoryStream(); @@ -75,10 +88,19 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx) var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms); TrainSchema = loader.Schema; - FeatureColumn = ctx.LoadString(); - if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col)) - throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn); - FeatureColumnType = TrainSchema.GetColumnType(col); + // count of feature columns. FAFM uses more than one. + int featCount = int.Parse(ctx.LoadString()); + + FeatureColumn = new string[featCount]; + FeatureColumnType = new ColumnType[featCount]; + + for (int i = 0; i < featCount; i++) + { + FeatureColumn[i] = ctx.LoadString(); + if (!TrainSchema.TryGetColumnIndex(FeatureColumn[i], out int col)) + throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn[i]); + FeatureColumnType[i] = TrainSchema.GetColumnType(col); + } BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); } @@ -87,10 +109,15 @@ public ISchema GetOutputSchema(ISchema inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null); - if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), inputSchema.GetColumnType(col).ToString()); + for (int i=0; i< FeatureColumn.Length; i++) + { + var feat = FeatureColumn[i]; + if (!inputSchema.TryGetColumnIndex(feat, out int col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, feat, FeatureColumnType[i].ToString(), null); + + if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType[i])) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, feat, FeatureColumnType[i].ToString(), inputSchema.GetColumnType(col).ToString()); + } return Transform(new EmptyDataView(Host, inputSchema)).Schema; } @@ -109,6 +136,7 @@ protected virtual void SaveCore(ModelSaveContext ctx) // *** Binary format *** // model: prediction model. // stream: empty data view that contains train schema. + // number of feature columns // id of string: feature column. ctx.SaveModel(Model, DirModel); @@ -121,7 +149,24 @@ protected virtual void SaveCore(ModelSaveContext ctx) } }); - ctx.SaveString(FeatureColumn); + int featCount = FeatureColumn.Length; + + ctx.SaveString(featCount.ToString()); + for(int i=0; i< featCount; i++) + ctx.SaveString(FeatureColumn[i]); + } + + protected RoleMappedSchema GetSchema(ISchema inputSchema = null, string trainLabelColumn = null) + { + var roles = new List>(); + foreach (var feat in FeatureColumn) + roles.Add(new KeyValuePair(ColumnRole.Feature, feat)); + + if(trainLabelColumn !=null) + roles.Add(new KeyValuePair(ColumnRole.Label, trainLabelColumn)); + + var schema = new RoleMappedSchema(inputSchema ?? TrainSchema, roles); + return schema; } } @@ -133,12 +178,12 @@ public sealed class BinaryPredictionTransformer : PredictionTransformerB public readonly string ThresholdColumn; public readonly float Threshold; - public BinaryPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, + public BinaryPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string[] featureColumn, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer)), model, inputSchema, featureColumn) { Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn)); - var schema = new RoleMappedSchema(inputSchema, null, featureColumn); + var schema = GetSchema(inputSchema); Threshold = threshold; ThresholdColumn = thresholdColumn; @@ -157,7 +202,7 @@ public BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) Threshold = ctx.Reader.ReadSingle(); ThresholdColumn = ctx.LoadString(); - var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn); + var schema = GetSchema(); var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn }; _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } @@ -201,7 +246,7 @@ public sealed class MulticlassPredictionTransformer : PredictionTransfor private readonly string _trainLabelColumn; public MulticlassPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, string labelColumn) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer)), model, inputSchema, featureColumn) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer)), model, inputSchema, new[] { featureColumn }) { Host.CheckValueOrNull(labelColumn); @@ -220,7 +265,7 @@ public MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ct _trainLabelColumn = ctx.LoadStringOrNull(); - var schema = new RoleMappedSchema(TrainSchema, _trainLabelColumn, FeatureColumn); + var schema = GetSchema(trainLabelColumn: _trainLabelColumn); var args = new MultiClassClassifierScorer.Arguments(); _scorer = new MultiClassClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } @@ -261,7 +306,7 @@ public sealed class RegressionPredictionTransformer : PredictionTransfor private readonly GenericScorer _scorer; public RegressionPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer)), model, inputSchema, featureColumn) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer)), model, inputSchema, new[] { featureColumn }) { var schema = new RoleMappedSchema(inputSchema, null, featureColumn); _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, schema), schema); @@ -270,7 +315,7 @@ public RegressionPredictionTransformer(IHostEnvironment env, TModel model, ISche internal RegressionPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer)), ctx) { - var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn); + var schema = GetSchema(); _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorContext.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorContext.cs new file mode 100644 index 0000000000..f1388dffd2 --- /dev/null +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorContext.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Training; + +namespace Microsoft.ML.Core.Prediction +{ + /// + /// Holds information relevant to trainers. It is passed to the constructor of the + /// holding additional data needed to fit the estimator. The additional data can be a validation set or an initial model. + /// This holds at least a training set, as well as optioonally a predictor. + /// + public class TrainerEstimatorContext + { + /// + /// The validation set. Can be null. Note that passing a non-null validation set into + /// a trainer that does not support validation sets should not be considered an error condition. It + /// should simply be ignored in that case. + /// + public IDataView ValidationSet { get; } + + /// + /// The initial predictor, for incremental training. Note that if a implementor + /// does not support incremental training, then it can ignore it similarly to how one would ignore + /// . However, if the trainer does support incremental training and there + /// is something wrong with a non-null value of this, then the trainer ought to throw an exception. + /// + public IPredictor InitialPredictor { get; } + + /// + /// Initializes a new instance of , given a training set and optional other arguments. + /// + /// Will set to this value if specified + /// Will set to this value if specified + public TrainerEstimatorContext(IDataView validationSet = null, IPredictor initialPredictor = null) + { + Contracts.CheckValueOrNull(validationSet); + Contracts.CheckValueOrNull(initialPredictor); + + ValidationSet = validationSet; + InitialPredictor = initialPredictor; + } + } +} diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index 6c0d4dc8f3..ad1b9fed5f 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq; using Microsoft.ML.Core.Data; +using Microsoft.ML.Core.Prediction; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -15,8 +16,9 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Training; -[assembly: LoadableClass(FieldAwareFactorizationMachineTrainer.Summary, typeof(FieldAwareFactorizationMachineTrainer), typeof(FieldAwareFactorizationMachineTrainer.Arguments), - new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }, FieldAwareFactorizationMachineTrainer.UserName, FieldAwareFactorizationMachineTrainer.LoadName, +[assembly: LoadableClass(FieldAwareFactorizationMachineTrainer.Summary, typeof(FieldAwareFactorizationMachineTrainer), + typeof(FieldAwareFactorizationMachineTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) } + , FieldAwareFactorizationMachineTrainer.UserName, FieldAwareFactorizationMachineTrainer.LoadName, FieldAwareFactorizationMachineTrainer.ShortName, DocName = "trainer/FactorizationMachine.md")] [assembly: LoadableClass(typeof(void), typeof(FieldAwareFactorizationMachineTrainer), null, typeof(SignatureEntryPointModule), FieldAwareFactorizationMachineTrainer.LoadName)] @@ -25,13 +27,14 @@ namespace Microsoft.ML.Runtime.FactorizationMachine { /* Train a field-aware factorization machine using ADAGRAD (an advanced stochastic gradient method). See references below - for details. This trainer is essentially faster the one introduced in [2] because of some implemtation tricks[3]. + for details. This trainer is essentially faster the one introduced in [2] because of some implementation tricks[3]. [1] http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf [2] https://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf */ /// - public sealed class FieldAwareFactorizationMachineTrainer : TrainerEstimatorBase, FieldAwareFactorizationMachinePredictor> + public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase, + ITrainerEstimator, FieldAwareFactorizationMachinePredictor> { internal const string Summary = "Train a field-aware factorization machine for binary classification"; internal const string UserName = "Field-aware Factorization Machine"; @@ -76,11 +79,33 @@ public sealed class Arguments : LearnerInputBaseWithLabel public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + /// + /// The feature column that the trainer expects. + /// + public readonly SchemaShape.Column[] FeatureColumns; + + /// + /// The label column that the trainer expects. Can be null, which indicates that label + /// is not used for training. + /// + public readonly SchemaShape.Column LabelColumn; + + /// + /// The weight column that the trainer expects. Can be null, which indicates that weight is + /// not used for training. + /// + public readonly SchemaShape.Column WeightColumn; + /// /// The containing at least the training data for this trainer. /// public override TrainerInfo Info { get; } + /// + /// Additional data for training, through + /// + public readonly TrainerEstimatorContext Context; + private int _latentDim; private int _latentDimAligned; private float _lambdaLinear; @@ -100,7 +125,7 @@ public sealed class Arguments : LearnerInputBaseWithLabel /// The private instance of . /// An instance of the legacy to apply advanced parameters to the algorithm. public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments args) - : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), MakeFeatureColumn(args.FeatureColumn), MakeLabelColumn(args.LabelColumn)) + :base(env, LoadName) { Initialize(env, args); Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true); @@ -111,16 +136,36 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments arg /// /// The private instance of . /// The name of the label column. - /// The name of the feature column. + /// The name of column hosting the features. /// A delegate to apply all the advanced arguments to the algorithm. - public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, string labelColumn, string featureColumn, Action advancedSettings) - : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), MakeFeatureColumn(featureColumn), MakeLabelColumn(labelColumn)) + /// The name of the weight column. + /// The for additional input data to training. + public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, string labelColumn, string[] featureColumns, + string weightColumn = null, TrainerEstimatorContext context = null, Action advancedSettings= null) + : base(env, LoadName) { var args = new Arguments(); advancedSettings?.Invoke(args); Initialize(env, args); Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true); + + Context = context; + + FeatureColumns = new SchemaShape.Column[featureColumns.Length]; + + for(int i=0; i< featureColumns.Length; i++) + FeatureColumns[i] = new SchemaShape.Column(featureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + + LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); + WeightColumn = weightColumn != null? new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false): null; + + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) + }; } /// @@ -146,13 +191,6 @@ private void Initialize(IHostEnvironment env, Arguments args) _shuffle = args.Shuffle; _verbose = args.Verbose; _radius = args.Radius; - - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), - new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) - }; } private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachinePredictor predictor, out float[] linearWeights, @@ -391,10 +429,11 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress ch.Warning($"Skipped {badExampleCount} examples with bad label/weight/features in training set"); if (validBadExampleCount != 0) ch.Warning($"Skipped {validBadExampleCount} examples with bad label/weight/features in validation set"); + return new FieldAwareFactorizationMachinePredictor(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned); } - protected override FieldAwareFactorizationMachinePredictor TrainModelCore(TrainContext context) + public override FieldAwareFactorizationMachinePredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); var initPredictor = context.InitialPredictor as FieldAwareFactorizationMachinePredictor; @@ -410,23 +449,6 @@ protected override FieldAwareFactorizationMachinePredictor TrainModelCore(TrainC } } - private static SchemaShape.Column MakeWeightColumn(string weightColumn) - { - if (weightColumn == null) - return null; - return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); - } - - private static SchemaShape.Column MakeFeatureColumn(string featureColumn) - { - return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); - } - - private static SchemaShape.Column MakeLabelColumn(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); - } - [TlcModule.EntryPoint(Name = "Trainers.FieldAwareFactorizationMachineBinaryClassifier", Desc = Summary, UserName = UserName, @@ -443,10 +465,68 @@ public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironm () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); } - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + public BinaryPredictionTransformer Fit(IDataView input) + { + FieldAwareFactorizationMachinePredictor model = null; + + var roles = new List>(); + foreach (var feat in FeatureColumns) + roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Feature, feat.Name)); + + roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Label, LabelColumn.Name)); - protected override BinaryPredictionTransformer MakeTransformer(FieldAwareFactorizationMachinePredictor model, ISchema trainSchema) - => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + if(WeightColumn != null) + roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Feature, WeightColumn.Name)); + + var trainingData = new RoleMappedData(input, roles); + + RoleMappedData validData = null; + if (Context != null) + validData = new RoleMappedData(Context.ValidationSet, roles); + + using (var ch = Host.Start("Training")) + using (var pch = Host.StartProgressChannel("Training")) + { + var pred = TrainCore(ch, pch, trainingData, validData, Context?.InitialPredictor as FieldAwareFactorizationMachinePredictor); + ch.Done(); + model = pred; + } + return new BinaryPredictionTransformer(Host, model, input.Schema, FeatureColumns.Select(x => x.Name).ToArray() ); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + + if (LabelColumn != null) + { + if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol)) + throw Host.ExceptSchemaMismatch(nameof(labelCol), DefaultColumnNames.PredictedLabel, DefaultColumnNames.PredictedLabel); + + if (!LabelColumn.IsCompatibleWith(labelCol)) + throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible"); + } + + var outColumns = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var col in GetOutputColumnsCore(inputSchema)) + outColumns[col.Name] = col; + + return new SchemaShape(outColumns.Values); + } + + private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol); + Contracts.Assert(success); + + // var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) + // .Concat(MetadataForScoreColumn())); + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true) + }; + } } } diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs index 37261cb55b..9c68956cf7 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs @@ -125,6 +125,9 @@ protected override void SaveCore(ModelSaveContext ctx) // float[]: linear coefficients // float[]: latent representation of features + // REVIEW:FAFM needs to store the names of the features, so that they prediction data does not have the + // restriciton of the columns needing to be ordered the same as the training data. + Host.Assert(FieldCount > 0); Host.Assert(FeatureCount > 0); Host.Assert(LatentDim > 0); @@ -163,9 +166,7 @@ internal float CalculateResponse(ValueGetter>[] getters, VBuffer< } public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) - { - return new FieldAwareFactorizationMachineScalarRowMapper(env, schema, new BinaryClassifierSchema(), this); - } + => new FieldAwareFactorizationMachineScalarRowMapper(env, schema, new BinaryClassifierSchema(), this); internal void CopyLinearWeightsTo(float[] linearWeights) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 7a862af977..ae2542463e 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1490,7 +1490,7 @@ protected override void CheckLabel(RoleMappedData examples, out int weightSetCou } protected override BinaryPredictionTransformer MakeTransformer(TScalarPredictor model, ISchema trainSchema) - => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + => new BinaryPredictionTransformer(Host, model, trainSchema, new[] { FeatureColumn.Name }); public BinaryPredictionTransformer Train(IDataView trainData, IDataView validationData = null, IPredictor initialPredictor = null) => TrainTransformer(trainData, validationData, initialPredictor); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index 4f9416ecef..b4b902d071 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -127,7 +127,7 @@ private IPredictionTransformer TrainOne(IChannel ch, TScalarTr // REVIEW: restoring the RoleMappedData, as much as we can. // not having the weight column on the data passed to the TrainCalibrator should be addressed. - var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn); + var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn[0]); if (calibratedModel == null) calibratedModel = CalibratorUtils.TrainCalibrator(Host, ch, Calibrator, Args.MaxCalibrationExamples, transformer.Model, trainedData) as TDistPredictor; @@ -185,7 +185,7 @@ public override MulticlassPredictionTransformer Fit(IDataView inpu if (i == 0) { var transformer = TrainOne(ch, GetTrainer(), td, i); - featureColumn = transformer.FeatureColumn; + featureColumn = transformer.FeatureColumn[0]; } predictors[i] = TrainOne(ch, GetTrainer(), td, i).Model; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index 9e7063cd70..79362009c7 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -129,7 +129,7 @@ private IPredictionTransformer TrainOne(IChannel ch, TScalarTrai var transformer = trainer.Fit(view); // the validations in the calibrator check for the feature column, in the RoleMappedData - var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn); + var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn[0]); var calibratedModel = transformer.Model as TDistPredictor; if (calibratedModel == null) @@ -198,7 +198,7 @@ public override TTransformer Fit(IDataView input) if (i == 0 && j == 0) { var transformer = TrainOne(ch, GetTrainer(), td, i, j); - featureColumn = transformer.FeatureColumn; + featureColumn = transformer.FeatureColumn[0]; } predictors[i][j] = TrainOne(ch, GetTrainer(), td, i, j).Model; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index f371322dca..07fd177059 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -122,7 +122,7 @@ protected override LinearBinaryPredictor CreatePredictor() } protected override BinaryPredictionTransformer MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema) - => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + => new BinaryPredictionTransformer(Host, model, trainSchema, new[] { FeatureColumn.Name }); [TlcModule.EntryPoint(Name = "Trainers.AveragedPerceptronBinaryClassifier", Desc = Summary, diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs index ad8442433d..ab27c4e283 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs @@ -260,6 +260,6 @@ public static CommonOutputs.BinaryClassificationOutput TrainLinearSvm(IHostEnvir } protected override BinaryPredictionTransformer MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema) - => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + => new BinaryPredictionTransformer(Host, model, trainSchema, new[] { FeatureColumn.Name }); } } diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index 64c85d88b2..5974b11021 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -123,8 +123,8 @@ protected void TestEstimatorCore(IEstimator estimator, // and original transformer. // This in turn means that the schema of loaded transformer matches for // Transform and GetOutputSchema calls. - CheckSameSchemas(scoredTrain.Schema, scoredTrain2.Schema); - CheckSameValues(scoredTrain, scoredTrain2); + CheckSameSchemas(scoredTrain.Schema, scoredTrain2.Schema); + CheckSameValues(scoredTrain, scoredTrain2); }; checkOnData(validFitInput); diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 7984c79f63..69480c9311 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -25,6 +25,7 @@ + diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs index 86799ce445..3e492106ed 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs @@ -177,9 +177,9 @@ public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx) public class ScorerWrapper : TransformWrapper, IPredictionTransformer where TModel : IPredictor { - protected readonly string _featureColumn; + protected readonly string[] _featureColumn; - public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel, string featureColumn) + public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel, string[] featureColumn) : base(env, scorer) { _featureColumn = featureColumn; @@ -188,20 +188,20 @@ public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel public TModel Model { get; } - public string FeatureColumn => _featureColumn; + public string[] FeatureColumn => _featureColumn; - public ColumnType FeatureColumnType => throw _env.ExceptNotSupp(); + public ColumnType[] FeatureColumnType => throw _env.ExceptNotSupp(); } public class BinaryScorerWrapper : ScorerWrapper where TModel : IPredictor { - public BinaryScorerWrapper(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, BinaryClassifierScorer.Arguments args) + public BinaryScorerWrapper(IHostEnvironment env, TModel model, ISchema inputSchema, string[] featureColumn, BinaryClassifierScorer.Arguments args) : base(env, MakeScorer(env, inputSchema, featureColumn, model, args), model, featureColumn) { } - private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string featureColumn, TModel model, BinaryClassifierScorer.Arguments args) + private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string[] featureColumn, TModel model, BinaryClassifierScorer.Arguments args) { var settings = $"Binary{{{CmdParser.GetSettings(env, args, new BinaryClassifierScorer.Arguments())}}}"; @@ -212,7 +212,14 @@ private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string var bindable = ScoreUtils.GetSchemaBindableMapper(env, model, scorerFactorySettings: scorerFactorySettings); var edv = new EmptyDataView(env, schema); - var data = new RoleMappedData(edv, "Label", featureColumn, opt: true); + + var roles = new List>(); + foreach (var feat in featureColumn) + roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Feature, feat)); + + roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Label, "Label")); + + var data = new RoleMappedData(edv, roles, opt: true); return new BinaryClassifierScorer(env, args, data.Data, bindable.Bind(env, data.Schema), data.Schema); } @@ -295,7 +302,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) protected ScorerWrapper MakeScorerBasic(TModel predictor, RoleMappedData data) { var scorer = ScoreUtils.GetScorer(predictor, data, _env, data.Schema); - return (TTransformer)(new ScorerWrapper(_env, scorer, predictor, data.Schema.Feature.Name)); + return (TTransformer)(new ScorerWrapper(_env, scorer, predictor, new[] { data.Schema.Feature.Name })); } } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs new file mode 100644 index 0000000000..fd8d729fd8 --- /dev/null +++ b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.FactorizationMachine; +using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.TrainerEstimators +{ + public partial class TrainerEstimators : TestDataPipeBase + { + [Fact] + public void FieldAwareFactorizationMachine_Estimator() + { + var data = new TextLoader(Env, GetFafmBCLoaderArgs()) + .Read(new MultiFileSource(GetDataPath(TestDatasets.breastCancer.trainFilename))); + + IEstimator est = new FieldAwareFactorizationMachineTrainer(Env, "Label", new[] { "Feature1", "Feature2", "Feature3", "Feature4" }); + + //var result = est.Fit(data); + TestEstimatorCore(est, data); + + Done(); + } + + private TextLoader.Arguments GetFafmBCLoaderArgs() + { + return new TextLoader.Arguments() + { + Separator = "\t", + HasHeader = false, + Column = new[] + { + new TextLoader.Column("Feature1", DataKind.R4, new [] { new TextLoader.Range(1, 2) }), + new TextLoader.Column("Feature2", DataKind.R4, new [] { new TextLoader.Range(3, 4) }), + new TextLoader.Column("Feature3", DataKind.R4, new [] { new TextLoader.Range(5, 6) }), + new TextLoader.Column("Feature4", DataKind.R4, new [] { new TextLoader.Range(7, 9) }), + new TextLoader.Column("Label", DataKind.BL, 0) + } + }; + } + } +} diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index 91fc72185d..985aaeff0d 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -12,14 +12,8 @@ namespace Microsoft.ML.Tests.TrainerEstimators { - public partial class MetalinearEstimators : TestDataPipeBase + public partial class TrainerEstimators { - - public MetalinearEstimators(ITestOutputHelper output) : base(output) - { - } - - /// /// OVA with calibrator argument /// diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs index cd0cb3fe94..2dea99e131 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs @@ -4,19 +4,14 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.FactorizationMachine; using Microsoft.ML.Runtime.Learners; -using Microsoft.ML.Runtime.RunTests; using Xunit; -using Xunit.Abstractions; -namespace Microsoft.ML.Tests.Transformers +namespace Microsoft.ML.Tests.TrainerEstimators { - public sealed class OnlineLinearTests : TestDataPipeBase + public partial class TrainerEstimators { - public OnlineLinearTests(ITestOutputHelper helper) : base(helper) - { - } - [Fact(Skip = "AP is now uncalibrated but advertises as calibrated")] public void OnlineLinearWorkout() { diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs index 6a31a38237..0eab4c7e1b 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs @@ -9,14 +9,10 @@ using Xunit; using Xunit.Abstractions; -namespace Microsoft.ML.Tests.Transformers +namespace Microsoft.ML.Tests.TrainerEstimators { - public sealed class SdcaTests : TestDataPipeBase + public partial class TrainerEstimators { - public SdcaTests(ITestOutputHelper helper) : base(helper) - { - } - [Fact] public void SdcaWorkout() { diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs new file mode 100644 index 0000000000..dc28fccc97 --- /dev/null +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.TrainerEstimators +{ + public partial class TrainerEstimators : TestDataPipeBase + { + public TrainerEstimators(ITestOutputHelper helper) : base(helper) + { + } + } +} From 67f41a3b8504eb1f7cfdf3400443b3d5bfef4f66 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Mon, 17 Sep 2018 17:39:51 -0700 Subject: [PATCH 03/10] incorporating metadata in the columns --- .../FactorizationMachineTrainer.cs | 36 ++++++++++--------- .../Standard/Online/AveragedPerceptron.cs | 4 +-- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index ad1b9fed5f..8dd18ac0c5 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -116,7 +116,6 @@ public sealed class Arguments : LearnerInputBaseWithLabel private bool _shuffle; private bool _verbose; private float _radius; - private SchemaShape.Column[] _outputColumns; /// /// Legacy constructor initializing a new instance of through the legacy @@ -159,13 +158,6 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, string labelC LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); WeightColumn = weightColumn != null? new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false): null; - - _outputColumns = new[] - { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), - new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) - }; } /// @@ -499,15 +491,26 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); + void CheckColumnsCompatible(SchemaShape.Column column, string defaultName){ + + if (!inputSchema.TryFindColumn(column.Name, out var col)) + throw Host.ExceptSchemaMismatch(nameof(col), defaultName, defaultName); + + if (!column.IsCompatibleWith(col)) + throw Host.Except($"{defaultName} column '{column.Name}' is not compatible"); + } + if (LabelColumn != null) - { - if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol)) - throw Host.ExceptSchemaMismatch(nameof(labelCol), DefaultColumnNames.PredictedLabel, DefaultColumnNames.PredictedLabel); + CheckColumnsCompatible(LabelColumn, DefaultColumnNames.Label); - if (!LabelColumn.IsCompatibleWith(labelCol)) - throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible"); + foreach (var feat in FeatureColumns) + { + CheckColumnsCompatible(feat, DefaultColumnNames.Features); } + if (WeightColumn != null) + CheckColumnsCompatible(WeightColumn, DefaultColumnNames.Weight); + var outColumns = inputSchema.Columns.ToDictionary(x => x.Name); foreach (var col in GetOutputColumnsCore(inputSchema)) outColumns[col.Name] = col; @@ -520,12 +523,11 @@ private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol); Contracts.Assert(success); - // var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) - // .Concat(MetadataForScoreColumn())); return new[] { - new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false), - new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true) + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) }; } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index 07fd177059..c3bdfa0bf6 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -95,9 +95,7 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol) } private static SchemaShape.Column MakeLabelColumn(string labelColumn) - { - return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); - } + => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); protected override LinearBinaryPredictor CreatePredictor() { From d4f541333eff84a6905972e7b8adbfd934345bbb Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Mon, 17 Sep 2018 18:15:26 -0700 Subject: [PATCH 04/10] checkign whether any new regressions due to the work. --- test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs index fd8d729fd8..37b4e65798 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators { public partial class TrainerEstimators : TestDataPipeBase { - [Fact] + [Fact(Skip ="Disabling to check whether any regressions, will enable prior to check-in")] public void FieldAwareFactorizationMachine_Estimator() { var data = new TextLoader(Env, GetFafmBCLoaderArgs()) From 26691e3ad96bcba8aabc7bcc290ce6d834530583 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Tue, 18 Sep 2018 13:55:30 -0700 Subject: [PATCH 05/10] reverting the change on IPredictionTransformer taking an array of features. Splitting IPredictionTransformer into two interfaces Creating a transformer wrapping the FAFM predictor. --- .../Prediction/IPredictionTransformer.cs | 10 +- .../Scorers/PredictionTransformer.cs | 184 +++++++++--------- .../Training/ITrainerEstimator.cs | 2 +- .../Training/TrainerEstimatorBase.cs | 2 +- .../FactorizationMachineTrainer.cs | 7 +- ...FieldAwareFactorizationMachinePredictor.cs | 163 ++++++++++++++++ .../Standard/LinearClassificationTrainer.cs | 4 +- .../MultiClass/MetaMulticlassTrainer.cs | 4 +- .../Standard/MultiClass/Ova.cs | 8 +- .../Standard/MultiClass/Pkpd.cs | 8 +- .../Standard/Online/AveragedLinear.cs | 2 +- .../Standard/Online/AveragedPerceptron.cs | 2 +- .../Standard/Online/LinearSvm.cs | 2 +- .../Standard/Online/OnlineLinear.cs | 2 +- .../Standard/StochasticTrainerBase.cs | 2 +- .../Scenarios/Api/Estimators/Wrappers.cs | 25 +-- .../TrainerEstimators/FAFMEstimator.cs | 4 +- 17 files changed, 290 insertions(+), 141 deletions(-) diff --git a/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs b/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs index a84aaf3f4e..1b93149783 100644 --- a/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs @@ -11,10 +11,14 @@ namespace Microsoft.ML.Runtime public interface IPredictionTransformer : ITransformer where TModel : IPredictor { - string[] FeatureColumn { get; } + TModel Model { get; } + } - ColumnType[] FeatureColumnType { get; } + public interface IClassicPredictionTransformer : IPredictionTransformer + where TModel : IPredictor + { + string FeatureColumn { get; } - TModel Model { get; } + ColumnType FeatureColumnType { get; } } } diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index 251f273952..6d91277fb1 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -3,13 +3,12 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Generic; using System.IO; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Model; -using static Microsoft.ML.Runtime.Data.RoleMappedSchema; [assembly: LoadableClass(typeof(BinaryPredictionTransformer>), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel), "", BinaryPredictionTransformer.LoaderSignature)] @@ -22,49 +21,33 @@ namespace Microsoft.ML.Runtime.Data { - public abstract class PredictionTransformerBase : IPredictionTransformer, ICanSaveModel + + public abstract class PredictionTransformerBase : IPredictionTransformer where TModel : class, IPredictor { - private const string DirModel = "Model"; - private const string DirTransSchema = "TrainSchema"; + /// + /// The model. + /// + public TModel Model { get; } + protected const string DirModel = "Model"; + protected const string DirTransSchema = "TrainSchema"; protected readonly IHost Host; - protected readonly ISchemaBindableMapper BindableMapper; - protected readonly ISchema TrainSchema; - - public string[] FeatureColumn { get; } - - public ColumnType[] FeatureColumnType { get; } + protected ISchemaBindableMapper BindableMapper; + protected ISchema TrainSchema; - public TModel Model { get; } - - public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string[] featureColumns) + protected PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema) { Contracts.CheckValue(host, nameof(host)); Host = host; Host.CheckValue(trainSchema, nameof(trainSchema)); - Host.CheckValue(featureColumns, nameof(featureColumns)); - - int featCount = featureColumns.Length; - Host.Check(featCount >= 0 , "Empty features column."); Model = model; - FeatureColumn = featureColumns; - FeatureColumnType = new ColumnType[featCount]; - - int i = 0; - foreach (var feat in featureColumns) - { - if (!trainSchema.TryGetColumnIndex(feat, out int col)) - throw Host.ExceptSchemaMismatch(nameof(featureColumns), RoleMappedSchema.ColumnRole.Feature.Value, feat); - FeatureColumnType[i++] = trainSchema.GetColumnType(col); - } - TrainSchema = trainSchema; - BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); } - internal PredictionTransformerBase(IHost host, ModelLoadContext ctx) + protected PredictionTransformerBase(IHost host, ModelLoadContext ctx) + { Host = host; @@ -74,8 +57,7 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx) // *** Binary format *** // model: prediction model. // stream: empty data view that contains train schema. - // count of features - // id of string: feature columns. + // id of string: feature column. // Clone the stream with the schema into memory. var ms = new MemoryStream(); @@ -87,56 +69,22 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx) ms.Position = 0; var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms); TrainSchema = loader.Schema; - - // count of feature columns. FAFM uses more than one. - int featCount = int.Parse(ctx.LoadString()); - - FeatureColumn = new string[featCount]; - FeatureColumnType = new ColumnType[featCount]; - - for (int i = 0; i < featCount; i++) - { - FeatureColumn[i] = ctx.LoadString(); - if (!TrainSchema.TryGetColumnIndex(FeatureColumn[i], out int col)) - throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn[i]); - FeatureColumnType[i] = TrainSchema.GetColumnType(col); - } - - BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); } - public ISchema GetOutputSchema(ISchema inputSchema) - { - Host.CheckValue(inputSchema, nameof(inputSchema)); - - for (int i=0; i< FeatureColumn.Length; i++) - { - var feat = FeatureColumn[i]; - if (!inputSchema.TryGetColumnIndex(feat, out int col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, feat, FeatureColumnType[i].ToString(), null); - - if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType[i])) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, feat, FeatureColumnType[i].ToString(), inputSchema.GetColumnType(col).ToString()); - } - - return Transform(new EmptyDataView(Host, inputSchema)).Schema; - } + public abstract ISchema GetOutputSchema(ISchema inputSchema); + /// + /// Transforms the input data. + /// + /// + /// public abstract IDataView Transform(IDataView input); - public void Save(ModelSaveContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - SaveCore(ctx); - } - - protected virtual void SaveCore(ModelSaveContext ctx) + protected void SaveModel(ModelSaveContext ctx) { // *** Binary format *** // model: prediction model. // stream: empty data view that contains train schema. - // number of feature columns // id of string: feature column. ctx.SaveModel(Model, DirModel); @@ -148,29 +96,71 @@ protected virtual void SaveCore(ModelSaveContext ctx) DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream); } }); + } + } - int featCount = FeatureColumn.Length; + public abstract class ClassicPredictionTransformerBase : PredictionTransformerBase, IClassicPredictionTransformer, ICanSaveModel + where TModel : class, IPredictor + { + /// + /// The name of the feature column used by the prediction transformer. + /// + public string FeatureColumn { get; } + + /// + /// The type of the prediction transformer + /// + public ColumnType FeatureColumnType { get; } + + public ClassicPredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn) + :base(host, model, trainSchema) + { + FeatureColumn = featureColumn; + if (!trainSchema.TryGetColumnIndex(featureColumn, out int col)) + throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn); + FeatureColumnType = trainSchema.GetColumnType(col); - ctx.SaveString(featCount.ToString()); - for(int i=0; i< featCount; i++) - ctx.SaveString(FeatureColumn[i]); + BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); } - protected RoleMappedSchema GetSchema(ISchema inputSchema = null, string trainLabelColumn = null) + internal ClassicPredictionTransformerBase(IHost host, ModelLoadContext ctx) + :base(host, ctx) { - var roles = new List>(); - foreach (var feat in FeatureColumn) - roles.Add(new KeyValuePair(ColumnRole.Feature, feat)); + FeatureColumn = ctx.LoadString(); + if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col)) + throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn); + FeatureColumnType = TrainSchema.GetColumnType(col); - if(trainLabelColumn !=null) - roles.Add(new KeyValuePair(ColumnRole.Label, trainLabelColumn)); + BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model); + } + + public override ISchema GetOutputSchema(ISchema inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); - var schema = new RoleMappedSchema(inputSchema ?? TrainSchema, roles); - return schema; + if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null); + if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), inputSchema.GetColumnType(col).ToString()); + + return Transform(new EmptyDataView(Host, inputSchema)).Schema; + } + + public void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + SaveCore(ctx); + } + + protected virtual void SaveCore(ModelSaveContext ctx) + { + SaveModel(ctx); + ctx.SaveString(FeatureColumn); } } - public sealed class BinaryPredictionTransformer : PredictionTransformerBase + public sealed class BinaryPredictionTransformer : ClassicPredictionTransformerBase where TModel : class, IPredictorProducing { private readonly BinaryClassifierScorer _scorer; @@ -178,12 +168,12 @@ public sealed class BinaryPredictionTransformer : PredictionTransformerB public readonly string ThresholdColumn; public readonly float Threshold; - public BinaryPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string[] featureColumn, + public BinaryPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer)), model, inputSchema, featureColumn) { Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn)); - var schema = GetSchema(inputSchema); + var schema = new RoleMappedSchema(inputSchema, null, featureColumn); Threshold = threshold; ThresholdColumn = thresholdColumn; @@ -202,7 +192,7 @@ public BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) Threshold = ctx.Reader.ReadSingle(); ThresholdColumn = ctx.LoadString(); - var schema = GetSchema(); + var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn); var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn }; _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } @@ -239,14 +229,14 @@ private static VersionInfo GetVersionInfo() } } - public sealed class MulticlassPredictionTransformer : PredictionTransformerBase + public sealed class MulticlassPredictionTransformer : ClassicPredictionTransformerBase where TModel : class, IPredictorProducing> { private readonly MultiClassClassifierScorer _scorer; private readonly string _trainLabelColumn; public MulticlassPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, string labelColumn) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer)), model, inputSchema, new[] { featureColumn }) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer)), model, inputSchema, featureColumn) { Host.CheckValueOrNull(labelColumn); @@ -265,7 +255,7 @@ public MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ct _trainLabelColumn = ctx.LoadStringOrNull(); - var schema = GetSchema(trainLabelColumn: _trainLabelColumn); + var schema = new RoleMappedSchema(TrainSchema, _trainLabelColumn, FeatureColumn); var args = new MultiClassClassifierScorer.Arguments(); _scorer = new MultiClassClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } @@ -300,13 +290,13 @@ private static VersionInfo GetVersionInfo() } } - public sealed class RegressionPredictionTransformer : PredictionTransformerBase + public sealed class RegressionPredictionTransformer : ClassicPredictionTransformerBase where TModel : class, IPredictorProducing { private readonly GenericScorer _scorer; public RegressionPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer)), model, inputSchema, new[] { featureColumn }) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer)), model, inputSchema, featureColumn) { var schema = new RoleMappedSchema(inputSchema, null, featureColumn); _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, schema), schema); @@ -315,7 +305,7 @@ public RegressionPredictionTransformer(IHostEnvironment env, TModel model, ISche internal RegressionPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer)), ctx) { - var schema = GetSchema(); + var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn); _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } @@ -369,4 +359,4 @@ internal static class RegressionPredictionTransformer public static RegressionPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) => new RegressionPredictionTransformer>(env, ctx); } -} +} \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs b/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs index 4eb6ea1482..243013d67e 100644 --- a/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs +++ b/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs @@ -7,7 +7,7 @@ namespace Microsoft.ML.Runtime.Training { public interface ITrainerEstimator: IEstimator - where TTransformer: IPredictionTransformer + where TTransformer: IClassicPredictionTransformer where TPredictor: IPredictor { TrainerInfo Info { get; } diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index 02c8c60667..dccf25cf51 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -15,7 +15,7 @@ namespace Microsoft.ML.Runtime.Training /// It produces a 'prediction transformer'. /// public abstract class TrainerEstimatorBase : ITrainerEstimator, ITrainer - where TTransformer : IPredictionTransformer + where TTransformer : IClassicPredictionTransformer where TModel : IPredictor { /// diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index 8dd18ac0c5..2f2161fa12 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -34,7 +34,7 @@ namespace Microsoft.ML.Runtime.FactorizationMachine */ /// public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase, - ITrainerEstimator, FieldAwareFactorizationMachinePredictor> + IEstimator { internal const string Summary = "Train a field-aware factorization machine for binary classification"; internal const string UserName = "Field-aware Factorization Machine"; @@ -457,7 +457,7 @@ public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironm () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); } - public BinaryPredictionTransformer Fit(IDataView input) + public FieldAwareFactorizationMachinePredictionTransformer Fit(IDataView input) { FieldAwareFactorizationMachinePredictor model = null; @@ -484,11 +484,12 @@ public BinaryPredictionTransformer Fit( model = pred; } - return new BinaryPredictionTransformer(Host, model, input.Schema, FeatureColumns.Select(x => x.Name).ToArray() ); + return new FieldAwareFactorizationMachinePredictionTransformer(Host, model, input.Schema, FeatureColumns.Select(x => x.Name).ToArray() ); } public SchemaShape GetOutputSchema(SchemaShape inputSchema) { + Host.CheckValue(inputSchema, nameof(inputSchema)); void CheckColumnsCompatible(SchemaShape.Column column, string defaultName){ diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs index 9c68956cf7..32866a0633 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs @@ -3,8 +3,11 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; +using System.IO; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.FactorizationMachine; using Microsoft.ML.Runtime.Internal.CpuMath; using Microsoft.ML.Runtime.Internal.Internallearn; @@ -13,6 +16,9 @@ [assembly: LoadableClass(typeof(FieldAwareFactorizationMachinePredictor), null, typeof(SignatureLoadModel), "Field Aware Factorization Machine", FieldAwareFactorizationMachinePredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(FieldAwareFactorizationMachinePredictionTransformer), typeof(FieldAwareFactorizationMachinePredictionTransformer), null, typeof(SignatureLoadModel), + "", FieldAwareFactorizationMachinePredictionTransformer.LoaderSignature)] + namespace Microsoft.ML.Runtime.FactorizationMachine { public sealed class FieldAwareFactorizationMachinePredictor : PredictorBase, ISchemaBindableMapper, ICanSaveModel @@ -182,4 +188,161 @@ internal void CopyLatentWeightsTo(AlignedArray latentWeights) latentWeights.CopyFrom(_latentWeightsAligned); } } + + public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase, ICanSaveModel + { + public const string LoaderSignature = "FAFMPredXfer"; + + /// + /// The name of the feature column used by the prediction transformer. + /// + public string[] FeatureColumns { get; } + + /// + /// The type of the prediction transformer + /// + public ColumnType[] FeatureColumnTypes { get; } + + private readonly BinaryClassifierScorer _scorer; + + public readonly string ThresholdColumn; + public readonly float Threshold; + + public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, FieldAwareFactorizationMachinePredictor model, ISchema trainSchema, + string[] featureColumns, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) + :base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), model, trainSchema) + { + Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn)); + Threshold = threshold; + ThresholdColumn = thresholdColumn; + + Host.CheckValue(featureColumns, nameof(featureColumns)); + int featCount = featureColumns.Length; + Host.Check(featCount >= 0, "Empty features column."); + + FeatureColumns = featureColumns; + FeatureColumnTypes = new ColumnType[featCount]; + + int i = 0; + foreach (var feat in featureColumns) + { + if (!trainSchema.TryGetColumnIndex(feat, out int col)) + throw Host.ExceptSchemaMismatch(nameof(featureColumns), RoleMappedSchema.ColumnRole.Feature.Value, feat); + FeatureColumnTypes[i++] = trainSchema.GetColumnType(col); + } + + BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); + + var schema = GetSchema(); + var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn }; + _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, trainSchema), BindableMapper.Bind(Host, schema), schema); + } + + public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, ModelLoadContext ctx) + :base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), ctx) + { + // *** Binary format *** + // + // ids of strings: feature columns. + // float: scorer threshold + // id of string: scorer threshold column + + // count of feature columns. FAFM uses more than one. + int featCount = Model.FieldCount; + + FeatureColumns = new string[featCount]; + FeatureColumnTypes = new ColumnType[featCount]; + + for (int i = 0; i < featCount; i++) + { + FeatureColumns[i] = ctx.LoadString(); + if (!TrainSchema.TryGetColumnIndex(FeatureColumns[i], out int col)) + throw Host.ExceptSchemaMismatch(nameof(FeatureColumns), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumns[i]); + FeatureColumnTypes[i] = TrainSchema.GetColumnType(col); + } + + Threshold = ctx.Reader.ReadSingle(); + ThresholdColumn = ctx.LoadString(); + + BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model); + + var schema = GetSchema(); + var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn }; + _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); + } + + public override ISchema GetOutputSchema(ISchema inputSchema) + { + for (int i = 0; i < FeatureColumns.Length; i++) + { + var feat = FeatureColumns[i]; + if (!inputSchema.TryGetColumnIndex(feat, out int col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, feat, FeatureColumnTypes[i].ToString(), null); + + if (!inputSchema.GetColumnType(col).Equals(FeatureColumnTypes[i])) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, feat, FeatureColumnTypes[i].ToString(), inputSchema.GetColumnType(col).ToString()); + } + + return Transform(new EmptyDataView(Host, inputSchema)).Schema; + } + + public override IDataView Transform(IDataView input) + { + Host.CheckValue(input, nameof(input)); + return _scorer.ApplyToData(Host, input); + } + + public void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // model: prediction model. + // stream: empty data view that contains train schema. + // ids of strings: feature columns. + // float: scorer threshold + // id of string: scorer threshold column + + ctx.SaveModel(Model, DirModel); + ctx.SaveBinaryStream(DirTransSchema, writer => + { + using (var ch = Host.Start("Saving train schema")) + { + var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true }); + DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream); + } + }); + + for (int i = 0; i < Model.FieldCount; i++) + ctx.SaveString(FeatureColumns[i]); + + ctx.Writer.Write(Threshold); + ctx.SaveString(ThresholdColumn); + } + + private RoleMappedSchema GetSchema() + { + var roles = new List>(); + foreach (var feat in FeatureColumns) + roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Feature, feat)); + + var schema = new RoleMappedSchema(TrainSchema, roles); + return schema; + } + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "FAFMPRED", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature); + } + + private static FieldAwareFactorizationMachinePredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new FieldAwareFactorizationMachinePredictionTransformer(env, ctx); + } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index ae2542463e..423f8c528e 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -147,7 +147,7 @@ protected virtual int ComputeNumThreads(FloatLabelCursor.Factory cursorFactory) } public abstract class SdcaTrainerBase : StochasticTrainerBase - where TTransformer : IPredictionTransformer + where TTransformer : IClassicPredictionTransformer where TModel : IPredictor { // REVIEW: Making it even faster and more accurate: @@ -1490,7 +1490,7 @@ protected override void CheckLabel(RoleMappedData examples, out int weightSetCou } protected override BinaryPredictionTransformer MakeTransformer(TScalarPredictor model, ISchema trainSchema) - => new BinaryPredictionTransformer(Host, model, trainSchema, new[] { FeatureColumn.Name }); + => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); public BinaryPredictionTransformer Train(IDataView trainData, IDataView validationData = null, IPredictor initialPredictor = null) => TrainTransformer(trainData, validationData, initialPredictor); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index 00563b7fc6..90ed322f1b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -15,10 +15,10 @@ namespace Microsoft.ML.Runtime.Learners { - using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; + using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; public abstract class MetaMulticlassTrainer : ITrainerEstimator, ITrainer - where TTransformer : IPredictionTransformer + where TTransformer : IClassicPredictionTransformer where TModel : IPredictor { public abstract class ArgumentsBase diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index b4b902d071..bd061b2c91 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -34,7 +34,7 @@ namespace Microsoft.ML.Runtime.Learners { using TScalarPredictor = IPredictorProducing; - using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; + using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; using TDistPredictor = IDistPredictorProducing; using CR = RoleMappedSchema.ColumnRole; @@ -111,7 +111,7 @@ protected override OvaPredictor TrainCore(IChannel ch, RoleMappedData data, int return OvaPredictor.Create(Host, _args.UseProbabilities, predictors); } - private IPredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) + private IClassicPredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) { var view = MapLabels(data, cls); @@ -127,7 +127,7 @@ private IPredictionTransformer TrainOne(IChannel ch, TScalarTr // REVIEW: restoring the RoleMappedData, as much as we can. // not having the weight column on the data passed to the TrainCalibrator should be addressed. - var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn[0]); + var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn); if (calibratedModel == null) calibratedModel = CalibratorUtils.TrainCalibrator(Host, ch, Calibrator, Args.MaxCalibrationExamples, transformer.Model, trainedData) as TDistPredictor; @@ -185,7 +185,7 @@ public override MulticlassPredictionTransformer Fit(IDataView inpu if (i == 0) { var transformer = TrainOne(ch, GetTrainer(), td, i); - featureColumn = transformer.FeatureColumn[0]; + featureColumn = transformer.FeatureColumn; } predictors[i] = TrainOne(ch, GetTrainer(), td, i).Model; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index 79362009c7..bad83ed24d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -26,7 +26,7 @@ namespace Microsoft.ML.Runtime.Learners { using TDistPredictor = IDistPredictorProducing; - using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; + using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; using CR = RoleMappedSchema.ColumnRole; using TTransformer = MulticlassPredictionTransformer; @@ -119,7 +119,7 @@ protected override PkpdPredictor TrainCore(IChannel ch, RoleMappedData data, int return new PkpdPredictor(Host, predModels); } - private IPredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls1, int cls2) + private IClassicPredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls1, int cls2) { // this should not be necessary when the legacy constructor doesn't exist, and the label column is not an optional parameter on the // MetaMulticlassTrainer constructor. @@ -129,7 +129,7 @@ private IPredictionTransformer TrainOne(IChannel ch, TScalarTrai var transformer = trainer.Fit(view); // the validations in the calibrator check for the feature column, in the RoleMappedData - var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn[0]); + var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn); var calibratedModel = transformer.Model as TDistPredictor; if (calibratedModel == null) @@ -198,7 +198,7 @@ public override TTransformer Fit(IDataView input) if (i == 0 && j == 0) { var transformer = TrainOne(ch, GetTrainer(), td, i, j); - featureColumn = transformer.FeatureColumn[0]; + featureColumn = transformer.FeatureColumn; } predictors[i][j] = TrainOne(ch, GetTrainer(), td, i, j).Model; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs index 402d227fad..c64bea48ca 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs @@ -54,7 +54,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments } public abstract class AveragedLinearTrainer : OnlineLinearTrainer - where TTransformer : IPredictionTransformer + where TTransformer : IClassicPredictionTransformer where TModel : IPredictor { protected readonly new AveragedLinearArguments Args; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index c3bdfa0bf6..5101e009eb 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -120,7 +120,7 @@ protected override LinearBinaryPredictor CreatePredictor() } protected override BinaryPredictionTransformer MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema) - => new BinaryPredictionTransformer(Host, model, trainSchema, new[] { FeatureColumn.Name }); + => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); [TlcModule.EntryPoint(Name = "Trainers.AveragedPerceptronBinaryClassifier", Desc = Summary, diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs index ab27c4e283..ad8442433d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs @@ -260,6 +260,6 @@ public static CommonOutputs.BinaryClassificationOutput TrainLinearSvm(IHostEnvir } protected override BinaryPredictionTransformer MakeTransformer(LinearBinaryPredictor model, ISchema trainSchema) - => new BinaryPredictionTransformer(Host, model, trainSchema, new[] { FeatureColumn.Name }); + => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index d7b7d4cf2f..1e0fa3821b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -44,7 +44,7 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel } public abstract class OnlineLinearTrainer : TrainerEstimatorBase - where TTransformer : IPredictionTransformer + where TTransformer : IClassicPredictionTransformer where TModel : IPredictor { protected readonly OnlineLinearArguments Args; diff --git a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs index 7fc70ce621..eaab91e31b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs @@ -11,7 +11,7 @@ namespace Microsoft.ML.Runtime.Learners { public abstract class StochasticTrainerBase : TrainerEstimatorBase - where TTransformer : IPredictionTransformer + where TTransformer : IClassicPredictionTransformer where TModel : IPredictor { public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = null) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs index 3e492106ed..b1953ac651 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs @@ -174,12 +174,12 @@ public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx) public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input); } - public class ScorerWrapper : TransformWrapper, IPredictionTransformer + public class ScorerWrapper : TransformWrapper, IClassicPredictionTransformer where TModel : IPredictor { - protected readonly string[] _featureColumn; + protected readonly string _featureColumn; - public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel, string[] featureColumn) + public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel, string featureColumn) : base(env, scorer) { _featureColumn = featureColumn; @@ -188,20 +188,20 @@ public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel public TModel Model { get; } - public string[] FeatureColumn => _featureColumn; + public string FeatureColumn => _featureColumn; - public ColumnType[] FeatureColumnType => throw _env.ExceptNotSupp(); + public ColumnType FeatureColumnType => throw _env.ExceptNotSupp(); } public class BinaryScorerWrapper : ScorerWrapper where TModel : IPredictor { - public BinaryScorerWrapper(IHostEnvironment env, TModel model, ISchema inputSchema, string[] featureColumn, BinaryClassifierScorer.Arguments args) + public BinaryScorerWrapper(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, BinaryClassifierScorer.Arguments args) : base(env, MakeScorer(env, inputSchema, featureColumn, model, args), model, featureColumn) { } - private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string[] featureColumn, TModel model, BinaryClassifierScorer.Arguments args) + private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string featureColumn, TModel model, BinaryClassifierScorer.Arguments args) { var settings = $"Binary{{{CmdParser.GetSettings(env, args, new BinaryClassifierScorer.Arguments())}}}"; @@ -212,14 +212,7 @@ private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string var bindable = ScoreUtils.GetSchemaBindableMapper(env, model, scorerFactorySettings: scorerFactorySettings); var edv = new EmptyDataView(env, schema); - - var roles = new List>(); - foreach (var feat in featureColumn) - roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Feature, feat)); - - roles.Add(new KeyValuePair(RoleMappedSchema.ColumnRole.Label, "Label")); - - var data = new RoleMappedData(edv, roles, opt: true); + var data = new RoleMappedData(edv, "Label", featureColumn, opt: true); return new BinaryClassifierScorer(env, args, data.Data, bindable.Bind(env, data.Schema), data.Schema); } @@ -302,7 +295,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) protected ScorerWrapper MakeScorerBasic(TModel predictor, RoleMappedData data) { var scorer = ScoreUtils.GetScorer(predictor, data, _env, data.Schema); - return (TTransformer)(new ScorerWrapper(_env, scorer, predictor, new[] { data.Schema.Feature.Name })); + return (TTransformer)(new ScorerWrapper(_env, scorer, predictor, data.Schema.Feature.Name)); } } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs index 37b4e65798..303cf099dc 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs @@ -14,15 +14,13 @@ namespace Microsoft.ML.Tests.TrainerEstimators { public partial class TrainerEstimators : TestDataPipeBase { - [Fact(Skip ="Disabling to check whether any regressions, will enable prior to check-in")] + [Fact] public void FieldAwareFactorizationMachine_Estimator() { var data = new TextLoader(Env, GetFafmBCLoaderArgs()) .Read(new MultiFileSource(GetDataPath(TestDatasets.breastCancer.trainFilename))); IEstimator est = new FieldAwareFactorizationMachineTrainer(Env, "Label", new[] { "Feature1", "Feature2", "Feature3", "Feature4" }); - - //var result = est.Fit(data); TestEstimatorCore(est, data); Done(); From 5890b11e8ecdf6e178346f7febcbf3117af0b986 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Wed, 19 Sep 2018 08:51:42 -0700 Subject: [PATCH 06/10] renaming the interface for transformers with feature column information. Commenting code further. --- src/Microsoft.ML.Core/Data/IEstimator.cs | 4 +- .../Prediction/IPredictionTransformer.cs | 16 ++++++- .../Scorers/PredictionTransformer.cs | 44 ++++++++++++++----- .../Training/ITrainerEstimator.cs | 2 +- .../Training/TrainerEstimatorBase.cs | 2 +- ...FieldAwareFactorizationMachinePredictor.cs | 36 ++++++++++----- .../Standard/LinearClassificationTrainer.cs | 2 +- .../MultiClass/MetaMulticlassTrainer.cs | 4 +- .../Standard/MultiClass/Ova.cs | 4 +- .../Standard/MultiClass/Pkpd.cs | 4 +- .../Standard/Online/AveragedLinear.cs | 2 +- .../Standard/Online/OnlineLinear.cs | 2 +- .../Standard/StochasticTrainerBase.cs | 2 +- .../DataPipe/TestDataPipeBase.cs | 4 +- .../Scenarios/Api/Estimators/Wrappers.cs | 2 +- .../TrainerEstimators/FAFMEstimator.cs | 8 +++- 16 files changed, 98 insertions(+), 40 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index a410f02001..60e085d38c 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -229,14 +229,14 @@ public interface IDataReaderEstimator /// /// The transformer is a component that transforms data. - /// It also supports 'schema propagation' to answer the question of 'how the data with this schema look after you transform it?'. + /// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'. /// public interface ITransformer { /// /// Schema propagation for transformers. /// Returns the output schema of the data, if the input schema is like the one provided. - /// Throws iff the input schema is not valid for the transformer. + /// Throws if the input schema is not valid for the transformer. /// ISchema GetOutputSchema(ISchema inputSchema); diff --git a/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs b/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs index 1b93149783..06c1894f0a 100644 --- a/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs @@ -8,17 +8,31 @@ namespace Microsoft.ML.Runtime { + /// + /// An interface for all the transformer that can transform data based on the field. + /// The implemendations of this interface either have no feature column, or have more than one feature column, and cannot implement the + /// , which most of the ML.Net tranformer implement. + /// + /// The used for the data transformation. public interface IPredictionTransformer : ITransformer where TModel : IPredictor { TModel Model { get; } } - public interface IClassicPredictionTransformer : IPredictionTransformer + /// + /// An ISingleFeaturePredictionTransformer contains the name of the + /// and its type, . Implementations of this interface, have the ability + /// to score the data of an input through the + /// + /// The used for the data transformation. + public interface ISingleFeaturePredictionTransformer : IPredictionTransformer where TModel : IPredictor { + /// The name of the feature column. string FeatureColumn { get; } + /// Holds information about the type of the feature column. ColumnType FeatureColumnType { get; } } } diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index 6d91277fb1..136d10ca9e 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.IO; -using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; @@ -22,6 +20,10 @@ namespace Microsoft.ML.Runtime.Data { + /// + /// Base class for transformers with no feature column, or more than one feature columns. + /// + /// public abstract class PredictionTransformerBase : IPredictionTransformer where TModel : class, IPredictor { @@ -71,13 +73,18 @@ protected PredictionTransformerBase(IHost host, ModelLoadContext ctx) TrainSchema = loader.Schema; } + /// + /// Gets the output schema resulting from the + /// + /// The of the input data. + /// The resulting . public abstract ISchema GetOutputSchema(ISchema inputSchema); /// /// Transforms the input data. /// - /// - /// + /// The input data. + /// The transformed public abstract IDataView Transform(IDataView input); protected void SaveModel(ModelSaveContext ctx) @@ -99,7 +106,12 @@ protected void SaveModel(ModelSaveContext ctx) } } - public abstract class ClassicPredictionTransformerBase : PredictionTransformerBase, IClassicPredictionTransformer, ICanSaveModel + /// + /// The base class for all the transformers implementing the . + /// Those are all the transformers that work with one feature column. + /// + /// The model used to transform the data. + public abstract class SingleFeaturePredictionTransformerBase : PredictionTransformerBase, ISingleFeaturePredictionTransformer, ICanSaveModel where TModel : class, IPredictor { /// @@ -112,7 +124,7 @@ public abstract class ClassicPredictionTransformerBase : PredictionTrans /// public ColumnType FeatureColumnType { get; } - public ClassicPredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn) + public SingleFeaturePredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn) :base(host, model, trainSchema) { FeatureColumn = featureColumn; @@ -123,7 +135,7 @@ public ClassicPredictionTransformerBase(IHost host, TModel model, ISchema trainS BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); } - internal ClassicPredictionTransformerBase(IHost host, ModelLoadContext ctx) + internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx) :base(host, ctx) { FeatureColumn = ctx.LoadString(); @@ -160,7 +172,11 @@ protected virtual void SaveCore(ModelSaveContext ctx) } } - public sealed class BinaryPredictionTransformer : ClassicPredictionTransformerBase + /// + /// Base class for the working on binary classification tasks. + /// + /// An implementation of the + public sealed class BinaryPredictionTransformer : SingleFeaturePredictionTransformerBase where TModel : class, IPredictorProducing { private readonly BinaryClassifierScorer _scorer; @@ -229,7 +245,11 @@ private static VersionInfo GetVersionInfo() } } - public sealed class MulticlassPredictionTransformer : ClassicPredictionTransformerBase + /// + /// Base class for the working on multi-class classification tasks. + /// + /// An implementation of the + public sealed class MulticlassPredictionTransformer : SingleFeaturePredictionTransformerBase where TModel : class, IPredictorProducing> { private readonly MultiClassClassifierScorer _scorer; @@ -290,7 +310,11 @@ private static VersionInfo GetVersionInfo() } } - public sealed class RegressionPredictionTransformer : ClassicPredictionTransformerBase + /// + /// Base class for the working on regression tasks. + /// + /// An implementation of the + public sealed class RegressionPredictionTransformer : SingleFeaturePredictionTransformerBase where TModel : class, IPredictorProducing { private readonly GenericScorer _scorer; diff --git a/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs b/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs index 243013d67e..2c9942e8d0 100644 --- a/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs +++ b/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs @@ -7,7 +7,7 @@ namespace Microsoft.ML.Runtime.Training { public interface ITrainerEstimator: IEstimator - where TTransformer: IClassicPredictionTransformer + where TTransformer: ISingleFeaturePredictionTransformer where TPredictor: IPredictor { TrainerInfo Info { get; } diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index dccf25cf51..7ae6475c35 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -15,7 +15,7 @@ namespace Microsoft.ML.Runtime.Training /// It produces a 'prediction transformer'. /// public abstract class TrainerEstimatorBase : ITrainerEstimator, ITrainer - where TTransformer : IClassicPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { /// diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs index 32866a0633..f5da9327c1 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs @@ -199,22 +199,22 @@ public sealed class FieldAwareFactorizationMachinePredictionTransformer : Predic public string[] FeatureColumns { get; } /// - /// The type of the prediction transformer + /// The type of the feature columns. /// public ColumnType[] FeatureColumnTypes { get; } private readonly BinaryClassifierScorer _scorer; - public readonly string ThresholdColumn; - public readonly float Threshold; + private readonly string _thresholdColumn; + private readonly float _threshold; public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, FieldAwareFactorizationMachinePredictor model, ISchema trainSchema, string[] featureColumns, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) :base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), model, trainSchema) { Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn)); - Threshold = threshold; - ThresholdColumn = thresholdColumn; + _threshold = threshold; + _thresholdColumn = thresholdColumn; Host.CheckValue(featureColumns, nameof(featureColumns)); int featCount = featureColumns.Length; @@ -234,7 +234,7 @@ public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); var schema = GetSchema(); - var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn }; + var args = new BinaryClassifierScorer.Arguments { Threshold = _threshold, ThresholdColumn = _thresholdColumn }; _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, trainSchema), BindableMapper.Bind(Host, schema), schema); } @@ -261,16 +261,21 @@ public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host FeatureColumnTypes[i] = TrainSchema.GetColumnType(col); } - Threshold = ctx.Reader.ReadSingle(); - ThresholdColumn = ctx.LoadString(); + _threshold = ctx.Reader.ReadSingle(); + _thresholdColumn = ctx.LoadString(); BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model); var schema = GetSchema(); - var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn }; + var args = new BinaryClassifierScorer.Arguments { Threshold = _threshold, ThresholdColumn = _thresholdColumn }; _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } + /// + /// Gets the result after applying . + /// + /// The of the input data. + /// The post transformation . public override ISchema GetOutputSchema(ISchema inputSchema) { for (int i = 0; i < FeatureColumns.Length; i++) @@ -286,12 +291,21 @@ public override ISchema GetOutputSchema(ISchema inputSchema) return Transform(new EmptyDataView(Host, inputSchema)).Schema; } + /// + /// Applies the transformer to the , scoring it through the . + /// + /// The data to be scored with the . + /// The scored . public override IDataView Transform(IDataView input) { Host.CheckValue(input, nameof(input)); return _scorer.ApplyToData(Host, input); } + /// + /// Saves the transformer to file. + /// + /// The that facilitates saving to the . public void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); @@ -318,8 +332,8 @@ public void Save(ModelSaveContext ctx) for (int i = 0; i < Model.FieldCount; i++) ctx.SaveString(FeatureColumns[i]); - ctx.Writer.Write(Threshold); - ctx.SaveString(ThresholdColumn); + ctx.Writer.Write(_threshold); + ctx.SaveString(_thresholdColumn); } private RoleMappedSchema GetSchema() diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 423f8c528e..8722b79ea6 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -147,7 +147,7 @@ protected virtual int ComputeNumThreads(FloatLabelCursor.Factory cursorFactory) } public abstract class SdcaTrainerBase : StochasticTrainerBase - where TTransformer : IClassicPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { // REVIEW: Making it even faster and more accurate: diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index 90ed322f1b..ee100e9e59 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -15,10 +15,10 @@ namespace Microsoft.ML.Runtime.Learners { - using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; + using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; public abstract class MetaMulticlassTrainer : ITrainerEstimator, ITrainer - where TTransformer : IClassicPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { public abstract class ArgumentsBase diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index bd061b2c91..b918c5acce 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -34,7 +34,7 @@ namespace Microsoft.ML.Runtime.Learners { using TScalarPredictor = IPredictorProducing; - using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; + using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; using TDistPredictor = IDistPredictorProducing; using CR = RoleMappedSchema.ColumnRole; @@ -111,7 +111,7 @@ protected override OvaPredictor TrainCore(IChannel ch, RoleMappedData data, int return OvaPredictor.Create(Host, _args.UseProbabilities, predictors); } - private IClassicPredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) + private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) { var view = MapLabels(data, cls); diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index bad83ed24d..82994143cb 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -26,7 +26,7 @@ namespace Microsoft.ML.Runtime.Learners { using TDistPredictor = IDistPredictorProducing; - using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; + using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; using CR = RoleMappedSchema.ColumnRole; using TTransformer = MulticlassPredictionTransformer; @@ -119,7 +119,7 @@ protected override PkpdPredictor TrainCore(IChannel ch, RoleMappedData data, int return new PkpdPredictor(Host, predModels); } - private IClassicPredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls1, int cls2) + private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls1, int cls2) { // this should not be necessary when the legacy constructor doesn't exist, and the label column is not an optional parameter on the // MetaMulticlassTrainer constructor. diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs index c64bea48ca..70ee279a1c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs @@ -54,7 +54,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments } public abstract class AveragedLinearTrainer : OnlineLinearTrainer - where TTransformer : IClassicPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { protected readonly new AveragedLinearArguments Args; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index 1e0fa3821b..15bd5da290 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -44,7 +44,7 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel } public abstract class OnlineLinearTrainer : TrainerEstimatorBase - where TTransformer : IClassicPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { protected readonly OnlineLinearArguments Args; diff --git a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs index eaab91e31b..82e2223e46 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs @@ -11,7 +11,7 @@ namespace Microsoft.ML.Runtime.Learners { public abstract class StochasticTrainerBase : TrainerEstimatorBase - where TTransformer : IClassicPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = null) diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index 9c81da6cda..42766d3934 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -123,8 +123,8 @@ protected void TestEstimatorCore(IEstimator estimator, // and original transformer. // This in turn means that the schema of loaded transformer matches for // Transform and GetOutputSchema calls. - CheckSameSchemas(scoredTrain.Schema, scoredTrain2.Schema); - CheckSameValues(scoredTrain, scoredTrain2); + CheckSameSchemas(scoredTrain.Schema, scoredTrain2.Schema); + CheckSameValues(scoredTrain, scoredTrain2); }; checkOnData(validFitInput); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs index b1953ac651..13ca935b15 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs @@ -174,7 +174,7 @@ public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx) public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input); } - public class ScorerWrapper : TransformWrapper, IClassicPredictionTransformer + public class ScorerWrapper : TransformWrapper, ISingleFeaturePredictionTransformer where TModel : IPredictor { protected readonly string _featureColumn; diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs index 303cf099dc..42795e3f4a 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs @@ -20,7 +20,13 @@ public void FieldAwareFactorizationMachine_Estimator() var data = new TextLoader(Env, GetFafmBCLoaderArgs()) .Read(new MultiFileSource(GetDataPath(TestDatasets.breastCancer.trainFilename))); - IEstimator est = new FieldAwareFactorizationMachineTrainer(Env, "Label", new[] { "Feature1", "Feature2", "Feature3", "Feature4" }); + IEstimator est = new FieldAwareFactorizationMachineTrainer(Env, "Label", new[] { "Feature1", "Feature2", "Feature3", "Feature4" }, + advancedSettings:s=> + { + s.Shuffle = false; + s.Iters = 3; + s.LatentDim = 7; + }); TestEstimatorCore(est, data); Done(); From 174e75db52227b0b43900cd3d39ef1c0f0762a6f Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Wed, 19 Sep 2018 08:55:41 -0700 Subject: [PATCH 07/10] Fixing thy space. --- src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index 136d10ca9e..803d9badd7 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -21,7 +21,7 @@ namespace Microsoft.ML.Runtime.Data { /// - /// Base class for transformers with no feature column, or more than one feature columns. + /// Base class for transformers with no feature column, or more than one feature columns. /// /// public abstract class PredictionTransformerBase : IPredictionTransformer From 65a629651b6658b948f62aa27cab3876f82f87f6 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Wed, 19 Sep 2018 10:47:40 -0700 Subject: [PATCH 08/10] Incorporating Pete's fix about the unecessary creation of getters --- .../FieldAwareFactorizationMachineUtils.cs | 8 ++++++-- .../Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs | 3 ++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs index 67c53223d7..11881b1925 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs @@ -101,8 +101,12 @@ public IRow GetOutputRow(IRow input, Func predicate, out Action actio var featureIndexBuffer = new int[_pred.FeatureCount]; var featureValueBuffer = new float[_pred.FeatureCount]; var inputGetters = new ValueGetter>[_pred.FieldCount]; - for (int f = 0; f < _pred.FieldCount; f++) - inputGetters[f] = input.GetGetter>(_inputColumnIndexes[f]); + + if (predicate(0) || predicate(1)) + { + for (int f = 0; f < _pred.FieldCount; f++) + inputGetters[f] = input.GetGetter>(_inputColumnIndexes[f]); + } action = null; var getters = new Delegate[2]; diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs index 42795e3f4a..87fa345d40 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs @@ -20,13 +20,14 @@ public void FieldAwareFactorizationMachine_Estimator() var data = new TextLoader(Env, GetFafmBCLoaderArgs()) .Read(new MultiFileSource(GetDataPath(TestDatasets.breastCancer.trainFilename))); - IEstimator est = new FieldAwareFactorizationMachineTrainer(Env, "Label", new[] { "Feature1", "Feature2", "Feature3", "Feature4" }, + var est = new FieldAwareFactorizationMachineTrainer(Env, "Label", new[] { "Feature1", "Feature2", "Feature3", "Feature4" }, advancedSettings:s=> { s.Shuffle = false; s.Iters = 3; s.LatentDim = 7; }); + TestEstimatorCore(est, data); Done(); From 492a890a2beca07f17ed3b88dbe8533c9aaa2bbd Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Wed, 19 Sep 2018 12:38:51 -0700 Subject: [PATCH 09/10] fixing comments. --- src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index 803d9badd7..cf51798af2 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -53,14 +53,14 @@ protected PredictionTransformerBase(IHost host, ModelLoadContext ctx) { Host = host; - ctx.LoadModel(host, out TModel model, DirModel); - Model = model; - // *** Binary format *** // model: prediction model. // stream: empty data view that contains train schema. // id of string: feature column. + ctx.LoadModel(host, out TModel model, DirModel); + Model = model; + // Clone the stream with the schema into memory. var ms = new MemoryStream(); ctx.TryLoadBinaryStream(DirTransSchema, reader => @@ -90,9 +90,9 @@ protected PredictionTransformerBase(IHost host, ModelLoadContext ctx) protected void SaveModel(ModelSaveContext ctx) { // *** Binary format *** + // // model: prediction model. // stream: empty data view that contains train schema. - // id of string: feature column. ctx.SaveModel(Model, DirModel); ctx.SaveBinaryStream(DirTransSchema, writer => From 3d4858d7c3332e018ba13d4f54b627ab95f1ca44 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 20 Sep 2018 10:28:53 -0700 Subject: [PATCH 10/10] post merge fixes --- .../Scorers/PredictionTransformer.cs | 21 +++++++------------ src/Microsoft.ML.FastTree/BoostingFastTree.cs | 2 +- src/Microsoft.ML.FastTree/FastTree.cs | 2 +- src/Microsoft.ML.FastTree/RandomForest.cs | 2 +- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index a66100f62b..aa71ac2dc5 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -44,19 +44,11 @@ public abstract class PredictionTransformerBase : IPredictionTransformer protected PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema) { Contracts.CheckValue(host, nameof(host)); - Contracts.CheckValueOrNull(featureColumn); + Host = host; Host.CheckValue(trainSchema, nameof(trainSchema)); Model = model; - FeatureColumn = featureColumn; - if (featureColumn == null) - FeatureColumnType = null; - else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col)) - throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn); - else - FeatureColumnType = trainSchema.GetColumnType(col); - TrainSchema = trainSchema; } @@ -103,7 +95,6 @@ protected void SaveModel(ModelSaveContext ctx) { // *** Binary format *** // - // model: prediction model. // stream: empty data view that contains train schema. ctx.SaveModel(Model, DirModel); @@ -141,11 +132,13 @@ public SingleFeaturePredictionTransformerBase(IHost host, TModel model, ISchema { FeatureColumn = featureColumn; - if (FeatureColumn == null) + FeatureColumn = featureColumn; + if (featureColumn == null) FeatureColumnType = null; - if (!trainSchema.TryGetColumnIndex(featureColumn, out int col)) + else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col)) throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn); - FeatureColumnType = trainSchema.GetColumnType(col); + else + FeatureColumnType = trainSchema.GetColumnType(col); BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); } @@ -382,7 +375,7 @@ private static VersionInfo GetVersionInfo() } } - public sealed class RankingPredictionTransformer : PredictionTransformerBase + public sealed class RankingPredictionTransformer : SingleFeaturePredictionTransformerBase where TModel : class, IPredictorProducing { private readonly GenericScorer _scorer; diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index ef91e6c688..1b2fb7e07c 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.Runtime.FastTree { public abstract class BoostingFastTreeTrainerBase : FastTreeTrainerBase - where TTransformer : IPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TArgs : BoostedTreeArgs, new() where TModel : IPredictorProducing { diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 124e1117bd..0d9d5bc192 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -46,7 +46,7 @@ internal static class FastTreeShared public abstract class FastTreeTrainerBase : TrainerEstimatorBase - where TTransformer: IPredictionTransformer + where TTransformer: ISingleFeaturePredictionTransformer where TArgs : TreeArgs, new() where TModel : IPredictorProducing { diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index 5ce40742ca..057841d78c 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -11,7 +11,7 @@ namespace Microsoft.ML.Runtime.FastTree public abstract class RandomForestTrainerBase : FastTreeTrainerBase where TArgs : FastForestArgumentsBase, new() where TModel : IPredictorProducing - where TTransformer: IPredictionTransformer + where TTransformer: ISingleFeaturePredictionTransformer { private readonly bool _quantileEnabled;