diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index c9aae84314..5e95aea98a 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -274,6 +274,16 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.FastTree", "Mi pkg\Microsoft.ML.FastTree\Microsoft.ML.FastTree.symbols.nupkgproj = pkg\Microsoft.ML.FastTree\Microsoft.ML.FastTree.symbols.nupkgproj EndProjectSection EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.ML", "src\Microsoft.Extensions.ML\Microsoft.Extensions.ML.csproj", "{D6741C37-B5E6-4050-BCBA-9715809EA15B}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.ML.Tests", "test\Microsoft.Extensions.ML.Tests\Microsoft.Extensions.ML.Tests.csproj", "{21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.Extensions.ML", "Microsoft.Extensions.ML", "{AE4F7569-26F3-4160-8A8B-7A57D0DA3350}" + ProjectSection(SolutionItems) = preProject + pkg\Microsoft.Extensions.ML\Microsoft.Extensions.ML.nupkgproj = pkg\Microsoft.Extensions.ML\Microsoft.Extensions.ML.nupkgproj + pkg\Microsoft.Extensions.ML\Microsoft.Extensions.ML.symbols.nupkgproj = pkg\Microsoft.Extensions.ML\Microsoft.Extensions.ML.symbols.nupkgproj + EndProjectSection +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StableApi", "tools-local\Microsoft.ML.StableApi\Microsoft.ML.StableApi.csproj", "{F308DC6B-7E59-40D7-A581-834E8CD99CFE}" EndProject Global @@ -970,6 +980,30 @@ Global {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU + {D6741C37-B5E6-4050-BCBA-9715809EA15B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D6741C37-B5E6-4050-BCBA-9715809EA15B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D6741C37-B5E6-4050-BCBA-9715809EA15B}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {D6741C37-B5E6-4050-BCBA-9715809EA15B}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {D6741C37-B5E6-4050-BCBA-9715809EA15B}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU + {D6741C37-B5E6-4050-BCBA-9715809EA15B}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU + {D6741C37-B5E6-4050-BCBA-9715809EA15B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D6741C37-B5E6-4050-BCBA-9715809EA15B}.Release|Any CPU.Build.0 = Release|Any CPU + {D6741C37-B5E6-4050-BCBA-9715809EA15B}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {D6741C37-B5E6-4050-BCBA-9715809EA15B}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {D6741C37-B5E6-4050-BCBA-9715809EA15B}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU + {D6741C37-B5E6-4050-BCBA-9715809EA15B}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}.Debug|Any CPU.Build.0 = Debug|Any CPU + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}.Release|Any CPU.ActiveCfg = Release|Any CPU + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}.Release|Any CPU.Build.0 = Release|Any CPU + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU {F308DC6B-7E59-40D7-A581-834E8CD99CFE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {F308DC6B-7E59-40D7-A581-834E8CD99CFE}.Debug|Any CPU.Build.0 = Debug|Any CPU {F308DC6B-7E59-40D7-A581-834E8CD99CFE}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU @@ -1069,6 +1103,9 @@ Global {AD7058C9-5608-49A8-BE23-58C33A74EE91} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {B1B3F284-FA3D-4D76-A712-FF04495D244B} = {D3D38B03-B557-484D-8348-8BADEE4DF592} + {D6741C37-B5E6-4050-BCBA-9715809EA15B} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {21CAD3A1-5E1F-42C1-BB73-46B6E67F4206} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} + {AE4F7569-26F3-4160-8A8B-7A57D0DA3350} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {F308DC6B-7E59-40D7-A581-834E8CD99CFE} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution diff --git a/build/Dependencies.props b/build/Dependencies.props index 377f59579c..2ed9835fb4 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -14,6 +14,7 @@ 3.5.1 2.2.3 + 2.1.0 0.3.0 0.0.0.9 2.1.3 diff --git a/pkg/Microsoft.Extensions.ML/Microsoft.Extensions.ML.nupkgproj b/pkg/Microsoft.Extensions.ML/Microsoft.Extensions.ML.nupkgproj new file mode 100644 index 0000000000..4dbb257b8f --- /dev/null +++ b/pkg/Microsoft.Extensions.ML/Microsoft.Extensions.ML.nupkgproj @@ -0,0 +1,16 @@ + + + + netstandard2.0 + An integration package for ML.NET models on scalable web apps and services. + + + + + + + + + + + diff --git a/pkg/Microsoft.Extensions.ML/Microsoft.Extensions.ML.symbols.nupkgproj b/pkg/Microsoft.Extensions.ML/Microsoft.Extensions.ML.symbols.nupkgproj new file mode 100644 index 0000000000..0b7af4d817 --- /dev/null +++ b/pkg/Microsoft.Extensions.ML/Microsoft.Extensions.ML.symbols.nupkgproj @@ -0,0 +1,5 @@ + + + + + diff --git a/src/Microsoft.Extensions.ML/Builder/BuilderExtensions.cs b/src/Microsoft.Extensions.ML/Builder/BuilderExtensions.cs new file mode 100644 index 0000000000..ebeb366086 --- /dev/null +++ b/src/Microsoft.Extensions.ML/Builder/BuilderExtensions.cs @@ -0,0 +1,213 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.Extensions.ML +{ + /// + /// Extension methods for . + /// + public static class BuilderExtensions + { + /// + /// Adds the model at the specified location to the builder. + /// + /// The builder to which to add the model. + /// The location of the model. + /// + /// The updated . + /// + public static PredictionEnginePoolBuilder FromUri( + this PredictionEnginePoolBuilder builder, string uri) + where TData : class + where TPrediction : class, new() + { + return builder.FromUri(string.Empty, new Uri(uri)); + } + + /// + /// Adds the named model at the specified location to the builder. + /// + /// The builder to which to add the model. + /// + /// The name of the model which allows for uniquely identifying the model when + /// multiple models have the same and + /// types. + /// + /// The location of the model. + /// + /// The updated . + /// + public static PredictionEnginePoolBuilder FromUri( + this PredictionEnginePoolBuilder builder, string modelName, string uri) + where TData : class + where TPrediction : class, new() + { + return builder.FromUri(modelName, new Uri(uri)); + } + + /// + /// Adds the named model at the specified location to the builder. + /// + /// The builder to which to add the model. + /// + /// The name of the model which allows for uniquely identifying the model when + /// multiple models have the same and + /// types. + /// + /// The location of the model. + /// + /// The updated . + /// + public static PredictionEnginePoolBuilder FromUri( + this PredictionEnginePoolBuilder builder, string modelName, Uri uri) + where TData : class where TPrediction : class, new() + { + return builder.FromUri(modelName, uri, TimeSpan.FromMinutes(5)); + } + + /// + /// Adds the model at the specified location to the builder. + /// + /// The builder to which to add the model. + /// The location of the model. + /// + /// How often to query if the model has been updated at the specified location. + /// + /// + /// The updated . + /// + public static PredictionEnginePoolBuilder FromUri( + this PredictionEnginePoolBuilder builder, string uri, TimeSpan period) + where TData : class where TPrediction : class, new() + { + return builder.FromUri(string.Empty, new Uri(uri), period); + } + + /// + /// Adds the named model at the specified location to the builder. + /// + /// The builder to which to add the model. + /// + /// The name of the model which allows for uniquely identifying the model when + /// multiple models have the same and + /// types. + /// + /// The location of the model. + /// + /// How often to query if the model has been updated at the specified location. + /// + /// + /// The updated . + /// + public static PredictionEnginePoolBuilder FromUri( + this PredictionEnginePoolBuilder builder, string modelName, string uri, TimeSpan period) + where TData : class + where TPrediction : class, new() + { + return builder.FromUri(modelName, new Uri(uri), period); + } + + /// + /// Adds the named model at the specified location to the builder. + /// + /// The builder to which to add the model. + /// + /// The name of the model which allows for uniquely identifying the model when + /// multiple models have the same and + /// types. + /// + /// The location of the model. + /// + /// How often to query if the model has been updated at the specified location. + /// + /// + /// The updated . + /// + public static PredictionEnginePoolBuilder FromUri( + this PredictionEnginePoolBuilder builder, string modelName, Uri uri, TimeSpan period) + where TData : class + where TPrediction : class, new() + { + builder.Services.AddTransient(); + builder.Services.AddOptions>(modelName) + .Configure((opt, loader) => + { + loader.Start(uri, period); + opt.ModelLoader = loader; + }); + return builder; + } + + /// + /// Adds the model at the specified file to the builder. + /// + /// The builder to which to add the model. + /// The location of the model. + /// + /// The updated . + /// + public static PredictionEnginePoolBuilder FromFile( + this PredictionEnginePoolBuilder builder, string filePath) + where TData : class + where TPrediction : class, new() + { + return builder.FromFile(string.Empty, filePath, true); + } + + /// + /// Adds the model at the specified file to the builder. + /// + /// The builder to which to add the model. + /// + /// The name of the model which allows for uniquely identifying the model when + /// multiple models have the same and + /// types. + /// + /// The location of the model. + /// + /// The updated . + /// + public static PredictionEnginePoolBuilder FromFile( + this PredictionEnginePoolBuilder builder, string modelName, string filePath) + where TData : class + where TPrediction : class, new() + { + return builder.FromFile(modelName, filePath, true); + } + + /// + /// Adds the model at the specified file to the builder. + /// + /// The builder to which to add the model. + /// + /// The name of the model which allows for uniquely identifying the model when + /// multiple models have the same and + /// types. + /// + /// The location of the model. + /// + /// Whether to watch for changes to the file path and update the model when the file is changed or not. + /// + /// + /// The updated . + /// + public static PredictionEnginePoolBuilder FromFile( + this PredictionEnginePoolBuilder builder, string modelName, string filePath, bool watchForChanges) + where TData : class + where TPrediction : class, new() + { + builder.Services.AddTransient(); + builder.Services.AddOptions>(modelName) + .Configure((options, loader) => + { + loader.Start(filePath, watchForChanges); + options.ModelLoader = loader; + }); + return builder; + } + } +} diff --git a/src/Microsoft.Extensions.ML/Builder/PredictionEnginePoolBuilder.cs b/src/Microsoft.Extensions.ML/Builder/PredictionEnginePoolBuilder.cs new file mode 100644 index 0000000000..cb99b1ef7f --- /dev/null +++ b/src/Microsoft.Extensions.ML/Builder/PredictionEnginePoolBuilder.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.ML; + +namespace Microsoft.Extensions.ML +{ + /// + /// A class that provides the mechanisms to configure a pool + /// of ML.NET objects. + /// + public class PredictionEnginePoolBuilder + where TData : class + where TPrediction : class, new() + { + /// + /// Initializes a new instance of . + /// + /// The to add services to. + public PredictionEnginePoolBuilder(IServiceCollection services) + { + Services = services ?? throw new ArgumentException(nameof(services)); + } + + /// + /// The to add services to. + /// + public IServiceCollection Services { get; private set; } + } +} diff --git a/src/Microsoft.Extensions.ML/MLOptions.cs b/src/Microsoft.Extensions.ML/MLOptions.cs new file mode 100644 index 0000000000..08beed6cc4 --- /dev/null +++ b/src/Microsoft.Extensions.ML/MLOptions.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.ML; + +namespace Microsoft.Extensions.ML +{ + /// + /// Provides options for ML.NET objects. + /// + public class MLOptions + { + private MLContext _context; + + /// + /// Initializes a new instance of . + /// + public MLOptions() + { + } + + /// + /// The which all the ML.NET operations happen. + /// + public MLContext MLContext + { + get { return _context ?? (_context = new MLContext()); } + set { _context = value; } + } + } + + /// + /// Configures the type. + /// + /// + /// Note: This is run after all . + /// + internal class PostMLContextOptionsConfiguration : IPostConfigureOptions + { + private readonly ILogger _logger; + + /// + /// Initializes a new instance of . + /// + /// The to write to. + public PostMLContextOptionsConfiguration(ILogger logger) + { + _logger = logger; + } + + /// + public void PostConfigure(string name, MLOptions options) + { + options.MLContext.Log += Log; + } + + private void Log(object sender, LoggingEventArgs e) + { + _logger.LogTrace(e.Message); + } + } +} diff --git a/src/Microsoft.Extensions.ML/Microsoft.Extensions.ML.csproj b/src/Microsoft.Extensions.ML/Microsoft.Extensions.ML.csproj new file mode 100644 index 0000000000..a27aed8386 --- /dev/null +++ b/src/Microsoft.Extensions.ML/Microsoft.Extensions.ML.csproj @@ -0,0 +1,17 @@ + + + + netstandard2.0 + Microsoft.Extensions.ML + + + + + + + + + + + + diff --git a/src/Microsoft.Extensions.ML/ModelLoaders/FileModelLoader.cs b/src/Microsoft.Extensions.ML/ModelLoaders/FileModelLoader.cs new file mode 100644 index 0000000000..ba343d82e4 --- /dev/null +++ b/src/Microsoft.Extensions.ML/ModelLoaders/FileModelLoader.cs @@ -0,0 +1,155 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.IO; +using System.Threading; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; +using Microsoft.ML; + +namespace Microsoft.Extensions.ML +{ + internal class FileModelLoader : ModelLoader, IDisposable + { + private readonly ILogger _logger; + private string _filePath; + private FileSystemWatcher _watcher; + private ModelReloadToken _reloadToken; + private ITransformer _model; + + private readonly MLContext _context; + + private readonly object _lock; + + public FileModelLoader(IOptions contextOptions, ILogger logger) + { + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _context = contextOptions.Value?.MLContext ?? throw new ArgumentNullException(nameof(contextOptions)); + _lock = new object(); + } + + public void Start(string filePath, bool watchFile) + { + _filePath = filePath; + _reloadToken = new ModelReloadToken(); + + if (!File.Exists(filePath)) + { + throw new ArgumentException($"The provided model file {filePath} doesn't exist."); + } + + var directory = Path.GetDirectoryName(filePath); + + if (string.IsNullOrEmpty(directory)) + { + directory = Directory.GetCurrentDirectory(); + } + + var file = Path.GetFileName(filePath); + + LoadModel(); + + if (watchFile) + { + _watcher = new FileSystemWatcher(directory, file); + _watcher.EnableRaisingEvents = true; + _watcher.Changed += WatcherChanged; + } + } + + private void WatcherChanged(object sender, FileSystemEventArgs e) + { + var timer = Stopwatch.StartNew(); + Logger.FileReloadBegin(_logger, _filePath); + + var previousToken = Interlocked.Exchange(ref _reloadToken, new ModelReloadToken()); + lock (_lock) + { + //TODO: We get here multiple times when you copy and paste a file + //because of the way file watchers work. Need to think through the + //ramifications. + LoadModel(); + Logger.ReloadingFile(_logger, _filePath, timer.Elapsed); + } + previousToken.OnReload(); + timer.Stop(); + Logger.FileReloadEnd(_logger, _filePath, timer.Elapsed); + } + + public override IChangeToken GetReloadToken() + { + if (_reloadToken == null) throw new InvalidOperationException("Start must be called on a ModelLoader before it can be used."); + return _reloadToken; + } + + public override ITransformer GetModel() + { + if (_model == null) throw new InvalidOperationException("Start must be called on a ModelLoader before it can be used."); + + return _model; + } + + //internal virtual for testing purposes. + internal virtual void LoadModel() + { + //Sleep to avoid some file system locking issues + //TODO: The same thing occurs in configuration reload + //we should make sure the sleeps are the same. + Thread.Sleep(50); + using (var fileStream = File.OpenRead(_filePath)) + { + _model = _context.Model.Load(fileStream, out _); + } + } + + public void Dispose() + { + _watcher?.Dispose(); + } + + internal static class EventIds + { + public static readonly EventId FileReloadBegin = new EventId(100, "FileReloadBegin"); + public static readonly EventId FileReloadEnd = new EventId(101, "FileReloadEnd"); + public static readonly EventId FileReload = new EventId(102, "FileReload"); + } + + private static class Logger + { + private static readonly Action _fileLoadBegin = LoggerMessage.Define( + LogLevel.Debug, + EventIds.FileReloadBegin, + "File reload for '{filePath}'"); + + private static readonly Action _fileLoadEnd = LoggerMessage.Define( + LogLevel.Debug, + EventIds.FileReloadEnd, + "File reload for '{filePath}' completed after {ElapsedMilliseconds}ms"); + + private static readonly Action _fileReLoad = LoggerMessage.Define( + LogLevel.Information, + EventIds.FileReloadEnd, + "Reloading file '{filePath}' completed after {ElapsedMilliseconds}ms"); + + public static void FileReloadBegin(ILogger logger, string filePath) + { + _fileLoadBegin(logger, filePath, null); + } + + public static void FileReloadEnd(ILogger logger, string filePath, TimeSpan duration) + { + _fileLoadEnd(logger, filePath, duration.TotalMilliseconds, null); + } + + public static void ReloadingFile(ILogger logger, string filePath, TimeSpan duration) + { + _fileReLoad(logger, filePath, duration.TotalMilliseconds, null); + } + } + + } +} diff --git a/src/Microsoft.Extensions.ML/ModelLoaders/ModelLoader.cs b/src/Microsoft.Extensions.ML/ModelLoaders/ModelLoader.cs new file mode 100644 index 0000000000..46a083fff8 --- /dev/null +++ b/src/Microsoft.Extensions.ML/ModelLoaders/ModelLoader.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Primitives; +using Microsoft.ML; + +namespace Microsoft.Extensions.ML +{ + /// + /// Defines a class that provides the mechanisms to load an ML.NET model + /// and to propagate notifications that the source of the model has changed. + /// + public abstract class ModelLoader + { + /// + /// Gets an object that can propagate notifications that + /// the model has changed. + /// + public abstract IChangeToken GetReloadToken(); + + /// + /// Gets the ML.NET model. + /// + public abstract ITransformer GetModel(); + } +} diff --git a/src/Microsoft.Extensions.ML/ModelLoaders/UriModelLoader.cs b/src/Microsoft.Extensions.ML/ModelLoaders/UriModelLoader.cs new file mode 100644 index 0000000000..eabdeb31ff --- /dev/null +++ b/src/Microsoft.Extensions.ML/ModelLoaders/UriModelLoader.cs @@ -0,0 +1,231 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; +using Microsoft.ML; + +namespace Microsoft.Extensions.ML +{ + internal class UriModelLoader : ModelLoader, IDisposable + { + //TODO: This should be able to be removed for HeaderNames.ETag + private const string ETagHeader = "ETag"; + private const int TimeoutMilliseconds = 60000; + private readonly MLContext _context; + private TimeSpan? _timerPeriod; + private Uri _uri; + private ITransformer _model; + private ModelReloadToken _reloadToken; + private Timer _reloadTimer; + private readonly object _reloadTimerLock; + private string _eTag; + private readonly ILogger _logger; + private readonly CancellationTokenSource _stopping; + private bool _started; + + public UriModelLoader(IOptions contextOptions, ILogger logger) + { + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _context = contextOptions.Value?.MLContext; + _reloadTimerLock = new object(); + _reloadToken = new ModelReloadToken(); + _stopping = new CancellationTokenSource(); + _started = false; + } + + internal void Start(Uri uri, TimeSpan period) + { + _timerPeriod = period; + _uri = uri; + if (LoadModel().ConfigureAwait(false).GetAwaiter().GetResult()) + { + StartReloadTimer(); + } + _started = true; + } + + private async void ReloadTimerTick(object state) + { + StopReloadTimer(); + + await RunAsync(); + + StartReloadTimer(); + } + + internal bool IsStopping => _stopping.IsCancellationRequested; + + internal async Task RunAsync() + { + CancellationTokenSource cancellation = null; + //TODO: Switch to ValueStopWatch + var duration = Stopwatch.StartNew(); + try + { + cancellation = CancellationTokenSource.CreateLinkedTokenSource(_stopping.Token); + cancellation.CancelAfter(TimeoutMilliseconds); + Logger.UriReloadBegin(_logger, _uri); + + var eTagMatches = await MatchEtag(_uri, _eTag); + if (!eTagMatches) + { + await LoadModel(); + var previousToken = Interlocked.Exchange(ref _reloadToken, new ModelReloadToken()); + previousToken.OnReload(); + } + + Logger.UriReloadEnd(_logger, _uri, duration.Elapsed); + } + catch (OperationCanceledException) when (IsStopping) + { + // This is a cancellation - if the app is shutting down we want to ignore it. + } + catch (Exception ex) + { + Logger.UriReloadError(_logger, _uri, duration.Elapsed, ex); + } + finally + { + cancellation.Dispose(); + } + } + + internal virtual async Task MatchEtag(Uri uri, string eTag) + { + using (var client = new HttpClient()) + { + var headRequest = new HttpRequestMessage(HttpMethod.Head, uri); + var resp = await client.SendAsync(headRequest); + + return resp.Headers.GetValues(ETagHeader).First() == eTag; + } + } + + internal void StartReloadTimer() + { + lock (_reloadTimerLock) + { + if (_reloadTimer == null) + { + _reloadTimer = new Timer(ReloadTimerTick, this, Convert.ToInt32(_timerPeriod.Value.TotalMilliseconds), Timeout.Infinite); + } + } + } + + internal void StopReloadTimer() + { + lock (_reloadTimerLock) + { + _reloadTimer.Dispose(); + _reloadTimer = null; + } + } + + internal virtual async Task LoadModel() + { + //TODO: We probably need some sort of retry policy for this. + try + { + using (var client = new HttpClient()) + { + var resp = await client.GetAsync(_uri); + using (var stream = await resp.Content.ReadAsStreamAsync()) + { + _model = _context.Model.Load(stream, out _); + } + + if (resp.Headers.Contains(ETagHeader)) + { + _eTag = resp.Headers.GetValues(ETagHeader).First(); + return true; + } + return false; + } + } + catch (Exception ex) + { + Logger.UriLoadError(_logger, _uri, ex); + throw; + } + } + + public override ITransformer GetModel() + { + if (!_started) throw new InvalidOperationException("Start must be called on a ModelLoader before it can be used."); + + return _model; + } + + public override IChangeToken GetReloadToken() + { + if (!_started) throw new InvalidOperationException("Start must be called on a ModelLoader before it can be used."); + + return _reloadToken; + } + + public void Dispose() + { + _reloadTimer?.Dispose(); + } + + internal static class EventIds + { + public static readonly EventId UriReloadBegin = new EventId(100, "UriReloadBegin"); + public static readonly EventId UriReloadEnd = new EventId(101, "UriReloadEnd"); + public static readonly EventId UriReloadError = new EventId(102, "UriReloadError"); + public static readonly EventId UriLoadError = new EventId(103, "UriLoadError"); + } + + private static class Logger + { + private static readonly Action _uriReloadBegin = LoggerMessage.Define( + LogLevel.Debug, + EventIds.UriReloadBegin, + "URI reload '{uri}'"); + + private static readonly Action _uriReloadEnd = LoggerMessage.Define( + LogLevel.Debug, + EventIds.UriReloadEnd, + "URI reload '{uri}' completed after {ElapsedMilliseconds}ms"); + + private static readonly Action _uriReloadError = LoggerMessage.Define( + LogLevel.Error, + EventIds.UriReloadError, + "URI reload for {uri} threw an unhandled exception after {ElapsedMilliseconds}ms"); + + private static readonly Action _uriLoadError = LoggerMessage.Define( + LogLevel.Error, + EventIds.UriLoadError, + "Error loading {uri}"); + + public static void UriReloadBegin(ILogger logger, Uri uri) + { + _uriReloadBegin(logger, uri, null); + } + + public static void UriReloadEnd(ILogger logger, Uri uri, TimeSpan duration) + { + _uriReloadEnd(logger, uri, duration.TotalMilliseconds, null); + } + + public static void UriReloadError(ILogger logger, Uri uri, TimeSpan duration, Exception exception) + { + _uriReloadError(logger, uri, duration.TotalMilliseconds, exception); + } + + public static void UriLoadError(ILogger logger, Uri uri, Exception exception) + { + _uriLoadError(logger, uri, exception); + } + } + } +} diff --git a/src/Microsoft.Extensions.ML/ModelReloadToken.cs b/src/Microsoft.Extensions.ML/ModelReloadToken.cs new file mode 100644 index 0000000000..4669f75d43 --- /dev/null +++ b/src/Microsoft.Extensions.ML/ModelReloadToken.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.Extensions.ML +{ + /// + /// Implements + /// + public class ModelReloadToken : IChangeToken + { + private readonly CancellationTokenSource _cts; + + public ModelReloadToken() + { + _cts = new CancellationTokenSource(); + } + + /// + /// Indicates if this token will proactively raise callbacks. + /// + public bool ActiveChangeCallbacks => true; + + /// + /// Gets a value that indicates if a change has occurred. + /// + public bool HasChanged => _cts.IsCancellationRequested; + + /// + /// Registers for a callback that will be invoked when the entry has changed. + /// MUST be set before the callback is invoked. + /// + /// The callback to invoke. + /// State to be passed into the callback. + /// + /// An System.IDisposable that is used to unregister the callback. + /// + public IDisposable RegisterChangeCallback(Action callback, object state) => _cts.Token.Register(callback, state); + + /// + /// Used to trigger the change token when a reload occurs. + /// + public void OnReload() => _cts.Cancel(); + } +} diff --git a/src/Microsoft.Extensions.ML/PoolLoader.cs b/src/Microsoft.Extensions.ML/PoolLoader.cs new file mode 100644 index 0000000000..42c2817b68 --- /dev/null +++ b/src/Microsoft.Extensions.ML/PoolLoader.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.ObjectPool; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; +using Microsoft.ML; + +namespace Microsoft.Extensions.ML +{ + /// + /// Encapsulates the data and logic required for loading and reloading PredictionEngine object pools. + /// + internal class PoolLoader: IDisposable + where TData : class + where TPrediction : class, new() + { + private DefaultObjectPool> _pool; + private readonly IDisposable _changeTokenRegistration; + + public PoolLoader(IServiceProvider sp, PredictionEnginePoolOptions poolOptions) + { + var contextOptions = sp.GetRequiredService>(); + Context = contextOptions.Value.MLContext ?? throw new ArgumentNullException(nameof(contextOptions)); + Loader = poolOptions.ModelLoader ?? throw new ArgumentNullException(nameof(poolOptions)); + + LoadPool(); + + _changeTokenRegistration = ChangeToken.OnChange( + () => Loader.GetReloadToken(), + () => LoadPool()); + } + + public ModelLoader Loader { get; } + private MLContext Context { get; } + public ObjectPool> PredictionEnginePool { get { return _pool; } } + + private void LoadPool() + { + var predictionEnginePolicy = new PredictionEnginePoolPolicy(Context, Loader.GetModel()); + Interlocked.Exchange(ref _pool, new DefaultObjectPool>(predictionEnginePolicy)); + } + + public void Dispose() + { + _changeTokenRegistration?.Dispose(); + } + } +} diff --git a/src/Microsoft.Extensions.ML/PredictionEnginePool.cs b/src/Microsoft.Extensions.ML/PredictionEnginePool.cs new file mode 100644 index 0000000000..57b7637286 --- /dev/null +++ b/src/Microsoft.Extensions.ML/PredictionEnginePool.cs @@ -0,0 +1,142 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Options; +using Microsoft.ML; + +namespace Microsoft.Extensions.ML +{ + /// + /// Provides a pool of objects + /// that can be used to make predictions. + /// + public class PredictionEnginePool + where TData : class + where TPrediction : class, new() + { + private readonly MLOptions _mlContextOptions; + private readonly IOptionsFactory> _predictionEngineOptions; + private readonly IServiceProvider _serviceProvider; + private readonly PoolLoader _defaultEnginePool; + private readonly Dictionary> _namedPools; + + public PredictionEnginePool(IServiceProvider serviceProvider, + IOptions mlContextOptions, + IOptionsFactory> predictionEngineOptions) + { + _mlContextOptions = mlContextOptions.Value; + _predictionEngineOptions = predictionEngineOptions; + _serviceProvider = serviceProvider; + + var defaultOptions = _predictionEngineOptions.Create(string.Empty); + + if (defaultOptions.ModelLoader != null) + { + _defaultEnginePool = new PoolLoader(_serviceProvider, defaultOptions); + } + + _namedPools = new Dictionary>(); + } + + /// + /// Get the Model used to create the pooled PredictionEngine. + /// + /// + /// The name of the model. Used when there are multiple models with the same input/output. + /// + public ITransformer GetModel(string modelName) + { + return _namedPools[modelName].Loader.GetModel(); + } + + /// + /// Get the Model used to create the pooled PredictionEngine. + /// + public ITransformer GetModel() + { + return _defaultEnginePool.Loader.GetModel(); + } + + /// + /// Gets a PredictionEngine that can be used to make predictions using + /// and . + /// + public PredictionEngine GetPredictionEngine() + { + return GetPredictionEngine(string.Empty); + } + + /// + /// Gets a PredictionEngine for a named model. + /// + /// + /// The name of the model which allows for uniquely identifying the model when + /// multiple models have the same and + /// types. + /// + public PredictionEngine GetPredictionEngine(string modelName) + { + if (_namedPools.ContainsKey(modelName)) + { + return _namedPools[modelName].PredictionEnginePool.Get(); + } + + //This is the case where someone has used string.Empty to get the default model. + //We can throw all the time, but it seems reasonable that we would just do what + //they are expecting if they know that an empty string means default. + if (string.IsNullOrEmpty(modelName)) + { + if (_defaultEnginePool == null) + { + throw new ArgumentException("You need to configure a default, not named, model before you use this method."); + } + + return _defaultEnginePool.PredictionEnginePool.Get(); + } + + //Here we are in the world of named models where the model hasn't been built yet. + var options = _predictionEngineOptions.Create(modelName); + var pool = new PoolLoader(_serviceProvider, options); + _namedPools.Add(modelName, pool); + return pool.PredictionEnginePool.Get(); + } + + /// + /// Returns a rented PredictionEngine to the pool. + /// + /// The rented PredictionEngine. + public void ReturnPredictionEngine(PredictionEngine engine) + { + ReturnPredictionEngine(string.Empty, engine); + } + + /// + /// Returns a rented PredictionEngine to the pool. + /// + /// + /// The name of the model which allows for uniquely identifying the model when + /// multiple models have the same and + /// types. + /// + /// The rented PredictionEngine. + public void ReturnPredictionEngine(string modelName, PredictionEngine engine) + { + if (engine == null) + { + throw new ArgumentNullException(nameof(engine)); + } + + if (string.IsNullOrEmpty(modelName)) + { + _defaultEnginePool.PredictionEnginePool.Return(engine); + } + else + { + _namedPools[modelName].PredictionEnginePool.Return(engine); + } + } + } +} diff --git a/src/Microsoft.Extensions.ML/PredictionEnginePoolExtensions.cs b/src/Microsoft.Extensions.ML/PredictionEnginePoolExtensions.cs new file mode 100644 index 0000000000..efdb4123f6 --- /dev/null +++ b/src/Microsoft.Extensions.ML/PredictionEnginePoolExtensions.cs @@ -0,0 +1,56 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Extensions.ML +{ + /// + /// Extension methods for . + /// + public static class PredictionEnginePoolExtensions + { + /// + /// Run prediction pipeline on one example using a PredictionEngine from the pool. + /// + /// + /// The pool of PredictionEngine instances to get the PredictionEngine. + /// + /// The example to run on. + /// The result of prediction. A new object is created for every call. + public static TPrediction Predict( + this PredictionEnginePool predictionEnginePool, TData example) + where TData : class + where TPrediction : class, new() + { + return predictionEnginePool.Predict(string.Empty, example); + } + + /// + /// Run prediction pipeline on one example using a PredictionEngine from the pool. + /// + /// + /// The pool of PredictionEngine instances to get the PredictionEngine. + /// + /// + /// The name of the model. Used when there are multiple models with the same input/output. + /// + /// The example to run on. + /// The result of prediction. A new object is created for every call. + public static TPrediction Predict( + this PredictionEnginePool predictionEnginePool, string modelName, TData example) + where TData : class + where TPrediction : class, new() + { + var predictionEngine = predictionEnginePool.GetPredictionEngine(modelName); + + try + { + return predictionEngine.Predict(example); + } + finally + { + predictionEnginePool.ReturnPredictionEngine(modelName, predictionEngine); + } + } + } +} diff --git a/src/Microsoft.Extensions.ML/PredictionEnginePoolOptions.cs b/src/Microsoft.Extensions.ML/PredictionEnginePoolOptions.cs new file mode 100644 index 0000000000..894037e986 --- /dev/null +++ b/src/Microsoft.Extensions.ML/PredictionEnginePoolOptions.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Extensions.ML +{ + /// + /// Specifies the options to use when creating a . + /// + public class PredictionEnginePoolOptions + where TData : class + where TPrediction : class, new() + { + /// + /// Gets the object used to load the model + /// from the source location. + /// + public ModelLoader ModelLoader { get; set; } + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.ML/PredictionEnginePoolPolicy.cs b/src/Microsoft.Extensions.ML/PredictionEnginePoolPolicy.cs new file mode 100644 index 0000000000..1048ad10c8 --- /dev/null +++ b/src/Microsoft.Extensions.ML/PredictionEnginePoolPolicy.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.ObjectPool; +using Microsoft.ML; + +namespace Microsoft.Extensions.ML +{ + /// + /// for + /// which is responsible for creating pooled objects, and when to return objects to the pool. + /// + internal class PredictionEnginePoolPolicy + : IPooledObjectPolicy> + where TData : class + where TPrediction : class, new() + { + private readonly MLContext _mlContext; + private readonly ITransformer _model; + private readonly List _references; + + /// + /// Initializes a new instance of . + /// + /// + /// used to load the model. + /// + /// The transformer to use for prediction. + public PredictionEnginePoolPolicy(MLContext mlContext, ITransformer model) + { + _mlContext = mlContext; + _model = model; + _references = new List(); + } + + /// + public PredictionEngine Create() + { + var engine = _mlContext.Model.CreatePredictionEngine(_model); + _references.Add(new WeakReference(engine)); + return engine; + } + + /// + public bool Return(PredictionEngine obj) + { + return _references.Any(x => x.Target == obj); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.ML/Properties/AssemblyInfo.cs b/src/Microsoft.Extensions.ML/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..85612a0522 --- /dev/null +++ b/src/Microsoft.Extensions.ML/Properties/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Microsoft.Extensions.ML.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")] diff --git a/src/Microsoft.Extensions.ML/ServiceCollectionExtensions.cs b/src/Microsoft.Extensions.ML/ServiceCollectionExtensions.cs new file mode 100644 index 0000000000..cdd4996ae5 --- /dev/null +++ b/src/Microsoft.Extensions.ML/ServiceCollectionExtensions.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Options; + +namespace Microsoft.Extensions.ML +{ + /// + /// Extension methods for . + /// + public static class ServiceCollectionExtensions + { + /// + /// Adds a to the service collection. + /// + /// + /// The to add services to. + /// + /// + /// The that was added to the collection. + /// + public static PredictionEnginePoolBuilder AddPredictionEnginePool( + this IServiceCollection services) + where TData : class + where TPrediction : class, new() + { + services.AddPredictionEngineServices(); + return new PredictionEnginePoolBuilder(services); + } + + internal static IServiceCollection AddPredictionEngineServices( + this IServiceCollection services) + where TData : class + where TPrediction : class, new() + { + services.AddLogging(); + services.AddOptions(); + services.TryAddEnumerable(ServiceDescriptor.Singleton, PostMLContextOptionsConfiguration>()); + services.AddSingleton, PredictionEnginePool>(); + return services; + } + } +} diff --git a/test/Microsoft.Extensions.ML.Tests/FileLoaderTests.cs b/test/Microsoft.Extensions.ML.Tests/FileLoaderTests.cs new file mode 100644 index 0000000000..ed096bb5c0 --- /dev/null +++ b/test/Microsoft.Extensions.ML.Tests/FileLoaderTests.cs @@ -0,0 +1,106 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; +using Microsoft.ML.Data; +using Xunit; + +namespace Microsoft.Extensions.ML +{ + public class FileLoaderTests + { + [Fact] + public void throw_until_started() + { + var services = new ServiceCollection() + .AddOptions() + .AddLogging(); + var sp = services.BuildServiceProvider(); + + var loaderUnderTest = ActivatorUtilities.CreateInstance(sp); + Assert.Throws(() => loaderUnderTest.GetModel()); + Assert.Throws(() => loaderUnderTest.GetReloadToken()); + } + + [Fact] + public void can_load_model() + { + var services = new ServiceCollection() + .AddOptions() + .AddLogging(); + var sp = services.BuildServiceProvider(); + + var loaderUnderTest = ActivatorUtilities.CreateInstance(sp); + loaderUnderTest.Start(Path.Combine("TestModels", "SentimentModel.zip"), false); + + var model = loaderUnderTest.GetModel(); + var context = sp.GetRequiredService>().Value.MLContext; + var engine = context.Model.CreatePredictionEngine(model); + + var prediction = engine.Predict(new SentimentData() { SentimentText = "This is great" }); + Assert.True(prediction.Sentiment); + } + + //TODO: This is a quick test to give coverage of the main scenarios. Refactoring and re-implementing of tests should happen. + //Right now this screams of probably flakeyness + [Fact] + public async Task can_reload_model() + { + var services = new ServiceCollection() + .AddOptions() + .AddLogging(); + var sp = services.BuildServiceProvider(); + + var loaderUnderTest = ActivatorUtilities.CreateInstance(sp); + loaderUnderTest.Start("testdata.txt", true); + + var changed = false; + var changeTokenRegistration = ChangeToken.OnChange( + () => loaderUnderTest.GetReloadToken(), + () => changed = true); + + File.WriteAllText("testdata.txt", "test"); + + await Task.Delay(1000); + + Assert.True(changed); + } + + + private class FileLoaderMock : FileModelLoader + { + public FileLoaderMock(IOptions contextOptions, ILogger logger) + : base(contextOptions, logger) + { + } + + internal override void LoadModel() + { + } + } + + public class SentimentData + { + [ColumnName("Label"), LoadColumn(0)] + public bool Sentiment; + + [LoadColumn(1)] + public string SentimentText; + } + + public class SentimentPrediction + { + [ColumnName("PredictedLabel")] + public bool Sentiment; + + public float Score; + } + } +} diff --git a/test/Microsoft.Extensions.ML.Tests/Microsoft.Extensions.ML.Tests.csproj b/test/Microsoft.Extensions.ML.Tests/Microsoft.Extensions.ML.Tests.csproj new file mode 100644 index 0000000000..2f65d79a5d --- /dev/null +++ b/test/Microsoft.Extensions.ML.Tests/Microsoft.Extensions.ML.Tests.csproj @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + Always + + + + diff --git a/test/Microsoft.Extensions.ML.Tests/UriLoaderTests.cs b/test/Microsoft.Extensions.ML.Tests/UriLoaderTests.cs new file mode 100644 index 0000000000..91d4c9a5e1 --- /dev/null +++ b/test/Microsoft.Extensions.ML.Tests/UriLoaderTests.cs @@ -0,0 +1,98 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; +using Microsoft.ML; +using Xunit; + +namespace Microsoft.Extensions.ML +{ + public class UriLoaderTests + { + [Fact] + public void throw_until_started() + { + var services = new ServiceCollection() + .AddOptions() + .AddLogging(); + var sp = services.BuildServiceProvider(); + + var loaderUnderTest = ActivatorUtilities.CreateInstance(sp); + Assert.Throws(() => loaderUnderTest.GetModel()); + Assert.Throws(() => loaderUnderTest.GetReloadToken()); + } + + [Fact] + public void can_reload_model() + { + var services = new ServiceCollection() + .AddOptions() + .AddLogging(); + var sp = services.BuildServiceProvider(); + + var loaderUnderTest = ActivatorUtilities.CreateInstance(sp); + loaderUnderTest.Start(new Uri("http://microsoft.com"), TimeSpan.FromMilliseconds(1)); + + var changed = false; + var changeTokenRegistration = ChangeToken.OnChange( + () => loaderUnderTest.GetReloadToken(), + () => changed = true); + Thread.Sleep(30); + Assert.True(changed); + } + + [Fact] + public void no_reload_no_change() + { + var services = new ServiceCollection() + .AddOptions() + .AddLogging(); + var sp = services.BuildServiceProvider(); + + var loaderUnderTest = ActivatorUtilities.CreateInstance(sp); + + loaderUnderTest.ETagMatches = (a,b) => true; + + loaderUnderTest.Start(new Uri("http://microsoft.com"), TimeSpan.FromMilliseconds(1)); + + var changed = false; + var changeTokenRegistration = ChangeToken.OnChange( + () => loaderUnderTest.GetReloadToken(), + () => changed = true); + Thread.Sleep(30); + Assert.False(changed); + } + } + + class UriLoaderMock : UriModelLoader + { + public Func ETagMatches { get; set; } = (_, __) => false; + + public UriLoaderMock(IOptions contextOptions, + ILogger logger) : base(contextOptions, logger) + { + } + + public override ITransformer GetModel() + { + return null; + } + + internal override Task LoadModel() + { + return Task.FromResult(true); + } + + internal override Task MatchEtag(Uri uri, string eTag) + { + return Task.FromResult(ETagMatches(uri, eTag)); + } + } +} diff --git a/test/Microsoft.Extensions.ML.Tests/testdata.txt b/test/Microsoft.Extensions.ML.Tests/testdata.txt new file mode 100644 index 0000000000..28e0d50134 --- /dev/null +++ b/test/Microsoft.Extensions.ML.Tests/testdata.txt @@ -0,0 +1 @@ +Contents don't matter. \ No newline at end of file