diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index 1537ab10cf..80b9638f77 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -66,6 +66,22 @@ public sealed class TensorFlowTransformer : RowToRowTransformerBase, IDisposable
         internal const string ShortName = "TFTransform";
         internal const string LoaderSignature = "TensorFlowTransform";
 
+        /// <summary>
+        /// Names of the folder and files that SaveModel reads from the saved-model
+        /// directory when it writes the "TFSavedModel" binary stream.
+        /// </summary>
+        internal static class DefaultModelFileNames
+        {
+            // Sub-folder of the saved-model directory that holds the variables files.
+            public const string VariablesFolder = "variables";
+            // Index file, located inside the variables folder.
+            public const string Index = "variables.index";
+            // Search pattern (not a literal file name) handed to Directory.GetFiles to find the data files.
+            public const string Data = "variables.data-?????-of-?????";
+            // Graph file, located directly in the saved-model directory.
+            public const string Graph = "saved_model.pb";
+        }
+
         private static VersionInfo GetVersionInfo()
         {
             return new VersionInfo(
@@ -454,6 +470,43 @@ private protected override void SaveModel(ModelSaveContext ctx)
                     });
                 }
             }
+            else
+            {
+                // Stream layout: file count, then for every file its relative path, its length
+                // and its content as copied by CopyRange.
+                ctx.SaveBinaryStream("TFSavedModel", w =>
+                {
+                    // Only these files need to be saved. Relative paths are composed from the known
+                    // names instead of being cut out of the full path, so a trailing directory
+                    // separator on _savedModelPath cannot truncate them.
+                    var variablesDirectory = Path.Combine(_savedModelPath, DefaultModelFileNames.VariablesFolder);
+                    var relativePaths = new List<string>
+                    {
+                        DefaultModelFileNames.Graph,
+                        Path.Combine(DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index)
+                    };
+                    foreach (var dataFile in Directory.GetFiles(variablesDirectory, DefaultModelFileNames.Data, SearchOption.TopDirectoryOnly))
+                        relativePaths.Add(Path.Combine(DefaultModelFileNames.VariablesFolder, Path.GetFileName(dataFile)));
+
+                    w.Write(relativePaths.Count);
+
+                    foreach (var relativePath in relativePaths)
+                    {
+                        w.Write(relativePath);
+
+                        // Open read-only: FileStream(path, mode) requests read/write access, which
+                        // fails for model files that are marked read-only.
+                        var fullPath = Path.Combine(_savedModelPath, relativePath);
+                        using (var fs = new FileStream(fullPath, FileMode.Open, FileAccess.Read, FileShare.Read))
+                        {
+                            long fileLength = fs.Length;
+                            w.Write(fileLength);
+                            long actualWritten = fs.CopyRange(w.BaseStream, fileLength);
+                            Host.Assert(actualWritten == fileLength);
+                        }
+                    }
+                });
+            }
 
             Host.AssertNonEmpty(Inputs);
             ctx.Writer.Write(Inputs.Length);
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index 3508cfe481..30861521f9 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -1152,6 +1152,79 @@ public void TensorFlowGettingSchemaMultipleTimes()
             }
         }
 
+        // This test has been created as a result of https://github.com/dotnet/machinelearning/issues/5797.
+        [TensorFlowFact]
+        public void TensorFlowSaveAndLoadSavedModel()
+        {
+            // Build a pipeline around the TensorFlow saved model and record reference predictions.
+            var imageHeight = 32;
+            var imageWidth = 32;
+            var modelLocation = "cifar_saved_model";
+            var dataFile = GetDataPath("images/images.tsv");
+            var imageFolder = Path.GetDirectoryName(dataFile);
+
+            var data = TextLoader.Create(_mlContext, new TextLoader.Options()
+            {
+                Columns = new[]
+                {
+                    new TextLoader.Column("ImagePath", DataKind.String, 0),
+                    new TextLoader.Column("Label", DataKind.String, 1),
+                }
+            }, new MultiFileSource(dataFile));
+
+            var pipeEstimator = new ImageLoadingEstimator(_mlContext, imageFolder, ("ImageReal", "ImagePath"))
+                .Append(new ImageResizingEstimator(_mlContext, "ImageCropped", imageHeight, imageWidth, "ImageReal"))
+                .Append(new ImagePixelExtractingEstimator(_mlContext, "Input", "ImageCropped", interleavePixelColors: true))
+                .Append(_mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel("Output", "Input"))
+                .Append(new ColumnConcatenatingEstimator(_mlContext, "Features", "Output"))
+                .Append(new ValueToKeyMappingEstimator(_mlContext, "Label"))
+                .AppendCacheCheckpoint(_mlContext)
+                .Append(_mlContext.MulticlassClassification.Trainers.NaiveBayes());
+
+            using var transformer = pipeEstimator.Fit(data);
+            var transformedData = transformer.Transform(data);
+            var outputSchema = transformer.GetOutputSchema(data.Schema);
+
+            var metrics = _mlContext.MulticlassClassification.Evaluate(transformedData);
+            Assert.Equal(1, metrics.MicroAccuracy, 2);
+
+            // NOTE(review): the generic type arguments of CreatePredictionEngine look stripped here and for testPredictFunction below; restore them before applying.
+            var predictFunction = _mlContext.Model.CreatePredictionEngine(transformer);
+            var predictions = new[]
+            {
+                predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/banana.jpg") }),
+                predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/hotdog.jpg") }),
+                predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/tomato.jpg") })
+            };
+
+            // Save the model as a standard ML.NET zip repo
+            var mlModelLocation = DeleteOutputPath(Path.ChangeExtension(modelLocation, ".zip"));
+            _mlContext.Model.Save(transformer, data.Schema, mlModelLocation);
+            transformer.Dispose();
+            predictFunction.Dispose();
+
+            // Reload the model and check that the loaded transformer (not the disposed original) yields the same output schema
+            DataViewSchema loadedInputschema;
+            var testTransformer = _mlContext.Model.Load(mlModelLocation, out loadedInputschema);
+            var testOutputSchema = testTransformer.GetOutputSchema(data.Schema);
+            Assert.True(TestCommon.CheckSameSchemas(outputSchema, testOutputSchema));
+
+            // Repeat the predictions with the model loaded as zip repo
+            using var testPredictFunction = _mlContext.Model.CreatePredictionEngine(testTransformer);
+            var testPredictions = new[]
+            {
+                testPredictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/banana.jpg") }),
+                testPredictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/hotdog.jpg") }),
+                testPredictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/tomato.jpg") })
+            };
+
+            // Check the predictions consistency
+            for (var i = 0; i < predictions.Length; i++) {
+                for (var j = 0; j < predictions[i].PredictedScores.Length; j++)
+                    Assert.Equal(predictions[i].PredictedScores[j], testPredictions[i].PredictedScores[j], 2);
+            }
+        }
+
         [TensorFlowFact]
         public void TensorFlowTransformCifarInvalidShape()
         {