diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index 8524a5b7bf..b449bdfa8e 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -17,7 +17,7 @@ namespace Microsoft.ML /// public sealed class ModelOperationsCatalog : IInternalCatalog { - private const string SchemaEntryName = "Schema"; + internal const string SchemaEntryName = "Schema"; IHostEnvironment IInternalCatalog.Environment => _env; private readonly IHostEnvironment _env; diff --git a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs index 46daa56578..80bf00a761 100644 --- a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs +++ b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs @@ -25,6 +25,8 @@ internal static class ModelFileUtils { public const string DirPredictor = "Predictor"; public const string DirDataLoaderModel = "DataLoaderModel"; + public const string DirTransformerChain = TransformerChain.LoaderSignature; + public const string SchemaEntryName = ModelOperationsCatalog.SchemaEntryName; // ResultsProcessor needs access to this constant. public const string DirTrainingInfo = "TrainingInfo"; @@ -64,6 +66,15 @@ public static IDataView LoadPipeline(IHostEnvironment env, RepositoryReader rep, Contracts.CheckValue(env, nameof(env)); env.CheckValue(rep, nameof(rep)); env.CheckValue(files, nameof(files)); + + var entry = rep.OpenEntryOrNull(SchemaEntryName); + if (entry != null) + { + var loader = new BinaryLoader(env, new BinaryLoader.Arguments(), entry.Stream); + ModelLoadContext.LoadModel(env, out var transformerChain, rep, DirTransformerChain); + return transformerChain.Transform(loader); + } + using (var ent = rep.OpenEntry(DirDataLoaderModel, ModelLoadContext.ModelStreamName)) { ILegacyDataLoader loader; diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 06796a6cad..d5bf30ab96 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -612,6 +612,100 @@ public void EntryPointExecGraphCommand() cmd.Run(); } + [Fact] + public void ScoreTransformerChainModel() + { + var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); + var modelPath = DeleteOutputPath("model.zip"); + var outputDataPath = DeleteOutputPath("scored.idv"); + + var mlContext = new MLContext(); + + var data = new TextLoader(mlContext, + new TextLoader.Options() + { + AllowQuoting = true, + Separator = "\t", + HasHeader = true, + Columns = new[] + { + new TextLoader.Column("Label", DataKind.Boolean, 0), + new TextLoader.Column("SentimentText", DataKind.String, 1) + } + }).Load(dataPath); + + var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText") + .Append(mlContext.BinaryClassification.Trainers.AveragedPerceptron()); + + var model = pipeline.Fit(data); + + mlContext.Model.Save(model, data.Schema, modelPath); + + string inputGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file': '{0}', + 'transform_model': '{1}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Sentiment:I8:0 col=SentimentText:TX:1 quote+ header=+ sep=tab', + 'InputFile': '$file' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Name': 'Transforms.DatasetTransformScorer', + 'Inputs': {{ + 'Data': '$data', + 'TransformModel': '$transform_model' + }}, + 'Outputs': {{ + 'ScoredData': '$scoredVectorData' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$scoredVectorData' + }}, + 'Name': 'Transforms.ScoreColumnSelector', + 'Outputs': {{ + 'OutputData': '$scoreColumnsOnlyData' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$scoreColumnsOnlyData', + 'PredictedLabelColumn': 'PredictedLabel' + }}, + 'Name': 'Transforms.PredictedLabelColumnOriginalValueConverter', + 'Outputs': {{ + 'OutputData': '$output_data' + }} + }} + ], + 'Outputs': {{ + 'output_data': '{2}' + }} + }}", EscapePath(dataPath), EscapePath(modelPath), EscapePath(outputDataPath)); + + var jsonPath = DeleteOutputPath("graph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); + + var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + var cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); + + var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); + + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PredictedLabel")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("Score")); + } + //[Fact] //public void EntryPointArrayOfVariables() //{