Skip to content

Commit

Permalink
P/Invokes should use C types not C++. (#94942)
Browse files Browse the repository at this point in the history
* Reduce source duplication

* Use correct function signature when marshalling is enabled.
  • Loading branch information
AaronRobinsonMSFT authored Nov 20, 2023
1 parent 6036eaf commit f0ba96a
Show file tree
Hide file tree
Showing 15 changed files with 175 additions and 240 deletions.
8 changes: 1 addition & 7 deletions src/tests/Interop/DisabledRuntimeMarshalling/AutoLayout.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,18 @@ public unsafe class PInvokes_AutoLayout
[Fact]
public static void AutoLayoutStruct()
{
short s = 42;
bool b = true;
Assert.Throws<MarshalDirectiveException>(() => DisabledRuntimeMarshallingNative.CallWithAutoLayoutStruct(new AutoLayoutStruct()));
}

[Fact]
public static void StructWithAutoLayoutField()
{
short s = 42;
bool b = true;
AssertThrowsMarshalDirectiveOrTypeLoad(() => DisabledRuntimeMarshallingNative.CallWithAutoLayoutStruct(new SequentialWithAutoLayoutField()));
}

[Fact]
public static void StructWithNestedAutoLayoutField()
{
short s = 42;
bool b = true;
AssertThrowsMarshalDirectiveOrTypeLoad(() => DisabledRuntimeMarshallingNative.CallWithAutoLayoutStruct(new SequentialWithAutoLayoutNestedField()));
}

Expand All @@ -41,7 +35,7 @@ private static void AssertThrowsMarshalDirectiveOrTypeLoad(Action testCode)
testCode();
return;
}
catch (Exception ex) when(ex is MarshalDirectiveException or TypeLoadException)
catch (Exception ex) when (ex is MarshalDirectiveException or TypeLoadException)
{
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,9 @@

#include <xplatform.h>

// MSVC versions before 19.38 generate incorrect code for this file when compiling with /O2
#if defined(_MSC_VER) && (_MSC_VER < 1938)
#pragma optimize("", off)
#endif

struct StructWithShortAndBool
{
bool b;
BYTE b;
short s;
// Make sure we don't have any cases where the native code could return a value of this type one way,
// but an invalid managed declaration would expect it differently. This ensures that test failures won't be
Expand All @@ -24,24 +19,31 @@ struct StructWithWCharAndShort
WCHAR c;
};

extern "C" DLL_EXPORT bool STDMETHODCALLTYPE CheckStructWithShortAndBool(StructWithShortAndBool str, short s, bool b)
extern "C" DLL_EXPORT BYTE STDMETHODCALLTYPE CheckStructWithShortAndBool(StructWithShortAndBool str, short s, BYTE b)
{
return str.s == s && str.b == b;
}

extern "C" DLL_EXPORT bool STDMETHODCALLTYPE CheckStructWithWCharAndShort(StructWithWCharAndShort str, short s, WCHAR c)
static BOOL STDMETHODCALLTYPE CheckStructWithShortAndBoolMarshalSupport(StructWithShortAndBool str, short s, BYTE b)
{
return str.s == s && str.c == c;
return (CheckStructWithShortAndBool(str, s, b) != 0) ? TRUE : FALSE;
}

using CheckStructWithShortAndBoolCallback = bool (STDMETHODCALLTYPE*)(StructWithShortAndBool, short, bool);
extern "C" DLL_EXPORT BYTE STDMETHODCALLTYPE CheckStructWithWCharAndShort(StructWithWCharAndShort str, short s, WCHAR c)
{
return str.s == s && str.c == c;
}

extern "C" DLL_EXPORT CheckStructWithShortAndBoolCallback STDMETHODCALLTYPE GetStructWithShortAndBoolCallback()
extern "C" DLL_EXPORT void* STDMETHODCALLTYPE GetStructWithShortAndBoolCallback(BYTE marshalSupported)
{
return &CheckStructWithShortAndBool;
return (marshalSupported != 0)
? (void*)&CheckStructWithShortAndBoolMarshalSupport
: (void*)&CheckStructWithShortAndBool;
}

extern "C" DLL_EXPORT bool STDMETHODCALLTYPE CallCheckStructWithShortAndBoolCallback(CheckStructWithShortAndBoolCallback cb, StructWithShortAndBool str, short s, bool b)
using CheckStructWithShortAndBoolCallback = BYTE (STDMETHODCALLTYPE*)(StructWithShortAndBool, short, BYTE);

extern "C" DLL_EXPORT BYTE STDMETHODCALLTYPE CallCheckStructWithShortAndBoolCallback(CheckStructWithShortAndBoolCallback cb, StructWithShortAndBool str, short s, BYTE b)
{
return cb(str, s, b);
}
Expand All @@ -54,14 +56,14 @@ extern "C" DLL_EXPORT BYTE PassThrough(BYTE b)
extern "C" DLL_EXPORT void Invalid(...) {}


extern "C" DLL_EXPORT bool STDMETHODCALLTYPE CheckStructWithShortAndBoolWithVariantBool(StructWithShortAndBool str, short s, VARIANT_BOOL b)
extern "C" DLL_EXPORT BYTE STDMETHODCALLTYPE CheckStructWithShortAndBoolWithVariantBool(StructWithShortAndBool str, short s, VARIANT_BOOL b)
{
// Specifically use VARIANT_TRUE here as invalid marshalling (in the "disabled runtime marshalling" case) will incorrectly marshal VARAINT_TRUE
// but could accidentally marshal VARIANT_FALSE correctly since it is 0, which is the same representation as a zero or sign extension of the C# false value.
return str.s == s && str.b == (b == VARIANT_TRUE);
}

using CheckStructWithShortAndBoolWithVariantBoolCallback = bool (STDMETHODCALLTYPE*)(StructWithShortAndBool, short, VARIANT_BOOL);
using CheckStructWithShortAndBoolWithVariantBoolCallback = BYTE (STDMETHODCALLTYPE*)(StructWithShortAndBool, short, VARIANT_BOOL);

extern "C" DLL_EXPORT CheckStructWithShortAndBoolWithVariantBoolCallback STDMETHODCALLTYPE GetStructWithShortAndBoolWithVariantBoolCallback()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ public StructWithShortAndBool(short s, bool b)

public struct StructWithShortAndBoolWithMarshalAs
{
#if DISABLE_RUNTIME_MARSHALLING
[MarshalAs(UnmanagedType.VariantBool)]
#else
[MarshalAs(UnmanagedType.U1)]
#endif
bool b;
short s;
int padding;
Expand All @@ -41,6 +45,10 @@ public StructWithShortAndBoolWithMarshalAs(short s, bool b)
public struct StructWithWCharAndShort
{
short s;
#if DISABLE_RUNTIME_MARSHALLING
#else
[MarshalAs(UnmanagedType.U2)]
#endif
char c;

public StructWithWCharAndShort(short s, char c)
Expand All @@ -54,7 +62,11 @@ public StructWithWCharAndShort(short s, char c)
public struct StructWithWCharAndShortWithMarshalAs
{
short s;
#if DISABLE_RUNTIME_MARSHALLING
[MarshalAs(UnmanagedType.U1)]
#else
[MarshalAs(UnmanagedType.U2)]
#endif
char c;

public StructWithWCharAndShortWithMarshalAs(short s, char c)
Expand All @@ -76,6 +88,7 @@ public StructWithShortAndGeneric(short s, T t)
}
}

#if DISABLE_RUNTIME_MARSHALLING
public struct StructWithString
{
string s;
Expand All @@ -89,6 +102,7 @@ public StructWithString(string s)
[StructLayout(LayoutKind.Sequential)]
public class LayoutClass
{}
#endif

[StructLayout(LayoutKind.Auto)]
public struct AutoLayoutStruct
Expand All @@ -106,52 +120,142 @@ public struct SequentialWithAutoLayoutNestedField
SequentialWithAutoLayoutField field;
}

#if DISABLE_RUNTIME_MARSHALLING
public enum ByteEnum : byte
{
Value = 42
}

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "PassThrough")]
public static extern byte GetEnumUnderlyingValue(ByteEnum b);
#endif

