Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix validation on build #87354

Merged
merged 6 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,9 @@ void AddCallSite(ServiceCallSite callSite, int index)
callSitesByIndex.Add(new(index, callSite));
}
}

ResultCache resultCache = ResultCache.None;
if (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root)
{
resultCache = new ResultCache(cacheLocation, callSiteKey);
}

ResultCache resultCache = (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root)
? new ResultCache(cacheLocation, callSiteKey)
: new ResultCache(CallSiteResultCacheLocation.None, callSiteKey);
return _callSiteCache[callSiteKey] = new IEnumerableCallSite(resultCache, itemType, callSites);
}
finally
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,23 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
internal sealed class CallSiteValidator: CallSiteVisitor<CallSiteValidator.CallSiteValidatorState, Type?>
{
// Keys are services being resolved via GetService, values - first scoped service in their call site tree
private readonly ConcurrentDictionary<Type, Type> _scopedServices = new ConcurrentDictionary<Type, Type>();
private readonly ConcurrentDictionary<ServiceCacheKey, Type> _scopedServices = new ConcurrentDictionary<ServiceCacheKey, Type>();

public void ValidateCallSite(ServiceCallSite callSite)
{
Type? scoped = VisitCallSite(callSite, default);
if (scoped != null)
{
_scopedServices[callSite.ServiceType] = scoped;
_scopedServices[callSite.Cache.Key] = scoped;
}
}

public void ValidateResolution(Type serviceType, IServiceScope scope, IServiceScope rootScope)
public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IServiceScope rootScope)
{
if (ReferenceEquals(scope, rootScope)
&& _scopedServices.TryGetValue(serviceType, out Type? scopedService))
&& _scopedServices.TryGetValue(callSite.Cache.Key, out Type? scopedService))
{
Type serviceType = callSite.ServiceType;
if (serviceType == scopedService)
{
throw new InvalidOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ internal sealed class ConstantCallSite : ServiceCallSite
private readonly Type _serviceType;
internal object? DefaultValue => Value;

public ConstantCallSite(Type serviceType, object? defaultValue): base(ResultCache.None)
public ConstantCallSite(Type serviceType, object? defaultValue) : base(ResultCache.None(serviceType))
{
_serviceType = serviceType ?? throw new ArgumentNullException(nameof(serviceType));
if (defaultValue != null && !serviceType.IsInstanceOfType(defaultValue))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
internal struct ResultCache
{
public static ResultCache None { get; } = new ResultCache(CallSiteResultCacheLocation.None, ServiceCacheKey.Empty);
public static ResultCache None(Type serviceType)
{
var cacheKey = new ServiceCacheKey(serviceType, 0);
return new ResultCache(CallSiteResultCacheLocation.None, cacheKey);
}

internal ResultCache(CallSiteResultCacheLocation lifetime, ServiceCacheKey cacheKey)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
internal readonly struct ServiceCacheKey : IEquatable<ServiceCacheKey>
{
public static ServiceCacheKey Empty { get; } = new ServiceCacheKey(null, 0);

/// <summary>
/// Type of service being cached
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
internal sealed class ServiceProviderCallSite : ServiceCallSite
{
public ServiceProviderCallSite() : base(ResultCache.None)
public ServiceProviderCallSite() : base(ResultCache.None(typeof(IServiceProvider)))
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ private void OnCreate(ServiceCallSite callSite)
_callSiteValidator?.ValidateCallSite(callSite);
}

private void OnResolve(Type serviceType, IServiceScope scope)
private void OnResolve(ServiceCallSite callSite, IServiceScope scope)
{
_callSiteValidator?.ValidateResolution(serviceType, scope, Root);
_callSiteValidator?.ValidateResolution(callSite, scope, Root);
}

internal object? GetService(Type serviceType, ServiceProviderEngineScope serviceProviderEngineScope)
Expand All @@ -133,8 +133,6 @@ private void OnResolve(Type serviceType, IServiceScope scope)
}

Func<ServiceProviderEngineScope, object?> realizedService = _realizedServices.GetOrAdd(serviceType, _createServiceAccessor);
OnResolve(serviceType, serviceProviderEngineScope);
DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType);
var result = realizedService.Invoke(serviceProviderEngineScope);
System.Diagnostics.Debug.Assert(result is null || CallSiteFactory.IsService(serviceType));
return result;
Expand Down Expand Up @@ -173,10 +171,20 @@ private void ValidateService(ServiceDescriptor descriptor)
if (callSite.Cache.Location == CallSiteResultCacheLocation.Root)
{
object? value = CallSiteRuntimeResolver.Instance.Resolve(callSite, Root);
return scope => value;
return scope =>
{
DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType);
return value;
};
}

return _engine.RealizeService(callSite);
Func<ServiceProviderEngineScope, object?> realizedService = _engine.RealizeService(callSite);
return scope =>
{
OnResolve(callSite, scope);
DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType);
return realizedService(scope);
};
}

