diff --git a/src/Core/src/MauiContext.cs b/src/Core/src/MauiContext.cs index df690b671e79..97f28fca6796 100644 --- a/src/Core/src/MauiContext.cs +++ b/src/Core/src/MauiContext.cs @@ -23,7 +23,11 @@ public MauiContext(IServiceProvider services, Android.Content.Context context) public MauiContext(IServiceProvider services) { - _services = new WrappedServiceProvider(services ?? throw new ArgumentNullException(nameof(services))); + _ = services ?? throw new ArgumentNullException(nameof(services)); + _services = services is IKeyedServiceProvider + ? new KeyedWrappedServiceProvider(services) + : new WrappedServiceProvider(services); + _handlers = new Lazy(() => _services.GetRequiredService()); #if ANDROID _context = new Lazy(() => _services.GetService()); @@ -73,5 +77,27 @@ public void AddSpecific(Type type, Func getter, object state) _scopeStatic[type] = (state, getter); } } + + class KeyedWrappedServiceProvider : WrappedServiceProvider, IKeyedServiceProvider + { + public KeyedWrappedServiceProvider(IServiceProvider serviceProvider) + : base(serviceProvider) + { + } + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + if (Inner is IKeyedServiceProvider provider) + return provider.GetKeyedService(serviceType, serviceKey); + + // we know this won't work, but we need to call it to throw the right exception + return Inner.GetRequiredKeyedService(serviceType, serviceKey); + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + return Inner.GetRequiredKeyedService(serviceType, serviceKey); + } + } } } diff --git a/src/Core/tests/UnitTests/MauiContextTests.cs b/src/Core/tests/UnitTests/MauiContextTests.cs index d27430a52764..3c7761db2a9f 100644 --- a/src/Core/tests/UnitTests/MauiContextTests.cs +++ b/src/Core/tests/UnitTests/MauiContextTests.cs @@ -1,5 +1,7 @@ using System; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Maui.Hosting; using Microsoft.Maui.Hosting.Internal; using Xunit; @@ -99,6 +101,126 @@ public void CloneCanOverrideIncludeService() Assert.Same(obj2, second.Services.GetService()); } + [Fact] + public void MauiContextSupportsKeyedServices() + { + var collection = new ServiceCollection(); + collection.AddKeyedTransient("foo"); + collection.AddKeyedTransient("foo2"); + var services = collection.BuildServiceProvider(); + + var context = new MauiContext(services); + + var foo = context.Services.GetRequiredKeyedService("foo"); + Assert.IsType(foo); + + var foo2 = context.Services.GetRequiredKeyedService("foo2"); + Assert.IsType(foo2); + } + + [Fact] + public void MauiContextSupportsKeyedServicesUsingAttributes() + { + var collection = new ServiceCollection(); + collection.AddKeyedTransient("foo"); + collection.AddKeyedTransient("bar"); + collection.AddTransient(); + var services = collection.BuildServiceProvider(); + + var context = new MauiContext(services); + + var foobar = context.Services.GetRequiredService(); + var keyed = Assert.IsType(foobar); + Assert.NotNull(keyed.Foo); + Assert.NotNull(keyed.Bar); + } + [Fact] + public void NonKeyedProviderStaysNonKeyed() + { + var builder = MauiApp.CreateBuilder(useDefaults: false); + builder.ConfigureContainer(new KeyedOrNonKeyedProviderFactory(false)); + var mauiApp = builder.Build(); + + var context = new MauiContext(mauiApp.Services); + + Assert.IsAssignableFrom(context.Services); + Assert.IsNotAssignableFrom(context.Services); + + var context2 = new MauiContext(context.Services); + + Assert.IsAssignableFrom(context2.Services); + Assert.IsNotAssignableFrom(context2.Services); + } + + [Fact] + public void KeyedProviderStaysKeyed() + { + var builder = MauiApp.CreateBuilder(useDefaults: false); + builder.ConfigureContainer(new KeyedOrNonKeyedProviderFactory(true)); + var mauiApp = builder.Build(); + + var context = new MauiContext(mauiApp.Services); + + Assert.IsAssignableFrom(context.Services); + Assert.IsAssignableFrom(context.Services); + + var context2 = new MauiContext(context.Services); + + Assert.IsAssignableFrom(context2.Services); + Assert.IsAssignableFrom(context2.Services); + } + + private class KeyedOrNonKeyedProviderFactory : IServiceProviderFactory + { + public KeyedOrNonKeyedProviderFactory(bool keyed) + { + Keyed = keyed; + } + + public bool Keyed { get; } + + public ServiceCollection CreateBuilder(IServiceCollection services) => + new() { services }; + + public IServiceProvider CreateServiceProvider(ServiceCollection containerBuilder) + { + var real = containerBuilder.BuildServiceProvider(); + return Keyed ? new KeyedProvider(real) : new NonKeyedProvider(real); + } + } + + private class NonKeyedProvider : IServiceProvider + { + public NonKeyedProvider(ServiceProvider provider) + { + Provider = provider; + } + + public ServiceProvider Provider { get; } + + public object GetService(Type serviceType) => + Provider.GetService(serviceType); + } + + private class KeyedProvider : IServiceProvider, IKeyedServiceProvider + { + public KeyedProvider(ServiceProvider provider) + { + Provider = provider; + } + + public ServiceProvider Provider { get; } + + public object GetKeyedService(Type serviceType, object serviceKey) => + Provider.GetKeyedService(serviceType, serviceKey); + + public object GetRequiredKeyedService(Type serviceType, object serviceKey) => + Provider.GetRequiredKeyedService(serviceType, serviceKey); + + public object GetService(Type serviceType) => + Provider.GetService(serviceType); + } + class TestThing { } diff --git a/src/Core/tests/UnitTests/TestClasses/TestServices.cs b/src/Core/tests/UnitTests/TestClasses/TestServices.cs index 63d372062a10..120af0d8584e 100644 --- a/src/Core/tests/UnitTests/TestClasses/TestServices.cs +++ b/src/Core/tests/UnitTests/TestClasses/TestServices.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using Microsoft.Extensions.DependencyInjection; namespace Microsoft.Maui.UnitTests { @@ -64,6 +65,19 @@ public FooBarService(IFooService foo, IBarService bar) public IBarService Bar { get; } } + class FooBarKeyedService : IFooBarService + { + public FooBarKeyedService([FromKeyedServices("foo")] IFooService foo, [FromKeyedServices("bar")] IBarService bar) + { + Foo = foo; + Bar = bar; + } + + public IFooService Foo { get; } + + public IBarService Bar { get; } + } + class FooTrioConstructor : IFooBarService { public FooTrioConstructor()