Skip to content

Commit e7e8d0d

Browse files
authored
[NRBF] Fixes and fuzzing improvements (#110194)
* Simplify array handling to fix issues with jagged and abstract array types Jagged arrays in the payload can contain cycles. In that scenario, no value is correct for `ArrayRecord.FlattenedLength`, and `ArrayRecord.GetArray` does not have enough context to know how to handle the cycles. To address these issues, jagged array handling is simplified so that calling code can handle the cycles in the most appropriate way for the application. Single-dimension arrays can be represented in the payload using abstract types such as `IComparable[]` where the concrete element type is not known. When the concrete element type is known, `ArrayRecord.GetArray` could return either `SZArrayRecord<ClassRecord>` or `SZArrayRecord<object>`; without a concrete type, we need to return something that represents the element abstractly. 1. `ArrayRecord.FlattenedLength` is removed from the API 2. `ArrayRecord.GetArray` now returns `ArrayRecord[]` for jagged arrays instead of trying to populate them 3. `ArrayRecord.GetArray` now returns `SZArrayRecord<SerializationRecord>` for single-dimension arrays instead of either `SZArrayRecord<ClassRecord>` or `SZArrayRecord<object>` * extend the Fuzzer to consume all possible data exposed by the NrbfDecoder
1 parent 592183a commit e7e8d0d

35 files changed

+1855
-878
lines changed

src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs

Lines changed: 215 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
using System.Buffers;
55
using System.Formats.Nrbf;
6+
using System.Numerics;
7+
using System.Runtime.CompilerServices;
68
using System.Runtime.Serialization;
79
using System.Text;
810

@@ -38,71 +40,27 @@ private static void Test(Span<byte> testSpan, Stream stream)
3840
{
3941
if (NrbfDecoder.StartsWithPayloadHeader(testSpan))
4042
{
43+
HashSet<SerializationRecordId> visited = new();
44+
Queue<SerializationRecord> queue = new();
4145
try
4246
{
4347
SerializationRecord record = NrbfDecoder.Decode(stream, out IReadOnlyDictionary<SerializationRecordId, SerializationRecord> recordMap);
44-
switch (record.RecordType)
48+
49+
Assert.Equal(true, recordMap.ContainsKey(record.Id)); // make sure the loop below includes it
50+
foreach (SerializationRecord fromMap in recordMap.Values)
4551
{
46-
case SerializationRecordType.ArraySingleObject:
47-
SZArrayRecord<object?> arrayObj = (SZArrayRecord<object?>)record;
48-
object?[] objArray = arrayObj.GetArray();
49-
Assert.Equal(arrayObj.Length, objArray.Length);
50-
Assert.Equal(1, arrayObj.Rank);
51-
break;
52-
case SerializationRecordType.ArraySingleString:
53-
SZArrayRecord<string?> arrayString = (SZArrayRecord<string?>)record;
54-
string?[] array = arrayString.GetArray();
55-
Assert.Equal(arrayString.Length, array.Length);
56-
Assert.Equal(1, arrayString.Rank);
57-
Assert.Equal(true, arrayString.TypeNameMatches(typeof(string[])));
58-
break;
59-
case SerializationRecordType.ArraySinglePrimitive:
60-
case SerializationRecordType.BinaryArray:
61-
ArrayRecord arrayBinary = (ArrayRecord)record;
62-
Assert.NotNull(arrayBinary.TypeName);
63-
break;
64-
case SerializationRecordType.BinaryObjectString:
65-
_ = ((PrimitiveTypeRecord<string>)record).Value;
66-
break;
67-
case SerializationRecordType.ClassWithId:
68-
case SerializationRecordType.ClassWithMembersAndTypes:
69-
case SerializationRecordType.SystemClassWithMembersAndTypes:
70-
ClassRecord classRecord = (ClassRecord)record;
71-
Assert.NotNull(classRecord.TypeName);
72-
73-
foreach (string name in classRecord.MemberNames)
74-
{
75-
Assert.Equal(true, classRecord.HasMember(name));
76-
}
77-
break;
78-
case SerializationRecordType.MemberPrimitiveTyped:
79-
PrimitiveTypeRecord primitiveType = (PrimitiveTypeRecord)record;
80-
Assert.NotNull(primitiveType.Value);
81-
break;
82-
case SerializationRecordType.MemberReference:
83-
Assert.NotNull(record.TypeName);
84-
break;
85-
case SerializationRecordType.BinaryLibrary:
86-
Assert.Equal(false, record.Id.Equals(default));
87-
break;
88-
case SerializationRecordType.ObjectNull:
89-
case SerializationRecordType.ObjectNullMultiple:
90-
case SerializationRecordType.ObjectNullMultiple256:
91-
Assert.Equal(default, record.Id);
92-
break;
93-
case SerializationRecordType.MessageEnd:
94-
case SerializationRecordType.SerializedStreamHeader:
95-
// case SerializationRecordType.ClassWithMembers: will cause NotSupportedException
96-
// case SerializationRecordType.SystemClassWithMembers: will cause NotSupportedException
97-
default:
98-
throw new Exception("Unexpected RecordType");
52+
visited.Add(fromMap.Id);
53+
queue.Enqueue(fromMap);
9954
}
10055
}
10156
catch (SerializationException) { /* Reading from the stream encountered invalid NRBF data.*/ }
10257
catch (NotSupportedException) { /* Reading from the stream encountered unsupported records */ }
10358
catch (DecoderFallbackException) { /* Reading from the stream encountered an invalid UTF8 sequence. */ }
10459
catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ }
10560
catch (IOException) { /* An I/O error occurred. */ }
61+
62+
// Lets consume it outside of the try/catch block to not swallow any exceptions by accident.
63+
Consume(visited, queue);
10664
}
10765
else
10866
{
@@ -117,6 +75,209 @@ private static void Test(Span<byte> testSpan, Stream stream)
11775
}
11876
}
11977