return _ => null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -786,17 +786,10 @@ public void CreateCallSite_EnumberableCachedAtLowestLevel(ServiceDescriptor[] de
var callSite = factory(typeof(IEnumerable<FakeService>));

var expectedLocation = (CallSiteResultCacheLocation)expectedCacheLocation;
Assert.Equal(expectedLocation, callSite.Cache.Location);

if (expectedLocation != CallSiteResultCacheLocation.None)
{
Assert.Equal(0, callSite.Cache.Key.Slot);
Assert.Equal(typeof(IEnumerable<FakeService>), callSite.Cache.Key.Type);
}
else
{
Assert.Equal(ResultCache.None, callSite.Cache);
}
Assert.Equal(expectedLocation, callSite.Cache.Location);
Assert.Equal(0, callSite.Cache.Key.Slot);
Assert.Equal(typeof(IEnumerable<FakeService>), callSite.Cache.Key.Type);
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Extensions.DependencyInjection.Specification.Fakes;
using Xunit;

Expand Down Expand Up @@ -97,6 +99,49 @@ public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRootViaTra
Assert.Equal($"Cannot resolve '{typeof(IFoo)}' from root provider because it requires scoped service '{typeof(IBar)}'.", exception.Message);
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public void GetService_DoesNotThrow_WhenGetServiceForPolymorphicServiceIsCalledOnRoot_AndTheLastOneIsNotScoped(bool validateOnBuild)
{
// Arrange
var serviceCollection = new ServiceCollection();
serviceCollection.AddScoped<IBar, Bar>();
serviceCollection.AddTransient<IBar, Bar3>();
using var serviceProvider = serviceCollection.BuildServiceProvider(new ServiceProviderOptions
{
ValidateScopes = true,
ValidateOnBuild = validateOnBuild
});

// Act
var actual = serviceProvider.GetService<IBar>();

// Assert
Assert.IsType<Bar3>(actual);
}

[Fact]
public void ScopeValidation_ShouldBeAbleToDistingushGenericCollections_WhenGetServiceIsCalledOnRoot()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears this passes without the changes here. Is that intentional?

Copy link
Contributor Author

@mapogolions mapogolions Jun 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, It is intentional. It tests the GetCacheKey() private method that replaces ResultCache.None for services. (For instance, IEnumerable<Foo> and IEnumerable<Bar> have the same ResultCache that is equal to ResultCache.None if at least one of the Foo and at least one of the Bar have transient lifetime)

{
// Arrange
var serviceCollection = new ServiceCollection();
serviceCollection.AddTransient<IBar, Bar>();
serviceCollection.AddScoped<IBar, Bar3>();

serviceCollection.AddTransient<IBaz, Baz>();
serviceCollection.AddTransient<IBaz, Baz2>();

// Act
using var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true);
Assert.Throws<InvalidOperationException>(() => serviceProvider.GetService<IEnumerable<IBar>>());
var actual = serviceProvider.GetService<IEnumerable<IBaz>>();

// Assert
Assert.IsType<Baz>(actual.First());
Assert.IsType<Baz2>(actual.Last());
}

[Fact]
public void GetService_DoesNotThrow_WhenScopeFactoryIsInjectedIntoSingleton()
{
Expand Down Expand Up @@ -206,13 +251,18 @@ private class Bar : IBar
{
}


private class Bar2 : IBar
{
public Bar2(IBaz baz)
{
}
}

private class Bar3 : IBar
{
}

private interface IBaz
{
}
Expand All @@ -221,6 +271,10 @@ private class Baz : IBaz
{
}

private class Baz2 : IBaz
{
}

private class BazRecursive : IBaz
{
public BazRecursive(IBaz baz)
Expand Down