Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 145 additions & 7 deletions src/coreclr/vm/ceeload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4659,6 +4659,122 @@ PTR_VOID ReflectionModule::GetRvaField(RVA field) // virtual
// VASigCookies
// ===========================================================================

static bool TypeSignatureContainsGenericVariables(SigParser& sp);
static bool MethodSignatureContainsGenericVariables(SigParser& sp);

static bool TypeSignatureContainsGenericVariables(SigParser& sp)
{
STANDARD_VM_CONTRACT;

CorElementType et = ELEMENT_TYPE_END;
IfFailThrow(sp.GetElemType(&et));

if (CorIsPrimitiveType(et))
return false;

switch (et)
{
case ELEMENT_TYPE_OBJECT:
case ELEMENT_TYPE_STRING:
case ELEMENT_TYPE_TYPEDBYREF:
return false;

case ELEMENT_TYPE_BYREF:
case ELEMENT_TYPE_PTR:
case ELEMENT_TYPE_SZARRAY:
return TypeSignatureContainsGenericVariables(sp);

case ELEMENT_TYPE_VALUETYPE:
case ELEMENT_TYPE_CLASS:
IfFailThrow(sp.GetToken(NULL)); // Skip RID
return false;

case ELEMENT_TYPE_FNPTR:
return MethodSignatureContainsGenericVariables(sp);

case ELEMENT_TYPE_ARRAY:
{
if (TypeSignatureContainsGenericVariables(sp))
return true;

uint32_t rank;
IfFailThrow(sp.GetData(&rank)); // Get rank
if (rank)
{
uint32_t nsizes;
IfFailThrow(sp.GetData(&nsizes)); // Get # of sizes
while (nsizes--)
{
IfFailThrow(sp.GetData(NULL)); // Skip size
}

uint32_t nlbounds;
IfFailThrow(sp.GetData(&nlbounds)); // Get # of lower bounds
while (nlbounds--)
{
IfFailThrow(sp.GetData(NULL)); // Skip lower bounds
}
}
}
return false;

case ELEMENT_TYPE_GENERICINST:
{
if (TypeSignatureContainsGenericVariables(sp))
return true;

uint32_t argCnt;
IfFailThrow(sp.GetData(&argCnt)); // Get number of parameters
while (argCnt--)
{
if (TypeSignatureContainsGenericVariables(sp))
return true;
}
}
return false;

case ELEMENT_TYPE_INTERNAL:
IfFailThrow(sp.GetPointer(NULL));
return false;

case ELEMENT_TYPE_VAR:
case ELEMENT_TYPE_MVAR:
return true;

default:
// Return conservative answer for unhandled elements
_ASSERTE(!"Unexpected element type.");
return true;
}
}

static bool MethodSignatureContainsGenericVariables(SigParser& sp)
{
STANDARD_VM_CONTRACT;

uint32_t callConv = 0;
IfFailThrow(sp.GetCallingConvInfo(&callConv));

if (callConv & IMAGE_CEE_CS_CALLCONV_GENERIC)
{
// Generic signatures should never show up here, return conservative answer.
_ASSERTE(!"Unexpected generic signature.");
return true;
}

uint32_t numArgs = 0;
IfFailThrow(sp.GetData(&numArgs));

// iterate over the return type and parameters
for (uint32_t i = 0; i <= numArgs; i++)
{
if (TypeSignatureContainsGenericVariables(sp))
return true;
}

return false;
}

//==========================================================================
// Enregisters a VASig.
//==========================================================================
Expand All @@ -4667,15 +4783,39 @@ VASigCookie *Module::GetVASigCookie(Signature vaSignature, const SigTypeContext*
CONTRACT(VASigCookie*)
{
INSTANCE_CHECK;
THROWS;
GC_TRIGGERS;
MODE_ANY;
STANDARD_VM_CHECK;
POSTCONDITION(CheckPointer(RETVAL));
INJECT_FAULT(COMPlusThrowOM());
}
CONTRACT_END;

Module* pLoaderModule = ClassLoader::ComputeLoaderModuleWorker(this, mdTokenNil, typeContext->m_classInst, typeContext->m_methodInst);
SigTypeContext emptyContext;

Module* pLoaderModule = this;
if (!typeContext->IsEmpty())
{
// Strip the generic context if it is not actually used by the signature. It is nececessary for both:
// - Performance: allow more sharing of vasig cookies
// - Functionality: built-in runtime marshalling is disallowed for generic signatures
SigParser sigParser = vaSignature.CreateSigParser();
if (MethodSignatureContainsGenericVariables(sigParser))
{
pLoaderModule = ClassLoader::ComputeLoaderModuleWorker(this, mdTokenNil, typeContext->m_classInst, typeContext->m_methodInst);
}
else
{
typeContext = &emptyContext;
}
}
else
{
#ifdef _DEBUG
// The method signature should not contain any generic variables if the generic context is not provided.
SigParser sigParser = vaSignature.CreateSigParser();
_ASSERTE(!MethodSignatureContainsGenericVariables(sigParser));
#endif
}

VASigCookie *pCookie = GetVASigCookieWorker(this, pLoaderModule, vaSignature, typeContext);

RETURN pCookie;
Expand All @@ -4685,9 +4825,7 @@ VASigCookie *Module::GetVASigCookieWorker(Module* pDefiningModule, Module* pLoad
{
CONTRACT(VASigCookie*)
{
THROWS;
GC_TRIGGERS;
MODE_ANY;
STANDARD_VM_CHECK;
POSTCONDITION(CheckPointer(RETVAL));
INJECT_FAULT(COMPlusThrowOM());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ static BlittableGeneric<string> UnmanagedExportedFunctionBlittableGenericString(
return new() { X = Convert.ToInt32(arg) };
}

[UnmanagedCallersOnly]
static unsafe void UnmanagedExportedFunctionRefInt(int* pval, float arg)
{
*pval = Convert.ToInt32(arg);
}

class GenericCaller<T>
{
internal static unsafe T GenericCalli<U>(void* fnptr, U arg)
Expand All @@ -40,6 +46,11 @@ internal static unsafe BlittableGeneric<T> WrappedGenericCalli<U>(void* fnptr, U
{
return ((delegate* unmanaged<U, BlittableGeneric<T>>)fnptr)(arg);
}

internal static unsafe void NonGenericCalli<U>(void* fnptr, ref int val, float arg)
{
((delegate* unmanaged<ref int, float, void>)fnptr)(ref val, arg);
}
}

struct BlittableGeneric<T>
Expand Down Expand Up @@ -81,6 +92,14 @@ public static void RunGenericFunctionPointerTest(float inVal)
outVar = GenericCaller<string>.WrappedGenericCalli((delegate* unmanaged<float, BlittableGeneric<string>>)&UnmanagedExportedFunctionBlittableGenericString, inVal).X;
}
Assert.Equal(expectedValue, outVar);

outVar = 0;
Console.WriteLine("Testing non-GenericCalli with non-blittable argument in a generic caller");
unsafe
{
GenericCaller<string>.NonGenericCalli<string>((delegate* unmanaged<int*, float, void>)&UnmanagedExportedFunctionRefInt, ref outVar, inVal);
}
Assert.Equal(expectedValue, outVar);
}

[ConditionalFact(nameof(CanRunInvalidGenericFunctionPointerTest))]
Expand Down