[DllImport(nameof(DisabledRuntimeMarshallingNative))]
#if !DISABLE_RUNTIME_MARSHALLING
[return:MarshalAs(UnmanagedType.U1)]
#endif
public static extern bool CheckStructWithShortAndBool(StructWithShortAndBool str, short s, bool b);

[DllImport(nameof(DisabledRuntimeMarshallingNative))]
public static extern bool CheckStructWithShortAndBool(StructWithShortAndBoolWithMarshalAs str, short s, [MarshalAs(UnmanagedType.I4)] bool b);
#if DISABLE_RUNTIME_MARSHALLING
public static extern bool CheckStructWithShortAndBool(StructWithShortAndBoolWithMarshalAs str, short s, bool b);
#else
[return:MarshalAs(UnmanagedType.U1)]
public static extern bool CheckStructWithShortAndBool(StructWithShortAndBoolWithMarshalAs str, short s, [MarshalAs(UnmanagedType.Bool)] bool b);
#endif

[DllImport(nameof(DisabledRuntimeMarshallingNative))]
#if !DISABLE_RUNTIME_MARSHALLING
[return:MarshalAs(UnmanagedType.U1)]
#endif
public static extern bool CheckStructWithWCharAndShort(StructWithWCharAndShort str, short s, char c);

[DllImport(nameof(DisabledRuntimeMarshallingNative))]
#if !DISABLE_RUNTIME_MARSHALLING
[return:MarshalAs(UnmanagedType.U1)]
#endif
public static extern bool CheckStructWithWCharAndShort(StructWithWCharAndShortWithMarshalAs str, short s, char c);

