Skip to content

Commit

Permalink
Fix c# generated string marshalling
Browse files Browse the repository at this point in the history
  • Loading branch information
arsher authored and ralfbiedert committed Feb 1, 2025
1 parent 31052bf commit 1f68902
Show file tree
Hide file tree
Showing 14 changed files with 180 additions and 29 deletions.
4 changes: 2 additions & 2 deletions crates/backend_csharp/src/interop/types/composite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ pub fn write_type_definition_composite_to_managed_marshal_field(
}
} else if matches!(a.array_type(), CType::Pattern(TypePattern::CChar)) {
indented!(w, r"var source_{0} = new ReadOnlySpan<byte>(unmanaged.{0}, {1});", field_name, a.len())?;
indented!(w, r"var terminatorIndex = source_{}.IndexOf<byte>(0);", field_name)?;
indented!(w, r"var terminatorIndex_{0} = source_{0}.IndexOf<byte>(0);", field_name)?;
indented!(
w,
r"result.{0} = Encoding.UTF8.GetString(source_{0}.Slice(0, terminatorIndex == -1 ? Math.Min(source_{0}.Length, {1}) : terminatorIndex));",
r"result.{0} = Encoding.UTF8.GetString(source_{0}.Slice(0, terminatorIndex_{0} == -1 ? Math.Min(source_{0}.Length, {1}) : terminatorIndex_{0}));",
field_name,
a.len()
)?;
Expand Down
1 change: 1 addition & 0 deletions crates/reference_project/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ pub fn nested_array_3(input: NestedArray) -> u8 {
pub fn char_array_1() -> CharArray {
let mut result = CharArray {
str: FixedString { data: [0; 32] },
str_2: FixedString { data: [0; 32] },
};

result.str.data[..14].copy_from_slice(b"Hello, World!\0");
Expand Down
1 change: 1 addition & 0 deletions crates/reference_project/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ unsafe impl<const N: usize> CTypeInfo for FixedString<N> {
#[ffi_type]
pub struct CharArray {
pub str: FixedString<32>,
pub str_2: FixedString<32>,
}

#[ffi_type]
Expand Down
1 change: 1 addition & 0 deletions tests/tests/c_function_styles_typedefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ typedef struct ARRAY
typedef struct CHARARRAY
{
char str[32];
char str_2[32];
} CHARARRAY;

typedef struct CONTAINER
Expand Down
1 change: 1 addition & 0 deletions tests/tests/c_reference_project/reference_project.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ typedef struct ARRAY
typedef struct CHARARRAY
{
char str[32];
char str_2[32];
} CHARARRAY;

typedef struct CONTAINER
Expand Down
13 changes: 12 additions & 1 deletion tests/tests/cpython_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,11 +1147,14 @@ class CharArray(ctypes.Structure):
# These fields represent the underlying C data layout
_fields_ = [
("str", ctypes.c_char * 32),
("str_2", ctypes.c_char * 32),
]

def __init__(self, str = None):
def __init__(self, str = None, str_2 = None):
if str is not None:
self.str = str
if str_2 is not None:
self.str_2 = str_2

@property
def str(self):
Expand All @@ -1161,6 +1164,14 @@ def str(self):
def str(self, value):
return ctypes.Structure.__set__(self, "str", value)

@property
def str_2(self):
return ctypes.Structure.__get__(self, "str_2")

@str_2.setter
def str_2(self, value):
return ctypes.Structure.__set__(self, "str_2", value)


class Container(ctypes.Structure):

Expand Down
13 changes: 12 additions & 1 deletion tests/tests/cpython_benchmarks/reference_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,11 +1147,14 @@ class CharArray(ctypes.Structure):
# These fields represent the underlying C data layout
_fields_ = [
("str", ctypes.c_char * 32),
("str_2", ctypes.c_char * 32),
]

def __init__(self, str = None):
def __init__(self, str = None, str_2 = None):
if str is not None:
self.str = str
if str_2 is not None:
self.str_2 = str_2

@property
def str(self):
Expand All @@ -1161,6 +1164,14 @@ def str(self):
def str(self, value):
return ctypes.Structure.__set__(self, "str", value)

@property
def str_2(self):
return ctypes.Structure.__get__(self, "str_2")

@str_2.setter
def str_2(self, value):
return ctypes.Structure.__set__(self, "str_2", value)


class Container(ctypes.Structure):

Expand Down
13 changes: 12 additions & 1 deletion tests/tests/cpython_reference_project/reference_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,11 +1147,14 @@ class CharArray(ctypes.Structure):
# These fields represent the underlying C data layout
_fields_ = [
("str", ctypes.c_char * 32),
("str_2", ctypes.c_char * 32),
]

def __init__(self, str = None):
def __init__(self, str = None, str_2 = None):
if str is not None:
self.str = str
if str_2 is not None:
self.str_2 = str_2

@property
def str(self):
Expand All @@ -1161,6 +1164,14 @@ def str(self):
def str(self, value):
return ctypes.Structure.__set__(self, "str", value)

@property
def str_2(self):
return ctypes.Structure.__get__(self, "str_2")

@str_2.setter
def str_2(self, value):
return ctypes.Structure.__set__(self, "str_2", value)


class Container(ctypes.Structure):

Expand Down
27 changes: 23 additions & 4 deletions tests/tests/csharp_benchmarks/Interop.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ public static partial class Interop
static Interop()
{
var api_version = Interop.pattern_api_guard();
if (api_version != 5676777239214057195ul)
if (api_version != 5403378736106778954ul)
{
throw new TypeLoadException($"API reports hash {api_version} which differs from hash in bindings (5676777239214057195). You probably forgot to update / copy either the bindings or the library.");
throw new TypeLoadException($"API reports hash {api_version} which differs from hash in bindings (5403378736106778954). You probably forgot to update / copy either the bindings or the library.");
}
}

Expand Down Expand Up @@ -1085,6 +1085,7 @@ public static BoolField ConvertToManaged(Unmanaged unmanaged)
public partial struct CharArray
{
public string str;
public string str_2;
}

[CustomMarshaller(typeof(CharArray), MarshalMode.Default, typeof(CharArrayMarshaller))]
Expand All @@ -1094,6 +1095,7 @@ internal static class CharArrayMarshaller
public unsafe struct Unmanaged
{
public fixed byte str[32];
public fixed byte str_2[32];
}

public static Unmanaged ConvertToUnmanaged(CharArray managed)
Expand All @@ -1116,6 +1118,19 @@ public static Unmanaged ConvertToUnmanaged(CharArray managed)
result.str[written] = 0;
}
}

if(managed.str_2 != null)
{
fixed(char* s = managed.str_2)
{
if(Encoding.UTF8.GetByteCount(managed.str_2, 0, managed.str_2.Length) + 1 > 32)
{
throw new InvalidOperationException($"The managed string field '{nameof(CharArray.str_2)}' cannot be encoded to fit the fixed size array of 32.");
}
var written = Encoding.UTF8.GetBytes(s, managed.str_2.Length, result.str_2, 31);
result.str_2[written] = 0;
}
}
}

return result;
Expand All @@ -1130,8 +1145,12 @@ public static CharArray ConvertToManaged(Unmanaged unmanaged)
unsafe
{
var source_str = new ReadOnlySpan<byte>(unmanaged.str, 32);
var terminatorIndex = source_str.IndexOf<byte>(0);
result.str = Encoding.UTF8.GetString(source_str.Slice(0, terminatorIndex == -1 ? Math.Min(source_str.Length, 32) : terminatorIndex));
var terminatorIndex_str = source_str.IndexOf<byte>(0);
result.str = Encoding.UTF8.GetString(source_str.Slice(0, terminatorIndex_str == -1 ? Math.Min(source_str.Length, 32) : terminatorIndex_str));

var source_str_2 = new ReadOnlySpan<byte>(unmanaged.str_2, 32);
var terminatorIndex_str_2 = source_str_2.IndexOf<byte>(0);
result.str_2 = Encoding.UTF8.GetString(source_str_2.Slice(0, terminatorIndex_str_2 == -1 ? Math.Min(source_str_2.Length, 32) : terminatorIndex_str_2));
}

return result;
Expand Down
27 changes: 23 additions & 4 deletions tests/tests/csharp_overloads_dotnet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ public static partial class Interop
static Interop()
{
var api_version = Interop.pattern_api_guard();
if (api_version != 5676777239214057195ul)
if (api_version != 5403378736106778954ul)
{
throw new TypeLoadException($"API reports hash {api_version} which differs from hash in bindings (5676777239214057195). You probably forgot to update / copy either the bindings or the library.");
throw new TypeLoadException($"API reports hash {api_version} which differs from hash in bindings (5403378736106778954). You probably forgot to update / copy either the bindings or the library.");
}
}

Expand Down Expand Up @@ -1085,6 +1085,7 @@ public static BoolField ConvertToManaged(Unmanaged unmanaged)
public partial struct CharArray
{
public string str;
public string str_2;
}

[CustomMarshaller(typeof(CharArray), MarshalMode.Default, typeof(CharArrayMarshaller))]
Expand All @@ -1094,6 +1095,7 @@ internal static class CharArrayMarshaller
public unsafe struct Unmanaged
{
public fixed byte str[32];
public fixed byte str_2[32];
}

public static Unmanaged ConvertToUnmanaged(CharArray managed)
Expand All @@ -1116,6 +1118,19 @@ public static Unmanaged ConvertToUnmanaged(CharArray managed)
result.str[written] = 0;
}
}

