Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -333,5 +333,22 @@ public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(ITransfor
return transformer.CreatePredictionEngine<TSrc, TDst>(_env, false,
DataViewConstructionUtils.GetSchemaDefinition<TSrc>(_env, inputSchema));
}

/// <summary>
/// Create a prediction engine for one-time prediction, configured through a
/// <see cref="PredictionEngineOptions"/> instance.
/// Use this overload when the defaults are not sufficient, for example to set
/// <see cref="PredictionEngineOptions.OwnsTransformer"/> to <see langword="false"/> so that disposing
/// the engine does not dispose <paramref name="transformer"/>.
/// </summary>
/// <typeparam name="TSrc">The class that defines the input data.</typeparam>
/// <typeparam name="TDst">The class that defines the output data.</typeparam>
/// <param name="transformer">The transformer to use for prediction.</param>
/// <param name="options">Advanced configuration options.</param>
/// <returns>A prediction engine built from <paramref name="transformer"/> and <paramref name="options"/>.</returns>
/// <exception cref="System.ArgumentNullException">
/// <paramref name="transformer"/> or <paramref name="options"/> is <see langword="null"/>.
/// </exception>
public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(ITransformer transformer, PredictionEngineOptions options)
    where TSrc : class
    where TDst : class, new()
{
    // Validate up front: without these checks a null 'options' fails with a bare
    // NullReferenceException on the first member access below.
    _env.CheckValue(transformer, nameof(transformer));
    _env.CheckValue(options, nameof(options));

    return transformer.CreatePredictionEngine<TSrc, TDst>(_env, options.IgnoreMissingColumns,
        options.InputSchemaDefinition, options.OutputSchemaDefinition, options.OwnsTransformer);
}
}
}
5 changes: 3 additions & 2 deletions src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ internal static class PredictionEngineExtensions
/// <typeparamref name="TDst"/>.</param>
/// <param name="inputSchemaDefinition">Additional settings of the input schema.</param>
/// <param name="outputSchemaDefinition">Additional settings of the output schema.</param>
/// <param name="ownsTransformer">Whether the prediction engine owns the transformer and should dispose of it.</param>
public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(this ITransformer transformer,
IHostEnvironment env, bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
IHostEnvironment env, bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null, bool ownsTransformer = true)
where TSrc : class
where TDst : class, new()
=> new PredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
=> new PredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, ownsTransformer);
}
}
39 changes: 35 additions & 4 deletions src/Microsoft.ML.Data/Prediction/PredictionEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;

Expand Down Expand Up @@ -58,8 +59,8 @@ public sealed class PredictionEngine<TSrc, TDst> : PredictionEngineBase<TSrc, TD
where TDst : class, new()
{
internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
: base(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null, bool ownsTransformer = true)
: base(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, ownsTransformer)
{
}

Expand Down Expand Up @@ -92,6 +93,7 @@ public abstract class PredictionEngineBase<TSrc, TDst> : IDisposable
private readonly DataViewConstructionUtils.InputRow<TSrc> _inputRow;
private readonly IRowReadableAs<TDst> _outputRow;
private readonly Action _disposer;
private readonly bool _ownsTransformer;
private bool _disposed;

/// <summary>
Expand All @@ -104,14 +106,15 @@ public abstract class PredictionEngineBase<TSrc, TDst> : IDisposable

[BestFriend]
private protected PredictionEngineBase(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null, bool ownsTransformer = true)
{
Contracts.CheckValue(env, nameof(env));
env.AssertValue(transformer);
Transformer = transformer;
var makeMapper = TransformerChecker(env, transformer);
env.AssertValue(makeMapper);
_inputRow = DataViewConstructionUtils.CreateInputRow<TSrc>(env, inputSchemaDefinition);
_ownsTransformer = ownsTransformer;
PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, outputSchemaDefinition, out _disposer, out _outputRow);
OutputSchema = Transformer.GetOutputSchema(_inputRow.Schema);
}
Expand Down Expand Up @@ -139,7 +142,9 @@ public void Dispose()
return;

_disposer?.Invoke();
(Transformer as IDisposable)?.Dispose();

