Skip to content
This repository was archived by the owner on Nov 1, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 55 additions & 8 deletions src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ public override string ConvertName(string name) {
}
public class EntityConverter {

private const int MAX_DESERIALIZATION_RECURSION_DEPTH = 100;
private readonly ConcurrentDictionary<Type, EntityInfo> _cache;
private static readonly JsonSerializerOptions _options = new() {
PropertyNamingPolicy = new OnefuzzNamingPolicy(),
Expand Down Expand Up @@ -124,8 +125,8 @@ public static JsonSerializerOptions GetJsonSerializerOptions() {
}

private static IEnumerable<EntityProperty> GetEntityProperties<T>(ParameterInfo parameterInfo) {
var name = parameterInfo.Name.EnsureNotNull($"Invalid paramter {parameterInfo}");
var parameterType = parameterInfo.ParameterType.EnsureNotNull($"Invalid paramter {parameterInfo}");
var name = parameterInfo.Name.EnsureNotNull($"Invalid parameter {parameterInfo}");
var parameterType = parameterInfo.ParameterType.EnsureNotNull($"Invalid parameter {parameterInfo}");
var isRowkey = parameterInfo.GetCustomAttribute(typeof(RowKeyAttribute)) != null;
var isPartitionkey = parameterInfo.GetCustomAttribute(typeof(PartitionKeyAttribute)) != null;

Expand All @@ -135,7 +136,7 @@ private static IEnumerable<EntityProperty> GetEntityProperties<T>(ParameterInfo

(TypeDiscrimnatorAttribute, ITypeProvider)? discriminator = null;
if (discriminatorAttribute != null) {
var t = (ITypeProvider)(Activator.CreateInstance(discriminatorAttribute.ConverterType) ?? throw new Exception("unable to retrive the type provider"));
var t = (ITypeProvider)(Activator.CreateInstance(discriminatorAttribute.ConverterType) ?? throw new Exception("unable to retrieve the type provider"));
discriminator = (discriminatorAttribute, t);
}

Expand Down Expand Up @@ -222,7 +223,7 @@ public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
}


private object? GetFieldValue(EntityInfo info, string name, TableEntity entity) {
private object? GetFieldValue(EntityInfo info, string name, TableEntity entity, int iterationCount) {
var ef = info.properties[name].First();
if (ef.kind == EntityPropertyKind.PartitionKey || ef.kind == EntityPropertyKind.RowKey) {
// partition & row keys must always be strings
Expand Down Expand Up @@ -285,7 +286,23 @@ public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
var outputType = ef.type;
if (ef.discriminator != null) {
var (attr, typeProvider) = ef.discriminator.Value;
var v = GetFieldValue(info, attr.FieldName, entity) ?? throw new Exception($"No value for {attr.FieldName}");
if (iterationCount > MAX_DESERIALIZATION_RECURSION_DEPTH) {
var tags = GenerateTableEntityTags(entity);
tags.AddRange(new (string, string)[] {
("outputType", outputType?.Name ?? string.Empty),
("fieldName", fieldName)
});
throw new OrmMaxRecursionDepthReachedException($"MAX_DESERIALIZATION_RECURSION_DEPTH reached. Too many iterations deserializing {info.type}. {PrintTags(tags)}");
}
if (attr.FieldName == name) {
var tags = GenerateTableEntityTags(entity);
tags.AddRange(new (string, string)[] {
("outputType", outputType?.Name ?? string.Empty),
("fieldName", fieldName)
});
throw new OrmInvalidDiscriminatorFieldException($"Discriminator field is the same as the field being deserialized {name}. {PrintTags(tags)}");
}
var v = GetFieldValue(info, attr.FieldName, entity, ++iterationCount) ?? throw new Exception($"No value for {attr.FieldName}");
outputType = typeProvider.GetTypeInfo(v);
}

Expand All @@ -302,8 +319,13 @@ public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
return JsonSerializer.Deserialize(value, outputType, options: _options);
}
}
} catch (Exception ex) {
throw new InvalidOperationException($"Unable to get value for property '{name}' (entity field '{fieldName}')", ex);
} catch (Exception ex)
when (ex is not OrmException) {
var tags = GenerateTableEntityTags(entity);
tags.AddRange(new (string, string)[] {
("fieldName", fieldName)
});
throw new InvalidOperationException($"Unable to get value for property '{name}' (entity field '{fieldName}'). {PrintTags(tags)}", ex);
}
}

Expand All @@ -313,7 +335,7 @@ public T ToRecord<T>(TableEntity entity) where T : EntityBase {

object?[] parameters;
try {
parameters = entityInfo.properties.Select(grouping => GetFieldValue(entityInfo, grouping.Key, entity)).ToArray();
parameters = entityInfo.properties.Select(grouping => GetFieldValue(entityInfo, grouping.Key, entity, 0)).ToArray();
} catch (Exception ex) {
throw new InvalidOperationException($"Unable to extract properties from TableEntity for {typeof(T)}", ex);
}
Expand Down Expand Up @@ -361,6 +383,31 @@ public T ToRecord<T>(TableEntity entity) where T : EntityBase {
return Expression.Lambda<Func<T, object?>>(call, paramter).Compile();
}

private static List<(string, string)> GenerateTableEntityTags(TableEntity entity) {
var entityKeys = string.Join(',', entity.Keys);
var partitionKey = entity.ContainsKey(EntityPropertyKind.PartitionKey.ToString()) ? entity.GetString(EntityPropertyKind.PartitionKey.ToString()) : string.Empty;
var rowKey = entity.ContainsKey(EntityPropertyKind.RowKey.ToString()) ? entity.GetString(EntityPropertyKind.RowKey.ToString()) : string.Empty;

return new List<(string, string)> {
("entityKeys", entityKeys),
("partitionKey", partitionKey),
("rowKey", rowKey)
};
}

private static string PrintTags(List<(string, string)>? tags) {
return tags != null ? string.Join(", ", tags.Select(x => $"{x.Item1}={x.Item2}")) : string.Empty;
}
}

public class OrmInvalidDiscriminatorFieldException : OrmException {
public OrmInvalidDiscriminatorFieldException(string message) : base(message) { }
}

public class OrmMaxRecursionDepthReachedException : OrmException {
public OrmMaxRecursionDepthReachedException(string message) : base(message) { }
}

public class OrmException : Exception {
public OrmException(string message) : base(message) { }
}
35 changes: 34 additions & 1 deletion src/ApiService/Tests/OrmTest.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
using System;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using Azure.Data.Tables;
using FluentAssertions;
using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using Moq;
using Xunit;

namespace Tests {
public class OrmTest {

sealed class TestObject {
public String? TheName { get; set; }
public TestEnum TheEnum { get; set; }
Expand Down Expand Up @@ -410,5 +412,36 @@ public void TestKeyGetters() {
Assert.Equal(test.PartitionKey, actualPartitionKey);
Assert.Equal(test.RowKey, actualRowKey);
}

sealed record NestedEntity(
[PartitionKey] int Id,
[RowKey] string TheName,
[property: TypeDiscrimnatorAttribute("EventType", typeof(EventTypeProvider))]
[property: JsonConverter(typeof(BaseEventConverter))]
Nested? EventType
) : EntityBase();

#pragma warning disable CS0169
public record Nested(
bool? B,
Nested? EventType
) : BaseEvent();
#pragma warning restore CS0169

[Fact]
public void TestDeeplyNestedObjects() {
var converter = new EntityConverter();
var deeplyNestedJson = $"{{{string.Concat(Enumerable.Repeat("\"EventType\": {", 3))}{new String('}', 3)}}}"; // {{{...}}}
var nestedEntity = new NestedEntity(
Id: 123,
TheName: "abc",
EventType: JsonSerializer.Deserialize<Nested>(deeplyNestedJson, new JsonSerializerOptions())
);

var tableEntity = converter.ToTableEntity(nestedEntity);
var toRecord = () => converter.ToRecord<NestedEntity>(tableEntity);

_ = toRecord.Should().Throw<Exception>().And.InnerException!.Should().BeOfType<OrmInvalidDiscriminatorFieldException>();
}
}
}