if(managed.str_2 != null)
{
fixed(char* s = managed.str_2)
{
if(Encoding.UTF8.GetByteCount(managed.str_2, 0, managed.str_2.Length) + 1 > 32)
{
throw new InvalidOperationException($"The managed string field '{nameof(CharArray.str_2)}' cannot be encoded to fit the fixed size array of 32.");
}
var written = Encoding.UTF8.GetBytes(s, managed.str_2.Length, result.str_2, 31);
result.str_2[written] = 0;
}
}
}

return result;
Expand All @@ -1130,8 +1145,12 @@ public static CharArray ConvertToManaged(Unmanaged unmanaged)
unsafe
{
var source_str = new ReadOnlySpan<byte>(unmanaged.str, 32);
var terminatorIndex = source_str.IndexOf<byte>(0);
result.str = Encoding.UTF8.GetString(source_str.Slice(0, terminatorIndex == -1 ? Math.Min(source_str.Length, 32) : terminatorIndex));
var terminatorIndex_str = source_str.IndexOf<byte>(0);
result.str = Encoding.UTF8.GetString(source_str.Slice(0, terminatorIndex_str == -1 ? Math.Min(source_str.Length, 32) : terminatorIndex_str));

var source_str_2 = new ReadOnlySpan<byte>(unmanaged.str_2, 32);
var terminatorIndex_str_2 = source_str_2.IndexOf<byte>(0);
result.str_2 = Encoding.UTF8.GetString(source_str_2.Slice(0, terminatorIndex_str_2 == -1 ? Math.Min(source_str_2.Length, 32) : terminatorIndex_str_2));
}

