From 76c72c9920531b8d949cbdcbe47232861b105ede Mon Sep 17 00:00:00 2001 From: feiyun0112 Date: Thu, 3 Jun 2021 21:07:46 +0800 Subject: [PATCH] Ensure the named model is loaded --- .../PredictionEnginePool.cs | 15 ++++- .../PredictionEnginePoolTests.cs | 61 +++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) create mode 100644 test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs diff --git a/src/Microsoft.Extensions.ML/PredictionEnginePool.cs b/src/Microsoft.Extensions.ML/PredictionEnginePool.cs index 6c15289bb5..ce501d46b2 100644 --- a/src/Microsoft.Extensions.ML/PredictionEnginePool.cs +++ b/src/Microsoft.Extensions.ML/PredictionEnginePool.cs @@ -50,6 +50,11 @@ public PredictionEnginePool(IServiceProvider serviceProvider, /// public ITransformer GetModel(string modelName) { + if (!_namedPools.ContainsKey(modelName)) + { + AddPool(modelName); + } + return _namedPools[modelName].Loader.GetModel(); } @@ -95,14 +100,20 @@ public PredictionEngine GetPredictionEngine(string modelName throw new ArgumentException("You need to configure a default, not named, model before you use this method."); } - return _defaultEnginePool.PredictionEnginePool.Get(); + return _defaultEnginePool.PredictionEnginePool.Get(); } + var pool = AddPool(modelName); + return pool.PredictionEnginePool.Get(); + } + + private PoolLoader AddPool(string modelName) + { //Here we are in the world of named models where the model hasn't been built yet. var options = _predictionEngineOptions.Create(modelName); var pool = new PoolLoader(_serviceProvider, options); pool = _namedPools.GetOrAdd(modelName, pool); - return pool.PredictionEnginePool.Get(); + return pool; } /// diff --git a/test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs b/test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs new file mode 100644 index 0000000000..8851aefdf8 --- /dev/null +++ b/test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs @@ -0,0 +1,61 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using System.Threading; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; +using Microsoft.ML.Data; +using Microsoft.ML.TestFramework; +using Microsoft.ML.TestFrameworkCommon; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.Extensions.ML +{ + public class PredictionEnginePoolTests : BaseTestClass + { + public PredictionEnginePoolTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public void can_load_namedmodel() + { + var services = new ServiceCollection() + .AddOptions() + .AddLogging(); + + services.AddPredictionEnginePool() + .FromFile(modelName: "model1", filePath: Path.Combine("TestModels", "SentimentModel.zip"), watchForChanges: false); + + var sp = services.BuildServiceProvider(); + + var pool = sp.GetRequiredService>(); + var model = pool.GetModel("model1"); + + Assert.NotNull(model); + } + + public class SentimentData + { + [ColumnName("Label"), LoadColumn(0)] + public bool Sentiment; + + [LoadColumn(1)] + public string SentimentText; + } + + public class SentimentPrediction + { + [ColumnName("PredictedLabel")] + public bool Sentiment; + + public float Score; + } + } +}