Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,76 @@ public void Decorate_KeyedService_Registered_As_Factory()
Assert.IsNotNull(decorator1.Inner);
Assert.IsInstanceOfType<TestService>(decorator1.Inner);

service = _serviceProvider.GetRequiredKeyedService<ITestService>("key2");
Assert.IsNotNull(service);
Assert.IsInstanceOfType<TestService>(service);
}
service = _serviceProvider.GetRequiredKeyedService<ITestService>("key2");
Assert.IsNotNull(service);
Assert.IsInstanceOfType<TestService>(service);
}

#region Tests for Class Decoration (not just interfaces)

[TestMethod]
public void Decorate_ConcreteClass_Registered_As_Type()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddTransient<TestService>(); // Register concrete class directly
serviceCollection.Decorate<TestService, TestServiceConcreteDecorator>();

serviceCollection.AddTransient<ExternalService>();

using var _serviceProvider = serviceCollection.BuildServiceProvider();

var service = _serviceProvider.GetRequiredService<TestService>();

Assert.IsNotNull(service);
Assert.IsInstanceOfType<TestServiceConcreteDecorator>(service);
var decorator = (TestServiceConcreteDecorator)service;
Assert.IsNotNull(decorator.Inner);
Assert.IsInstanceOfType<TestService>(decorator.Inner);
}

[TestMethod]
public void Decorate_ConcreteClass_Registered_As_Instance()
{
var serviceCollection = new ServiceCollection();
var implementationInstance = new TestService();
serviceCollection.AddSingleton(implementationInstance); // Register concrete class instance
serviceCollection.Decorate<TestService, TestServiceConcreteDecorator>();

serviceCollection.AddTransient<ExternalService>();

using var _serviceProvider = serviceCollection.BuildServiceProvider();

var service = _serviceProvider.GetRequiredService<TestService>();

Assert.IsNotNull(service);
Assert.IsInstanceOfType<TestServiceConcreteDecorator>(service);
var decorator = (TestServiceConcreteDecorator)service;
Assert.IsNotNull(decorator.Inner);
Assert.IsInstanceOfType<TestService>(decorator.Inner);
Assert.AreEqual(decorator.Inner, implementationInstance);
}

[TestMethod]
public void Decorate_ConcreteClass_Registered_As_Factory()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddTransient<TestService>(_ => new TestService()); // Register with factory
serviceCollection.Decorate<TestService, TestServiceConcreteDecorator>();

serviceCollection.AddTransient<ExternalService>();

using var _serviceProvider = serviceCollection.BuildServiceProvider();

var service = _serviceProvider.GetRequiredService<TestService>();

Assert.IsNotNull(service);
Assert.IsInstanceOfType<TestServiceConcreteDecorator>(service);
var decorator = (TestServiceConcreteDecorator)service;
Assert.IsNotNull(decorator.Inner);
Assert.IsInstanceOfType<TestService>(decorator.Inner);
}

#endregion
}

public class TestServiceDecorator1 : ITestService
Expand Down Expand Up @@ -409,11 +475,11 @@ public TestServiceDecorator3(ITestService transientService)
}

public bool WasDisposed()
{
return Inner.WasDisposed();
}
}
}
#pragma warning restore MSTEST0032 // Assertion condition is always true
{
return Inner.WasDisposed();
}
}
}

#pragma warning restore MSTEST0032 // Assertion condition is always true
#pragma warning restore IDE0079 // Remove unnecessary suppression
56 changes: 38 additions & 18 deletions src/Mammoth.Extensions.DependencyInjection.Tests/TestServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ public interface ITestService

public interface IAnotherInterface;

