diff --git a/src/Serilog.Extensions.Hosting/SerilogHostBuilderExtensions.cs b/src/Serilog.Extensions.Hosting/SerilogHostBuilderExtensions.cs index 086f810..d588000 100644 --- a/src/Serilog.Extensions.Hosting/SerilogHostBuilderExtensions.cs +++ b/src/Serilog.Extensions.Hosting/SerilogHostBuilderExtensions.cs @@ -84,6 +84,9 @@ public static IHostBuilder UseSerilog( { // This won't (and shouldn't) take ownership of the logger. collection.AddSingleton(logger); + + // Still need to use RegisteredLogger as it is used by ConfigureDiagnosticContext. + collection.AddSingleton(new RegisteredLogger(logger)); } bool useRegisteredLogger = logger != null; ConfigureDiagnosticContext(collection, useRegisteredLogger); diff --git a/test/Serilog.Extensions.Hosting.Tests/SerilogHostBuilderExtensionsTests.cs b/test/Serilog.Extensions.Hosting.Tests/SerilogHostBuilderExtensionsTests.cs new file mode 100644 index 0000000..c9ba14b --- /dev/null +++ b/test/Serilog.Extensions.Hosting.Tests/SerilogHostBuilderExtensionsTests.cs @@ -0,0 +1,109 @@ +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace Serilog.Extensions.Hosting.Tests +{ + public class SerilogHostBuilderExtensionsTests + { + [Fact] + public void ServicesAreRegisteredWhenCallingUseSerilog() + { + // Arrange + var collection = new ServiceCollection(); + IHostBuilder builder = new FakeHostBuilder(collection); + + // Act + builder.UseSerilog(); + + // Assert + IServiceProvider provider = collection.BuildServiceProvider(); + provider.GetRequiredService(); + provider.GetRequiredService(); + } + + [Fact] + public void ServicesAreRegisteredWhenCallingUseSerilogWithLogger() + { + // Arrange + var collection = new ServiceCollection(); + IHostBuilder builder = new FakeHostBuilder(collection); + ILogger logger = new LoggerConfiguration().CreateLogger(); + + // Act + builder.UseSerilog(logger); + + // Assert + IServiceProvider provider = collection.BuildServiceProvider(); + provider.GetRequiredService(); + provider.GetRequiredService(); + provider.GetRequiredService(); + } + + [Fact] + public void ServicesAreRegisteredWhenCallingUseSerilogWithConfigureDelegate() + { + // Arrange + var collection = new ServiceCollection(); + IHostBuilder builder = new FakeHostBuilder(collection); + + // Act + builder.UseSerilog((_, _) => { }); + + // Assert + IServiceProvider provider = collection.BuildServiceProvider(); + provider.GetRequiredService(); + provider.GetRequiredService(); + provider.GetRequiredService(); + } + + private class FakeHostBuilder : IHostBuilder + { + private readonly IServiceCollection _collection; + + public FakeHostBuilder(IServiceCollection collection) => _collection = collection; + + public IHostBuilder ConfigureHostConfiguration(Action configureDelegate) + { + throw new NotImplementedException(); + } + + public IHostBuilder ConfigureAppConfiguration(Action configureDelegate) + { + throw new NotImplementedException(); + } + + public IHostBuilder ConfigureServices(Action configureDelegate) + { + configureDelegate(null, _collection); + return this; + } + + public IHostBuilder UseServiceProviderFactory(IServiceProviderFactory factory) + { + throw new NotImplementedException(); + } + + public IHostBuilder UseServiceProviderFactory(Func> factory) + { + throw new NotImplementedException(); + } + + public IHostBuilder ConfigureContainer(Action configureDelegate) + { + throw new NotImplementedException(); + } + + public IHost Build() + { + throw new NotImplementedException(); + } + + public IDictionary Properties { get; } + } + } +}