From 6a6c361095a57631df1c2e32e7a15e54a7d16da6 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 30 Jul 2018 14:09:48 -0700 Subject: [PATCH] support validation and incremental trainers --- src/Microsoft.ML.FastTree/FastTree.cs | 2 +- .../FactorizationMachine/FactorizationMachineTrainer.cs | 2 +- .../Standard/Online/OnlineLinear.cs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index fad40495f4..9f24b4bc09 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -95,7 +95,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args) // The discretization step renders this trainer non-parametric, and therefore it does not need normalization. // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration. - Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration); + Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true); int numThreads = Args.NumThreads ?? Environment.ProcessorCount; if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor) { diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index 62270763de..0980cb22dc 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -103,7 +103,7 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments arg _shuffle = args.Shuffle; _verbose = args.Verbose; _radius = args.Radius; - Info = new TrainerInfo(); + Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true); } private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachinePredictor predictor, out float[] linearWeights, diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index 60fe7f9705..4976bf20d3 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -84,7 +84,7 @@ protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name Args = args; // REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue. - Info = new TrainerInfo(calibration: NeedCalibration); + Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true); } ///