| 
5 | 5 | using System.Collections.Generic;  | 
6 | 6 | using System.Collections.Immutable;  | 
7 | 7 | using System.Diagnostics;  | 
 | 8 | +using System.Diagnostics.CodeAnalysis;  | 
8 | 9 | using System.IO;  | 
9 | 10 | using System.Linq;  | 
10 | 11 | using System.Text;  | 
 | 
15 | 16 | 
 
  | 
16 | 17 | namespace XUnitWrapperGenerator;  | 
17 | 18 | 
 
  | 
 | 19 | +internal struct CompData  | 
 | 20 | +{  | 
 | 21 | +    internal CompData(string assemblyName, IMethodSymbol? entryPoint, IEnumerable<IMethodSymbol> possibleEntryPoints, OutputKind outputKind)  | 
 | 22 | +    {  | 
 | 23 | +        AssemblyName = assemblyName;  | 
 | 24 | +        EntryPoint = entryPoint;  | 
 | 25 | +        PossibleEntryPoints = possibleEntryPoints;  | 
 | 26 | +        OutputKind = outputKind;  | 
 | 27 | +    }  | 
 | 28 | + | 
 | 29 | +    public string AssemblyName { get; private set; }  | 
 | 30 | +    public IMethodSymbol? EntryPoint { get; private set; }  | 
 | 31 | +    public IEnumerable<IMethodSymbol> PossibleEntryPoints { get; private set; }  | 
 | 32 | +    public OutputKind OutputKind { get; private set; }  | 
 | 33 | +}  | 
 | 34 | + | 
18 | 35 | [Generator]  | 
19 | 36 | public sealed class XUnitWrapperGenerator : IIncrementalGenerator  | 
20 | 37 | {  | 
@@ -58,9 +75,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context)  | 
58 | 75 |             return aliasMap.ToImmutable();  | 
59 | 76 |         }).WithComparer(new ImmutableDictionaryValueComparer<string, string>(EqualityComparer<string>.Default));  | 
60 | 77 | 
 
  | 
61 |  | -        var assemblyName = context.CompilationProvider.Select((comp, ct) => comp.Assembly.MetadataName);  | 
62 |  | - | 
63 |  | -        var alwaysWriteEntryPoint = context.CompilationProvider.Select((comp, ct) => comp.Options.OutputKind == OutputKind.ConsoleApplication && comp.GetEntryPoint(ct) is null);  | 
 | 78 | +        var compData = context.CompilationProvider.Select((comp, ct) => new CompData(  | 
 | 79 | +            assemblyName: comp.Assembly.MetadataName,  | 
 | 80 | +            entryPoint: comp.GetEntryPoint(ct),  | 
 | 81 | +            possibleEntryPoints: RoslynUtils.GetPossibleEntryPoints(comp, ct),  | 
 | 82 | +            outputKind: comp.Options.OutputKind));  | 
64 | 83 | 
 
  | 
65 | 84 |         var testsInSource =  | 
66 | 85 |             methodsInSource  | 
@@ -112,40 +131,65 @@ public void Initialize(IncrementalGeneratorInitializationContext context)  | 
112 | 131 |             .Collect()  | 
113 | 132 |             .Combine(context.AnalyzerConfigOptionsProvider)  | 
114 | 133 |             .Combine(aliasMap)  | 
115 |  | -            .Combine(assemblyName)  | 
116 |  | -            .Combine(alwaysWriteEntryPoint),  | 
 | 134 | +            .Combine(compData),  | 
