diff --git a/src/coreclr/tools/Common/TypeSystem/Ecma/MetadataExtensions.cs b/src/coreclr/tools/Common/TypeSystem/Ecma/MetadataExtensions.cs index 4cbc5f3b2a19..35650f9044ca 100644 --- a/src/coreclr/tools/Common/TypeSystem/Ecma/MetadataExtensions.cs +++ b/src/coreclr/tools/Common/TypeSystem/Ecma/MetadataExtensions.cs @@ -196,8 +196,8 @@ public static bool GetAttributeTypeNamespaceAndName(this MetadataReader metadata public static PInvokeFlags GetDelegatePInvokeFlags(this EcmaType type) { - PInvokeFlags flags = new PInvokeFlags(); - + PInvokeFlags flags = new PInvokeFlags(PInvokeAttributes.PreserveSig); + if (!type.IsDelegate) { return flags; diff --git a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/PInvokeILEmitter.cs b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/PInvokeILEmitter.cs index aa67eaf60d89..d77bb5994ec0 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/PInvokeILEmitter.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/PInvokeILEmitter.cs @@ -94,9 +94,26 @@ private static Marshaller[] InitializeMarshallers(MethodDesc targetMethod, Inter parameterMetadata = parameterMetadataArray[parameterIndex++]; } - TypeDesc parameterType = (i == 0) - ? methodSig.ReturnType //first item is the return type - : methodSig[i - 1]; + TypeDesc parameterType; + bool isHRSwappedRetVal = false; + if (i == 0) + { + // First item is the return type + parameterType = methodSig.ReturnType; + if (!flags.PreserveSig && !parameterType.IsVoid) + { + // PreserveSig = false can only show up an regular forward PInvokes + Debug.Assert(direction == MarshalDirection.Forward); + + parameterType = methodSig.Context.GetByRefType(parameterType); + isHRSwappedRetVal = true; + } + } + else + { + parameterType = methodSig[i - 1]; + } + marshallers[i] = Marshaller.CreateMarshaller(parameterType, parameterIndex, methodSig.GetEmbeddedSignatureData(), @@ -108,8 +125,8 @@ private static Marshaller[] InitializeMarshallers(MethodDesc targetMethod, Inter indexOffset + parameterMetadata.Index, flags, parameterMetadata.In, - parameterMetadata.Out, - parameterMetadata.Return + isHRSwappedRetVal ? true : parameterMetadata.Out, + isHRSwappedRetVal ? false : parameterMetadata.Return ); } @@ -231,16 +248,14 @@ private void EmitDelegateCall(DelegateMarshallingMethodThunk delegateMethod, PIn private void EmitPInvokeCall(PInvokeILCodeStreams ilCodeStreams) { - if (!_flags.PreserveSig && _targetMethod.Signature.ReturnType != _targetMethod.Context.GetWellKnownType(WellKnownType.Void)) - throw new NotSupportedException(); - ILEmitter emitter = ilCodeStreams.Emitter; ILCodeStream fnptrLoadStream = ilCodeStreams.FunctionPointerLoadStream; ILCodeStream callsiteSetupCodeStream = ilCodeStreams.CallsiteSetupCodeStream; TypeSystemContext context = _targetMethod.Context; + bool isHRSwappedRetVal = !_flags.PreserveSig && !_targetMethod.Signature.ReturnType.IsVoid; TypeDesc nativeReturnType = _flags.PreserveSig ? _marshallers[0].NativeParameterType : context.GetWellKnownType(WellKnownType.Int32); - TypeDesc[] nativeParameterTypes = new TypeDesc[_marshallers.Length - 1]; + TypeDesc[] nativeParameterTypes = new TypeDesc[isHRSwappedRetVal ? _marshallers.Length : _marshallers.Length - 1]; // if the SetLastError flag is set in DllImport, clear the error code before doing P/Invoke if (_flags.SetLastError) @@ -254,6 +269,11 @@ private void EmitPInvokeCall(PInvokeILCodeStreams ilCodeStreams) nativeParameterTypes[i - 1] = _marshallers[i].NativeParameterType; } + if (isHRSwappedRetVal) + { + nativeParameterTypes[_marshallers.Length - 1] = _marshallers[0].NativeParameterType; + } + if (!_pInvokeILEmitterConfiguration.GenerateDirectCall(_targetMethod, out _)) { MetadataType lazyHelperType = context.GetHelperType("InteropHelpers"); @@ -357,11 +377,17 @@ private MethodIL EmitIL() cleanupCodestream.BeginHandler(tryFinally); // Marshal the arguments - for (int i = 0; i < _marshallers.Length; i++) + bool isHRSwappedRetVal = !_flags.PreserveSig && !_targetMethod.Signature.ReturnType.IsVoid; + for (int i = isHRSwappedRetVal ? 1 : 0; i < _marshallers.Length; i++) { _marshallers[i].EmitMarshallingIL(pInvokeILCodeStreams); } + if (isHRSwappedRetVal) + { + _marshallers[0].EmitMarshallingIL(pInvokeILCodeStreams); + } + // make the call switch (_targetMethod) { diff --git a/src/coreclr/tools/Common/TypeSystem/Interop/IL/Marshaller.cs b/src/coreclr/tools/Common/TypeSystem/Interop/IL/Marshaller.cs index 46c53157b465..d79d54c618a2 100644 --- a/src/coreclr/tools/Common/TypeSystem/Interop/IL/Marshaller.cs +++ b/src/coreclr/tools/Common/TypeSystem/Interop/IL/Marshaller.cs @@ -149,6 +149,8 @@ internal virtual bool CleanupRequired } } + internal bool IsHRSwappedRetVal => Index == 0 && !Return; + public bool In; public bool Out; public bool Return; @@ -512,7 +514,7 @@ protected virtual void EmitMarshalReturnValueManagedToNative() public virtual void LoadReturnValue(ILCodeStream codeStream) { - Debug.Assert(Return); + Debug.Assert(Return || IsHRSwappedRetVal); switch (MarshalDirection) { @@ -653,6 +655,13 @@ protected void PropagateFromByRefArg(ILCodeStream stream, Home home) /// protected void PropagateToByRefArg(ILCodeStream stream, Home home) { + // If by-ref arg has index == 0 then that argument is used for HR swapping and we just return that value. + if (IsHRSwappedRetVal) + { + // Returning result would be handled by LoadReturnValue + return; + } + stream.EmitLdArg(Index - 1); home.LoadValue(stream); stream.EmitStInd(ManagedType); @@ -903,6 +912,12 @@ class BlittableValueMarshaller : Marshaller { protected override void EmitMarshalArgumentManagedToNative() { + if (IsHRSwappedRetVal) + { + base.EmitMarshalArgumentManagedToNative(); + return; + } + if (IsNativeByRef && MarshalDirection == MarshalDirection.Forward) { ILCodeStream marshallingCodeStream = _ilCodeStreams.MarshallingCodeStream; diff --git a/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.cs b/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.cs index 5fa111cba18a..df2c9942cef2 100644 --- a/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.cs +++ b/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.cs @@ -64,6 +64,9 @@ public static void ThrowIfNotEquals(bool expected, bool actual, string message) [DllImport("ComWrappersNative", CallingConvention = CallingConvention.StdCall)] static extern int BuildComPointer(out IComInterface foo); + [DllImport("ComWrappersNative", CallingConvention = CallingConvention.StdCall, PreserveSig = false, EntryPoint="BuildComPointer")] + static extern IComInterface BuildComPointerNoPreserveSig(); + public static void TestComInteropNullPointers() { Console.WriteLine("Testing Marshal APIs for COM interfaces"); @@ -160,6 +163,9 @@ public static void TestComInteropCCWCreation() int result = BuildComPointer(out var comPointer); ThrowIfNotEquals(0, result, "Seems to be COM marshalling behave strange."); comPointer.DoWork(11); + + comPointer = BuildComPointerNoPreserveSig(); + comPointer.DoWork(22); } [MethodImpl(MethodImplOptions.NoInlining)] diff --git a/src/tests/nativeaot/SmokeTests/PInvoke/PInvoke.cs b/src/tests/nativeaot/SmokeTests/PInvoke/PInvoke.cs index f1698766d895..4e5375389455 100644 --- a/src/tests/nativeaot/SmokeTests/PInvoke/PInvoke.cs +++ b/src/tests/nativeaot/SmokeTests/PInvoke/PInvoke.cs @@ -275,9 +275,20 @@ internal struct Callbacks [DllImport("PInvokeNative", CallingConvention = CallingConvention.StdCall, PreserveSig = false)] static extern void ValidateSuccessCall(int errorCode); + [DllImport("PInvokeNative", CallingConvention = CallingConvention.StdCall, PreserveSig = false)] + static extern int ValidateIntResult(int errorCode); + + [DllImport("PInvokeNative", EntryPoint = "ValidateIntResult", CallingConvention = CallingConvention.StdCall, PreserveSig = false)] + static extern MagicEnum ValidateEnumResult(int errorCode); + [DllImport("PInvokeNative", CallingConvention = CallingConvention.StdCall)] internal static extern decimal DecimalTest(decimal value); + internal enum MagicEnum + { + MagicResult = 42, + } + public static int Main(string[] args) { TestBlittableType(); @@ -1035,6 +1046,22 @@ private static void TestWithoutPreserveSig() catch (NotImplementedException) { } + + var intResult = ValidateIntResult(0); + ThrowIfNotEquals(intResult, 42, "Int32 marshalling failed."); + + try + { + const int E_NOTIMPL = -2147467263; + intResult = ValidateIntResult(E_NOTIMPL); + throw new Exception("Exception should be thrown for E_NOTIMPL error code"); + } + catch (NotImplementedException) + { + } + + var enumResult = ValidateEnumResult(0); + ThrowIfNotEquals(enumResult, MagicEnum.MagicResult, "Enum marshalling failed."); } public static unsafe void TestForwardDelegateWithUnmanagedCallersOnly() diff --git a/src/tests/nativeaot/SmokeTests/PInvoke/PInvokeNative.cpp b/src/tests/nativeaot/SmokeTests/PInvoke/PInvokeNative.cpp index 9cc98b3f72f6..8e3c15108c10 100644 --- a/src/tests/nativeaot/SmokeTests/PInvoke/PInvokeNative.cpp +++ b/src/tests/nativeaot/SmokeTests/PInvoke/PInvokeNative.cpp @@ -678,6 +678,12 @@ DLL_EXPORT int __stdcall ValidateSuccessCall(int errorCode) return errorCode; } +DLL_EXPORT int __stdcall ValidateIntResult(int errorCode, int* result) +{ + *result = 42; + return errorCode; +} + #ifndef DECIMAL_NEG // defined in wtypes.h typedef struct tagDEC { uint16_t wReserved;