From 2a2b6557b901fcfde44f2a80890a8093fff51b3c Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Tue, 31 Mar 2020 14:44:29 -0700 Subject: [PATCH 1/7] Added the assembly name of the custom transform to the model file --- .../CustomMappingTransformer.cs | 5 ++++- .../LambdaTransform.cs | 21 +++++++++++++++---- .../UnitTests/TestCustomTypeRegister.cs | 6 +++++- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs index 7a781128a5..76dae2b75c 100644 --- a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs +++ b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs @@ -4,6 +4,7 @@ using System; using System.Linq; +using System.Reflection; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; @@ -22,6 +23,7 @@ public sealed class CustomMappingTransformer : ITransformer private readonly IHost _host; private readonly Action _mapAction; private readonly string _contractName; + private readonly string _contractAssembly; internal InternalSchemaDefinition AddedSchema { get; } internal SchemaDefinition InputSchemaDefinition { get; } @@ -58,6 +60,7 @@ internal CustomMappingTransformer(IHostEnvironment env, Action mapAc : InternalSchemaDefinition.Create(typeof(TDst), outputSchemaDefinition); _contractName = contractName; + _contractAssembly = _mapAction.Method.DeclaringType.Assembly.FullName; AddedSchema = outSchema; } @@ -67,7 +70,7 @@ internal void SaveModel(ModelSaveContext ctx) { if (_contractName == null) throw _host.Except("Empty contract name for a transform: the transform cannot be saved"); - LambdaTransform.SaveCustomTransformer(_host, ctx, _contractName); + LambdaTransform.SaveCustomTransformer(_host, ctx, _contractName, _contractAssembly); } /// diff --git a/src/Microsoft.ML.Transforms/LambdaTransform.cs b/src/Microsoft.ML.Transforms/LambdaTransform.cs index 4ba05ee4f8..869c311225 100644 --- a/src/Microsoft.ML.Transforms/LambdaTransform.cs +++ b/src/Microsoft.ML.Transforms/LambdaTransform.cs @@ -3,7 +3,9 @@ // See the LICENSE file in the project root for more information. using System; +using System.Diagnostics.Contracts; using System.IO; +using System.Reflection; using System.Text; using Microsoft.ML; using Microsoft.ML.Data; @@ -40,14 +42,17 @@ private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "CUSTOMXF", - verWrittenCur: 0x00010001, - verReadableCur: 0x00010001, + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Added name of assembly in which the contractName is present + verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, loaderAssemblyName: typeof(LambdaTransform).Assembly.FullName); } - internal static void SaveCustomTransformer(IExceptionContext ectx, ModelSaveContext ctx, string contractName) + private const uint VerAssemblyNameSaved = 0x00010002; + + internal static void SaveCustomTransformer(IExceptionContext ectx, ModelSaveContext ctx, string contractName, string contractAssembly) { ectx.CheckValue(ctx, nameof(ctx)); ectx.CheckValue(contractName, nameof(contractName)); @@ -56,6 +61,7 @@ internal static void SaveCustomTransformer(IExceptionContext ectx, ModelSaveCont ctx.SetVersionInfo(GetVersionInfo()); ctx.SaveString(contractName); + ctx.SaveString(contractAssembly); } // Factory for SignatureLoadModel. @@ -63,9 +69,16 @@ private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); + var versionInfo = GetVersionInfo(); + ctx.CheckAtModel(versionInfo); var contractName = ctx.LoadString(); + if (ctx.Header.ModelVerWritten >= VerAssemblyNameSaved) + { + var contractAssembly = ctx.LoadString(); + Assembly assembly = Assembly.Load(contractAssembly); + env.ComponentCatalog.RegisterAssembly(assembly); + } object factoryObject = env.ComponentCatalog.GetExtensionValue(env, typeof(CustomMappingFactoryAttributeAttribute), contractName); if (!(factoryObject is ICustomMappingFactory mappingFactory)) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index b0660bfe08..a746bd762b 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -184,6 +184,10 @@ public void RegisterTypeWithAttribute() var tribeTransformed = model.Transform(tribeDataView); var tribeEnumerable = ML.Data.CreateEnumerable(tribeTransformed, false).ToList(); + // save and reload the model + ML.Model.Save(model, tribeDataView.Schema, "customTransform.zip"); + var modelSaved = ML.Model.Load("customTransform.zip", out var tribeDataViewSaved); + // Make sure the pipeline output is correct. Assert.Equal(tribeEnumerable[0].Name, "Super " + tribe[0].Name); Assert.Equal(tribeEnumerable[0].Merged.Age, tribe[0].One.Age + tribe[0].Two.Age); @@ -192,7 +196,7 @@ public void RegisterTypeWithAttribute() Assert.Equal(tribeEnumerable[0].Merged.HandCount, tribe[0].One.HandCount + tribe[0].Two.HandCount); // Build prediction engine from the trained pipeline. - var engine = ML.Model.CreatePredictionEngine(model); + var engine = ML.Model.CreatePredictionEngine(modelSaved); var alien = new AlienHero("TEN.LM", 1, 2, 3, 4, 5, 6, 7, 8); var superAlien = engine.Predict(alien); From 87eb3708e98e631d975c1ff0b6a8659a2db4c117 Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Tue, 31 Mar 2020 14:51:36 -0700 Subject: [PATCH 2/7] Backed out unnecessary change --- src/Microsoft.ML.Transforms/LambdaTransform.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Transforms/LambdaTransform.cs b/src/Microsoft.ML.Transforms/LambdaTransform.cs index 869c311225..8e61332d87 100644 --- a/src/Microsoft.ML.Transforms/LambdaTransform.cs +++ b/src/Microsoft.ML.Transforms/LambdaTransform.cs @@ -69,8 +69,7 @@ private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); - var versionInfo = GetVersionInfo(); - ctx.CheckAtModel(versionInfo); + ctx.CheckAtModel(GetVersionInfo()); var contractName = ctx.LoadString(); if (ctx.Header.ModelVerWritten >= VerAssemblyNameSaved) From c77a09e93928571d17116894727903fa32c3c2fb Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Tue, 31 Mar 2020 15:47:17 -0700 Subject: [PATCH 3/7] Fixed failing test --- .../Transformers/CustomMappingTests.cs | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs index c4d9378024..a54af27426 100644 --- a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs @@ -62,17 +62,6 @@ public void TestCustomTransformer() var tempoEnv = new MLContext(1); var customEst = new CustomMappingEstimator(tempoEnv, MyLambda.MyAction, "MyLambda"); - try - { - TestEstimatorCore(customEst, data); - Assert.True(false, "Cannot work without RegisterAssembly"); - } - catch (InvalidOperationException ex) - { - if (!ex.IsMarked()) - throw; - } - ML.ComponentCatalog.RegisterAssembly(typeof(MyLambda).Assembly); TestEstimatorCore(customEst, data); transformedData = customEst.Fit(data).Transform(data); From 47966e7aae1f87eec7be35192001ad1580c9e5f1 Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Tue, 31 Mar 2020 22:28:56 -0700 Subject: [PATCH 4/7] Disabled two tensorflow tests that are hanging --- .../TensorflowTests.cs | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 08b7f38ee8..97f473dee6 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -1250,14 +1250,10 @@ public void TensorFlowStringTest() } [TensorFlowFact] + // This test hangs occasionally + [Trait("Category", "SkipInCI")] public void TensorFlowImageClassificationDefault() { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - Output.WriteLine("TODO TEST_STABILITY: TensorFlowImageClassificationDefault hangs on Linux."); - return; - } - string imagesDownloadFolderPath = Path.Combine(TensorFlowScenariosTestsFixture.assetsPath, "inputs", "images"); @@ -1628,13 +1624,10 @@ internal void TensorFlowImageClassificationWithLRScheduling(LearningRateSchedule [TensorFlowTheory] [InlineData(ImageClassificationTrainer.EarlyStoppingMetric.Accuracy)] [InlineData(ImageClassificationTrainer.EarlyStoppingMetric.Loss)] + // This test hangs ocassionally + [Trait("Category", "SkipInCI")] public void TensorFlowImageClassificationEarlyStopping(ImageClassificationTrainer.EarlyStoppingMetric earlyStoppingMetric) { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - Output.WriteLine("TODO TEST_STABILITY: TensorFlowImageClassificationEarlyStopping hangs on Linux."); - return; - } string imagesDownloadFolderPath = Path.Combine(TensorFlowScenariosTestsFixture.assetsPath, "inputs", "images"); From db3dcb7f986c515ca3afab4952cd807c6935917c Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Wed, 1 Apr 2020 11:25:53 -0700 Subject: [PATCH 5/7] Addressed code review comments and added another test for backcompat --- .../UnitTests/TestCustomTypeRegister.cs | 45 +++++++++++++++--- .../Transformers/CustomMappingTests.cs | 13 ++++- test/data/backcompat/customTransform.zip | Bin 0 -> 785 bytes 3 files changed, 49 insertions(+), 9 deletions(-) create mode 100644 test/data/backcompat/customTransform.zip diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index a746bd762b..3bfa01c6b7 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using Microsoft.ML.Data; using Microsoft.ML.Transforms; @@ -151,7 +152,7 @@ public SuperAlienHero() /// /// A mapping from to . It is used to create a - /// in . + /// in . /// [CustomMappingFactoryAttribute("LambdaAlienHero")] private class AlienFusionProcess : CustomMappingFactory @@ -171,8 +172,10 @@ public override Action GetMapping() } } - [Fact] - public void RegisterTypeWithAttribute() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void RegisterTypeWithAttribute(bool saveModel) { // Build in-memory data. var tribe = new List() { new AlienHero("ML.NET", 2, 1000, 2000, 3000, 4000, 5000, 6000, 7000) }; @@ -184,9 +187,12 @@ public void RegisterTypeWithAttribute() var tribeTransformed = model.Transform(tribeDataView); var tribeEnumerable = ML.Data.CreateEnumerable(tribeTransformed, false).ToList(); - // save and reload the model - ML.Model.Save(model, tribeDataView.Schema, "customTransform.zip"); - var modelSaved = ML.Model.Load("customTransform.zip", out var tribeDataViewSaved); + ITransformer modelForPrediction = model; + if (saveModel) + { + ML.Model.Save(model, tribeDataView.Schema, "customTransform.zip"); + modelForPrediction = ML.Model.Load("customTransform.zip", out var tribeDataViewSchema); + } // Make sure the pipeline output is correct. Assert.Equal(tribeEnumerable[0].Name, "Super " + tribe[0].Name); @@ -196,7 +202,30 @@ public void RegisterTypeWithAttribute() Assert.Equal(tribeEnumerable[0].Merged.HandCount, tribe[0].One.HandCount + tribe[0].Two.HandCount); // Build prediction engine from the trained pipeline. - var engine = ML.Model.CreatePredictionEngine(modelSaved); + var engine = ML.Model.CreatePredictionEngine(modelForPrediction); + var alien = new AlienHero("TEN.LM", 1, 2, 3, 4, 5, 6, 7, 8); + var superAlien = engine.Predict(alien); + + // Make sure the prediction engine produces expected result. + Assert.Equal(superAlien.Name, "Super " + alien.Name); + Assert.Equal(superAlien.Merged.Age, alien.One.Age + alien.Two.Age); + Assert.Equal(superAlien.Merged.Height, alien.One.Height + alien.Two.Height); + Assert.Equal(superAlien.Merged.Weight, alien.One.Weight + alien.Two.Weight); + Assert.Equal(superAlien.Merged.HandCount, alien.One.HandCount + alien.Two.HandCount); + + Done(); + } + + [Fact] + void TestCustomTransformBackcompat() + { + // With older versions, it is necessary to register the assembly + ML.ComponentCatalog.RegisterAssembly(typeof(AlienFusionProcess).Assembly); + + var modelPath = Path.Combine(DataDir, "backcompat", "customTransform.zip"); + var trainedModel = ML.Model.Load(modelPath, out var dataViewSchema); + + var engine = ML.Model.CreatePredictionEngine(trainedModel); var alien = new AlienHero("TEN.LM", 1, 2, 3, 4, 5, 6, 7, 8); var superAlien = engine.Predict(alien); @@ -206,6 +235,8 @@ public void RegisterTypeWithAttribute() Assert.Equal(superAlien.Merged.Height, alien.One.Height + alien.Two.Height); Assert.Equal(superAlien.Merged.Weight, alien.One.Weight + alien.Two.Weight); Assert.Equal(superAlien.Merged.HandCount, alien.One.HandCount + alien.Two.HandCount); + + Done(); } [Fact] diff --git a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs index a54af27426..d6e58b5ebc 100644 --- a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs @@ -44,8 +44,10 @@ public override Action GetMapping() } } - [Fact] - public void TestCustomTransformer() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void TestCustomTransformer(bool registerAssembly) { string dataPath = GetDataPath("adult.tiny.with-schema.txt"); var source = new MultiFileSource(dataPath); @@ -62,6 +64,13 @@ public void TestCustomTransformer() var tempoEnv = new MLContext(1); var customEst = new CustomMappingEstimator(tempoEnv, MyLambda.MyAction, "MyLambda"); + // Before 1.5-preview3 it was required to register the assembly. + // Now, the assembly information is automatically saved in the model and the assembly is registered + // when loading. + // This tests the case that the CustomTransformer still works even if you explicitly register the assembly + if (registerAssembly) + ML.ComponentCatalog.RegisterAssembly(typeof(MyLambda).Assembly); + TestEstimatorCore(customEst, data); transformedData = customEst.Fit(data).Transform(data); diff --git a/test/data/backcompat/customTransform.zip b/test/data/backcompat/customTransform.zip new file mode 100644 index 0000000000000000000000000000000000000000..967a43f0d5c51a2ad5341a0caf29c50045d4d202 GIT binary patch literal 785 zcmWIWW@Zs#U|`^2kO^%JXxu+lXabOT42UIwIHV{sGcPkQ-7_yOKPD`-s5mn}Pp_n+ zr1zAgP?LfP!-ZPoyN$KN9jPm{H~aeTwm2wmcWB9t1LYUre&$$vfAXG3scZcviLZHm z-5~IP?#y{PL` zg&its26irU1cG=9MWzU}GA4(-oTH=DP?Eaz!qw-FeCrriu3EP7Rnqq2ALj4M?E_GJ zVQsgV_W{uL&lnjPq#(Y?D^ANV%1teD&H#ES#y3ADHAgQywQ{YWZTBGuf!eOt*8Ps_ z!~|^xSBtoZwA@y03d~9NaWK()bL`Q->HD9b`YpRJw*B>yV>9RWn}1GA-_qolR(#ck z`QD$;m#SE{uxw{7XDMerKD%_*i@S>sI z%anUg$o`w1eq=|Hzia(bM65>5Ak?_H!CO<7?^?Z5RkTH0`UOWgCVK_ literal 0 HcmV?d00001 From 2fd1d400b5485e5ebf07f3145b778c9e449aee50 Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Wed, 1 Apr 2020 13:33:11 -0700 Subject: [PATCH 6/7] Removed unnecessary using statements --- src/Microsoft.ML.Transforms/CustomMappingTransformer.cs | 1 - src/Microsoft.ML.Transforms/LambdaTransform.cs | 1 - 2 files changed, 2 deletions(-) diff --git a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs index 76dae2b75c..03c55c8c58 100644 --- a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs +++ b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs @@ -4,7 +4,6 @@ using System; using System.Linq; -using System.Reflection; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; diff --git a/src/Microsoft.ML.Transforms/LambdaTransform.cs b/src/Microsoft.ML.Transforms/LambdaTransform.cs index 8e61332d87..88de4a68e1 100644 --- a/src/Microsoft.ML.Transforms/LambdaTransform.cs +++ b/src/Microsoft.ML.Transforms/LambdaTransform.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.Diagnostics.Contracts; using System.IO; using System.Reflection; using System.Text; From 019d7cf0334c3664a624e3f68063c6ef6dcd6fb9 Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Wed, 1 Apr 2020 14:01:52 -0700 Subject: [PATCH 7/7] Updated docs --- .../Dynamic/Transforms/CustomMappingSaveAndLoad.cs | 1 + src/Microsoft.ML.Transforms/CustomMappingCatalog.cs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSaveAndLoad.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSaveAndLoad.cs index f2d97b70eb..57b52c342c 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSaveAndLoad.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSaveAndLoad.cs @@ -45,6 +45,7 @@ public static void Example() // the custom action is defined needs to be registered in the // environment. The following registers the assembly where // IsUnderThirtyCustomAction is defined. + // This is necessary only in versions v1.5-preview2 and earlier mlContext.ComponentCatalog.RegisterAssembly(typeof( IsUnderThirtyCustomAction).Assembly); diff --git a/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs b/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs index 819a518188..c918302d5e 100644 --- a/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs +++ b/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs @@ -25,7 +25,7 @@ public static class CustomMappingCatalog /// If the resulting transformer needs to be save-able, the class defining should implement /// and needs to be decorated with /// with the provided . - /// The assembly containing the class should be registered in the environment where it is loaded back + /// In versions v1.5-preview2 and earlier, the assembly containing the class should be registered in the environment where it is loaded back /// using . /// The contract name, used by ML.NET for loading the model. /// If is specified, resulting transformer would not be save-able.