diff --git a/src/Scrutor/AttributeSelector.cs b/src/Scrutor/AttributeSelector.cs index 0026ded4..af3ee630 100644 --- a/src/Scrutor/AttributeSelector.cs +++ b/src/Scrutor/AttributeSelector.cs @@ -37,7 +37,9 @@ void ISelector.Populate(IServiceCollection services, RegistrationStrategy? regis foreach (var serviceType in serviceTypes) { - var descriptor = new ServiceDescriptor(serviceType, type, attribute.Lifetime); + var descriptor = attribute.ServiceKey is null + ? new ServiceDescriptor(serviceType, type, attribute.Lifetime) + : new ServiceDescriptor(serviceType, attribute.ServiceKey, type, attribute.Lifetime); strategy.Apply(services, descriptor); } @@ -47,6 +49,6 @@ void ISelector.Populate(IServiceCollection services, RegistrationStrategy? regis private static IEnumerable GetDuplicates(IEnumerable attributes) { - return attributes.GroupBy(s => s.ServiceType).SelectMany(grp => grp.Skip(1)); + return attributes.GroupBy(s => new { s.ServiceType, s.ServiceKey }).SelectMany(grp => grp.Skip(1)); } } diff --git a/src/Scrutor/ServiceDescriptorAttribute.cs b/src/Scrutor/ServiceDescriptorAttribute.cs index 4a613466..47e8965a 100644 --- a/src/Scrutor/ServiceDescriptorAttribute.cs +++ b/src/Scrutor/ServiceDescriptorAttribute.cs @@ -13,16 +13,21 @@ public ServiceDescriptorAttribute() : this(null) { } public ServiceDescriptorAttribute(Type? serviceType) : this(serviceType, ServiceLifetime.Transient) { } - public ServiceDescriptorAttribute(Type? serviceType, ServiceLifetime lifetime) + public ServiceDescriptorAttribute(Type? serviceType, ServiceLifetime lifetime) : this(serviceType, lifetime, null) { } + + public ServiceDescriptorAttribute(Type? serviceType, ServiceLifetime lifetime, object? serviceKey) { ServiceType = serviceType; Lifetime = lifetime; + ServiceKey = serviceKey; } public Type? ServiceType { get; } public ServiceLifetime Lifetime { get; } + public object? ServiceKey { get; } + public IEnumerable GetServiceTypes(Type fallbackType) { if (ServiceType is null) @@ -60,4 +65,8 @@ public sealed class ServiceDescriptorAttribute : ServiceDescriptorAttr public ServiceDescriptorAttribute() : base(typeof(TService)) { } public ServiceDescriptorAttribute(ServiceLifetime lifetime) : base(typeof(TService), lifetime) { } + + public ServiceDescriptorAttribute(object? serviceKey) : base(typeof(TService), ServiceLifetime.Transient, serviceKey) { } + + public ServiceDescriptorAttribute(ServiceLifetime lifetime, object? serviceKey) : base(typeof(TService), lifetime, serviceKey) { } } diff --git a/test/Scrutor.Tests/KeyedServiceTests.cs b/test/Scrutor.Tests/KeyedServiceTests.cs new file mode 100644 index 00000000..b28c003e --- /dev/null +++ b/test/Scrutor.Tests/KeyedServiceTests.cs @@ -0,0 +1,329 @@ +using Microsoft.Extensions.DependencyInjection; +using System; +using System.Linq; +using Xunit; + +namespace Scrutor.Tests; + +public class KeyedServiceTests : TestBase +{ + private IServiceCollection Collection { get; } = new ServiceCollection(); + + [Fact] + public void CanRegisterKeyedServiceWithStringKey() + { + Collection.Scan(scan => scan + .FromTypes(typeof(KeyedTransientService)) + .UsingAttributes()); + + Assert.Single(Collection); + + var service = Collection.Single(); + Assert.Equal(typeof(IKeyedTestService), service.ServiceType); + Assert.Equal(typeof(KeyedTransientService), service.KeyedImplementationType); + Assert.Equal("test-key", service.ServiceKey); + Assert.Equal(ServiceLifetime.Transient, service.Lifetime); + Assert.True(service.IsKeyedService); + } + + [Fact] + public void CanRegisterKeyedServiceWithGenericAttribute() + { + Collection.Scan(scan => scan + .FromTypes(typeof(GenericKeyedService)) + .UsingAttributes()); + + Assert.Single(Collection); + + var service = Collection.Single(); + Assert.Equal(typeof(IKeyedTestService), service.ServiceType); + Assert.Equal(typeof(GenericKeyedService), service.KeyedImplementationType); + Assert.Equal("generic-key", service.ServiceKey); + Assert.Equal(ServiceLifetime.Scoped, service.Lifetime); + Assert.True(service.IsKeyedService); + } + + [Fact] + public void CanRegisterMultipleKeyedServicesOnSameType() + { + Collection.Scan(scan => scan + .FromTypes(typeof(MultipleKeyedService)) + .UsingAttributes()); + + Assert.Equal(2, Collection.Count); + + var services = Collection.ToArray(); + + var service1 = services.First(s => s.ServiceKey?.ToString() == "key1"); + Assert.Equal(typeof(IKeyedTestService), service1.ServiceType); + Assert.Equal(typeof(MultipleKeyedService), service1.KeyedImplementationType); + Assert.Equal(ServiceLifetime.Transient, service1.Lifetime); + + var service2 = services.First(s => s.ServiceKey?.ToString() == "key2"); + Assert.Equal(typeof(IKeyedTestService), service2.ServiceType); + Assert.Equal(typeof(MultipleKeyedService), service2.KeyedImplementationType); + Assert.Equal(ServiceLifetime.Singleton, service2.Lifetime); + } + + [Fact] + public void CanRegisterMixedKeyedAndNonKeyedServices() + { + Collection.Scan(scan => scan + .FromTypes(typeof(MixedKeyedService)) + .UsingAttributes()); + + Assert.Equal(2, Collection.Count); + + var keyedService = Collection.First(s => s.IsKeyedService); + Assert.Equal(typeof(IKeyedTestService), keyedService.ServiceType); + Assert.Equal(typeof(MixedKeyedService), keyedService.KeyedImplementationType); + Assert.Equal("mixed-key", keyedService.ServiceKey); + Assert.Equal(ServiceLifetime.Scoped, keyedService.Lifetime); + + var nonKeyedService = Collection.First(s => !s.IsKeyedService); + Assert.Equal(typeof(IKeyedTestService), nonKeyedService.ServiceType); + Assert.Equal(typeof(MixedKeyedService), nonKeyedService.ImplementationType); + Assert.Null(nonKeyedService.ServiceKey); + Assert.Equal(ServiceLifetime.Transient, nonKeyedService.Lifetime); + } + + [Fact] + public void CanResolveKeyedServices() + { + var provider = ConfigureProvider(services => + { + services.Scan(scan => scan + .FromTypes(typeof(KeyedTransientService), typeof(GenericKeyedService), typeof(MultipleKeyedService)) + .UsingAttributes()); + }); + + var keyedTransient = provider.GetRequiredKeyedService("test-key"); + Assert.IsType(keyedTransient); + + using var scope = provider.CreateScope(); + var genericKeyed = scope.ServiceProvider.GetRequiredKeyedService("generic-key"); + Assert.IsType(genericKeyed); + + var multipleKeyed1 = provider.GetRequiredKeyedService("key1"); + Assert.IsType(multipleKeyed1); + + var multipleKeyed2 = provider.GetRequiredKeyedService("key2"); + Assert.IsType(multipleKeyed2); + + // Verify they are different instances for transient services + var anotherKeyedTransient = provider.GetRequiredKeyedService("test-key"); + Assert.NotSame(keyedTransient, anotherKeyedTransient); + + // Verify singleton behavior + var anotherMultipleKeyed2 = provider.GetRequiredKeyedService("key2"); + Assert.Same(multipleKeyed2, anotherMultipleKeyed2); + } + + [Fact] + public void KeyedServicesAreIsolatedFromNonKeyedServices() + { + var provider = ConfigureProvider(services => + { + services.Scan(scan => scan + .FromTypes(typeof(MixedKeyedService)) + .UsingAttributes()); + }); + + using var scope = provider.CreateScope(); + var keyedService = scope.ServiceProvider.GetRequiredKeyedService("mixed-key"); + var nonKeyedService = provider.GetRequiredService(); + + Assert.IsType(keyedService); + Assert.IsType(nonKeyedService); + Assert.NotSame(keyedService, nonKeyedService); + } + + [Fact] + public void CanRegisterKeyedServiceWithObjectKey() + { + Collection.Scan(scan => scan + .FromTypes(typeof(ObjectKeyedService)) + .UsingAttributes()); + + Assert.Single(Collection); + + var service = Collection.Single(); + Assert.Equal(typeof(IKeyedTestService), service.ServiceType); + Assert.Equal(typeof(ObjectKeyedService), service.KeyedImplementationType); + Assert.Equal(42, service.ServiceKey); + Assert.True(service.IsKeyedService); + } + + [Fact] + public void CanResolveKeyedServiceWithObjectKey() + { + var provider = ConfigureProvider(services => + { + services.Scan(scan => scan + .FromTypes(typeof(ObjectKeyedService)) + .UsingAttributes()); + }); + + var keyedService = provider.GetRequiredKeyedService(42); + Assert.IsType(keyedService); + } + + [Fact] + public void CanRegisterKeyedServiceWithEnumKey() + { + Collection.Scan(scan => scan + .FromTypes(typeof(EnumKeyedService)) + .UsingAttributes()); + + Assert.Single(Collection); + + var service = Collection.Single(); + Assert.Equal(typeof(IKeyedTestService), service.ServiceType); + Assert.Equal(typeof(EnumKeyedService), service.KeyedImplementationType); + Assert.Equal(TestEnum.Value1, service.ServiceKey); + Assert.True(service.IsKeyedService); + } + + [Fact] + public void CanRegisterKeyedServiceWithDifferentServiceTypes() + { + Collection.Scan(scan => scan + .FromTypes(typeof(MultiServiceKeyedService)) + .UsingAttributes()); + + Assert.Equal(2, Collection.Count); + + var keyedService = Collection.First(s => s.ServiceType == typeof(IKeyedTestService)); + Assert.Equal(typeof(MultiServiceKeyedService), keyedService.KeyedImplementationType); + Assert.Equal("service-key", keyedService.ServiceKey); + + var otherKeyedService = Collection.First(s => s.ServiceType == typeof(IOtherKeyedTestService)); + Assert.Equal(typeof(MultiServiceKeyedService), otherKeyedService.KeyedImplementationType); + Assert.Equal("other-key", otherKeyedService.ServiceKey); + } + + [Fact] + public void ThrowsWhenResolvingNonExistentKeyedService() + { + var provider = ConfigureProvider(services => + { + services.Scan(scan => scan + .FromTypes(typeof(KeyedTransientService)) + .UsingAttributes()); + }); + + Assert.Throws(() => + provider.GetRequiredKeyedService("non-existent-key")); + } + + [Fact] + public void CanRegisterKeyedServiceWithNullServiceType() + { + Collection.Scan(scan => scan + .FromTypes(typeof(KeyedServiceWithNullServiceType)) + .UsingAttributes()); + + // Should register for the implementation type and all its interfaces + Assert.Equal(2, Collection.Count); // IKeyedTestService and KeyedServiceWithNullServiceType itself + + var services = Collection.ToArray(); + Assert.All(services, s => + { + Assert.Equal("null-service-type-key", s.ServiceKey); + Assert.True(s.IsKeyedService); + }); + } + + + [Fact] + public void AllowsSameServiceTypeWithDifferentKeys() + { + Collection.Scan(scan => scan + .FromTypes(typeof(SameServiceTypeDifferentKeys)) + .UsingAttributes()); + + Assert.Equal(2, Collection.Count); + + var service1 = Collection.First(s => s.ServiceKey?.ToString() == "key1"); + var service2 = Collection.First(s => s.ServiceKey?.ToString() == "key2"); + + Assert.Equal(typeof(IKeyedTestService), service1.ServiceType); + Assert.Equal(typeof(IKeyedTestService), service2.ServiceType); + Assert.NotEqual(service1.ServiceKey, service2.ServiceKey); + } + + [Fact] + public void CanRegisterServiceWithNullKey() + { + Collection.Scan(scan => scan + .FromTypes(typeof(NullKeyedService)) + .UsingAttributes()); + + Assert.Single(Collection); + + var service = Collection.Single(); + Assert.Equal(typeof(IKeyedTestService), service.ServiceType); + Assert.Equal(typeof(NullKeyedService), service.ImplementationType); + Assert.Null(service.ServiceKey); + Assert.False(service.IsKeyedService); + } + + [Fact] + public void CanResolveServiceWithNullKey() + { + var provider = ConfigureProvider(services => + { + services.Scan(scan => scan + .FromTypes(typeof(NullKeyedService)) + .UsingAttributes()); + }); + + var service = provider.GetRequiredService(); + Assert.IsType(service); + } +} + +// Test interfaces and classes for keyed services +public interface IKeyedTestService { } +public interface IOtherKeyedTestService { } + +[ServiceDescriptor(typeof(IKeyedTestService), ServiceLifetime.Transient, "test-key")] +public class KeyedTransientService : IKeyedTestService { } + +[ServiceDescriptor(ServiceLifetime.Scoped, "generic-key")] +public class GenericKeyedService : IKeyedTestService { } + +[ServiceDescriptor(typeof(IKeyedTestService), ServiceLifetime.Transient, "key1")] +[ServiceDescriptor(typeof(IKeyedTestService), ServiceLifetime.Singleton, "key2")] +public class MultipleKeyedService : IKeyedTestService { } + +[ServiceDescriptor(typeof(IKeyedTestService), ServiceLifetime.Scoped, "mixed-key")] +[ServiceDescriptor(typeof(IKeyedTestService), ServiceLifetime.Transient)] +public class MixedKeyedService : IKeyedTestService { } + +[ServiceDescriptor(typeof(IKeyedTestService), ServiceLifetime.Transient, 42)] +public class ObjectKeyedService : IKeyedTestService { } + +public enum TestEnum +{ + Value1, + Value2 +} + +[ServiceDescriptor(typeof(IKeyedTestService), ServiceLifetime.Transient, TestEnum.Value1)] +public class EnumKeyedService : IKeyedTestService { } + +[ServiceDescriptor(typeof(IKeyedTestService), ServiceLifetime.Transient, "service-key")] +[ServiceDescriptor(typeof(IOtherKeyedTestService), ServiceLifetime.Scoped, "other-key")] +public class MultiServiceKeyedService : IKeyedTestService, IOtherKeyedTestService { } + +[ServiceDescriptor(null, ServiceLifetime.Transient, "null-service-type-key")] +public class KeyedServiceWithNullServiceType : IKeyedTestService { } + + +[ServiceDescriptor(typeof(IKeyedTestService), ServiceLifetime.Transient, "key1")] +[ServiceDescriptor(typeof(IKeyedTestService), ServiceLifetime.Scoped, "key2")] +public class SameServiceTypeDifferentKeys : IKeyedTestService { } + +[ServiceDescriptor(typeof(IKeyedTestService), ServiceLifetime.Transient, null)] +public class NullKeyedService : IKeyedTestService { }