Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -333,5 +333,22 @@ public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(ITransfor
return transformer.CreatePredictionEngine<TSrc, TDst>(_env, false,
DataViewConstructionUtils.GetSchemaDefinition<TSrc>(_env, inputSchema));
}

/// <summary>
/// Create a prediction engine for one-time prediction.
/// It's mainly used in conjunction with <see cref="Load(Stream, out DataViewSchema)"/>,
/// where input schema is extracted during loading the model.
/// </summary>
/// <typeparam name="TSrc">The class that defines the input data.</typeparam>
/// <typeparam name="TDst">The class that defines the output data.</typeparam>
/// <param name="transformer">The transformer to use for prediction.</param>
/// <param name="options">Advaned configuration options.</param>
Comment thread
michaelgsharp marked this conversation as resolved.
Outdated
public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(ITransformer transformer, PredictionEngine.Options options)
where TSrc : class
where TDst : class, new()
{
return transformer.CreatePredictionEngine<TSrc, TDst>(_env, options.IgnoreMissingColumns,
options.InputSchemaDefinition, options.OutputSchemaDefinition, options.OwnModelFile);
}
}
}
5 changes: 3 additions & 2 deletions src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ internal static class PredictionEngineExtensions
/// <typeparamref name="TDst"/>.</param>
/// <param name="inputSchemaDefinition">Additional settings of the input schema.</param>
/// <param name="outputSchemaDefinition">Additional settings of the output schema.</param>
/// <param name="ownModelFile">Whether the prediction engine owns the model file and should dispose of it.</param>
public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(this ITransformer transformer,
IHostEnvironment env, bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
IHostEnvironment env, bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null, bool ownModelFile = true)
where TSrc : class
where TDst : class, new()
=> new PredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
=> new PredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, ownModelFile);
}
}
43 changes: 39 additions & 4 deletions src/Microsoft.ML.Data/Prediction/PredictionEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;

Expand Down Expand Up @@ -58,8 +59,8 @@ public sealed class PredictionEngine<TSrc, TDst> : PredictionEngineBase<TSrc, TD
where TDst : class, new()
{
internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
: base(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null, bool ownModelFile = true)
: base(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, ownModelFile)
{
}

Expand Down Expand Up @@ -92,6 +93,7 @@ public abstract class PredictionEngineBase<TSrc, TDst> : IDisposable
private readonly DataViewConstructionUtils.InputRow<TSrc> _inputRow;
private readonly IRowReadableAs<TDst> _outputRow;
private readonly Action _disposer;
private readonly bool _ownModelFile;
private bool _disposed;

/// <summary>
Expand All @@ -104,14 +106,15 @@ public abstract class PredictionEngineBase<TSrc, TDst> : IDisposable

[BestFriend]
private protected PredictionEngineBase(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null, bool ownModelFile = true)
{
Contracts.CheckValue(env, nameof(env));
env.AssertValue(transformer);
Transformer = transformer;
var makeMapper = TransformerChecker(env, transformer);
env.AssertValue(makeMapper);
_inputRow = DataViewConstructionUtils.CreateInputRow<TSrc>(env, inputSchemaDefinition);
_ownModelFile = ownModelFile;
PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, outputSchemaDefinition, out _disposer, out _outputRow);
OutputSchema = Transformer.GetOutputSchema(_inputRow.Schema);
}
Expand Down Expand Up @@ -139,7 +142,9 @@ public void Dispose()
return;

_disposer?.Invoke();
(Transformer as IDisposable)?.Dispose();

if (_ownModelFile)
(Transformer as IDisposable)?.Dispose();

_disposed = true;
}
Expand Down Expand Up @@ -170,4 +175,34 @@ public TDst Predict(TSrc example)
/// is reused.</param>
public abstract void Predict(TSrc example, ref TDst prediction);
}

