Skip to content

Commit

Permalink
Add Binding cache
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-strecker-sonarsource committed Dec 28, 2023
1 parent 404f314 commit f76cdf4
Showing 1 changed file with 52 additions and 21 deletions.
73 changes: 52 additions & 21 deletions ProtoBuf.Logic/MessageBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
using Google.Protobuf;
using Google.Protobuf.Reflection;
using ProtoBuf.Antlr;
using System.Collections.Concurrent;
using static Google.Protobuf.WireFormat;
using static ProtoBuf.Antlr.Protobuf3Parser;

namespace ProtoBuf.Logic
{
internal class MessageBinder(ProtoContext protoContext, MessageDefContext? messageDef) : IMessage
internal class MessageBinder : IMessage
{
enum FieldType
{
Expand All @@ -16,7 +17,19 @@ enum FieldType
Packed
}

public TypedMessage? Result { get; internal set; }
public MessageBinder(ProtoContext protoContext, MessageDefContext? messageDef) : this(protoContext, messageDef, null) { }

private MessageBinder(ProtoContext protoContext, MessageDefContext? messageDef, ConcurrentDictionary<(Type_Context, string names), ParserRuleContext?>? cache)
{
ProtoContext = protoContext;
MessageDef = messageDef;
BindingCache = cache ?? new();
}

private ConcurrentDictionary<(Type_Context, string names), ParserRuleContext?> BindingCache { get; }
public ProtoContext ProtoContext { get; }
public MessageDefContext? MessageDef { get; }
public TypedMessage? Result { get; private set; }

public void MergeFrom(CodedInputStream input)
{
Expand All @@ -26,13 +39,13 @@ public void MergeFrom(CodedInputStream input)
while (input.Position < targetPosition && !input.IsAtEnd)
{
var (index, type) = input.ReadWireTag();
var field = messageDef?.messageBody().messageElement().Select(x => x.field()).FirstOrDefault(x => int.TryParse(x?.fieldNumber()?.GetText(), out var i) && i == index);
var parsedFields = messageDef != null && field != null && FitsFieldType(type, field.type_()) is var fieldType and not FieldType.Unknown
? ParseField(input, protoContext, fieldType, field)
var field = MessageDef?.messageBody().messageElement().Select(x => x.field()).FirstOrDefault(x => int.TryParse(x?.fieldNumber()?.GetText(), out var i) && i == index);
var parsedFields = MessageDef != null && field != null && FitsFieldType(type, field.type_()) is var fieldType and not FieldType.Unknown
? ParseField(input, fieldType, field)
: ParseUnknownField(input, new WireTag(index, type));
fields.AddRange(parsedFields);
}
Result = new TypedMessage(fields, messageDef, messageDef);
Result = new TypedMessage(fields, MessageDef, MessageDef);
}

private IEnumerable<TypedField> ParseUnknownField(CodedInputStream stream, WireTag wireTag)
Expand All @@ -42,40 +55,40 @@ private IEnumerable<TypedField> ParseUnknownField(CodedInputStream stream, WireT
yield return new TypedField("Unknown", index, null, new TypedUnknown(value));
}

private IEnumerable<TypedField> ParseField(CodedInputStream stream, ProtoContext protoContext, FieldType type, FieldContext field)
private IEnumerable<TypedField> ParseField(CodedInputStream stream, FieldType type, FieldContext field)
{
var (fieldName, index) = (field.fieldName().GetText(), int.Parse(field.fieldNumber().GetText()));
var values = ReadExpectedType(stream, protoContext, type, field.type_());
var values = ReadExpectedType(stream, type, field.type_());
foreach (var value in values)
{
yield return new TypedField(fieldName, index, field, value);
}
}

private static IEnumerable<ProtoType> ReadExpectedType(CodedInputStream stream, ProtoContext protoContext, FieldType type, Type_Context expectedType)
private IEnumerable<ProtoType> ReadExpectedType(CodedInputStream stream, FieldType type, Type_Context expectedType)
{
if (type == FieldType.Packed)
{
var length = stream.ReadLength();
var expectedEnd = stream.Position + length;
while (stream.Position < expectedEnd)
{
if (ReadExpectedType(stream, protoContext, expectedType) is { } protoType)
if (ReadExpectedType(stream, expectedType) is { } protoType)
{
yield return protoType;
}
}
}
else
{
if (ReadExpectedType(stream, protoContext, expectedType) is { } protoType)
if (ReadExpectedType(stream, expectedType) is { } protoType)
{
yield return protoType;
}
}
}

private static ProtoType? ReadExpectedType(CodedInputStream stream, ProtoContext protoContext, Type_Context expectedType)
private ProtoType? ReadExpectedType(CodedInputStream stream, Type_Context expectedType)
{
if (expectedType.INT32() is not null)
return new TypedInt32(stream.ReadInt32(), expectedType);
Expand Down Expand Up @@ -107,10 +120,10 @@ private static IEnumerable<ProtoType> ReadExpectedType(CodedInputStream stream,
return new TypedString(stream.ReadString(), expectedType);
else if (expectedType.enumType() is not null || expectedType.messageType() is not null)
{
return BindMessageOrEnumDef(protoContext, expectedType) switch
return BindMessageOrEnumDef(expectedType) switch
{
EnumDefContext enumDef => new TypedEnum(stream.ReadEnum(), expectedType, enumDef),
MessageDefContext innerMessageDef => ParseMessage(stream, protoContext, innerMessageDef),
MessageDefContext innerMessageDef => ParseMessage(stream, innerMessageDef),
_ => throw new NotImplementedException(),
};
}
Expand Down Expand Up @@ -154,19 +167,37 @@ static IEnumerable<string> DottedNames(EnumTypeContext message)
yield return name.GetText();
}

private static ParserRuleContext? BindMessageOrEnumDef(ProtoContext protoContext, Type_Context expectedType)
private ParserRuleContext? BindMessageOrEnumDef(Type_Context expectedType)
{
var names = (expectedType.enumType(), expectedType.messageType()) switch
{
(null, { } messageType) => DottedNames(messageType),
({ } enumType, null) => DottedNames(enumType),
_ => throw new InvalidOperationException(),
};
return BindType(protoContext, names, expectedType, names => new MessageDefBinder(names)) as ParserRuleContext
?? BindType(protoContext, names, expectedType, names => new EnumDefBinder(names));
return BindCached(expectedType, names)
?? BindType(names, expectedType, names => new MessageDefBinder(names)) as ParserRuleContext
?? BindType(names, expectedType, names => new EnumDefBinder(names));
}

private ParserRuleContext? BindCached(Type_Context expectedType, IEnumerable<string> names)
{
var namesString = string.Join(".", names);
var key = (expectedType, namesString);
if (BindingCache.TryGetValue(key, out var cached))
{
return cached;
}
else
{
var result = BindType(names, expectedType, names => new MessageDefBinder(names)) as ParserRuleContext
?? BindType(names, expectedType, names => new EnumDefBinder(names));
BindingCache.TryAdd(key, result);
return result;
}
}

private static T? BindType<T>(ProtoContext protoContext, IEnumerable<string> names, Type_Context expectedType, Func<IEnumerable<string>, IProtobuf3Visitor<T>> visitorFactory)
private T? BindType<T>(IEnumerable<string> names, Type_Context expectedType, Func<IEnumerable<string>, IProtobuf3Visitor<T>> visitorFactory)
{
var rootSearch = names.FirstOrDefault() == ".";
if (!rootSearch
Expand All @@ -175,12 +206,12 @@ static IEnumerable<string> DottedNames(EnumTypeContext message)
{
return relative;
}
return visitorFactory(rootSearch ? names.Skip(1) : names).Visit(protoContext);
return visitorFactory(rootSearch ? names.Skip(1) : names).Visit(ProtoContext);
}

private static ProtoType? ParseMessage(CodedInputStream stream, ProtoContext protoContext, MessageDefContext? messageDef)
private ProtoType? ParseMessage(CodedInputStream stream, MessageDefContext? messageDef)
{
var builder = new MessageBinder(protoContext, messageDef);
var builder = new MessageBinder(ProtoContext, messageDef, BindingCache);
stream.ReadRawMessage(builder);
return builder.Result;
}
Expand Down

0 comments on commit f76cdf4

Please sign in to comment.