From 485e4bf291285e281f1d8ff8861bf9b7a7827c64 Mon Sep 17 00:00:00 2001 From: Pavel Ivanov Date: Thu, 20 Jul 2023 18:57:55 +0500 Subject: [PATCH] Fix service accessor scope validation for the emit-based version --- .../src/ServiceProvider.cs | 50 ++++++++++--------- .../ServiceProviderValidationTests.cs | 25 ++++++++++ 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs index 613305cca6e5c..6705f7db39fd5 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs @@ -21,14 +21,14 @@ public sealed class ServiceProvider : IServiceProvider, IKeyedServiceProvider, I { private readonly CallSiteValidator? _callSiteValidator; - private readonly Func> _createServiceAccessor; + private readonly Func _createServiceAccessor; // Internal for testing internal ServiceProviderEngine _engine; private bool _disposed; - private readonly ConcurrentDictionary> _realizedServices; + private readonly ConcurrentDictionary _serviceAccessors; internal CallSiteFactory CallSiteFactory { get; } @@ -50,7 +50,7 @@ internal ServiceProvider(ICollection serviceDescriptors, Serv Root = new ServiceProviderEngineScope(this, isRootScope: true); _engine = GetEngine(); _createServiceAccessor = CreateServiceAccessor; - _realizedServices = new ConcurrentDictionary>(); + _serviceAccessors = new ConcurrentDictionary(); CallSiteFactory = new CallSiteFactory(serviceDescriptors); // The list of built in services that aren't part of the list of service descriptors @@ -137,9 +137,12 @@ private void OnCreate(ServiceCallSite callSite) _callSiteValidator?.ValidateCallSite(callSite); } - private void OnResolve(ServiceCallSite callSite, IServiceScope scope) + private void OnResolve(ServiceCallSite? callSite, IServiceScope scope) { - _callSiteValidator?.ValidateResolution(callSite, scope, Root); + if (callSite != null) + { + _callSiteValidator?.ValidateResolution(callSite, scope, Root); + } } internal object? GetService(ServiceIdentifier serviceIdentifier, ServiceProviderEngineScope serviceProviderEngineScope) @@ -148,9 +151,10 @@ private void OnResolve(ServiceCallSite callSite, IServiceScope scope) { ThrowHelper.ThrowObjectDisposedException(); } - - Func realizedService = _realizedServices.GetOrAdd(serviceIdentifier, _createServiceAccessor); - var result = realizedService.Invoke(serviceProviderEngineScope); + ServiceAccessor serviceAccessor = _serviceAccessors.GetOrAdd(serviceIdentifier, _createServiceAccessor); + OnResolve(serviceAccessor.CallSite, serviceProviderEngineScope); + DependencyInjectionEventSource.Log.ServiceResolved(this, serviceIdentifier.ServiceType); + object? result = serviceAccessor.RealizedService?.Invoke(serviceProviderEngineScope); System.Diagnostics.Debug.Assert(result is null || CallSiteFactory.IsService(serviceIdentifier)); return result; } @@ -176,7 +180,7 @@ private void ValidateService(ServiceDescriptor descriptor) } } - private Func CreateServiceAccessor(ServiceIdentifier serviceIdentifier) + private ServiceAccessor CreateServiceAccessor(ServiceIdentifier serviceIdentifier) { ServiceCallSite? callSite = CallSiteFactory.GetCallSite(serviceIdentifier, new CallSiteChain()); if (callSite != null) @@ -188,28 +192,22 @@ private void ValidateService(ServiceDescriptor descriptor) if (callSite.Cache.Location == CallSiteResultCacheLocation.Root) { object? value = CallSiteRuntimeResolver.Instance.Resolve(callSite, Root); - return scope => - { - DependencyInjectionEventSource.Log.ServiceResolved(this, serviceIdentifier.ServiceType); - return value; - }; + return new ServiceAccessor { CallSite = callSite, RealizedService = scope => value }; } Func realizedService = _engine.RealizeService(callSite); - return scope => - { - OnResolve(callSite, scope); - DependencyInjectionEventSource.Log.ServiceResolved(this, serviceIdentifier.ServiceType); - return realizedService(scope); - }; + return new ServiceAccessor { CallSite = callSite, RealizedService = realizedService }; } - - return _ => null; + return new ServiceAccessor { CallSite = callSite, RealizedService = _ => null }; } internal void ReplaceServiceAccessor(ServiceCallSite callSite, Func accessor) { - _realizedServices[new ServiceIdentifier(callSite.Key, callSite.ServiceType)] = accessor; + _serviceAccessors[new ServiceIdentifier(callSite.Key, callSite.ServiceType)] = new ServiceAccessor + { + CallSite = callSite, + RealizedService = accessor + }; } internal IServiceScope CreateScope() @@ -262,5 +260,11 @@ public ServiceProviderDebugView(ServiceProvider serviceProvider) public bool Disposed => _serviceProvider.Root.Disposed; public bool IsScope => !_serviceProvider.Root.IsRootScope; } + + private sealed class ServiceAccessor + { + public ServiceCallSite? CallSite { get; set; } + public Func? RealizedService { get; set; } + } } } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs index c6f1834888f0f..8780312c2e8ff 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection.Specification.Fakes; using Xunit; @@ -85,6 +86,30 @@ public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot() Assert.Equal($"Cannot resolve scoped service '{typeof(IBar)}' from root provider.", exception.Message); } + [Fact] + public async void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot_IL_Replacement() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddScoped(); + var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true); + + // Act + Assert + using (var scope = serviceProvider.CreateScope()) + { + // Switch to an emit-based version which is triggered in the background after 2 calls to GetService. + scope.ServiceProvider.GetRequiredService(typeof(IBar)); + scope.ServiceProvider.GetRequiredService(typeof(IBar)); + + // Give the background thread time to generate the emit version. + await Task.Delay(100); + + // Ensure the emit-based version has the correct scope checks. + var exception = Assert.Throws(serviceProvider.GetRequiredService); + Assert.Equal($"Cannot resolve scoped service '{typeof(IBar)}' from root provider.", exception.Message); + } + } + [Fact] public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRootViaTransient() {