Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/Scrutor/ILifetimeSelector.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using System;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection;
using System;

namespace Scrutor;

Expand Down Expand Up @@ -29,4 +29,16 @@ public interface ILifetimeSelector : IServiceTypeSelector
/// Registers each matching concrete type with a lifetime based on the provided <paramref name="selector"/>.
/// </summary>
IImplementationTypeSelector WithLifetime(Func<Type, ServiceLifetime> selector);

/// <summary>
/// Registers each matching concrete type with the specified <paramref name="serviceKey"/>.
/// </summary>
/// <param name="serviceKey">The service key to use for registration.</param>
ILifetimeSelector WithServiceKey(object serviceKey);

/// <summary>
/// Registers each matching concrete type with a service key based on the provided <paramref name="selector"/>.
/// </summary>
/// <param name="selector">A function to determine the service key for each type.</param>
ILifetimeSelector WithServiceKey(Func<Type, object?> selector);
}
55 changes: 49 additions & 6 deletions src/Scrutor/LifetimeSelector.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
using System;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyModel;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyModel;

namespace Scrutor;

Expand All @@ -24,6 +24,8 @@ public LifetimeSelector(ServiceTypeSelector inner, IEnumerable<TypeMap> typeMaps

public Func<Type, ServiceLifetime>? SelectorFn { get; set; }

public Func<Type, object?>? ServiceKeySelectorFn { get; set; }

public IImplementationTypeSelector WithSingletonLifetime()
{
return WithLifetime(ServiceLifetime.Singleton);
Expand Down Expand Up @@ -54,6 +56,22 @@ public IImplementationTypeSelector WithLifetime(Func<Type, ServiceLifetime> sele
return this;
}

public ILifetimeSelector WithServiceKey(object serviceKey)
{
Preconditions.NotNull(serviceKey, nameof(serviceKey));

return WithServiceKey(_ => serviceKey);
}

public ILifetimeSelector WithServiceKey(Func<Type, object?> selector)
{
Preconditions.NotNull(selector, nameof(selector));

Inner.PropagateServiceKey(selector);

return this;
}

#region Chain Methods

[ExcludeFromCodeCoverage]
Expand Down Expand Up @@ -231,6 +249,7 @@ void ISelector.Populate(IServiceCollection services, RegistrationStrategy? strat
strategy ??= RegistrationStrategy.Append;

var lifetimes = new Dictionary<Type, ServiceLifetime>();
var serviceKeys = new Dictionary<Type, object?>();

foreach (var typeMap in TypeMaps)
{
Expand All @@ -244,8 +263,10 @@ void ISelector.Populate(IServiceCollection services, RegistrationStrategy? strat
}

var lifetime = GetOrAddLifetime(lifetimes, implementationType);

var descriptor = new ServiceDescriptor(serviceType, implementationType, lifetime);
var serviceKey = GetOrAddServiceKey(serviceKeys, implementationType);
var descriptor = serviceKey is not null
? new ServiceDescriptor(serviceType, serviceKey, implementationType, lifetime)
: new ServiceDescriptor(serviceType, implementationType, lifetime);

strategy.Apply(services, descriptor);
}
Expand All @@ -256,8 +277,11 @@ void ISelector.Populate(IServiceCollection services, RegistrationStrategy? strat
foreach (var serviceType in typeFactoryMap.ServiceTypes)
{
var lifetime = GetOrAddLifetime(lifetimes, typeFactoryMap.ImplementationType);
var serviceKey = GetOrAddServiceKey(serviceKeys, typeFactoryMap.ImplementationType);

var descriptor = new ServiceDescriptor(serviceType, typeFactoryMap.ImplementationFactory, lifetime);
var descriptor = serviceKey is not null
? new ServiceDescriptor(serviceType, serviceKey, WrapImplementationFactory(typeFactoryMap.ImplementationFactory), lifetime)
: new ServiceDescriptor(serviceType, typeFactoryMap.ImplementationFactory, lifetime);

strategy.Apply(services, descriptor);
}
Expand All @@ -277,4 +301,23 @@ private ServiceLifetime GetOrAddLifetime(Dictionary<Type, ServiceLifetime> lifet

return lifetime;
}

private object? GetOrAddServiceKey(Dictionary<Type, object?> serviceKeys, Type implementationType)
{
if (serviceKeys.TryGetValue(implementationType, out var serviceKey))
{
return serviceKey;
}

serviceKey = ServiceKeySelectorFn?.Invoke(implementationType);

serviceKeys[implementationType] = serviceKey;

return serviceKey;
}

private static Func<IServiceProvider, object?, object> WrapImplementationFactory(Func<IServiceProvider, object> factory)
{
return (sp, _) => factory(sp);
}
}
14 changes: 11 additions & 3 deletions src/Scrutor/ServiceTypeSelector.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
using System;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyModel;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyModel;

namespace Scrutor;

Expand Down Expand Up @@ -207,6 +207,14 @@ internal void PropagateLifetime(Func<Type, ServiceLifetime> selectorFn)
}
}

internal void PropagateServiceKey(Func<Type, object?> selectorFn)
{
foreach (var selector in Selectors.OfType<LifetimeSelector>())
{
selector.ServiceKeySelectorFn = selectorFn;
}
}

void ISelector.Populate(IServiceCollection services, RegistrationStrategy? registrationStrategy)
{
if (Selectors.Count == 0)
Expand Down
64 changes: 62 additions & 2 deletions test/Scrutor.Tests/ScanningTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -584,10 +584,70 @@ public void ShouldAllowOptInToCompilerGeneratedTypes()
.AsSelf()
.WithTransientLifetime());
});