117 | 135 |             static (context, data) =>  | 
118 | 136 |             {  | 
119 |  | -                var ((((methods, configOptions), aliasMap), assemblyName), alwaysWriteEntryPoint) = data;  | 
 | 137 | +                var (((methods, configOptions), aliasMap), compData) = data;  | 
120 | 138 | 
 
  | 
121 |  | -                if (methods.Length == 0 && !alwaysWriteEntryPoint)  | 
 | 139 | +                bool inMergedTestDirectory = configOptions.GlobalOptions.InMergedTestDirectory();  | 
 | 140 | +                if (inMergedTestDirectory)  | 
122 | 141 |                 {  | 
123 |  | -                    // If we have no test methods, assume that this project is not migrated to the new system yet  | 
124 |  | -                    // and that we shouldn't generate a no-op Main method.  | 
125 |  | -                    return;  | 
 | 142 | +                    CheckNoEntryPoint(context, compData);  | 
126 | 143 |                 }  | 
127 | 144 | 
 
  | 
128 |  | -                bool isMergedTestRunnerAssembly = configOptions.GlobalOptions.IsMergedTestRunnerAssembly();  | 
129 |  | -                configOptions.GlobalOptions.TryGetValue("build_property.TargetOS", out string? targetOS);  | 
130 |  | - | 
131 |  | -                if (isMergedTestRunnerAssembly)  | 
 | 145 | +                if (compData.OutputKind != OutputKind.ConsoleApplication)  | 
132 | 146 |                 {  | 
133 |  | -                    if (targetOS?.ToLowerInvariant() is "ios" or "iossimulator" or "tvos" or "tvossimulator" or "maccatalyst" or "android" or "browser")  | 
134 |  | -                    {  | 
135 |  | -                        context.AddSource("XHarnessRunner.g.cs", GenerateXHarnessTestRunner(methods, aliasMap, assemblyName));  | 
136 |  | -                    }  | 
137 |  | -                    else  | 
138 |  | -                    {  | 
139 |  | -                        context.AddSource("FullRunner.g.cs", GenerateFullTestRunner(methods, aliasMap, assemblyName));  | 
140 |  | -                    }  | 
 | 147 | +                    return;  | 
141 | 148 |                 }  | 
142 |  | -                else  | 
 | 149 | + | 
 | 150 | +                bool alwaysWriteEntryPoint = (compData.EntryPoint is null);  | 
 | 151 | +                if (methods.IsEmpty && !alwaysWriteEntryPoint)  | 
143 | 152 |                 {  | 
144 |  | -                    context.AddSource("SimpleRunner.g.cs", GenerateStandaloneSimpleTestRunner(methods, aliasMap));  | 
 | 153 | +                    // If we have no test methods, assume that this project is not migrated to the new system yet  | 
 | 154 | +                    // and that we shouldn't generate a no-op Main method.  | 
 | 155 | +                    return;  | 
145 | 156 |                 }  | 
 | 157 | + | 
 | 158 | +                AddRunnerSource(context, methods, configOptions, aliasMap, compData);  | 
146 | 159 |             });  | 
147 | 160 |     }  | 
148 | 161 | 
 
  | 
 | 162 | +    private static void AddRunnerSource(SourceProductionContext context, ImmutableArray<ITestInfo> methods, AnalyzerConfigOptionsProvider configOptions, ImmutableDictionary<string, string> aliasMap, CompData compData)  | 
 | 163 | +    {  | 
 | 164 | +        bool isMergedTestRunnerAssembly = configOptions.GlobalOptions.IsMergedTestRunnerAssembly();  | 
 | 165 | +        configOptions.GlobalOptions.TryGetValue("build_property.TargetOS", out string? targetOS);  | 
 | 166 | +        string assemblyName = compData.AssemblyName;  | 
 | 167 | + | 
 | 168 | +        if (isMergedTestRunnerAssembly)  | 
 | 169 | +        {  | 
 | 170 | +            if (targetOS?.ToLowerInvariant() is "ios" or "iossimulator" or "tvos" or "tvossimulator" or "maccatalyst" or "android" or "browser")  | 
 | 171 | +            {  | 
 | 172 | +                context.AddSource("XHarnessRunner.g.cs", GenerateXHarnessTestRunner(methods, aliasMap, assemblyName));  | 
 | 173 | +            }  | 
 | 174 | +            else  | 
 | 175 | +            {  | 
 | 176 | +                context.AddSource("FullRunner.g.cs", GenerateFullTestRunner(methods, aliasMap, assemblyName));  | 
 | 177 | +            }  | 
 | 178 | +        }  | 
 | 179 | +        else  | 
 | 180 | +        {  | 
 | 181 | +            context.AddSource("SimpleRunner.g.cs", GenerateStandaloneSimpleTestRunner(methods, aliasMap));  | 
 | 182 | +        }  | 
 | 183 | +    }  | 
 | 184 | + | 
 | 185 | +    private static void CheckNoEntryPoint(SourceProductionContext context, CompData compData)  | 
 | 186 | +    {  | 
 | 187 | +        foreach (IMethodSymbol entryPoint in compData.PossibleEntryPoints)  | 
 | 188 | +        {  | 
 | 189 | +            context.ReportDiagnostic(Diagnostic.Create(Descriptors.XUWG1001, entryPoint.Locations[0]));  | 
 | 190 | +        }  | 
 | 191 | +    }  | 
 | 192 | + | 
149 | 193 |     private static void AppendAliasMap(CodeBuilder builder, ImmutableDictionary<string, string> aliasMap)  | 
150 | 194 |     {  | 
151 | 195 |         bool didOutput = false;  | 
@@ -312,7 +356,7 @@ private static string GenerateXHarnessTestRunner(ImmutableArray<ITestInfo> testI  | 
312 | 356 |             builder.AppendLine("System.Collections.Generic.HashSet<string> testExclusionList = XUnitWrapperLibrary.TestFilter.LoadTestExclusionList();");  | 
313 | 357 |             builder.AppendLine($@"return await XHarnessRunnerLibrary.RunnerEntryPoint.RunTests(RunTests, ""{assemblyName}"", args.Length != 0 ? args[0] : null, testExclusionList);");  | 
314 | 358 |         }  | 
315 |  | -        builder.AppendLine("catch(System.Exception ex)");  | 
 | 359 | +        builder.AppendLine("catch (System.Exception ex)");  | 
316 | 360 |         using (builder.NewBracesScope())  | 
317 | 361 |         {  | 
318 | 362 |             builder.AppendLine("System.Console.WriteLine(ex.ToString());");  | 
@@ -435,7 +479,7 @@ private static string GenerateStandaloneSimpleTestRunner(ImmutableArray<ITestInf  | 
435 | 479 |                         builder.Append(testInfo.GenerateTestExecution(reporter));  | 
436 | 480 |                     }  | 
437 | 481 |                 }  | 
438 |  | -                builder.AppendLine("catch(System.Exception ex)");  | 
 | 482 | +                builder.AppendLine("catch (System.Exception ex)");  | 
439 | 483 |                 using (builder.NewBracesScope())  | 
440 | 484 |                 {  | 
441 | 485 |                     builder.AppendLine("System.Console.WriteLine(ex.ToString());");  | 
 | 
0 commit comments