Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assign correct slots during IEnumerable resolution #80410

Merged
merged 6 commits into from
Jun 9, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,28 @@ public void ClosedServicesPreferredOverOpenGenericServices()
Assert.IsType<FakeService>(service);
}

// Reproduces https://github.com/dotnet/runtime/issues/79938
[Fact]
public void ResolvingEnumerableContainingOpenGenericServiceUsesCorrectSlot()
steveharter marked this conversation as resolved.
Show resolved Hide resolved
{
// Arrange
TestServiceCollection collection = new();
collection.AddTransient<IFakeOpenGenericService<PocoClass>, FakeService>();
collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(FakeOpenGenericService<>));
collection.AddSingleton<PocoClass>(); // needed for FakeOpenGenericService<>
IServiceProvider provider = CreateServiceProvider(collection);

// Act
IFakeOpenGenericService<PocoClass> service = provider.GetService<IFakeOpenGenericService<PocoClass>>();
IFakeOpenGenericService<PocoClass>[] services = provider.GetServices<IFakeOpenGenericService<PocoClass>>().ToArray();

// Assert
Assert.IsType<FakeService>(service);
Assert.Equal(2, services.Length);
Assert.True(services.Any(s => s.GetType() == typeof(FakeService)));
Assert.True(services.Any(s => s.GetType() == typeof(FakeOpenGenericService<PocoClass>)));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: ideally we'd assert some order, but the Lamar container returns a different order than other implementations.

steveharter marked this conversation as resolved.
Show resolved Hide resolved
}

[Fact]
public void AttemptingToResolveNonexistentServiceReturnsNull()
{
Expand Down Expand Up @@ -905,6 +927,7 @@ public void ResolvesMixedOpenClosedGenericsAsEnumerable()
Assert.NotNull(enumerable[2]);

Assert.Equal(instance, enumerable[2]);
Assert.True(enumerable[0] is FakeService, string.Join(", ", enumerable.Select(e => e.GetType())));
Assert.IsType<FakeService>(enumerable[0]);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
madelson marked this conversation as resolved.
Show resolved Hide resolved
using System.Reflection;
using Microsoft.Extensions.Internal;

Expand Down Expand Up @@ -240,65 +241,80 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica
{
callSiteChain.Add(serviceType);

if (serviceType.IsConstructedGenericType &&
serviceType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
if (!serviceType.IsConstructedGenericType ||
serviceType.GetGenericTypeDefinition() != typeof(IEnumerable<>))
{
Type itemType = serviceType.GenericTypeArguments[0];
CallSiteResultCacheLocation cacheLocation = CallSiteResultCacheLocation.Root;
return null;
}

var callSites = new List<ServiceCallSite>();
Type itemType = serviceType.GenericTypeArguments[0];
CallSiteResultCacheLocation cacheLocation = CallSiteResultCacheLocation.Root;
ServiceCallSite[] callSites;

// If item type is not generic we can safely use descriptor cache
if (!itemType.IsConstructedGenericType &&
_descriptorLookup.TryGetValue(itemType, out ServiceDescriptorCacheItem descriptors))
// If item type is not generic we can safely use descriptor cache
if (!itemType.IsConstructedGenericType &&
_descriptorLookup.TryGetValue(itemType, out ServiceDescriptorCacheItem descriptors))
{
callSites = new ServiceCallSite[descriptors.Count];
for (int i = 0; i < descriptors.Count; i++)
{
for (int i = 0; i < descriptors.Count; i++)
{
ServiceDescriptor descriptor = descriptors[i];
ServiceDescriptor descriptor = descriptors[i];

// Last service should get slot 0
int slot = descriptors.Count - i - 1;
// There may not be any open generics here
ServiceCallSite? callSite = TryCreateExact(descriptor, itemType, callSiteChain, slot);
Debug.Assert(callSite != null);

cacheLocation = GetCommonCacheLocation(cacheLocation, callSite.Cache.Location);
callSites[i] = callSite;
}
}
else
{
// We need to construct a list of matching call sites in declaration order, but to ensure
// correct caching we must assign slots in reverse declaration order and with slots being
// given out first to any exact matches before any open generic matches. Therefore, we
// iterate over the descriptors twice in reverse, catching exact matches on the first pass
// and open generic matches on the second pass.

// Last service should get slot 0
int slot = descriptors.Count - i - 1;
// There may not be any open generics here
ServiceCallSite? callSite = TryCreateExact(descriptor, itemType, callSiteChain, slot);
Debug.Assert(callSite != null);
List<KeyValuePair<int, ServiceCallSite>> callSitesByIndex = new();

cacheLocation = GetCommonCacheLocation(cacheLocation, callSite.Cache.Location);
callSites.Add(callSite);
int slot = 0;
for (int i = _descriptors.Length - 1; i >= 0; i--)
{
if (TryCreateExact(_descriptors[i], itemType, callSiteChain, slot) is { } callSite)
{
AddCallSite(callSite, i);
}
}
else
for (int i = _descriptors.Length - 1; i >= 0; i--)
{
int slot = 0;
// We are going in reverse so the last service in descriptor list gets slot 0
for (int i = _descriptors.Length - 1; i >= 0; i--)
if (TryCreateOpenGeneric(_descriptors[i], itemType, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite)
{
ServiceDescriptor descriptor = _descriptors[i];
ServiceCallSite? callSite = TryCreateExact(descriptor, itemType, callSiteChain, slot) ??
TryCreateOpenGeneric(descriptor, itemType, callSiteChain, slot, false);

if (callSite != null)
{
slot++;

cacheLocation = GetCommonCacheLocation(cacheLocation, callSite.Cache.Location);
callSites.Add(callSite);
}
AddCallSite(callSite, i);
}

callSites.Reverse();
}

callSitesByIndex.Sort((a, b) => a.Key.CompareTo(b.Key));
callSites = callSitesByIndex.Select(p => p.Value).ToArray();

ResultCache resultCache = ResultCache.None;
if (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root)
void AddCallSite(ServiceCallSite callSite, int index)
{
resultCache = new ResultCache(cacheLocation, callSiteKey);
slot++;

cacheLocation = GetCommonCacheLocation(cacheLocation, callSite.Cache.Location);
callSitesByIndex.Add(new(index, callSite));
}
}

return _callSiteCache[callSiteKey] = new IEnumerableCallSite(resultCache, itemType, callSites.ToArray());
ResultCache resultCache = ResultCache.None;
if (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root)
{
resultCache = new ResultCache(cacheLocation, callSiteKey);
}

return null;
return _callSiteCache[callSiteKey] = new IEnumerableCallSite(resultCache, itemType, callSites);
}
finally
{
Expand Down