if (_ownsTransformer)
(Transformer as IDisposable)?.Dispose();

_disposed = true;
}
Expand Down Expand Up @@ -170,4 +175,30 @@ public TDst Predict(TSrc example)
/// is reused.</param>
public abstract void Predict(TSrc example, ref TDst prediction);
}

/// <summary>
/// Options for the <see cref="PredictionEngine{TSrc, TDst}"/>.
/// Each public field is initialized from the matching constant in <see cref="Defaults"/>.
/// </summary>
public sealed class PredictionEngineOptions
{
/// <summary>
/// Controls how a column that exists in the output schema but not in the output object is handled.
/// Defaults to <see langword="true"/>.
/// </summary>
// NOTE(review): the help text is phrased as "whether to throw" while the field name says "ignore";
// confirm the polarity against the code that consumes this flag.
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to throw an error if a column exists in the output schema but not the output object.", ShortName = "ignore", SortOrder = 50)]
public bool IgnoreMissingColumns = Defaults.IgnoreMissingColumns;

/// <summary>
/// Additional settings of the input schema. Defaults to <see langword="null"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Additional settings of the input schema.", ShortName = "input", SortOrder = 50)]
public SchemaDefinition InputSchemaDefinition = Defaults.InputSchemaDefinition;

/// <summary>
/// Additional settings of the output schema. Defaults to <see langword="null"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Additional settings of the output schema.", ShortName = "output")]
public SchemaDefinition OutputSchemaDefinition = Defaults.OutputSchemaDefinition;

/// <summary>
/// Whether the prediction engine owns the transformer and should dispose of it.
/// Defaults to <see langword="true"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the prediction engine owns the transformer and should dispose of it.", ShortName = "own")]
public bool OwnsTransformer = Defaults.OwnsTransformer;

// Default values for the option fields declared above.
internal static class Defaults
{
public const bool IgnoreMissingColumns = true;
public const SchemaDefinition InputSchemaDefinition = null;
public const SchemaDefinition OutputSchemaDefinition = null;
public const bool OwnsTransformer = true;
}
}
}
37 changes: 37 additions & 0 deletions src/Microsoft.ML.TimeSeries/PredictionEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@ public TimeSeriesPredictionEngine(IHostEnvironment env, ITransformer transformer
{
}

/// <summary>
/// Constructor for creating a time series specific prediction engine from a
/// <see cref="PredictionEngineOptions"/> instance. It allows the time series model to be updated with the
/// observations seen at prediction time via <see cref="CheckPoint(IHostEnvironment, string)"/>.
/// </summary>
/// <param name="env">The host environment passed to the base prediction engine.</param>
/// <param name="transformer">The pipeline; it is passed through <c>CloneTransformers</c> before reaching the base constructor.</param>
/// <param name="options">Supplies the missing-column flag, both schema definitions and the ownership flag.</param>
internal TimeSeriesPredictionEngine(IHostEnvironment env, ITransformer transformer, PredictionEngineOptions options)
    : base(
        env,
        CloneTransformers(transformer),
        options.IgnoreMissingColumns,
        options.InputSchemaDefinition,
        options.OutputSchemaDefinition,
        options.OwnsTransformer)
{
    // NOTE(review): the base receives the result of CloneTransformers, not the caller's instance.
    // Presumably that is a private copy; if so, OwnsTransformer = false would leave it undisposed - confirm.
}

internal DataViewRow GetStatefulRows(DataViewRow input, IRowToRowMapper mapper, IEnumerable<DataViewSchema.Column> activeColumns, List<StatefulRow> rows)
{
Contracts.CheckValue(input, nameof(input));
Expand Down Expand Up @@ -398,5 +407,33 @@ public static TimeSeriesPredictionEngine<TSrc, TDst> CreateTimeSeriesEngine<TSrc
env.CheckValueOrNull(outputSchemaDefinition);
return new TimeSeriesPredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
}

/// <summary>
/// Creates a <see cref="TimeSeriesPredictionEngine{TSrc, TDst}"/> for a time series pipeline, configured through a
/// <see cref="PredictionEngineOptions"/> instance.
/// The engine updates the state of time series model with observations seen at prediction phase and allows checkpointing the model.
/// </summary>
/// <typeparam name="TSrc">Class describing input schema to the model.</typeparam>
/// <typeparam name="TDst">Class describing the output schema of the prediction.</typeparam>
/// <param name="transformer">The time series pipeline in the form of a <see cref="ITransformer"/>.</param>
/// <param name="env">Usually <see cref="MLContext"/></param>
/// <param name="options">Advanced configuration options.</param>
/// <returns>A new <see cref="TimeSeriesPredictionEngine{TSrc, TDst}"/> built from <paramref name="transformer"/>.</returns>
/// <exception cref="System.ArgumentNullException">
/// <paramref name="env"/>, <paramref name="transformer"/> or <paramref name="options"/> is <see langword="null"/>.
/// </exception>
/// <remarks>
/// <p>Example code can be found by searching for <i>TimeSeriesPredictionEngine</i> in <a href='https://github.com/dotnet/machinelearning'>ML.NET.</a></p>
/// </remarks>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// This is an example for detecting change point using Singular Spectrum Analysis (SSA) model.
/// [!code-csharp[MF](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectChangePointBySsa.cs)]
/// ]]>
/// </format>
/// </example>
public static TimeSeriesPredictionEngine<TSrc, TDst> CreateTimeSeriesEngine<TSrc, TDst>(this ITransformer transformer, IHostEnvironment env,
    PredictionEngineOptions options)
    where TSrc : class
    where TDst : class, new()
{
    // env is validated first because the remaining checks are reported through it.
    Contracts.CheckValue(env, nameof(env));
    // Validate the pipeline too, so a null receiver is reported as an argument error
    // here instead of failing somewhere inside the engine constructor.
    env.CheckValue(transformer, nameof(transformer));
    env.CheckValue(options, nameof(options));
    return new TimeSeriesPredictionEngine<TSrc, TDst>(env, transformer, options);
}
}
}
41 changes: 41 additions & 0 deletions test/Microsoft.ML.IntegrationTests/Prediction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.Reflection;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.IntegrationTests.Datasets;
Expand Down Expand Up @@ -97,5 +98,45 @@ public void ReconfigurablePredictionNoPipeline()
Assert.True(pr.Score <= 0);
}

