diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs index a8120cd524..c2f2173f1c 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs @@ -217,7 +217,17 @@ private Delegate MakeGetterImageDataViewType(DataViewRow input, int iinfo, Func< { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); - disposer = null; + var lastImage = default(Bitmap); /* Tracks the Bitmap most recently stored in dst by the getter so the disposer can release it. */ + + disposer = () => + { /* Dispose the tracked image, if any, and clear the reference so a second call does nothing. */ + if (lastImage != null) + { + lastImage.Dispose(); + lastImage = null; + } + }; + var getSrc = input.GetGetter>(input.Schema[ColMapNewToOld[iinfo]]); ReadOnlyMemory src = default; ValueGetter del = @@ -247,6 +257,8 @@ private Delegate MakeGetterImageDataViewType(DataViewRow input, int iinfo, Func< if (dst.PixelFormat == System.Drawing.Imaging.PixelFormat.DontCare) throw Host.Except($"Failed to load image {src.ToString()}."); } + + lastImage = dst; /* Record the image now held in dst so the disposer can release it later. */ }; return del; diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs index 626f5fee6e..00abbf5ad3 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs @@ -285,7 +285,6 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func LoadFromTsv(MLContext mlContext, string tsvPath, string imageFolder) + { /* Reads the ImagePath (column 0) and Label (column 1) text columns from tsvPath and loads each referenced file under imageFolder into a Bitmap. */ + var inMemoryImages = new List(); + var tsvFile = mlContext.Data.LoadFromTextFile(tsvPath, columns: new[] + { + new TextLoader.Column("ImagePath", DataKind.String, 0), + new TextLoader.Column("Label", DataKind.String, 1), + } + ); + + using (var cursor = tsvFile.GetRowCursorForAllColumns()) + { /* Walk the rows once; the two buffers below are reused for every row. */ + var pathBuffer = default(ReadOnlyMemory); + var labelBuffer = default(ReadOnlyMemory); + var pathGetter = cursor.GetGetter>(tsvFile.Schema["ImagePath"]); + var labelGetter = cursor.GetGetter>(tsvFile.Schema["Label"]); + while (cursor.MoveNext()) + { + pathGetter(ref 
pathBuffer); + labelGetter(ref labelBuffer); + + var label = labelBuffer.ToString(); + var fileName = pathBuffer.ToString(); + var imagePath = Path.Combine(imageFolder, fileName); /* File names from the TSV are resolved against imageFolder. */ + + inMemoryImages.Add( + new InMemoryImage() + { + Label = label, + LoadedImage = (Bitmap)Image.FromFile(imagePath) /* Loaded eagerly; this method never disposes the Bitmap it creates. */ + } + ); + } + } + + return inMemoryImages; + + } + } + + public class InMemoryImageOutput : InMemoryImage + { /* Output row type: the inherited InMemoryImage members plus the ResizedImage column. */ + [ImageType(100, 100)] + public Bitmap ResizedImage; + } + + [Fact] + public void ResizeInMemoryImages() + { /* Regression test: resizing in-memory Bitmaps must give outputs of height 100 while the source Bitmaps stay unresized, keep their identity and stay usable. */ + var mlContext = new MLContext(seed: 1); + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var dataObjects = InMemoryImage.LoadFromTsv(mlContext, dataFile, imageFolder); + + var dataView = mlContext.Data.LoadFromEnumerable(dataObjects); + var pipeline = mlContext.Transforms.ResizeImages("ResizedImage", 100, 100, nameof(InMemoryImage.LoadedImage)); + + // Check that the output is resized, and that it didn't resize the original image object + var model = pipeline.Fit(dataView); + var resizedDV = model.Transform(dataView); + var rowView = resizedDV.Preview().RowView; + var resizedImage = (Bitmap)rowView.First().Values.Last().Value; /* Takes the last column of the first previewed row; presumably that is ResizedImage because it is the column the pipeline adds - TODO confirm. */ + Assert.Equal(100, resizedImage.Height); + Assert.NotEqual(100, dataObjects[0].LoadedImage.Height); + + // Also check usage of prediction Engine + // And that the references to the original image objects aren't lost + var predEngine = mlContext.Model.CreatePredictionEngine(model); + for(int i = 0; i < dataObjects.Count(); i++) /* NOTE(review): Count() here is a method call; if dataObjects is a List, the Count property would do - confirm the return type of LoadFromTsv. */ + { /* Each prediction must carry the same source Bitmap reference and a different Bitmap for the resized output. */ + var prediction = predEngine.Predict(dataObjects[i]); + Assert.Equal(100, prediction.ResizedImage.Height); + Assert.NotEqual(100, prediction.LoadedImage.Height); + Assert.True(prediction.LoadedImage == dataObjects[i].LoadedImage); + Assert.False(prediction.ResizedImage == dataObjects[i].LoadedImage); + } + + // Check that the last in-memory image hasn't been disposed + // By running ResizeImageTransformer (see 
https://github.com/dotnet/machinelearning/issues/4126) + bool disposed = false; + try + { + int i = dataObjects.Last().LoadedImage.Height; /* The local i is never read; only the Height access matters. */ + } + catch + { /* Any exception thrown while reading Height is taken to mean the Bitmap was disposed. */ + disposed = true; + } + + Assert.False(disposed, "The last in memory image had been disposed by running ResizeImageTransformer"); + } } } diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index d2357d138a..31050f5d6d 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.Drawing; using System.IO; using System.IO.Compression; using System.Linq; @@ -18,6 +19,7 @@ using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Image; using Microsoft.ML.TensorFlow; +using InMemoryImage = Microsoft.ML.Tests.ImageTests.InMemoryImage; using Xunit; using Xunit.Abstractions; using static Microsoft.ML.DataOperationsCatalog; @@ -1126,6 +1128,35 @@ public void TensorFlowTransformCifarSavedModel() } } + // This test doesn't really check the values of the results + // Simply checks that CrossValidation is doable with in-memory images + // See issue https://github.com/dotnet/machinelearning/issues/4126 + [TensorFlowFact] + public void TensorFlowTransformCifarCrossValidationWithInMemoryImages() + { + var modelLocation = "cifar_saved_model"; + var mlContext = new MLContext(seed: 1); + using var tensorFlowModel = mlContext.Model.LoadTensorFlowModel(modelLocation); + var schema = tensorFlowModel.GetInputSchema(); + Assert.True(schema.TryGetColumnIndex("Input", out int column)); + var type = (VectorDataViewType)schema[column].Type; + var imageHeight = type.Dimensions[0]; + var imageWidth = type.Dimensions[1]; /* Height and width come from the first two dimensions of the model Input column. */ + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var dataObjects 
= InMemoryImage.LoadFromTsv(mlContext, dataFile, imageFolder); + + var dataView = mlContext.Data.LoadFromEnumerable(dataObjects); /* Pipeline below: resize to the model input size, extract interleaved pixels into Input, score with the TensorFlow model into Output, map Label to a key, then train NaiveBayes on Output. */ + var pipeline = mlContext.Transforms.ResizeImages("ResizedImage", imageWidth, imageHeight, nameof(InMemoryImage.LoadedImage)) + .Append(mlContext.Transforms.ExtractPixels("Input", "ResizedImage", interleavePixelColors: true)) + .Append(tensorFlowModel.ScoreTensorFlowModel("Output", "Input")) + .Append(mlContext.Transforms.Conversion.MapValueToKey("Label")) + .Append(mlContext.MulticlassClassification.Trainers.NaiveBayes("Label", "Output")); + + var cross = mlContext.MulticlassClassification.CrossValidate(dataView, pipeline, 2); /* The third argument is presumably the fold count (confirm against the CrossValidate signature); the assertion below only checks that two results come back. */ + Assert.Equal(2, cross.Count()); + } + // This test has been created as result of https://github.com/dotnet/machinelearning/issues/2156. [TensorFlowFact] public void TensorFlowGettingSchemaMultipleTimes()