return result;
Expand Down
27 changes: 23 additions & 4 deletions tests/tests/csharp_reference_project/Interop.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ public static partial class Interop
static Interop()
{
var api_version = Interop.pattern_api_guard();
if (api_version != 5676777239214057195ul)
if (api_version != 5403378736106778954ul)
{
throw new TypeLoadException($"API reports hash {api_version} which differs from hash in bindings (5676777239214057195). You probably forgot to update / copy either the bindings or the library.");
throw new TypeLoadException($"API reports hash {api_version} which differs from hash in bindings (5403378736106778954). You probably forgot to update / copy either the bindings or the library.");
}
}

Expand Down Expand Up @@ -1085,6 +1085,7 @@ public static BoolField ConvertToManaged(Unmanaged unmanaged)
public partial struct CharArray
{
public string str;
public string str_2;
}

[CustomMarshaller(typeof(CharArray), MarshalMode.Default, typeof(CharArrayMarshaller))]
Expand All @@ -1094,6 +1095,7 @@ internal static class CharArrayMarshaller
public unsafe struct Unmanaged
{
public fixed byte str[32];
public fixed byte str_2[32];
}

public static Unmanaged ConvertToUnmanaged(CharArray managed)
Expand All @@ -1116,6 +1118,19 @@ public static Unmanaged ConvertToUnmanaged(CharArray managed)
result.str[written] = 0;
}
}

if(managed.str_2 != null)
{
fixed(char* s = managed.str_2)
{
if(Encoding.UTF8.GetByteCount(managed.str_2, 0, managed.str_2.Length) + 1 > 32)
{
throw new InvalidOperationException($"The managed string field '{nameof(CharArray.str_2)}' cannot be encoded to fit the fixed size array of 32.");
}
var written = Encoding.UTF8.GetBytes(s, managed.str_2.Length, result.str_2, 31);
result.str_2[written] = 0;
}
}
}