[DllImport(nameof(DisabledRuntimeMarshallingNative))]
#if !DISABLE_RUNTIME_MARSHALLING
[return:MarshalAs(UnmanagedType.U1)]
#endif
public static extern bool CheckStructWithWCharAndShort(StructWithShortAndGeneric<char> str, short s, char c);

[DllImport(nameof(DisabledRuntimeMarshallingNative))]
#if !DISABLE_RUNTIME_MARSHALLING
[return:MarshalAs(UnmanagedType.U1)]
#endif
public static extern bool CheckStructWithWCharAndShort(StructWithShortAndGeneric<short> str, short s, short c);

#if DISABLE_RUNTIME_MARSHALLING
[DllImport(nameof(DisabledRuntimeMarshallingNative))]
public static extern bool CallCheckStructWithShortAndBoolCallback(delegate* unmanaged<StructWithShortAndBool, short, bool, bool> cb, StructWithShortAndBool str, short s, bool b);
#endif

public static IntPtr GetStructWithShortAndBoolCallback()
{
#if DISABLE_RUNTIME_MARSHALLING
return GetStructWithShortAndBoolCallback(false);
#else
return GetStructWithShortAndBoolCallback(true);
#endif
[DllImport(nameof(DisabledRuntimeMarshallingNative))]
static extern IntPtr GetStructWithShortAndBoolCallback([MarshalAs(UnmanagedType.U1)] bool marshalSupported);
}

[DllImport(nameof(DisabledRuntimeMarshallingNative))]
public static extern IntPtr GetStructWithShortAndBoolWithVariantBoolCallback();

#if DISABLE_RUNTIME_MARSHALLING
[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "PassThrough")]
public static extern bool GetByteAsBool(byte b);
#endif

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWithAutoLayoutStruct(AutoLayoutStruct s);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWithAutoLayoutStruct(SequentialWithAutoLayoutField s);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWithAutoLayoutStruct(SequentialWithAutoLayoutNestedField s);

#if DISABLE_RUNTIME_MARSHALLING
[DllImport(nameof(DisabledRuntimeMarshallingNative))]
[return:MarshalAs(UnmanagedType.U1)]
public static extern bool CheckStructWithShortAndBoolWithVariantBool(StructWithShortAndBool str, short s, [MarshalAs(UnmanagedType.VariantBool)] bool b);
#else
[DllImport(nameof(DisabledRuntimeMarshallingNative))]
[return:MarshalAs(UnmanagedType.U1)]
public static extern bool CheckStructWithShortAndBoolWithVariantBool(StructWithShortAndBoolWithMarshalAs str, short s, [MarshalAs(UnmanagedType.VariantBool)] bool b);
#endif

// Apply the UnmanagedFunctionPointer attributes with the default calling conventions so that Mono's AOT compiler
// recognizes that these delegate types are used in interop and should have managed->native thunks generated for them.
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate bool CheckStructWithShortAndBoolCallback(StructWithShortAndBool str, short s, bool b);

#if DISABLE_RUNTIME_MARSHALLING
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate bool CheckStructWithShortAndBoolWithVariantBoolCallback(StructWithShortAndBool str, short s, [MarshalAs(UnmanagedType.VariantBool)] bool b);
#else
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate bool CheckStructWithShortAndBoolWithMarshalAsAndVariantBoolCallback(StructWithShortAndBoolWithMarshalAs str, short s, [MarshalAs(UnmanagedType.VariantBool)] bool b);
#endif

