Skip to content

Commit

Permalink
Fix service accessor scope validation for the emit-based version
Browse files Browse the repository at this point in the history
  • Loading branch information
mapogolions authored Jul 20, 2023
1 parent d672bcc commit 485e4bf
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ public sealed class ServiceProvider : IServiceProvider, IKeyedServiceProvider, I
{
private readonly CallSiteValidator? _callSiteValidator;

private readonly Func<ServiceIdentifier, Func<ServiceProviderEngineScope, object?>> _createServiceAccessor;
private readonly Func<ServiceIdentifier, ServiceAccessor> _createServiceAccessor;

// Internal for testing
internal ServiceProviderEngine _engine;

private bool _disposed;

private readonly ConcurrentDictionary<ServiceIdentifier, Func<ServiceProviderEngineScope, object?>> _realizedServices;
private readonly ConcurrentDictionary<ServiceIdentifier, ServiceAccessor> _serviceAccessors;

internal CallSiteFactory CallSiteFactory { get; }

Expand All @@ -50,7 +50,7 @@ internal ServiceProvider(ICollection<ServiceDescriptor> serviceDescriptors, Serv
Root = new ServiceProviderEngineScope(this, isRootScope: true);
_engine = GetEngine();
_createServiceAccessor = CreateServiceAccessor;
_realizedServices = new ConcurrentDictionary<ServiceIdentifier, Func<ServiceProviderEngineScope, object?>>();
_serviceAccessors = new ConcurrentDictionary<ServiceIdentifier, ServiceAccessor>();

CallSiteFactory = new CallSiteFactory(serviceDescriptors);
// The list of built in services that aren't part of the list of service descriptors
Expand Down Expand Up @@ -137,9 +137,12 @@ private void OnCreate(ServiceCallSite callSite)
_callSiteValidator?.ValidateCallSite(callSite);
}

private void OnResolve(ServiceCallSite callSite, IServiceScope scope)
private void OnResolve(ServiceCallSite? callSite, IServiceScope scope)
{
_callSiteValidator?.ValidateResolution(callSite, scope, Root);
if (callSite != null)
{
_callSiteValidator?.ValidateResolution(callSite, scope, Root);
}
}

internal object? GetService(ServiceIdentifier serviceIdentifier, ServiceProviderEngineScope serviceProviderEngineScope)
Expand All @@ -148,9 +151,10 @@ private void OnResolve(ServiceCallSite callSite, IServiceScope scope)
{
ThrowHelper.ThrowObjectDisposedException();
}

Func<ServiceProviderEngineScope, object?> realizedService = _realizedServices.GetOrAdd(serviceIdentifier, _createServiceAccessor);
var result = realizedService.Invoke(serviceProviderEngineScope);
ServiceAccessor serviceAccessor = _serviceAccessors.GetOrAdd(serviceIdentifier, _createServiceAccessor);
OnResolve(serviceAccessor.CallSite, serviceProviderEngineScope);
DependencyInjectionEventSource.Log.ServiceResolved(this, serviceIdentifier.ServiceType);
object? result = serviceAccessor.RealizedService?.Invoke(serviceProviderEngineScope);
System.Diagnostics.Debug.Assert(result is null || CallSiteFactory.IsService(serviceIdentifier));
return result;
}
Expand All @@ -176,7 +180,7 @@ private void ValidateService(ServiceDescriptor descriptor)
}
}

private Func<ServiceProviderEngineScope, object?> CreateServiceAccessor(ServiceIdentifier serviceIdentifier)
private ServiceAccessor CreateServiceAccessor(ServiceIdentifier serviceIdentifier)
{
ServiceCallSite? callSite = CallSiteFactory.GetCallSite(serviceIdentifier, new CallSiteChain());
if (callSite != null)
Expand All @@ -188,28 +192,22 @@ private void ValidateService(ServiceDescriptor descriptor)
if (callSite.Cache.Location == CallSiteResultCacheLocation.Root)
{
object? value = CallSiteRuntimeResolver.Instance.Resolve(callSite, Root);
return scope =>
{
DependencyInjectionEventSource.Log.ServiceResolved(this, serviceIdentifier.ServiceType);
return value;
};
return new ServiceAccessor { CallSite = callSite, RealizedService = scope => value };
}

Func<ServiceProviderEngineScope, object?> realizedService = _engine.RealizeService(callSite);
return scope =>
{
OnResolve(callSite, scope);
DependencyInjectionEventSource.Log.ServiceResolved(this, serviceIdentifier.ServiceType);
return realizedService(scope);
};
return new ServiceAccessor { CallSite = callSite, RealizedService = realizedService };
}

return _ => null;
return new ServiceAccessor { CallSite = callSite, RealizedService = _ => null };
}

internal void ReplaceServiceAccessor(ServiceCallSite callSite, Func<ServiceProviderEngineScope, object?> accessor)
{
_realizedServices[new ServiceIdentifier(callSite.Key, callSite.ServiceType)] = accessor;
_serviceAccessors[new ServiceIdentifier(callSite.Key, callSite.ServiceType)] = new ServiceAccessor
{
CallSite = callSite,
RealizedService = accessor
};
}

internal IServiceScope CreateScope()
Expand Down Expand Up @@ -262,5 +260,11 @@ public ServiceProviderDebugView(ServiceProvider serviceProvider)
public bool Disposed => _serviceProvider.Root.Disposed;
public bool IsScope => !_serviceProvider.Root.IsRootScope;
}

private sealed class ServiceAccessor
{
public ServiceCallSite? CallSite { get; set; }
public Func<ServiceProviderEngineScope, object?>? RealizedService { get; set; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection.Specification.Fakes;
using Xunit;

Expand Down Expand Up @@ -85,6 +86,30 @@ public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot()
Assert.Equal($"Cannot resolve scoped service '{typeof(IBar)}' from root provider.", exception.Message);
}

[Fact]
public async void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot_IL_Replacement()
{
// Arrange
var serviceCollection = new ServiceCollection();
serviceCollection.AddScoped<IBar, Bar>();
var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true);

// Act + Assert
using (var scope = serviceProvider.CreateScope())
{
// Switch to an emit-based version which is triggered in the background after 2 calls to GetService.
scope.ServiceProvider.GetRequiredService(typeof(IBar));
scope.ServiceProvider.GetRequiredService(typeof(IBar));

// Give the background thread time to generate the emit version.
await Task.Delay(100);

// Ensure the emit-based version has the correct scope checks.
var exception = Assert.Throws<InvalidOperationException>(serviceProvider.GetRequiredService<IBar>);
Assert.Equal($"Cannot resolve scoped service '{typeof(IBar)}' from root provider.", exception.Message);
}
}

[Fact]
public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRootViaTransient()
{
Expand Down

0 comments on commit 485e4bf

Please sign in to comment.