/// <summary>
/// Checks that disposing a prediction engine disposes the model when the engine is created with default
/// <see cref="PredictionEngineOptions"/>, and leaves the model alone when <c>OwnsTransformer</c> is false.
/// </summary>
[Fact]
public void PredictionEngineModelDisposal()
{
    var context = new MLContext(seed: 1);
    var trainingData = context.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());
    var estimator = context.BinaryClassification.Trainers.LbfgsLogisticRegression(
        new Trainers.LbfgsLogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 });

    // Scenario 1: default options, so the engine owns the model.
    var ownedModel = estimator.Fit(trainingData);
    var owningEngine = context.Model.CreatePredictionEngine<TypeTestData, Prediction>(
        ownedModel, new PredictionEngineOptions());

    // Disposing the engine is expected to dispose the model as well.
    owningEngine.Dispose();

    // The disposed flag is private, so it is read through reflection from the type
    // two levels up the model's inheritance chain.
    var disposedField = ownedModel.GetType().BaseType.BaseType.GetField(
        "_disposed", BindingFlags.Instance | BindingFlags.NonPublic);

    Assert.True((bool)disposedField.GetValue(ownedModel));

    // Scenario 2: a freshly fitted model and an engine told not to take ownership.
    var sharedModel = estimator.Fit(trainingData);
    var nonOwningEngine = context.Model.CreatePredictionEngine<TypeTestData, Prediction>(
        sharedModel, new PredictionEngineOptions { OwnsTransformer = false });

    // Disposing this engine must leave the model untouched.
    nonOwningEngine.Dispose();

    Assert.False((bool)disposedField.GetValue(sharedModel));

    // Dispose of the model explicitly for test cleanliness.
    sharedModel.Dispose();
}
}
}