return result;
Expand All @@ -1130,8 +1145,12 @@ public static CharArray ConvertToManaged(Unmanaged unmanaged)
unsafe
{
var source_str = new ReadOnlySpan<byte>(unmanaged.str, 32);
var terminatorIndex = source_str.IndexOf<byte>(0);
result.str = Encoding.UTF8.GetString(source_str.Slice(0, terminatorIndex == -1 ? Math.Min(source_str.Length, 32) : terminatorIndex));
var terminatorIndex_str = source_str.IndexOf<byte>(0);
result.str = Encoding.UTF8.GetString(source_str.Slice(0, terminatorIndex_str == -1 ? Math.Min(source_str.Length, 32) : terminatorIndex_str));

var source_str_2 = new ReadOnlySpan<byte>(unmanaged.str_2, 32);
var terminatorIndex_str_2 = source_str_2.IndexOf<byte>(0);
result.str_2 = Encoding.UTF8.GetString(source_str_2.Slice(0, terminatorIndex_str_2 == -1 ? Math.Min(source_str_2.Length, 32) : terminatorIndex_str_2));
}

return result;
Expand Down
27 changes: 23 additions & 4 deletions tests/tests/csharp_write_types_all.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ public static partial class Interop
static Interop()
{
var api_version = Interop.pattern_api_guard();
if (api_version != 5676777239214057195ul)
if (api_version != 5403378736106778954ul)
{
throw new TypeLoadException($"API reports hash {api_version} which differs from hash in bindings (5676777239214057195). You probably forgot to update / copy either the bindings or the library.");
throw new TypeLoadException($"API reports hash {api_version} which differs from hash in bindings (5403378736106778954). You probably forgot to update / copy either the bindings or the library.");
}
}

Expand Down Expand Up @@ -1085,6 +1085,7 @@ public static BoolField ConvertToManaged(Unmanaged unmanaged)
public partial struct CharArray
{
public string str;
public string str_2;
}

[CustomMarshaller(typeof(CharArray), MarshalMode.Default, typeof(CharArrayMarshaller))]
Expand All @@ -1094,6 +1095,7 @@ internal static class CharArrayMarshaller
public unsafe struct Unmanaged
{
public fixed byte str[32];
public fixed byte str_2[32];
}

public static Unmanaged ConvertToUnmanaged(CharArray managed)
Expand All @@ -1116,6 +1118,19 @@ public static Unmanaged ConvertToUnmanaged(CharArray managed)
result.str[written] = 0;
}
}

if(managed.str_2 != null)
{
fixed(char* s = managed.str_2)
{
if(Encoding.UTF8.GetByteCount(managed.str_2, 0, managed.str_2.Length) + 1 > 32)
{
throw new InvalidOperationException($"The managed string field '{nameof(CharArray.str_2)}' cannot be encoded to fit the fixed size array of 32.");
}
var written = Encoding.UTF8.GetBytes(s, managed.str_2.Length, result.str_2, 31);
result.str_2[written] = 0;
}
}
}

return result;
Expand All @@ -1130,8 +1145,12 @@ public static CharArray ConvertToManaged(Unmanaged unmanaged)
unsafe
{
var source_str = new ReadOnlySpan<byte>(unmanaged.str, 32);
var terminatorIndex = source_str.IndexOf<byte>(0);
result.str = Encoding.UTF8.GetString(source_str.Slice(0, terminatorIndex == -1 ? Math.Min(source_str.Length, 32) : terminatorIndex));
var terminatorIndex_str = source_str.IndexOf<byte>(0);
result.str = Encoding.UTF8.GetString(source_str.Slice(0, terminatorIndex_str == -1 ? Math.Min(source_str.Length, 32) : terminatorIndex_str));

var source_str_2 = new ReadOnlySpan<byte>(unmanaged.str_2, 32);
var terminatorIndex_str_2 = source_str_2.IndexOf<byte>(0);
result.str_2 = Encoding.UTF8.GetString(source_str_2.Slice(0, terminatorIndex_str_2 == -1 ? Math.Min(source_str_2.Length, 32) : terminatorIndex_str_2));
}

return result;
Expand Down
Loading

0 comments on commit 1f68902

Please sign in to comment.