Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/Microsoft.Extensions.ML/PredictionEnginePool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ public PredictionEnginePool(IServiceProvider serviceProvider,
/// </param>
public ITransformer GetModel(string modelName)
{
if (!_namedPools.ContainsKey(modelName))
{
AddPool(modelName);
}

return _namedPools[modelName].Loader.GetModel();
}

Expand Down Expand Up @@ -95,14 +100,20 @@ public PredictionEngine<TData, TPrediction> GetPredictionEngine(string modelName
throw new ArgumentException("You need to configure a default, not named, model before you use this method.");
}

return _defaultEnginePool.PredictionEnginePool.Get();
return _defaultEnginePool.PredictionEnginePool.Get();
}

var pool = AddPool(modelName);
return pool.PredictionEnginePool.Get();
}

private PoolLoader<TData, TPrediction> AddPool(string modelName)
{
//Here we are in the world of named models where the model hasn't been built yet.
var options = _predictionEngineOptions.Create(modelName);
var pool = new PoolLoader<TData, TPrediction>(_serviceProvider, options);
pool = _namedPools.GetOrAdd(modelName, pool);
return pool.PredictionEnginePool.Get();
return pool;
}

/// <summary>
Expand Down
61 changes: 61 additions & 0 deletions test/Microsoft.Extensions.ML.Tests/PredictionEnginePoolTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.IO;
using System.Threading;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

I think you have some redundant imports here.

using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
using Microsoft.ML.Data;
using Microsoft.ML.TestFramework;
using Microsoft.ML.TestFrameworkCommon;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.Extensions.ML
{
public class PredictionEnginePoolTests : BaseTestClass
{
public PredictionEnginePoolTests(ITestOutputHelper output) : base(output)
{
}

[Fact]
public void can_load_namedmodel()
{
var services = new ServiceCollection()
.AddOptions()
.AddLogging();

services.AddPredictionEnginePool<SentimentData, SentimentPrediction>()
.FromFile(modelName: "model1", filePath: Path.Combine("TestModels", "SentimentModel.zip"), watchForChanges: false);

var sp = services.BuildServiceProvider();

var pool = sp.GetRequiredService<PredictionEnginePool<SentimentData, SentimentPrediction>>();
var model = pool.GetModel("model1");

Assert.NotNull(model);
}

public class SentimentData
{
[ColumnName("Label"), LoadColumn(0)]
public bool Sentiment;

[LoadColumn(1)]
public string SentimentText;
}

public class SentimentPrediction
{
[ColumnName("PredictedLabel")]
public bool Sentiment;

public float Score;
}
}
}