|  | 
|  | 1 | +using System.Buffers; | 
|  | 2 | +using System.Collections.Immutable; | 
|  | 3 | +using System.Reflection; | 
|  | 4 | +using System.Text; | 
|  | 5 | +using Microsoft.CodeAnalysis; | 
|  | 6 | +using Microsoft.CodeAnalysis.CSharp; | 
|  | 7 | +using Microsoft.CodeAnalysis.CSharp.Syntax; | 
|  | 8 | + | 
|  | 9 | +namespace StackExchange.Redis.Build; | 
|  | 10 | + | 
|  | 11 | +[Generator(LanguageNames.CSharp)] | 
|  | 12 | +public class FastHashGenerator : IIncrementalGenerator | 
|  | 13 | +{ | 
|  | 14 | +    public void Initialize(IncrementalGeneratorInitializationContext context) | 
|  | 15 | +    { | 
|  | 16 | +        var literals = context.SyntaxProvider | 
|  | 17 | +            .CreateSyntaxProvider(Predicate, Transform) | 
|  | 18 | +            .Where(pair => pair.Name is { Length: > 0 }) | 
|  | 19 | +            .Collect(); | 
|  | 20 | + | 
|  | 21 | +        context.RegisterSourceOutput(literals, Generate); | 
|  | 22 | +    } | 
|  | 23 | + | 
|  | 24 | +    private bool Predicate(SyntaxNode node, CancellationToken cancellationToken) | 
|  | 25 | +    { | 
|  | 26 | +        // looking for [FastHash] partial static class Foo { } | 
|  | 27 | +        if (node is ClassDeclarationSyntax decl | 
|  | 28 | +            && decl.Modifiers.Any(SyntaxKind.StaticKeyword) | 
|  | 29 | +            && decl.Modifiers.Any(SyntaxKind.PartialKeyword)) | 
|  | 30 | +        { | 
|  | 31 | +            foreach (var attribList in decl.AttributeLists) | 
|  | 32 | +            { | 
|  | 33 | +                foreach (var attrib in attribList.Attributes) | 
|  | 34 | +                { | 
|  | 35 | +                    if (attrib.Name.ToString() is "FastHashAttribute" or "FastHash") return true; | 
|  | 36 | +                } | 
|  | 37 | +            } | 
|  | 38 | +        } | 
|  | 39 | + | 
|  | 40 | +        return false; | 
|  | 41 | +    } | 
|  | 42 | + | 
|  | 43 | +    private static string GetName(INamedTypeSymbol type) | 
|  | 44 | +    { | 
|  | 45 | +        if (type.ContainingType is null) return type.Name; | 
|  | 46 | +        var stack = new Stack<string>(); | 
|  | 47 | +        while (true) | 
|  | 48 | +        { | 
|  | 49 | +            stack.Push(type.Name); | 
|  | 50 | +            if (type.ContainingType is null) break; | 
|  | 51 | +            type = type.ContainingType; | 
|  | 52 | +        } | 
|  | 53 | +        var sb = new StringBuilder(stack.Pop()); | 
|  | 54 | +        while (stack.Count != 0) | 
|  | 55 | +        { | 
|  | 56 | +            sb.Append('.').Append(stack.Pop()); | 
|  | 57 | +        } | 
|  | 58 | +        return sb.ToString(); | 
|  | 59 | +    } | 
|  | 60 | + | 
|  | 61 | +    private (string Namespace, string ParentType, string Name, string Value) Transform( | 
|  | 62 | +        GeneratorSyntaxContext ctx, | 
|  | 63 | +        CancellationToken cancellationToken) | 
|  | 64 | +    { | 
|  | 65 | +        // extract the name and value (defaults to name, but can be overridden via attribute) and the location | 
|  | 66 | +        if (ctx.SemanticModel.GetDeclaredSymbol(ctx.Node) is not INamedTypeSymbol named) return default; | 
|  | 67 | +        string ns = "", parentType = ""; | 
|  | 68 | +        if (named.ContainingType is { } containingType) | 
|  | 69 | +        { | 
|  | 70 | +            parentType = GetName(containingType); | 
|  | 71 | +            ns = containingType.ContainingNamespace.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat); | 
|  | 72 | +        } | 
|  | 73 | +        else if (named.ContainingNamespace is { } containingNamespace) | 
|  | 74 | +        { | 
|  | 75 | +            ns = containingNamespace.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat); | 
|  | 76 | +        } | 
|  | 77 | + | 
|  | 78 | +        string name = named.Name, value = ""; | 
|  | 79 | +        foreach (var attrib in named.GetAttributes()) | 
|  | 80 | +        { | 
|  | 81 | +            if (attrib.AttributeClass?.Name == "FastHashAttribute") | 
|  | 82 | +            { | 
|  | 83 | +                if (attrib.ConstructorArguments.Length == 1) | 
|  | 84 | +                { | 
|  | 85 | +                    if (attrib.ConstructorArguments[0].Value?.ToString() is { Length: > 0 } val) | 
|  | 86 | +                    { | 
|  | 87 | +                        value = val; | 
|  | 88 | +                        break; | 
|  | 89 | +                    } | 
|  | 90 | +                } | 
|  | 91 | +            } | 
|  | 92 | +        } | 
|  | 93 | + | 
|  | 94 | +        if (string.IsNullOrWhiteSpace(value)) | 
|  | 95 | +        { | 
|  | 96 | +            value = name.Replace("_", "-"); // if nothing explicit: infer from name | 
|  | 97 | +        } | 
|  | 98 | + | 
|  | 99 | +        return (ns, parentType, name, value); | 
|  | 100 | +    } | 
|  | 101 | + | 
|  | 102 | +    private string GetVersion() | 
|  | 103 | +    { | 
|  | 104 | +        var asm = GetType().Assembly; | 
|  | 105 | +        if (asm.GetCustomAttributes(typeof(AssemblyFileVersionAttribute), false).FirstOrDefault() is | 
|  | 106 | +            AssemblyFileVersionAttribute { Version: { Length: > 0 } } version) | 
|  | 107 | +        { | 
|  | 108 | +            return version.Version; | 
|  | 109 | +        } | 
|  | 110 | + | 
|  | 111 | +        return asm.GetName().Version?.ToString() ?? "??"; | 
|  | 112 | +    } | 
|  | 113 | + | 
|  | 114 | +    private void Generate( | 
|  | 115 | +        SourceProductionContext ctx, | 
|  | 116 | +        ImmutableArray<(string Namespace, string ParentType, string Name, string Value)> literals) | 
|  | 117 | +    { | 
|  | 118 | +        if (literals.IsDefaultOrEmpty) return; | 
|  | 119 | + | 
|  | 120 | +        var sb = new StringBuilder("// <auto-generated />") | 
|  | 121 | +            .AppendLine().Append("// ").Append(GetType().Name).Append(" v").Append(GetVersion()).AppendLine(); | 
|  | 122 | + | 
|  | 123 | +        // lease a buffer that is big enough for the longest string | 
|  | 124 | +        var buffer = ArrayPool<byte>.Shared.Rent( | 
|  | 125 | +            Encoding.UTF8.GetMaxByteCount(literals.Max(l => l.Value.Length))); | 
|  | 126 | +        int indent = 0; | 
|  | 127 | + | 
|  | 128 | +        StringBuilder NewLine() => sb.AppendLine().Append(' ', indent * 4); | 
|  | 129 | +        NewLine().Append("using System;"); | 
|  | 130 | +        NewLine().Append("using StackExchange.Redis;"); | 
|  | 131 | +        NewLine().Append("#pragma warning disable CS8981"); | 
|  | 132 | +        foreach (var grp in literals.GroupBy(l => (l.Namespace, l.ParentType))) | 
|  | 133 | +        { | 
|  | 134 | +            NewLine(); | 
|  | 135 | +            int braces = 0; | 
|  | 136 | +            if (!string.IsNullOrWhiteSpace(grp.Key.Namespace)) | 
|  | 137 | +            { | 
|  | 138 | +                NewLine().Append("namespace ").Append(grp.Key.Namespace); | 
|  | 139 | +                NewLine().Append("{"); | 
|  | 140 | +                indent++; | 
|  | 141 | +                braces++; | 
|  | 142 | +            } | 
|  | 143 | +            if (!string.IsNullOrWhiteSpace(grp.Key.ParentType)) | 
|  | 144 | +            { | 
|  | 145 | +                if (grp.Key.ParentType.Contains('.')) // nested types | 
|  | 146 | +                { | 
|  | 147 | +                    foreach (var part in grp.Key.ParentType.Split('.')) | 
|  | 148 | +                    { | 
|  | 149 | +                        NewLine().Append("partial class ").Append(part); | 
|  | 150 | +                        NewLine().Append("{"); | 
|  | 151 | +                        indent++; | 
|  | 152 | +                        braces++; | 
|  | 153 | +                    } | 
|  | 154 | +                } | 
|  | 155 | +                else | 
|  | 156 | +                { | 
|  | 157 | +                    NewLine().Append("partial class ").Append(grp.Key.ParentType); | 
|  | 158 | +                    NewLine().Append("{"); | 
|  | 159 | +                    indent++; | 
|  | 160 | +                    braces++; | 
|  | 161 | +                } | 
|  | 162 | +            } | 
|  | 163 | + | 
|  | 164 | +            foreach (var literal in grp) | 
|  | 165 | +            { | 
|  | 166 | +                int len; | 
|  | 167 | +                unsafe | 
|  | 168 | +                { | 
|  | 169 | +                    fixed (byte* bPtr = buffer) // netstandard2.0 forces fallback API | 
|  | 170 | +                    { | 
|  | 171 | +                        fixed (char* cPtr = literal.Value) | 
|  | 172 | +                        { | 
|  | 173 | +                            len = Encoding.UTF8.GetBytes(cPtr, literal.Value.Length, bPtr, buffer.Length); | 
|  | 174 | +                        } | 
|  | 175 | +                    } | 
|  | 176 | +                } | 
|  | 177 | + | 
|  | 178 | +                // perform string escaping on the generated value (this includes the quotes, note) | 
|  | 179 | +                var csValue = SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal(literal.Value)).ToFullString(); | 
|  | 180 | + | 
|  | 181 | +                var hash = FastHash.Hash64(buffer.AsSpan(0, len)); | 
|  | 182 | +                NewLine().Append("static partial class ").Append(literal.Name); | 
|  | 183 | +                NewLine().Append("{"); | 
|  | 184 | +                indent++; | 
|  | 185 | +                NewLine().Append("public const int Length = ").Append(len).Append(';'); | 
|  | 186 | +                NewLine().Append("public const long Hash = ").Append(hash).Append(';'); | 
|  | 187 | +                NewLine().Append("public static ReadOnlySpan<byte> U8 => ").Append(csValue).Append("u8;"); | 
|  | 188 | +                NewLine().Append("public const string Text = ").Append(csValue).Append(';'); | 
|  | 189 | +                if (len <= 8) | 
|  | 190 | +                { | 
|  | 191 | +                    // the hash enforces all the values | 
|  | 192 | +                    NewLine().Append("public static bool Is(long hash, in RawResult value) => hash == Hash && value.Payload.Length == Length;"); | 
|  | 193 | +                    NewLine().Append("public static bool Is(long hash, ReadOnlySpan<byte> value) => hash == Hash & value.Length == Length;"); | 
|  | 194 | +                } | 
|  | 195 | +                else | 
|  | 196 | +                { | 
|  | 197 | +                    NewLine().Append("public static bool Is(long hash, in RawResult value) => hash == Hash && value.IsEqual(U8);"); | 
|  | 198 | +                    NewLine().Append("public static bool Is(long hash, ReadOnlySpan<byte> value) => hash == Hash && value.SequenceEqual(U8);"); | 
|  | 199 | +                } | 
|  | 200 | +                indent--; | 
|  | 201 | +                NewLine().Append("}"); | 
|  | 202 | +            } | 
|  | 203 | + | 
|  | 204 | +            // handle any closing braces | 
|  | 205 | +            while (braces-- > 0) | 
|  | 206 | +            { | 
|  | 207 | +                indent--; | 
|  | 208 | +                NewLine().Append("}"); | 
|  | 209 | +            } | 
|  | 210 | +        } | 
|  | 211 | + | 
|  | 212 | +        ArrayPool<byte>.Shared.Return(buffer); | 
|  | 213 | +        ctx.AddSource("FastHash.generated.cs", sb.ToString()); | 
|  | 214 | +    } | 
|  | 215 | +} | 
0 commit comments