diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs index 15a9b46843922..47d7d02f4db4b 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs @@ -18,6 +18,7 @@ internal sealed class CallSiteFactory private readonly ServiceDescriptor[] _descriptors; private readonly ConcurrentDictionary _callSiteCache = new ConcurrentDictionary(); private readonly Dictionary _descriptorLookup = new Dictionary(); + private readonly ConcurrentDictionary _callSiteLocks = new ConcurrentDictionary(); private readonly StackGuard _stackGuard; @@ -98,13 +99,28 @@ private ServiceCallSite CreateCallSite(Type serviceType, CallSiteChain callSiteC return _stackGuard.RunOnEmptyStack((type, chain) => CreateCallSite(type, chain), serviceType, callSiteChain); } - callSiteChain.CheckCircularDependency(serviceType); + // We need to lock the resolution process for a single service type at a time: + // Consider the following: + // C -> D -> A + // E -> D -> A + // Resolving C and E in parallel means that they will be modifying the callsite cache concurrently + // to add the entry for C and E, but the resolution of D and A is synchronized + // to make sure C and E both reference the same instance of the callsite. - ServiceCallSite callSite = TryCreateExact(serviceType, callSiteChain) ?? - TryCreateOpenGeneric(serviceType, callSiteChain) ?? - TryCreateEnumerable(serviceType, callSiteChain); + // This is to make sure we can safely store singleton values on the callsites themselves - return callSite; + var callsiteLock = _callSiteLocks.GetOrAdd(serviceType, static _ => new object()); + + lock (callsiteLock) + { + callSiteChain.CheckCircularDependency(serviceType); + + ServiceCallSite callSite = TryCreateExact(serviceType, callSiteChain) ?? + TryCreateOpenGeneric(serviceType, callSiteChain) ?? + TryCreateEnumerable(serviceType, callSiteChain); + + return callSite; + } } private ServiceCallSite TryCreateExact(Type serviceType, CallSiteChain callSiteChain) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs index 13d4decd776bf..5a72853b7ce7e 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; +using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.DependencyInjection.Specification.Fakes; using Xunit; @@ -738,6 +739,82 @@ public void CreateCallSite_EnumberableCachedAtLowestLevel(ServiceDescriptor[] de } } + [Fact] + public void CallSitesAreUniquePerServiceTypeAndSlot() + { + // Connected graph + // Class1 -> Class2 -> Class3 + // Class4 -> Class3 + // Class5 -> Class2 -> Class3 + var types = new Type[] { typeof(Class1), typeof(Class2), typeof(Class3), typeof(Class4), typeof(Class5) }; + + + for (int i = 0; i < 100; i++) + { + var factory = GetCallSiteFactory(types.Select(t => ServiceDescriptor.Transient(t, t)).ToArray()); + + var tasks = new Task[types.Length]; + for (int j = 0; j < types.Length; j++) + { + var type = types[j]; + tasks[j] = Task.Run(() => factory(type)); + } + + Task.WaitAll(tasks); + + var callsites = tasks.Select(t => t.Result).Cast().ToArray(); + + Assert.Equal(5, callsites.Length); + // Class1 -> Class2 + Assert.Same(callsites[0].ParameterCallSites[0], callsites[1]); + // Class2 -> Class3 + Assert.Same(callsites[1].ParameterCallSites[0], callsites[2]); + // Class4 -> Class3 + Assert.Same(callsites[3].ParameterCallSites[0], callsites[2]); + // Class5 -> Class2 + Assert.Same(callsites[4].ParameterCallSites[0], callsites[1]); + } + } + + [Fact] + public void CallSitesAreUniquePerServiceTypeAndSlotWithOpenGenericInGraph() + { + // Connected graph + // ClassA -> ClassB -> ClassC + // ClassD -> ClassC + // ClassE -> ClassB -> ClassC + var types = new Type[] { typeof(ClassA), typeof(ClassB), typeof(ClassC<>), typeof(ClassD), typeof(ClassE) }; + + for (int i = 0; i < 100; i++) + { + var factory = GetCallSiteFactory(types.Select(t => ServiceDescriptor.Transient(t, t)).ToArray()); + + var tasks = new Task[types.Length]; + for (int j = 0; j < types.Length; j++) + { + var type = types[j]; + tasks[j] = Task.Run(() => factory(type)); + } + + Task.WaitAll(tasks); + + var callsites = tasks.Select(t => t.Result).Cast().ToArray(); + + var cOfObject = factory(typeof(ClassC)); + var cOfString = factory(typeof(ClassC)); + + Assert.Equal(5, callsites.Length); + // ClassA -> ClassB + Assert.Same(callsites[0].ParameterCallSites[0], callsites[1]); + // ClassB -> ClassC + Assert.Same(callsites[1].ParameterCallSites[0], cOfObject); + // ClassD -> ClassC + Assert.Same(callsites[3].ParameterCallSites[0], cOfString); + // ClassE -> ClassB + Assert.Same(callsites[4].ParameterCallSites[0], callsites[1]); + } + } + private static Func GetCallSiteFactory(params ServiceDescriptor[] descriptors) { var collection = new ServiceCollection(); @@ -762,5 +839,19 @@ private static ConstructorInfo GetConstructor(Type type, Type[] parameterTypes) c => Enumerable.SequenceEqual( c.GetParameters().Select(p => p.ParameterType), parameterTypes)); + + + private class Class1 { public Class1(Class2 c2) { } } + private class Class2 { public Class2(Class3 c3) { } } + private class Class3 { } + private class Class4 { public Class4(Class3 c3) { } } + private class Class5 { public Class5(Class2 c2) { } } + + // Open generic + private class ClassA { public ClassA(ClassB cb) { } } + private class ClassB { public ClassB(ClassC cc) { } } + private class ClassC { } + private class ClassD { public ClassD(ClassC cd) { } } + private class ClassE { public ClassE(ClassB cb) { } } } }