var compilerGeneratedSubclass = provider.GetService<AllowedCompilerGeneratedSubclass>();
Assert.NotNull(compilerGeneratedSubclass);
}

[Fact]
public void CanRegisterWithServiceKey()
{
Collection.Scan(scan => scan
.FromTypes<TransientService1, TransientService2>()
.AsImplementedInterfaces(x => x != typeof(IOtherInheritance))
.WithServiceKey("my-key")
.WithSingletonLifetime());

Assert.Equal(2, Collection.Count);

Assert.All(Collection, x =>
{
Assert.Equal(ServiceLifetime.Singleton, x.Lifetime);
Assert.Equal(typeof(ITransientService), x.ServiceType);
Assert.True(x.IsKeyedService);
Assert.Equal("my-key", x.ServiceKey);
});
}

[Fact]
public void CanRegisterWithServiceKeySelector()
{
Collection.Scan(scan => scan
.FromTypes<TransientService1, TransientService2>()
.AsImplementedInterfaces(x => x != typeof(IOtherInheritance))
.WithServiceKey(type => type.Name)
.WithSingletonLifetime());

Assert.Equal(2, Collection.Count);

var service1 = Collection.First(x => x.ServiceKey as string == nameof(TransientService1));
Assert.Equal(typeof(ITransientService), service1.ServiceType);
Assert.Equal(ServiceLifetime.Singleton, service1.Lifetime);
Assert.True(service1.IsKeyedService);

var service2 = Collection.First(x => x.ServiceKey as string == nameof(TransientService2));
Assert.Equal(typeof(ITransientService), service2.ServiceType);
Assert.Equal(ServiceLifetime.Singleton, service2.Lifetime);
Assert.True(service2.IsKeyedService);
}

[Fact]
public void CanResolveKeyedServices()
{
Collection.Scan(scan => scan
.FromTypes<TransientService1, TransientService2>()
.AsSelf()
.WithServiceKey(type => type.Name)
.WithTransientLifetime());

var provider = Collection.BuildServiceProvider();

var service1 = provider.GetRequiredKeyedService<TransientService1>(nameof(TransientService1));
var service2 = provider.GetRequiredKeyedService<TransientService2>(nameof(TransientService2));

Assert.NotNull(service1);
Assert.NotNull(service2);
}
}

// ReSharper disable UnusedTypeParameter
Expand Down Expand Up @@ -671,7 +731,7 @@ public class DefaultAttributes : IDefault3Level2, IDefault1, IDefault2 { }
[CompilerGenerated]
public class CompilerGenerated { }

public class CombinedService2: IDefault1, IDefault2, IDefault3Level2 { }
public class CombinedService2 : IDefault1, IDefault2, IDefault3Level2 { }

public interface IGenericAttribute { }

Expand Down