78+
private static void Consume(HashSet<SerializationRecordId> visited, Queue<SerializationRecord> queue)
79+
{
80+
while (queue.Count > 0)
81+
{
82+
SerializationRecord serializationRecord = queue.Dequeue();
83+
84+
if (serializationRecord is PrimitiveTypeRecord primitiveTypeRecord)
85+
{
86+
ConsumePrimitiveValue(primitiveTypeRecord.Value);
87+
}
88+
else if (serializationRecord is ClassRecord classRecord)
89+
{
90+
foreach (string memberName in classRecord.MemberNames)
91+
{
92+
ConsumePrimitiveValue(memberName);
93+
94+
Assert.Equal(true, classRecord.HasMember(memberName));
95+
96+
object? rawValue;
97+
98+
try
99+
{
100+
rawValue = classRecord.GetRawValue(memberName);
101+
}
102+
catch (SerializationException ex) when (ex.Message == "Invalid member reference.")
103+
{
104+
// It was a reference to a non-existing record, just continue.
105+
continue;
106+
}
107+
108+
if (rawValue is not null)
109+
{
110+
if (rawValue is SerializationRecord nestedRecord)
111+
{
112+
TryEnqueue(nestedRecord);
113+
}
114+
else
115+
{
116+
ConsumePrimitiveValue(rawValue);
117+
}
118+
}
119+
}
120+
}
121+
else if (serializationRecord is ArrayRecord arrayRecord)
122+
{
123+
Type? type;
124+
125+
try
126+
{
127+
// THIS IS VERY BAD IDEA FOR ANY KIND OF PRODUCT CODE!!
128+
// IT'S USED ONLY FOR THE PURPOSE OF TESTING, DO NOT COPY IT.
129+
type = Type.GetType(arrayRecord.TypeName.AssemblyQualifiedName, throwOnError: false);
130+
if (type is null)
131+
{
132+
continue;
133+
}
134+
}
135+
catch (Exception) // throwOnError passed to GetType does not prevent from all kinds of exceptions
136+
{
137+
// It was some type made up by the Fuzzer.
138+
// Since it's currently impossible to get the array without providing the type,
139+
// we just bail here (in the future we may add an enumerator to ArrayRecord).
140+
continue;
141+
}
142+
143+
Array? array;
144+
try
145+
{
146+
array = arrayRecord.GetArray(type);
147+
}
148+
catch (SerializationException ex) when (ex.Message == "Invalid member reference.")
149+
{
150+
// It contained a reference to a non-existing record, just continue.
151+
continue;
152+
}
153+
154+
ReadOnlySpan<int> lengths = arrayRecord.Lengths;
155+
long totalElementsCount = 1;
156+
for (int i = 0; i < arrayRecord.Rank; i++)
157+
{
158+
Assert.Equal(lengths[i], array.GetLength(i));
159+
totalElementsCount *= lengths[i];
160+
}
161+
162+
// This array contains indices that are used to get values of multi-dimensional array.
163+
// At the beginning, all values are set to 0, so we start from the first element.
164+
int[] indices = new int[arrayRecord.Rank];
165+
166+
long flatIndex = 0;
167+
for (; flatIndex < totalElementsCount; flatIndex++)
168+
{
169+
object? rawValue = array.GetValue(indices);
170+
if (rawValue is not null)
171+
{
172+
if (rawValue is SerializationRecord record)
173+
{
174+
TryEnqueue(record);
175+
}
176+
else
177+
{
178+
ConsumePrimitiveValue(rawValue);
179+
}
180+
}
181+
182+
// The loop below is responsible for incrementing the multi-dimensional indices.
183+
// It finds the dimension and then performs an increment.
184+
int dimension = indices.Length - 1;
185+
while (dimension >= 0)
186+
{
187+
indices[dimension]++;
188+
if (indices[dimension] < lengths[dimension])
189+
{
190+
break;
191+
}
192+
indices[dimension] = 0;
193+
dimension--;
194+
}
195+
}
196+
197+
// We track the flat index to ensure that we have enumerated over all elements.
198+
Assert.Equal(totalElementsCount, flatIndex);
199+
}
200+
else
201+
{
202+
// The map may currently contain it (it may change in the future)
203+
Assert.Equal(SerializationRecordType.BinaryLibrary, serializationRecord.RecordType);
204+
}
205+
}
206+
207+
void TryEnqueue(SerializationRecord record)
208+
{
209+
if (visited.Add(record.Id)) // avoid unbounded recursion
210+
{
211+
queue.Enqueue(record);
212+
}
213+
}
214+
}
215+
216+
[MethodImpl(MethodImplOptions.NoInlining)]
217+
private static void ConsumePrimitiveValue(object value)
218+
{
219+
if (value is string text)
220+
Assert.Equal(text, text.ToString()); // we want to touch all elements to see if memory is not corrupted
221+
else if (value is bool boolean)
222+
Assert.Equal(true, Unsafe.BitCast<bool, byte>(boolean) is 1 or 0); // other values are illegal!!
223+
else if (value is sbyte @sbyte)
224+
TestNumber(@sbyte);
225+
else if (value is byte @byte)
226+
TestNumber(@byte);
227+
else if (value is char character)
228+
TestNumber(character);
229+
else if (value is short @short)
230+
TestNumber(@short);
231+
else if (value is ushort @ushort)
232+
TestNumber(@ushort);
233+
else if (value is int integer)
234+
TestNumber(integer);
235+
else if (value is uint @uint)
236+
TestNumber(@uint);
237+
else if (value is long @long)
238+
TestNumber(@long);
239+
else if (value is ulong @ulong)
240+
TestNumber(@ulong);
241+
else if (value is float @float)
242+
{
243+
if (!float.IsNaN(@float) && !float.IsInfinity(@float))
244+
{
245+
TestNumber(@float);
246+
}
247+
}
248+
else if (value is double @double)
249+
{
250+
if (!double.IsNaN(@double) && !double.IsInfinity(@double))
251+
{
252+
TestNumber(@double);
253+
}
254+
}
255+
else if (value is decimal @decimal)
256+
TestNumber(@decimal);
257+
else if (value is nint @nint)
258+
TestNumber(@nint);
259+
else if (value is nuint @nuint)
260+
TestNumber(@nuint);
261+
else if (value is DateTime datetime)
262+
Assert.Equal(true, datetime >= DateTime.MinValue && datetime <= DateTime.MaxValue);
263+
else if (value is TimeSpan timeSpan)
264+
Assert.Equal(true, timeSpan >= TimeSpan.MinValue && timeSpan <= TimeSpan.MaxValue);
265+
else
266+
throw new InvalidOperationException();
267+
268+
static void TestNumber<T>(T value) where T : IComparable<T>, IMinMaxValue<T>
269+
{
270+
if (value.CompareTo(T.MinValue) < 0)
271+
{
272+
throw new Exception($"Expected {value} to be more or equal {T.MinValue}, {value.CompareTo(T.MinValue)}.");
273+
}
274+
if (value.CompareTo(T.MaxValue) > 0)
275+
{
276+
throw new Exception($"Expected {value} to be less or equal {T.MaxValue}, {value.CompareTo(T.MaxValue)}.");
277+
}
278+
}
279+
}
280+
120281
private sealed class NonSeekableStream : MemoryStream
121282
{
122283
public NonSeekableStream(byte[] buffer) : base(buffer) { }

src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ namespace System.Formats.Nrbf
99
public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRecord
1010
{
1111
internal ArrayRecord() { }
12-
public virtual long FlattenedLength { get { throw null; } }
1312
public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } }
1413
public abstract System.ReadOnlySpan<int> Lengths { get; }
1514
public int Rank { get { throw null; } }

