diff --git a/src/MemoryPack.Generator/MemoryPackGenerator.Emitter.cs b/src/MemoryPack.Generator/MemoryPackGenerator.Emitter.cs index d17735fa..be06a69a 100644 --- a/src/MemoryPack.Generator/MemoryPackGenerator.Emitter.cs +++ b/src/MemoryPack.Generator/MemoryPackGenerator.Emitter.cs @@ -105,6 +105,30 @@ static void Generate(TypeDeclarationSyntax syntax, Compilation compilation, stri } sb.AppendLine(); + // emit type info + if (unionFormatter) + { + AppendTypeRemarks(serializationInfoLogDirectoryPath, typeMeta, sb, fullType); + typeMeta.EmitUnionFormatterTemplate(sb, context, typeSymbol); + } + else + { + // Emit the type. Wrap the append method to capture the current variables. + typeMeta.Emit(sb, context, + () => AppendTypeRemarks(serializationInfoLogDirectoryPath, typeMeta, sb, fullType)); + } + + if (!ns.IsGlobalNamespace && !context.IsCSharp10OrGreater()) + { + sb.AppendLine($"}}"); + } + + var code = sb.ToString(); + context.AddSource($"{fullType}.MemoryPackFormatter.g.cs", code); + } + + static void AppendTypeRemarks(string? serializationInfoLogDirectoryPath, TypeMeta typeMeta, StringBuilder sb, string fullType) + { // Write document comment as remarks if (typeMeta.GenerateType is GenerateType.Object or GenerateType.VersionTolerant or GenerateType.CircularReference) { @@ -131,24 +155,6 @@ static void Generate(TypeDeclarationSyntax syntax, Compilation compilation, stri } } } - - // emit type info - if (unionFormatter) - { - typeMeta.EmitUnionFormatterTemplate(sb, context, typeSymbol); - } - else - { - typeMeta.Emit(sb, context); - } - - if (!ns.IsGlobalNamespace && !context.IsCSharp10OrGreater()) - { - sb.AppendLine($"}}"); - } - - var code = sb.ToString(); - context.AddSource($"{fullType}.MemoryPackFormatter.g.cs", code); } static bool IsPartial(TypeDeclarationSyntax typeDeclaration) @@ -253,17 +259,54 @@ string WithEscape(ISymbol symbol) public partial class TypeMeta { - public void Emit(StringBuilder writer, IGeneratorContext context) + public void Emit(StringBuilder writer, IGeneratorContext context, Action appendTypeRemarks) { + var containingTypeDeclarations = new List(); + var containingType = Symbol.ContainingType; + while (containingType is not null) + { + var isInterface = containingType.TypeKind == TypeKind.Interface; + containingTypeDeclarations.Add((containingType.IsRecord, containingType.IsValueType, containingType.IsAbstract, isInterface) switch + { + (true, true, false, false) => $"partial record struct {containingType.Name}", + (true, false, false, false) => $"partial record {containingType.Name}", + (false, true, false, false) => $"partial struct {containingType.Name}", + (false, false, false, false) => $"partial class {containingType.Name}", + (false, false, true, false) => $"abstract partial class {containingType.Name}", + (false, false, true, true) => $"partial interface {containingType.Name}", + _ => $"partial class {containingType.Name}" + }); + containingType = containingType.ContainingType; + } + containingTypeDeclarations.Reverse(); + + foreach (var declaration in containingTypeDeclarations) + { + writer.AppendLine(declaration); + writer.AppendLine("{"); + } + + + // Write the documentation header after any containing classes. + appendTypeRemarks.Invoke(); + if (IsUnion) { writer.AppendLine(EmitUnionTemplate(context)); - return; } if (GenerateType == GenerateType.Collection) { writer.AppendLine(EmitGenericCollectionTemplate(context)); + } + + // Write the closing braces. + if (IsUnion || GenerateType == GenerateType.Collection) + { + for(int i = 0; i < containingTypeDeclarations.Count; ++i) + { + writer.AppendLine("}"); + } return; } @@ -310,21 +353,6 @@ public void Emit(StringBuilder writer, IGeneratorContext context) (false, false) => "class", }; - var containingTypeDeclarations = new List(); - var containingType = Symbol.ContainingType; - while (containingType is not null) - { - containingTypeDeclarations.Add((containingType.IsRecord, containingType.IsValueType) switch - { - (true, true) => $"partial record struct {containingType.Name}", - (true, false) => $"partial record {containingType.Name}", - (false, true) => $"partial struct {containingType.Name}", - (false, false) => $"partial class {containingType.Name}", - }); - containingType = containingType.ContainingType; - } - containingTypeDeclarations.Reverse(); - var nullable = IsValueType ? "" : "?"; string staticRegisterFormatterMethod, staticMemoryPackableMethod, scopedRef, constraint, registerBody, registerT; @@ -374,12 +402,6 @@ public void Emit(StringBuilder writer, IGeneratorContext context) ? "Serialize(ref MemoryPackWriter" : "Serialize(ref MemoryPackWriter"; - foreach (var declaration in containingTypeDeclarations) - { - writer.AppendLine(declaration); - writer.AppendLine("{"); - } - writer.AppendLine($$""" partial {{classOrStructOrRecord}} {{TypeName}} : IMemoryPackable<{{TypeName}}>{{fixedSizeInterface}} { diff --git a/tests/MemoryPack.Tests/GeneratorTest.cs b/tests/MemoryPack.Tests/GeneratorTest.cs index cb388931..53567ad3 100644 --- a/tests/MemoryPack.Tests/GeneratorTest.cs +++ b/tests/MemoryPack.Tests/GeneratorTest.cs @@ -54,6 +54,10 @@ public void Standard() public void Nested() { VerifyEquivalent(new NestedContainer.StandardTypeNested() { One = 9999 }); + VerifyEquivalent(new NestedStructContainer.StandardTypeNested() { One = 9999 }); + VerifyEquivalent(new NestedRecordStructContainer.StandardTypeNested() { One = 9999 }); + VerifyEquivalent(new NestedInterfaceContainer.StandardTypeNested() { One = 9999 }); + VerifyEquivalent(new NestedAbstractClassContainer.StandardTypeNested() { One = 9999 }); VerifyEquivalent(new DoublyNestedContainer.DoublyNestedContainerInner.StandardTypeDoublyNested() { One = 9999 }); } diff --git a/tests/MemoryPack.Tests/Models/StandardType.cs b/tests/MemoryPack.Tests/Models/StandardType.cs index 484835fa..225439cb 100644 --- a/tests/MemoryPack.Tests/Models/StandardType.cs +++ b/tests/MemoryPack.Tests/Models/StandardType.cs @@ -71,6 +71,42 @@ public partial class StandardTypeNested } } + public partial struct NestedStructContainer + { + [MemoryPackable] + public partial class StandardTypeNested + { + public int One { get; set; } + } + } + + public partial record struct NestedRecordStructContainer + { + [MemoryPackable] + public partial class StandardTypeNested + { + public int One { get; set; } + } + } + + public partial interface NestedInterfaceContainer + { + [MemoryPackable] + public partial class StandardTypeNested + { + public int One { get; set; } + } + } + + public abstract partial class NestedAbstractClassContainer + { + [MemoryPackable] + public partial class StandardTypeNested + { + public int One { get; set; } + } + } + public partial class DoublyNestedContainer { public partial class DoublyNestedContainerInner