public sealed class PredictionEngine
Comment thread
eerhardt marked this conversation as resolved.
Outdated
{
/// <summary>
/// Options for the <see cref="PredictionEngine{TSrc, TDst}"/> as used in
/// [RandomizedPca(Options)](xref:Microsoft.ML.PcaCatalog.RandomizedPca(Microsoft.ML.AnomalyDetectionCatalog.AnomalyDetectionTrainers,Microsoft.ML.Trainers.RandomizedPcaTrainer.Options)).

@eerhardt eerhardt Oct 11, 2021

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a copy-paste error? #Resolved

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, resolved.

/// </summary>
public sealed class Options
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to throw an error if a column exists in the output schema but not the output object.", ShortName = "ignore", SortOrder = 50)]
public bool IgnoreMissingColumns = Defaults.IgnoreMissingColumns;

[Argument(ArgumentType.AtMostOnce, HelpText = "Additional settings of the input schema.", ShortName = "input", SortOrder = 50)]
public SchemaDefinition InputSchemaDefinition = Defaults.InputSchemaDefinition;

[Argument(ArgumentType.AtMostOnce, HelpText = "Additional settings of the output schema.", ShortName = "output")]
public SchemaDefinition OutputSchemaDefinition = Defaults.OutputSchemaDefinition;

[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the prediction engine owns the model file and should dispose of it.", ShortName = "own")]
public bool OwnModelFile = Defaults.OwnModelFile;
Comment thread
eerhardt marked this conversation as resolved.
Outdated

internal static class Defaults
{
public const bool IgnoreMissingColumns = true;
public const SchemaDefinition InputSchemaDefinition = null;
public const SchemaDefinition OutputSchemaDefinition = null;
public const bool OwnModelFile = true;
}
}
}
}
37 changes: 37 additions & 0 deletions src/Microsoft.ML.TimeSeries/PredictionEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@ public TimeSeriesPredictionEngine(IHostEnvironment env, ITransformer transformer
{
}

/// <summary>
/// Contructor for creating time series specific prediction engine. It allows the time series model to be updated with the observations
/// seen at prediction time via <see cref="CheckPoint(IHostEnvironment, string)"/>
/// </summary>
public TimeSeriesPredictionEngine(IHostEnvironment env, ITransformer transformer, PredictionEngine.Options options) :
base(env, CloneTransformers(transformer), options.IgnoreMissingColumns, options.InputSchemaDefinition, options.OutputSchemaDefinition, options.OwnModelFile)
{
}

internal DataViewRow GetStatefulRows(DataViewRow input, IRowToRowMapper mapper, IEnumerable<DataViewSchema.Column> activeColumns, List<StatefulRow> rows)
{
Contracts.CheckValue(input, nameof(input));
Expand Down Expand Up @@ -398,5 +407,33 @@ public static TimeSeriesPredictionEngine<TSrc, TDst> CreateTimeSeriesEngine<TSrc
env.CheckValueOrNull(outputSchemaDefinition);
return new TimeSeriesPredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
}

/// <summary>
/// <see cref="TimeSeriesPredictionEngine{TSrc, TDst}"/> creates a prediction engine for a time series pipeline.
/// It updates the state of time series model with observations seen at prediction phase and allows checkpointing the model.
/// </summary>
/// <typeparam name="TSrc">Class describing input schema to the model.</typeparam>
/// <typeparam name="TDst">Class describing the output schema of the prediction.</typeparam>
/// <param name="transformer">The time series pipeline in the form of a <see cref="ITransformer"/>.</param>
/// <param name="env">Usually <see cref="MLContext"/></param>
/// <param name="options">Advaned configuration options.</param>
Comment thread
eerhardt marked this conversation as resolved.
Outdated
/// <p>Example code can be found by searching for <i>TimeSeriesPredictionEngine</i> in <a href='https://github.com/dotnet/machinelearning'>ML.NET.</a></p>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// This is an example for detecting change point using Singular Spectrum Analysis (SSA) model.
/// [!code-csharp[MF](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectChangePointBySsa.cs)]
/// ]]>
/// </format>
/// </example>
public static TimeSeriesPredictionEngine<TSrc, TDst> CreateTimeSeriesEngine<TSrc, TDst>(this ITransformer transformer, IHostEnvironment env,
Comment thread
eerhardt marked this conversation as resolved.
PredictionEngine.Options options)
where TSrc : class
where TDst : class, new()
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(options, nameof(options));
return new TimeSeriesPredictionEngine<TSrc, TDst>(env, transformer, options);
}
}
}
41 changes: 41 additions & 0 deletions test/Microsoft.ML.IntegrationTests/Prediction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.Reflection;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.IntegrationTests.Datasets;
Expand Down Expand Up @@ -97,5 +98,45 @@ public void ReconfigurablePredictionNoPipeline()
Assert.True(pr.Score <= 0);
}

[Fact]
public void PredictionEngineModelDisposal()
{
var mlContext = new MLContext(seed: 1);
var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());
var pipeline = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(
new Trainers.LbfgsLogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 });
var model = pipeline.Fit(data);

var engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model, new PredictionEngine.Options());

// Dispose of prediction engine, should dispose of model
engine.Dispose();

// Get disposed flag using reflection
var bfIsDisposed = BindingFlags.Instance | BindingFlags.NonPublic;
var field = model.GetType().BaseType.BaseType.GetField("_disposed", bfIsDisposed);

// Make sure the model is actually disposed
Assert.True((bool)field.GetValue(model));

// Make a new model/prediction engine. Set the options so prediction engine doesn't dispose
model = pipeline.Fit(data);

var options = new PredictionEngine.Options()
{
OwnModelFile = false
};

engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model, options);

// Dispose of prediction engine, shoudln't dispose of model
Comment thread
eerhardt marked this conversation as resolved.
Outdated
engine.Dispose();

// Make sure model is not disposed of.
Assert.False((bool)field.GetValue(model));

// Dispose of the model for test cleanliness
model.Dispose();
}
}
}