src/libraries/System.Formats.Nrbf/src/PACKAGE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ There are more than a dozen different serialization [record types](https://learn
5454
- `PrimitiveTypeRecord<T>` derives from the non-generic [PrimitiveTypeRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord), which also exposes a [Value](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord.value) property. But on the base class, the value is returned as `object` (which introduces boxing for value types).
5555
- [ClassRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.classrecord): describes all `class` and `struct` besides the aforementioned primitive types.
5656
- [ArrayRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.arrayrecord): describes all array records, including jagged and multi-dimensional arrays.
57-
- [`SZArrayRecord<T>`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `ClassRecord`.
57+
- [`SZArrayRecord<T>`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `SerializationRecord`.
5858

5959
```csharp
6060
SerializationRecord rootObject = NrbfDecoder.Decode(payload); // payload is a Stream

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ internal enum AllowedRecordTypes : uint
2828
ArraySingleString = 1 << SerializationRecordType.ArraySingleString,
2929

3030
Nulls = ObjectNull | ObjectNullMultiple256 | ObjectNullMultiple,
31+
Arrays = ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray,
3132

3233
/// <summary>
3334
/// Any .NET object (a primitive, a reference type, a reference or single null).
3435
/// </summary>
3536
AnyObject = MemberPrimitiveTyped
36-
| ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray
37+
| Arrays
3738
| ClassWithId | ClassWithMembersAndTypes | SystemClassWithMembersAndTypes
3839
| BinaryObjectString
3940
| MemberReference

0 commit comments

Comments
 (0)