public class TestService : IDisposable, ITestService, IAnotherInterface
{
private bool disposedValue;
public bool WasDisposed()
{
return disposedValue;
public class TestService : IDisposable, ITestService, IAnotherInterface
{
private bool disposedValue;

public virtual bool WasDisposed()
{
return disposedValue;
}

protected virtual void Dispose(bool disposing)
Expand Down Expand Up @@ -44,15 +44,15 @@ public void Dispose()
{
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
public T? GetById<T>(string id) where T : class
{
return default;
}
}
GC.SuppressFinalize(this);
}

public virtual T? GetById<T>(string id) where T : class
{
return default;
}
}

public class AnotherTestService : IDisposable, ITestService, IAnotherInterface
{
private bool disposedValue;
Expand Down Expand Up @@ -103,8 +103,28 @@ public class KeyedService1 : IKeyedService;

public class KeyedService2 : IKeyedService;

public class ExternalService;

public class ExternalService;

public class TestServiceConcreteDecorator : TestService
{
public TestService Inner { get; }

public TestServiceConcreteDecorator(TestService inner)
{
Inner = inner;
}

public override T? GetById<T>(string id) where T : class
{
return Inner.GetById<T>(id);
}

public override bool WasDisposed()
{
return Inner.WasDisposed();
}
}

namespace Nested
{
public interface INestedAnotherInterface ;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,61 +17,86 @@ public static partial class ServiceCollectionExtensions
/// <typeparam name="TDecorator">The type of the decorator.</typeparam>
/// <param name="services">The service collection.</param>
/// <exception cref="InvalidOperationException">Thrown when the service type is not registered.</exception>
public static void Decorate<TInterface, TDecorator>(this IServiceCollection services)
where TInterface : class
where TDecorator : class, TInterface
{
var originalServiceDescriptor = services.LastOrDefault(d => d.ServiceType == typeof(TInterface))
?? throw new InvalidOperationException($"Service type {typeof(TInterface).Name} not registered.");

// Throw exception if TInterface is not an interface
if (!typeof(TInterface).IsInterface)
{
throw new InvalidOperationException($"Service type {typeof(TInterface).Name} is not an interface.");
}

// Remove the original service descriptor that has the interface as the service type
// Create a new descriptor that has the implementation type as new service type.
// If the original descriptor was registered with a factory function, a new proxy type will be created
// and used as the service type to replace the original descriptor.
// A new ServiceDescriptor will be created with the interface as the service type and the decorator as the implementation type.

services.Remove(originalServiceDescriptor);
var implementationType = originalServiceDescriptor.GetImplementationType()
?? throw new InvalidOperationException($"Service type {typeof(TInterface).Name} does not have an implementation type.");
var originalServiceDescriptorReplacement = originalServiceDescriptor.ChangeServiceType(implementationType);
services.Add(originalServiceDescriptorReplacement);

// Create a new service descriptor for the decorator
ServiceDescriptor newServiceDescriptor;
if (!originalServiceDescriptor.IsKeyedService)
{
newServiceDescriptor = new ServiceDescriptor(
typeof(TInterface),
serviceProvider =>
{
TInterface originalService = (TInterface)serviceProvider.GetRequiredService(originalServiceDescriptorReplacement.ServiceType);
return ActivatorUtilities.CreateInstance<TDecorator>(serviceProvider, originalService);
},
originalServiceDescriptor.Lifetime
);
}
else
{
newServiceDescriptor = new ServiceDescriptor(
typeof(TInterface),
originalServiceDescriptor.ServiceKey,
(serviceProvider, _) =>
{
TInterface originalService = (TInterface)serviceProvider.GetRequiredKeyedService(originalServiceDescriptorReplacement.ServiceType, originalServiceDescriptor.ServiceKey);
return ActivatorUtilities.CreateInstance<TDecorator>(serviceProvider, originalService);
},
originalServiceDescriptor.Lifetime
);
}

// Replace or insert the service descriptor
services.Add(newServiceDescriptor);
public static void Decorate<TInterface, TDecorator>(this IServiceCollection services)
where TInterface : class
where TDecorator : class, TInterface
{
var originalServiceDescriptor = services.LastOrDefault(d => d.ServiceType == typeof(TInterface))
?? throw new InvalidOperationException($"Service type {typeof(TInterface).Name} not registered.");

// Remove the original service descriptor that has the service type
services.Remove(originalServiceDescriptor);

// Create a new service descriptor for the decorator
Copy link
Contributor

@AGiorgetti AGiorgetti Aug 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot improve the comment explaining the logic, exaplin why it shuould work

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improved comments throughout ServiceCollectionExtensions.Decorators.cs explaining why the implementation works - the new approach preserves original registration information and creates wrapper factories that eliminate reflection overhead while supporting both interfaces and concrete classes. Commit 7693b10

ServiceDescriptor newServiceDescriptor;
if (!originalServiceDescriptor.IsKeyedService)
{
newServiceDescriptor = new ServiceDescriptor(
typeof(TInterface),
serviceProvider =>
{
TInterface originalService = CreateOriginalService<TInterface>(originalServiceDescriptor, serviceProvider);
return ActivatorUtilities.CreateInstance<TDecorator>(serviceProvider, originalService);
},
originalServiceDescriptor.Lifetime
);
}
else
{
newServiceDescriptor = new ServiceDescriptor(
typeof(TInterface),
originalServiceDescriptor.ServiceKey,
(serviceProvider, _) =>
{
TInterface originalService = CreateOriginalKeyedService<TInterface>(originalServiceDescriptor, serviceProvider, originalServiceDescriptor.ServiceKey);
return ActivatorUtilities.CreateInstance<TDecorator>(serviceProvider, originalService);
},
originalServiceDescriptor.Lifetime
);
}

// Add the decorator service descriptor
services.Add(newServiceDescriptor);
}

private static T CreateOriginalService<T>(ServiceDescriptor originalDescriptor, IServiceProvider serviceProvider) where T : class
{
if (originalDescriptor.ImplementationType != null)
{
return (T)ActivatorUtilities.CreateInstance(serviceProvider, originalDescriptor.ImplementationType);
}

if (originalDescriptor.ImplementationInstance != null)
{
return (T)originalDescriptor.ImplementationInstance;
}

if (originalDescriptor.ImplementationFactory != null)
{
return (T)originalDescriptor.ImplementationFactory(serviceProvider);
}

throw new InvalidOperationException($"Unable to create original service for type {typeof(T).Name}.");
}

private static T CreateOriginalKeyedService<T>(ServiceDescriptor originalDescriptor, IServiceProvider serviceProvider, object? serviceKey) where T : class
{
if (originalDescriptor.KeyedImplementationType != null)
{
return (T)ActivatorUtilities.CreateInstance(serviceProvider, originalDescriptor.KeyedImplementationType);
}

if (originalDescriptor.KeyedImplementationInstance != null)
{
return (T)originalDescriptor.KeyedImplementationInstance;
}

if (originalDescriptor.KeyedImplementationFactory != null)
{
return (T)originalDescriptor.KeyedImplementationFactory(serviceProvider, serviceKey);
}

throw new InvalidOperationException($"Unable to create original keyed service for type {typeof(T).Name}.");
}
}
}
Loading