diff --git a/src/Scrutor/ILifetimeSelector.cs b/src/Scrutor/ILifetimeSelector.cs index d473110b..9bab980d 100644 --- a/src/Scrutor/ILifetimeSelector.cs +++ b/src/Scrutor/ILifetimeSelector.cs @@ -1,5 +1,5 @@ -using System; -using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection; +using System; namespace Scrutor; @@ -29,4 +29,16 @@ public interface ILifetimeSelector : IServiceTypeSelector /// Registers each matching concrete type with a lifetime based on the provided . /// IImplementationTypeSelector WithLifetime(Func selector); + + /// + /// Registers each matching concrete type with the specified . + /// + /// The service key to use for registration. + ILifetimeSelector WithServiceKey(object serviceKey); + + /// + /// Registers each matching concrete type with a service key based on the provided . + /// + /// A function to determine the service key for each type. + ILifetimeSelector WithServiceKey(Func selector); } diff --git a/src/Scrutor/LifetimeSelector.cs b/src/Scrutor/LifetimeSelector.cs index 595ab23d..3e07acd8 100644 --- a/src/Scrutor/LifetimeSelector.cs +++ b/src/Scrutor/LifetimeSelector.cs @@ -1,9 +1,9 @@ -using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyModel; +using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Reflection; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.DependencyModel; namespace Scrutor; @@ -24,6 +24,8 @@ public LifetimeSelector(ServiceTypeSelector inner, IEnumerable typeMaps public Func? SelectorFn { get; set; } + public Func? ServiceKeySelectorFn { get; set; } + public IImplementationTypeSelector WithSingletonLifetime() { return WithLifetime(ServiceLifetime.Singleton); @@ -54,6 +56,22 @@ public IImplementationTypeSelector WithLifetime(Func sele return this; } + public ILifetimeSelector WithServiceKey(object serviceKey) + { + Preconditions.NotNull(serviceKey, nameof(serviceKey)); + + return WithServiceKey(_ => serviceKey); + } + + public ILifetimeSelector WithServiceKey(Func selector) + { + Preconditions.NotNull(selector, nameof(selector)); + + Inner.PropagateServiceKey(selector); + + return this; + } + #region Chain Methods [ExcludeFromCodeCoverage] @@ -231,6 +249,7 @@ void ISelector.Populate(IServiceCollection services, RegistrationStrategy? strat strategy ??= RegistrationStrategy.Append; var lifetimes = new Dictionary(); + var serviceKeys = new Dictionary(); foreach (var typeMap in TypeMaps) { @@ -244,8 +263,10 @@ void ISelector.Populate(IServiceCollection services, RegistrationStrategy? strat } var lifetime = GetOrAddLifetime(lifetimes, implementationType); - - var descriptor = new ServiceDescriptor(serviceType, implementationType, lifetime); + var serviceKey = GetOrAddServiceKey(serviceKeys, implementationType); + var descriptor = serviceKey is not null + ? new ServiceDescriptor(serviceType, serviceKey, implementationType, lifetime) + : new ServiceDescriptor(serviceType, implementationType, lifetime); strategy.Apply(services, descriptor); } @@ -256,8 +277,11 @@ void ISelector.Populate(IServiceCollection services, RegistrationStrategy? strat foreach (var serviceType in typeFactoryMap.ServiceTypes) { var lifetime = GetOrAddLifetime(lifetimes, typeFactoryMap.ImplementationType); + var serviceKey = GetOrAddServiceKey(serviceKeys, typeFactoryMap.ImplementationType); - var descriptor = new ServiceDescriptor(serviceType, typeFactoryMap.ImplementationFactory, lifetime); + var descriptor = serviceKey is not null + ? new ServiceDescriptor(serviceType, serviceKey, WrapImplementationFactory(typeFactoryMap.ImplementationFactory), lifetime) + : new ServiceDescriptor(serviceType, typeFactoryMap.ImplementationFactory, lifetime); strategy.Apply(services, descriptor); } @@ -277,4 +301,23 @@ private ServiceLifetime GetOrAddLifetime(Dictionary lifet return lifetime; } + + private object? GetOrAddServiceKey(Dictionary serviceKeys, Type implementationType) + { + if (serviceKeys.TryGetValue(implementationType, out var serviceKey)) + { + return serviceKey; + } + + serviceKey = ServiceKeySelectorFn?.Invoke(implementationType); + + serviceKeys[implementationType] = serviceKey; + + return serviceKey; + } + + private static Func WrapImplementationFactory(Func factory) + { + return (sp, _) => factory(sp); + } } diff --git a/src/Scrutor/ServiceTypeSelector.cs b/src/Scrutor/ServiceTypeSelector.cs index 4bed2357..c971e552 100644 --- a/src/Scrutor/ServiceTypeSelector.cs +++ b/src/Scrutor/ServiceTypeSelector.cs @@ -1,10 +1,10 @@ -using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyModel; +using System; using System.Collections; using System.Collections.Generic; using System.Linq; using System.Reflection; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.DependencyModel; namespace Scrutor; @@ -207,6 +207,14 @@ internal void PropagateLifetime(Func selectorFn) } } + internal void PropagateServiceKey(Func selectorFn) + { + foreach (var selector in Selectors.OfType()) + { + selector.ServiceKeySelectorFn = selectorFn; + } + } + void ISelector.Populate(IServiceCollection services, RegistrationStrategy? registrationStrategy) { if (Selectors.Count == 0) diff --git a/test/Scrutor.Tests/ScanningTests.cs b/test/Scrutor.Tests/ScanningTests.cs index 4dcbbf2e..bda4e6d0 100644 --- a/test/Scrutor.Tests/ScanningTests.cs +++ b/test/Scrutor.Tests/ScanningTests.cs @@ -584,10 +584,70 @@ public void ShouldAllowOptInToCompilerGeneratedTypes() .AsSelf() .WithTransientLifetime()); }); - + var compilerGeneratedSubclass = provider.GetService(); Assert.NotNull(compilerGeneratedSubclass); } + + [Fact] + public void CanRegisterWithServiceKey() + { + Collection.Scan(scan => scan + .FromTypes() + .AsImplementedInterfaces(x => x != typeof(IOtherInheritance)) + .WithServiceKey("my-key") + .WithSingletonLifetime()); + + Assert.Equal(2, Collection.Count); + + Assert.All(Collection, x => + { + Assert.Equal(ServiceLifetime.Singleton, x.Lifetime); + Assert.Equal(typeof(ITransientService), x.ServiceType); + Assert.True(x.IsKeyedService); + Assert.Equal("my-key", x.ServiceKey); + }); + } + + [Fact] + public void CanRegisterWithServiceKeySelector() + { + Collection.Scan(scan => scan + .FromTypes() + .AsImplementedInterfaces(x => x != typeof(IOtherInheritance)) + .WithServiceKey(type => type.Name) + .WithSingletonLifetime()); + + Assert.Equal(2, Collection.Count); + + var service1 = Collection.First(x => x.ServiceKey as string == nameof(TransientService1)); + Assert.Equal(typeof(ITransientService), service1.ServiceType); + Assert.Equal(ServiceLifetime.Singleton, service1.Lifetime); + Assert.True(service1.IsKeyedService); + + var service2 = Collection.First(x => x.ServiceKey as string == nameof(TransientService2)); + Assert.Equal(typeof(ITransientService), service2.ServiceType); + Assert.Equal(ServiceLifetime.Singleton, service2.Lifetime); + Assert.True(service2.IsKeyedService); + } + + [Fact] + public void CanResolveKeyedServices() + { + Collection.Scan(scan => scan + .FromTypes() + .AsSelf() + .WithServiceKey(type => type.Name) + .WithTransientLifetime()); + + var provider = Collection.BuildServiceProvider(); + + var service1 = provider.GetRequiredKeyedService(nameof(TransientService1)); + var service2 = provider.GetRequiredKeyedService(nameof(TransientService2)); + + Assert.NotNull(service1); + Assert.NotNull(service2); + } } // ReSharper disable UnusedTypeParameter @@ -671,7 +731,7 @@ public class DefaultAttributes : IDefault3Level2, IDefault1, IDefault2 { } [CompilerGenerated] public class CompilerGenerated { } - public class CombinedService2: IDefault1, IDefault2, IDefault3Level2 { } + public class CombinedService2 : IDefault1, IDefault2, IDefault3Level2 { } public interface IGenericAttribute { }