Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3046,7 +3046,7 @@ private enum AggregateFunction
Max
}

bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
Host.CheckValue(ctx, nameof(ctx));

Expand Down Expand Up @@ -3132,6 +3132,10 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string

return true;
}
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
return SaveAsOnnx(ctx,outputNames,featureColumn);
}

void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema)
{
Expand Down
13 changes: 12 additions & 1 deletion src/Microsoft.ML.FastTree/FastTreeTweedie.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;

Expand Down Expand Up @@ -482,7 +483,7 @@ protected override void GetGradientInOneQuery(int query, int threadIndex)
/// <summary>
/// Model parameters for <see cref="FastTreeTweedieTrainer"/>.
/// </summary>
public sealed class FastTreeTweedieModelParameters : TreeEnsembleModelParametersBasedOnRegressionTree
public sealed class FastTreeTweedieModelParameters : TreeEnsembleModelParametersBasedOnRegressionTree, ISingleCanSaveOnnx
{
internal const string LoaderSignature = "FastTreeTweedieExec";
internal const string RegistrationName = "FastTreeTweediePredictor";
Expand Down Expand Up @@ -530,6 +531,16 @@ private static FastTreeTweedieModelParameters Create(IHostEnvironment env, Model
return new FastTreeTweedieModelParameters(env, ctx);
}

bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
// Mapping score to prediction
var fastTreeOutput = ctx.AddIntermediateVariable(null, "FastTreeOutput", true);
base.SaveAsOnnx(ctx, new[] { fastTreeOutput }, featureColumn);
var opType = "Exp";
ctx.CreateNode(opType, new[] { fastTreeOutput }, outputNames, ctx.GetNodeName(opType), "");
return true;
}

private protected override void Map(in VBuffer<float> src, ref float dst)
{
// The value learnt and predicted by the trees is the log of the expected value,
Expand Down
15 changes: 14 additions & 1 deletion src/Microsoft.ML.FastTree/RandomForestRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;

Expand All @@ -30,7 +31,8 @@ namespace Microsoft.ML.Trainers.FastTree
public sealed class FastForestRegressionModelParameters :
TreeEnsembleModelParametersBasedOnQuantileRegressionTree,
IQuantileValueMapper,
IQuantileRegressionPredictor
IQuantileRegressionPredictor,
ISingleCanSaveOnnx
{
private sealed class QuantileStatistics
{
Expand Down Expand Up @@ -209,6 +211,17 @@ private static FastForestRegressionModelParameters Create(IHostEnvironment env,

private protected override PredictionKind PredictionKind => PredictionKind.Regression;

bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
// Mapping score to prediction
var fastTreeOutput = ctx.AddIntermediateVariable(null, "FastTreeOutput", true);
var numTrees = ctx.AddInitializer((float)TrainedEnsemble.NumTrees, "NumTrees");
base.SaveAsOnnx(ctx, new[] { fastTreeOutput }, featureColumn);
var opType = "Div";
ctx.CreateNode(opType, new[] { fastTreeOutput, numTrees }, outputNames, ctx.GetNodeName(opType), "");
return true;
}

private protected override void Map(in VBuffer<float> src, ref float dst)
{
int inputVectorSize = InputType.GetVectorSize();
Expand Down
39 changes: 26 additions & 13 deletions src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,20 @@ internal LinearModelParameters(IHostEnvironment env, string name, in VBuffer<flo
_weightsDenseLock = new object();
}

private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputs) == 1);
string opType = "LinearRegressor";
var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType));
// Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT}
node.AddAttribute("post_transform", "NONE");
node.AddAttribute("targets", 1);
node.AddAttribute("coefficients", Weight.DenseValues());
node.AddAttribute("intercepts", new float[] { Bias });
return true;
}

private protected LinearModelParameters(IHostEnvironment env, string name, ModelLoadContext ctx)
: base(env, name, ctx)
{
Expand Down Expand Up @@ -188,7 +202,6 @@ private protected LinearModelParameters(IHostEnvironment env, string name, Model
else
_weightsDenseLock = new object();
}

[BestFriend]
private protected override void SaveCore(ModelSaveContext ctx)
{
Expand Down Expand Up @@ -239,17 +252,7 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input)

bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputs) == 1);

string opType = "LinearRegressor";
var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType));
// Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT}
node.AddAttribute("post_transform", "NONE");
node.AddAttribute("targets", 1);
node.AddAttribute("coefficients", Weight.DenseValues());
node.AddAttribute("intercepts", new float[] { Bias });
return true;
return SaveAsOnnx(ctx, outputs, featureColumn);
}

// Generate the score from the given values, assuming they have already been normalized.
Expand Down Expand Up @@ -685,7 +688,7 @@ IList<KeyValuePair<string, object>> ICanGetSummaryInKeyValuePairs.GetSummaryInKe
/// <summary>
/// Model parameters for Poisson Regression.
/// </summary>
public sealed class PoissonRegressionModelParameters : RegressionModelParameters, IParameterMixer<float>
public sealed class PoissonRegressionModelParameters : RegressionModelParameters, IParameterMixer<float>, ISingleCanSaveOnnx
{
internal const string LoaderSignature = "PoissonRegressionExec";
internal const string RegistrationName = "PoissonRegressionPredictor";
Expand Down Expand Up @@ -727,6 +730,16 @@ private static PoissonRegressionModelParameters Create(IHostEnvironment env, Mod
return new PoissonRegressionModelParameters(env, ctx);
}

bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
{
// Mapping score to prediction
var linearRegressorOutput = ctx.AddIntermediateVariable(null,"LinearRegressorOutput",true);
base.SaveAsOnnx(ctx, new[] { linearRegressorOutput }, featureColumn);
var opType = "Exp";
ctx.CreateNode(opType, new[] { linearRegressorOutput }, outputs, ctx.GetNodeName(opType), "");
return true;
}

private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
Expand Down
Loading