Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ApiService/ApiService/Functions/AgentCommands.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
}
var nodeCommand = request.OkV;

var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId).FirstOrDefaultAsync();
var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId);
if (message != null) {
var command = message.Message;
var messageId = message.MessageId;
Expand Down
2 changes: 1 addition & 1 deletion src/ApiService/ApiService/Functions/Node.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ private async Async.Task<HttpResponseData> Get(HttpRequestData req) {

var (tasks, messages) = await (
_context.NodeTasksOperations.GetByMachineId(machineId).ToListAsync().AsTask(),
_context.NodeMessageOperations.GetMessage(machineId).ToListAsync().AsTask());
_context.NodeMessageOperations.GetMessages(machineId).ToListAsync().AsTask());

var commands = messages.Select(m => m.Message).ToList();
return await RequestHandling.Ok(req, NodeToNodeSearchResult(node with { Tasks = tasks, Messages = commands }));
Expand Down
12 changes: 9 additions & 3 deletions src/ApiService/ApiService/onefuzzlib/NodeMessageOperations.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ApiService.OneFuzzLib.Orm;
using System.Threading.Tasks;
using ApiService.OneFuzzLib.Orm;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;

namespace Microsoft.OneFuzz.Service;
Expand All @@ -14,7 +15,9 @@ public NodeMessage(Guid machineId, NodeCommand message) : this(machineId, NewSor
};

public interface INodeMessageOperations : IOrm<NodeMessage> {
IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId);
IAsyncEnumerable<NodeMessage> GetMessages(Guid machineId);

Async.Task<NodeMessage?> GetMessage(Guid machineId);
Async.Task ClearMessages(Guid machineId);

Async.Task SendMessage(Guid machineId, NodeCommand message, string? messageId = null);
Expand All @@ -25,7 +28,7 @@ public class NodeMessageOperations : Orm<NodeMessage>, INodeMessageOperations {
public NodeMessageOperations(ILogTracer log, IOnefuzzContext context)
: base(log, context) { }

public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId)
public IAsyncEnumerable<NodeMessage> GetMessages(Guid machineId)
=> QueryAsync(Query.PartitionKey(machineId.ToString()));

public async Async.Task ClearMessages(Guid machineId) {
Expand All @@ -45,4 +48,7 @@ public async Async.Task SendMessage(Guid machineId, NodeCommand message, string?
_logTracer.WithHttpStatus(r.ErrorV).Error($"failed to insert message with id: {messageId:Tag:MessageId} for machine id: {machineId:Tag:MachineId} message: {message:Tag:Message}");
}
}

public async Task<NodeMessage?> GetMessage(Guid machineId)
=> await QueryAsync(Query.PartitionKey(machineId.ToString()), maxPerPage: 1).FirstOrDefaultAsync();
}
63 changes: 55 additions & 8 deletions src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ public override string ConvertName(string name) {
}
public class EntityConverter {

private const int MAX_DESERIALIZATION_RECURSION_DEPTH = 100;
private readonly ConcurrentDictionary<Type, EntityInfo> _cache;
private static readonly JsonSerializerOptions _options = new() {
PropertyNamingPolicy = new OnefuzzNamingPolicy(),
Expand Down Expand Up @@ -124,8 +125,8 @@ public static JsonSerializerOptions GetJsonSerializerOptions() {
}

private static IEnumerable<EntityProperty> GetEntityProperties<T>(ParameterInfo parameterInfo) {
var name = parameterInfo.Name.EnsureNotNull($"Invalid paramter {parameterInfo}");
var parameterType = parameterInfo.ParameterType.EnsureNotNull($"Invalid paramter {parameterInfo}");
var name = parameterInfo.Name.EnsureNotNull($"Invalid parameter {parameterInfo}");
var parameterType = parameterInfo.ParameterType.EnsureNotNull($"Invalid parameter {parameterInfo}");
var isRowkey = parameterInfo.GetCustomAttribute(typeof(RowKeyAttribute)) != null;
var isPartitionkey = parameterInfo.GetCustomAttribute(typeof(PartitionKeyAttribute)) != null;

Expand All @@ -135,7 +136,7 @@ private static IEnumerable<EntityProperty> GetEntityProperties<T>(ParameterInfo

(TypeDiscrimnatorAttribute, ITypeProvider)? discriminator = null;
if (discriminatorAttribute != null) {
var t = (ITypeProvider)(Activator.CreateInstance(discriminatorAttribute.ConverterType) ?? throw new Exception("unable to retrive the type provider"));
var t = (ITypeProvider)(Activator.CreateInstance(discriminatorAttribute.ConverterType) ?? throw new Exception("unable to retrieve the type provider"));
discriminator = (discriminatorAttribute, t);
}

Expand Down Expand Up @@ -222,7 +223,7 @@ public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
}


private object? GetFieldValue(EntityInfo info, string name, TableEntity entity) {
private object? GetFieldValue(EntityInfo info, string name, TableEntity entity, int iterationCount) {
var ef = info.properties[name].First();
if (ef.kind == EntityPropertyKind.PartitionKey || ef.kind == EntityPropertyKind.RowKey) {
// partition & row keys must always be strings
Expand Down Expand Up @@ -285,7 +286,23 @@ public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
var outputType = ef.type;
if (ef.discriminator != null) {
var (attr, typeProvider) = ef.discriminator.Value;
var v = GetFieldValue(info, attr.FieldName, entity) ?? throw new Exception($"No value for {attr.FieldName}");
if (iterationCount > MAX_DESERIALIZATION_RECURSION_DEPTH) {
var tags = GenerateTableEntityTags(entity);
tags.AddRange(new (string, string)[] {
("outputType", outputType?.Name ?? string.Empty),
("fieldName", fieldName)
});
throw new OrmMaxRecursionDepthReachedException($"MAX_DESERIALIZATION_RECURSION_DEPTH reached. Too many iterations deserializing {info.type}. {PrintTags(tags)}");
}
if (attr.FieldName == name) {
var tags = GenerateTableEntityTags(entity);
tags.AddRange(new (string, string)[] {
("outputType", outputType?.Name ?? string.Empty),
("fieldName", fieldName)
});
throw new OrmInvalidDiscriminatorFieldException($"Discriminator field is the same as the field being deserialized {name}. {PrintTags(tags)}");
}
var v = GetFieldValue(info, attr.FieldName, entity, ++iterationCount) ?? throw new Exception($"No value for {attr.FieldName}");
outputType = typeProvider.GetTypeInfo(v);
}

Expand All @@ -302,8 +319,13 @@ public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
return JsonSerializer.Deserialize(value, outputType, options: _options);
}
}
} catch (Exception ex) {
throw new InvalidOperationException($"Unable to get value for property '{name}' (entity field '{fieldName}')", ex);
} catch (Exception ex)
when (ex is not OrmException) {
var tags = GenerateTableEntityTags(entity);
tags.AddRange(new (string, string)[] {
("fieldName", fieldName)
});
throw new InvalidOperationException($"Unable to get value for property '{name}' (entity field '{fieldName}'). {PrintTags(tags)}", ex);
}
}

Expand All @@ -313,7 +335,7 @@ public T ToRecord<T>(TableEntity entity) where T : EntityBase {

object?[] parameters;
try {
parameters = entityInfo.properties.Select(grouping => GetFieldValue(entityInfo, grouping.Key, entity)).ToArray();
parameters = entityInfo.properties.Select(grouping => GetFieldValue(entityInfo, grouping.Key, entity, 0)).ToArray();
} catch (Exception ex) {
throw new InvalidOperationException($"Unable to extract properties from TableEntity for {typeof(T)}", ex);
}
Expand Down Expand Up @@ -361,6 +383,31 @@ public T ToRecord<T>(TableEntity entity) where T : EntityBase {
return Expression.Lambda<Func<T, object?>>(call, paramter).Compile();
}

private static List<(string, string)> GenerateTableEntityTags(TableEntity entity) {
var entityKeys = string.Join(',', entity.Keys);
var partitionKey = entity.ContainsKey(EntityPropertyKind.PartitionKey.ToString()) ? entity.GetString(EntityPropertyKind.PartitionKey.ToString()) : string.Empty;
var rowKey = entity.ContainsKey(EntityPropertyKind.RowKey.ToString()) ? entity.GetString(EntityPropertyKind.RowKey.ToString()) : string.Empty;

return new List<(string, string)> {
("entityKeys", entityKeys),
("partitionKey", partitionKey),
("rowKey", rowKey)
};
}

private static string PrintTags(List<(string, string)>? tags) {
return tags != null ? string.Join(", ", tags.Select(x => $"{x.Item1}={x.Item2}")) : string.Empty;
}
}

public class OrmInvalidDiscriminatorFieldException : OrmException {
public OrmInvalidDiscriminatorFieldException(string message) : base(message) { }
}

public class OrmMaxRecursionDepthReachedException : OrmException {
public OrmMaxRecursionDepthReachedException(string message) : base(message) { }
}

public class OrmException : Exception {
public OrmException(string message) : base(message) { }
}
6 changes: 3 additions & 3 deletions src/ApiService/ApiService/onefuzzlib/orm/Orm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace ApiService.OneFuzzLib.Orm {
public interface IOrm<T> where T : EntityBase {
Task<TableClient> GetTableClient(string table, ResourceIdentifier? accountId = null);
IAsyncEnumerable<T> QueryAsync(string? filter = null);
IAsyncEnumerable<T> QueryAsync(string? filter = null, int? maxPerPage = null);

Task<T> GetEntityAsync(string partitionKey, string rowKey);
Task<ResultVoid<(HttpStatusCode Status, string Reason)>> Insert(T entity);
Expand Down Expand Up @@ -49,14 +49,14 @@ public Orm(ILogTracer logTracer, IOnefuzzContext context) {
_entityConverter = _context.EntityConverter;
}

public async IAsyncEnumerable<T> QueryAsync(string? filter = null) {
public async IAsyncEnumerable<T> QueryAsync(string? filter = null, int? maxPerPage = null) {
var tableClient = await GetTableClient(typeof(T).Name);

if (filter == "") {
filter = null;
}

await foreach (var x in tableClient.QueryAsync<TableEntity>(filter).Select(x => _entityConverter.ToRecord<T>(x))) {
await foreach (var x in tableClient.QueryAsync<TableEntity>(filter: filter, maxPerPage: maxPerPage).Select(x => _entityConverter.ToRecord<T>(x))) {
yield return x;
}
}
Expand Down
32 changes: 31 additions & 1 deletion src/ApiService/IntegrationTests/AgentCommandsTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Net;
using System;
using System.Net;
using FluentAssertions;
using IntegrationTests.Fakes;
using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.Functions;
Expand Down Expand Up @@ -50,4 +52,32 @@ public async Async.Task AgentAuthorization_IsAccepted() {
var result = await func.Run(TestHttpRequestData.Empty("GET"));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to no body, not Unauthorized
}

[Fact]
public async Async.Task AgentCommand_GetsCommand() {
var machineId = Guid.NewGuid();
var messageId = Guid.NewGuid().ToString();
var command = new NodeCommand {
Stop = new StopNodeCommand()
};
await Context.InsertAll(new[] {
new NodeMessage (
machineId,
messageId,
command
),
});

var commandRequest = new NodeCommandGet(machineId);
var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context);
var func = new AgentCommands(Logger, auth, Context);

var result = await func.Run(TestHttpRequestData.FromJson("GET", commandRequest));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);

var pendingNodeCommand = BodyAs<PendingNodeCommand>(result);
pendingNodeCommand.Envelope.Should().NotBeNull();
pendingNodeCommand.Envelope?.Command.Should().BeEquivalentTo(command);
pendingNodeCommand.Envelope?.MessageId.Should().Be(messageId);
}
}
35 changes: 34 additions & 1 deletion src/ApiService/Tests/OrmTest.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
using System;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using Azure.Data.Tables;
using FluentAssertions;
using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using Moq;
using Xunit;

namespace Tests {
public class OrmTest {

sealed class TestObject {
public String? TheName { get; set; }
public TestEnum TheEnum { get; set; }
Expand Down Expand Up @@ -410,5 +412,36 @@ public void TestKeyGetters() {
Assert.Equal(test.PartitionKey, actualPartitionKey);
Assert.Equal(test.RowKey, actualRowKey);
}

sealed record NestedEntity(
[PartitionKey] int Id,
[RowKey] string TheName,
[property: TypeDiscrimnatorAttribute("EventType", typeof(EventTypeProvider))]
[property: JsonConverter(typeof(BaseEventConverter))]
Nested? EventType
) : EntityBase();

#pragma warning disable CS0169
public record Nested(
bool? B,
Nested? EventType
) : BaseEvent();
#pragma warning restore CS0169

[Fact]
public void TestDeeplyNestedObjects() {
var converter = new EntityConverter();
var deeplyNestedJson = $"{{{string.Concat(Enumerable.Repeat("\"EventType\": {", 3))}{new String('}', 3)}}}"; // {{{...}}}
var nestedEntity = new NestedEntity(
Id: 123,
TheName: "abc",
EventType: JsonSerializer.Deserialize<Nested>(deeplyNestedJson, new JsonSerializerOptions())
);

var tableEntity = converter.ToTableEntity(nestedEntity);
var toRecord = () => converter.ToRecord<NestedEntity>(tableEntity);

_ = toRecord.Should().Throw<Exception>().And.InnerException!.Should().BeOfType<OrmInvalidDiscriminatorFieldException>();
}
}
}