diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs index d660eb104007d..908d46787f13a 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs @@ -316,13 +316,9 @@ void AddCallSite(ServiceCallSite callSite, int index) callSitesByIndex.Add(new(index, callSite)); } } - - ResultCache resultCache = ResultCache.None; - if (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root) - { - resultCache = new ResultCache(cacheLocation, callSiteKey); - } - + ResultCache resultCache = (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root) + ? new ResultCache(cacheLocation, callSiteKey) + : new ResultCache(CallSiteResultCacheLocation.None, callSiteKey); return _callSiteCache[callSiteKey] = new IEnumerableCallSite(resultCache, itemType, callSites); } finally diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs index 17c5d34068c64..f228507a6b8ab 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs @@ -10,22 +10,23 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup internal sealed class CallSiteValidator: CallSiteVisitor { // Keys are services being resolved via GetService, values - first scoped service in their call site tree - private readonly ConcurrentDictionary _scopedServices = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _scopedServices = new ConcurrentDictionary(); public void ValidateCallSite(ServiceCallSite callSite) { Type? scoped = VisitCallSite(callSite, default); if (scoped != null) { - _scopedServices[callSite.ServiceType] = scoped; + _scopedServices[callSite.Cache.Key] = scoped; } } - public void ValidateResolution(Type serviceType, IServiceScope scope, IServiceScope rootScope) + public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IServiceScope rootScope) { if (ReferenceEquals(scope, rootScope) - && _scopedServices.TryGetValue(serviceType, out Type? scopedService)) + && _scopedServices.TryGetValue(callSite.Cache.Key, out Type? scopedService)) { + Type serviceType = callSite.ServiceType; if (serviceType == scopedService) { throw new InvalidOperationException( diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs index 8071c67013352..6698464813455 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs @@ -10,7 +10,7 @@ internal sealed class ConstantCallSite : ServiceCallSite private readonly Type _serviceType; internal object? DefaultValue => Value; - public ConstantCallSite(Type serviceType, object? defaultValue): base(ResultCache.None) + public ConstantCallSite(Type serviceType, object? defaultValue) : base(ResultCache.None(serviceType)) { _serviceType = serviceType ?? throw new ArgumentNullException(nameof(serviceType)); if (defaultValue != null && !serviceType.IsInstanceOfType(defaultValue)) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs index 65b1c799b6f3a..5b4da78aaec11 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs @@ -8,7 +8,11 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup { internal struct ResultCache { - public static ResultCache None { get; } = new ResultCache(CallSiteResultCacheLocation.None, ServiceCacheKey.Empty); + public static ResultCache None(Type serviceType) + { + var cacheKey = new ServiceCacheKey(serviceType, 0); + return new ResultCache(CallSiteResultCacheLocation.None, cacheKey); + } internal ResultCache(CallSiteResultCacheLocation lifetime, ServiceCacheKey cacheKey) { diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs index 737c23d7f4445..569fbef9de9bb 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs @@ -8,8 +8,6 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup { internal readonly struct ServiceCacheKey : IEquatable { - public static ServiceCacheKey Empty { get; } = new ServiceCacheKey(null, 0); - /// /// Type of service being cached /// diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs index 3db2f7f0723e0..6271473505c29 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs @@ -7,7 +7,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup { internal sealed class ServiceProviderCallSite : ServiceCallSite { - public ServiceProviderCallSite() : base(ResultCache.None) + public ServiceProviderCallSite() : base(ResultCache.None(typeof(IServiceProvider))) { } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs index b776b78b835ff..0610d2096f273 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs @@ -120,9 +120,9 @@ private void OnCreate(ServiceCallSite callSite) _callSiteValidator?.ValidateCallSite(callSite); } - private void OnResolve(Type serviceType, IServiceScope scope) + private void OnResolve(ServiceCallSite callSite, IServiceScope scope) { - _callSiteValidator?.ValidateResolution(serviceType, scope, Root); + _callSiteValidator?.ValidateResolution(callSite, scope, Root); } internal object? GetService(Type serviceType, ServiceProviderEngineScope serviceProviderEngineScope) @@ -133,8 +133,6 @@ private void OnResolve(Type serviceType, IServiceScope scope) } Func realizedService = _realizedServices.GetOrAdd(serviceType, _createServiceAccessor); - OnResolve(serviceType, serviceProviderEngineScope); - DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType); var result = realizedService.Invoke(serviceProviderEngineScope); System.Diagnostics.Debug.Assert(result is null || CallSiteFactory.IsService(serviceType)); return result; @@ -173,10 +171,20 @@ private void ValidateService(ServiceDescriptor descriptor) if (callSite.Cache.Location == CallSiteResultCacheLocation.Root) { object? value = CallSiteRuntimeResolver.Instance.Resolve(callSite, Root); - return scope => value; + return scope => + { + DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType); + return value; + }; } - return _engine.RealizeService(callSite); + Func realizedService = _engine.RealizeService(callSite); + return scope => + { + OnResolve(callSite, scope); + DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType); + return realizedService(scope); + }; } return _ => null; diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs index f1abdcb46b0c2..dbf4e7b561918 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs @@ -786,17 +786,10 @@ public void CreateCallSite_EnumberableCachedAtLowestLevel(ServiceDescriptor[] de var callSite = factory(typeof(IEnumerable)); var expectedLocation = (CallSiteResultCacheLocation)expectedCacheLocation; - Assert.Equal(expectedLocation, callSite.Cache.Location); - if (expectedLocation != CallSiteResultCacheLocation.None) - { - Assert.Equal(0, callSite.Cache.Key.Slot); - Assert.Equal(typeof(IEnumerable), callSite.Cache.Key.Type); - } - else - { - Assert.Equal(ResultCache.None, callSite.Cache); - } + Assert.Equal(expectedLocation, callSite.Cache.Location); + Assert.Equal(0, callSite.Cache.Key.Slot); + Assert.Equal(typeof(IEnumerable), callSite.Cache.Key.Type); } [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs index c26e8d65fbfce..c6f1834888f0f 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Generic; +using System.Linq; using Microsoft.Extensions.DependencyInjection.Specification.Fakes; using Xunit; @@ -97,6 +99,49 @@ public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRootViaTra Assert.Equal($"Cannot resolve '{typeof(IFoo)}' from root provider because it requires scoped service '{typeof(IBar)}'.", exception.Message); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public void GetService_DoesNotThrow_WhenGetServiceForPolymorphicServiceIsCalledOnRoot_AndTheLastOneIsNotScoped(bool validateOnBuild) + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddScoped(); + serviceCollection.AddTransient(); + using var serviceProvider = serviceCollection.BuildServiceProvider(new ServiceProviderOptions + { + ValidateScopes = true, + ValidateOnBuild = validateOnBuild + }); + + // Act + var actual = serviceProvider.GetService(); + + // Assert + Assert.IsType(actual); + } + + [Fact] + public void ScopeValidation_ShouldBeAbleToDistingushGenericCollections_WhenGetServiceIsCalledOnRoot() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddTransient(); + serviceCollection.AddScoped(); + + serviceCollection.AddTransient(); + serviceCollection.AddTransient(); + + // Act + using var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true); + Assert.Throws(() => serviceProvider.GetService>()); + var actual = serviceProvider.GetService>(); + + // Assert + Assert.IsType(actual.First()); + Assert.IsType(actual.Last()); + } + [Fact] public void GetService_DoesNotThrow_WhenScopeFactoryIsInjectedIntoSingleton() { @@ -206,6 +251,7 @@ private class Bar : IBar { } + private class Bar2 : IBar { public Bar2(IBaz baz) @@ -213,6 +259,10 @@ public Bar2(IBaz baz) } } + private class Bar3 : IBar + { + } + private interface IBaz { } @@ -221,6 +271,10 @@ private class Baz : IBaz { } + private class Baz2 : IBaz + { + } + private class BazRecursive : IBaz { public BazRecursive(IBaz baz)