[UnmanagedCallersOnly]
public static bool CheckStructWithShortAndBoolManaged(StructWithShortAndBool str, short s, bool b)
{
return str.s == s && str.b == b;
}

#if DISABLE_RUNTIME_MARSHALLING
[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid", CharSet = CharSet.Ansi)]
public static extern void CheckStringWithAnsiCharSet(string s);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid", CharSet = CharSet.Unicode)]
public static extern void CheckStringWithUnicodeCharSet(string s);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid", CharSet = CharSet.Unicode)]
public static extern string GetStringWithUnicodeCharSet();

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CheckStructWithStructWithString(StructWithString s);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CheckLayoutClass(LayoutClass c);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid", SetLastError = true)]
public static extern void CallWithSetLastError();

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
[LCIDConversion(0)]
public static extern void CallWithLCID();

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid", PreserveSig = false)]
public static extern int CallWithHResultSwap();

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWithAutoLayoutStruct(AutoLayoutStruct s);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWithAutoLayoutStruct(SequentialWithAutoLayoutField s);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWithAutoLayoutStruct(SequentialWithAutoLayoutNestedField s);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWithByRef(ref int i);

Expand All @@ -161,50 +265,28 @@ public enum ByteEnum : byte
[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWithInt128(Int128 i);

[DllImport(nameof(DisabledRuntimeMarshallingNative))]
public static extern delegate* unmanaged<StructWithShortAndBool, short, bool, bool> GetStructWithShortAndBoolCallback();

[DllImport(nameof(DisabledRuntimeMarshallingNative))]
public static extern delegate* unmanaged<StructWithShortAndBool, short, bool, bool> GetStructWithShortAndBoolWithVariantBoolCallback();

[DllImport(nameof(DisabledRuntimeMarshallingNative))]
public static extern bool CallCheckStructWithShortAndBoolCallback(delegate* unmanaged<StructWithShortAndBool, short, bool, bool> cb, StructWithShortAndBool str, short s, bool b);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "PassThrough")]
public static extern bool GetByteAsBool(byte b);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "PassThrough")]
public static extern byte GetEnumUnderlyingValue(ByteEnum b);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "CheckStructWithShortAndBoolWithVariantBool")]
[return:MarshalAs(UnmanagedType.U1)]
public static extern bool CheckStructWithShortAndBoolWithVariantBool_FailureExpected(StructWithShortAndBool str, short s, [MarshalAs(UnmanagedType.VariantBool)] bool b);
[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWithUInt128(UInt128 i);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWith(Nullable<int> s);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWith(Span<int> s);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWith(ReadOnlySpan<int> ros);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWith(Vector64<int> v);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWith(Vector128<int> v);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWith(Vector256<int> v);

[DllImport(nameof(DisabledRuntimeMarshallingNative), EntryPoint = "Invalid")]
public static extern void CallWith(Vector<int> v);

// Apply the UnmanagedFunctionPointer attributes with the default calling conventions so that Mono's AOT compiler
// recognizes that these delegate types are used in interop and should have managed->native thunks generated for them.
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate bool CheckStructWithShortAndBoolCallback(StructWithShortAndBool str, short s, bool b);
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate bool CheckStructWithShortAndBoolWithVariantBoolCallback(StructWithShortAndBool str, short s, [MarshalAs(UnmanagedType.VariantBool)] bool b);

[UnmanagedCallersOnly]
public static bool CheckStructWithShortAndBoolManaged(StructWithShortAndBool str, short s, bool b)
{
return str.s == s && str.b == b;
}
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
</PropertyGroup>
<ItemGroup>
<Compile Include="PInvokeAssemblyMarshallingDisabled/*.cs" />
<Compile Include="*.cs" />
<Compile Include="AutoLayout.cs" />
<Compile Include="FunctionPointers.cs" />
<Compile Include="RuntimeMarshallingDisabledAttribute.cs" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="Native_DisabledMarshalling/DisabledRuntimeMarshallingNative_DisabledMarshalling.csproj" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
</PropertyGroup>
<ItemGroup>
<Compile Include="PInvokeAssemblyMarshallingEnabled/*.cs" />
<Compile Include="*.cs" />
<Compile Include="AutoLayout.cs" />
<Compile Include="FunctionPointers.cs" />
<Compile Include="RuntimeMarshallingDisabledAttribute.cs" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="Native_Default/DisabledRuntimeMarshallingNative_Default.csproj" />
Expand Down
Loading

0 comments on commit f0ba96a

Please sign in to comment.