Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement handling PreserveSig = false for interfaces #1206

Merged
merged 9 commits into from
Jun 8, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ public static bool GetAttributeTypeNamespaceAndName(this MetadataReader metadata

public static PInvokeFlags GetDelegatePInvokeFlags(this EcmaType type)
{
PInvokeFlags flags = new PInvokeFlags();

PInvokeFlags flags = new PInvokeFlags(PInvokeAttributes.PreserveSig);
if (!type.IsDelegate)
{
return flags;
Expand Down
46 changes: 36 additions & 10 deletions src/coreclr/tools/Common/TypeSystem/IL/Stubs/PInvokeILEmitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,26 @@ private static Marshaller[] InitializeMarshallers(MethodDesc targetMethod, Inter
parameterMetadata = parameterMetadataArray[parameterIndex++];
}

TypeDesc parameterType = (i == 0)
? methodSig.ReturnType //first item is the return type
: methodSig[i - 1];
TypeDesc parameterType;
bool isHRSwappedRetVal = false;
if (i == 0)
{
// First item is the return type
parameterType = methodSig.ReturnType;
if (!flags.PreserveSig && !parameterType.IsVoid)
{
// PreserveSig = false can only show up an regular forward PInvokes
Debug.Assert(direction == MarshalDirection.Forward);

parameterType = methodSig.Context.GetByRefType(parameterType);
isHRSwappedRetVal = true;
}
}
else
{
parameterType = methodSig[i - 1];
}

marshallers[i] = Marshaller.CreateMarshaller(parameterType,
parameterIndex,
methodSig.GetEmbeddedSignatureData(),
Expand All @@ -108,8 +125,8 @@ private static Marshaller[] InitializeMarshallers(MethodDesc targetMethod, Inter
indexOffset + parameterMetadata.Index,
flags,
parameterMetadata.In,
parameterMetadata.Out,
parameterMetadata.Return
isHRSwappedRetVal ? true : parameterMetadata.Out,
isHRSwappedRetVal ? false : parameterMetadata.Return
);
}

Expand Down Expand Up @@ -231,16 +248,14 @@ private void EmitDelegateCall(DelegateMarshallingMethodThunk delegateMethod, PIn

private void EmitPInvokeCall(PInvokeILCodeStreams ilCodeStreams)
{
if (!_flags.PreserveSig && _targetMethod.Signature.ReturnType != _targetMethod.Context.GetWellKnownType(WellKnownType.Void))
throw new NotSupportedException();

ILEmitter emitter = ilCodeStreams.Emitter;
ILCodeStream fnptrLoadStream = ilCodeStreams.FunctionPointerLoadStream;
ILCodeStream callsiteSetupCodeStream = ilCodeStreams.CallsiteSetupCodeStream;
TypeSystemContext context = _targetMethod.Context;

bool isHRSwappedRetVal = !_flags.PreserveSig && !_targetMethod.Signature.ReturnType.IsVoid;
TypeDesc nativeReturnType = _flags.PreserveSig ? _marshallers[0].NativeParameterType : context.GetWellKnownType(WellKnownType.Int32);
TypeDesc[] nativeParameterTypes = new TypeDesc[_marshallers.Length - 1];
TypeDesc[] nativeParameterTypes = new TypeDesc[isHRSwappedRetVal ? _marshallers.Length : _marshallers.Length - 1];

// if the SetLastError flag is set in DllImport, clear the error code before doing P/Invoke
if (_flags.SetLastError)
Expand All @@ -254,6 +269,11 @@ private void EmitPInvokeCall(PInvokeILCodeStreams ilCodeStreams)
nativeParameterTypes[i - 1] = _marshallers[i].NativeParameterType;
}

if (isHRSwappedRetVal)
{
nativeParameterTypes[_marshallers.Length - 1] = _marshallers[0].NativeParameterType;
}

if (!_pInvokeILEmitterConfiguration.GenerateDirectCall(_targetMethod, out _))
{
MetadataType lazyHelperType = context.GetHelperType("InteropHelpers");
Expand Down Expand Up @@ -357,11 +377,17 @@ private MethodIL EmitIL()
cleanupCodestream.BeginHandler(tryFinally);

// Marshal the arguments
for (int i = 0; i < _marshallers.Length; i++)
bool isHRSwappedRetVal = !_flags.PreserveSig && !_targetMethod.Signature.ReturnType.IsVoid;
for (int i = isHRSwappedRetVal ? 1 : 0; i < _marshallers.Length; i++)
{
_marshallers[i].EmitMarshallingIL(pInvokeILCodeStreams);
}

if (isHRSwappedRetVal)
{
_marshallers[0].EmitMarshallingIL(pInvokeILCodeStreams);
}

// make the call
switch (_targetMethod)
{
Expand Down
17 changes: 16 additions & 1 deletion src/coreclr/tools/Common/TypeSystem/Interop/IL/Marshaller.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ internal virtual bool CleanupRequired
}
}

internal bool IsHRSwappedRetVal => Index == 0 && !Return;

public bool In;
public bool Out;
public bool Return;
Expand Down Expand Up @@ -512,7 +514,7 @@ protected virtual void EmitMarshalReturnValueManagedToNative()

public virtual void LoadReturnValue(ILCodeStream codeStream)
{
Debug.Assert(Return);
Debug.Assert(Return || IsHRSwappedRetVal);

switch (MarshalDirection)
{
Expand Down Expand Up @@ -653,6 +655,13 @@ protected void PropagateFromByRefArg(ILCodeStream stream, Home home)
/// </summary>
protected void PropagateToByRefArg(ILCodeStream stream, Home home)
{
// If by-ref arg has index == 0 then that argument is used for HR swapping and we just return that value.
if (IsHRSwappedRetVal)
{
// Returning result would be handled by LoadReturnValue
return;
}

stream.EmitLdArg(Index - 1);
home.LoadValue(stream);
stream.EmitStInd(ManagedType);
Expand Down Expand Up @@ -903,6 +912,12 @@ class BlittableValueMarshaller : Marshaller
{
protected override void EmitMarshalArgumentManagedToNative()
{
if (IsHRSwappedRetVal)
{
base.EmitMarshalArgumentManagedToNative();
return;
}

if (IsNativeByRef && MarshalDirection == MarshalDirection.Forward)
{
ILCodeStream marshallingCodeStream = _ilCodeStreams.MarshallingCodeStream;
Expand Down
6 changes: 6 additions & 0 deletions src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ public static void ThrowIfNotEquals(bool expected, bool actual, string message)
[DllImport("ComWrappersNative", CallingConvention = CallingConvention.StdCall)]
static extern int BuildComPointer(out IComInterface foo);

[DllImport("ComWrappersNative", CallingConvention = CallingConvention.StdCall, PreserveSig = false, EntryPoint="BuildComPointer")]
static extern IComInterface BuildComPointerNoPreserveSig();

public static void TestComInteropNullPointers()
{
Console.WriteLine("Testing Marshal APIs for COM interfaces");
Expand Down Expand Up @@ -160,6 +163,9 @@ public static void TestComInteropCCWCreation()
int result = BuildComPointer(out var comPointer);
ThrowIfNotEquals(0, result, "Seems to be COM marshalling behave strange.");
comPointer.DoWork(11);

comPointer = BuildComPointerNoPreserveSig();
comPointer.DoWork(22);
}

[MethodImpl(MethodImplOptions.NoInlining)]
Expand Down
27 changes: 27 additions & 0 deletions src/tests/nativeaot/SmokeTests/PInvoke/PInvoke.cs
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,20 @@ internal struct Callbacks
[DllImport("PInvokeNative", CallingConvention = CallingConvention.StdCall, PreserveSig = false)]
static extern void ValidateSuccessCall(int errorCode);

[DllImport("PInvokeNative", CallingConvention = CallingConvention.StdCall, PreserveSig = false)]
static extern int ValidateIntResult(int errorCode);

[DllImport("PInvokeNative", EntryPoint = "ValidateIntResult", CallingConvention = CallingConvention.StdCall, PreserveSig = false)]
static extern MagicEnum ValidateEnumResult(int errorCode);

[DllImport("PInvokeNative", CallingConvention = CallingConvention.StdCall)]
internal static extern decimal DecimalTest(decimal value);

internal enum MagicEnum
{
MagicResult = 42,
}

public static int Main(string[] args)
{
TestBlittableType();
Expand Down Expand Up @@ -1035,6 +1046,22 @@ private static void TestWithoutPreserveSig()
catch (NotImplementedException)
{
}

var intResult = ValidateIntResult(0);
ThrowIfNotEquals(intResult, 42, "Int32 marshalling failed.");

try
{
const int E_NOTIMPL = -2147467263;
intResult = ValidateIntResult(E_NOTIMPL);
throw new Exception("Exception should be thrown for E_NOTIMPL error code");
}
catch (NotImplementedException)
{
}

var enumResult = ValidateEnumResult(0);
ThrowIfNotEquals(enumResult, MagicEnum.MagicResult, "Enum marshalling failed.");
}

public static unsafe void TestForwardDelegateWithUnmanagedCallersOnly()
Expand Down
6 changes: 6 additions & 0 deletions src/tests/nativeaot/SmokeTests/PInvoke/PInvokeNative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,12 @@ DLL_EXPORT int __stdcall ValidateSuccessCall(int errorCode)
return errorCode;
}

DLL_EXPORT int __stdcall ValidateIntResult(int errorCode, int* result)
{
*result = 42;
return errorCode;
}

#ifndef DECIMAL_NEG // defined in wtypes.h
typedef struct tagDEC {
uint16_t wReserved;
Expand Down