Skip to content

Commit

Permalink
Avoid scanning typeof checks when building whole program view (#103883)
Browse files Browse the repository at this point in the history
Before this PR, we were somewhat able to eliminate dead typeof checks such as:

```csharp
if (someType == typeof(Foo)
{
    ExpensiveMethod();
}
```

This work was done in #102248.

However, the optimization only happened during codegen. This meant that when building the whole program view, we'd still look at `ExpensiveMethod` and whatever damage this caused to the whole program view was permanent.

With this PR, the scanner now becomes aware of the optimization we do during codegen and tries to defer injecting dependencies until we will need them.

With this change, we detect the conditional branch, and generate whatever dependencies from the basic block as conditional. That way scanning can fully skip scanning `ExpensiveMethod` and the subsequent optimization will ensure the missed scanning will not cause issues at codegen time.
  • Loading branch information
MichalStrehovsky authored Jul 18, 2024
1 parent fee1b41 commit 9c7ee97
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,37 @@ public override IEnumerable<DependencyListEntry> GetStaticDependencies(NodeFacto
return dependencies;
}

public sealed override IEnumerable<CombinedDependencyListEntry> GetConditionalStaticDependencies(NodeFactory factory)
{
// Instantiate the runtime determined dependencies of the canonical method body
// with the concrete instantiation of the method to get concrete dependencies.
Instantiation typeInst = Method.OwningType.Instantiation;
Instantiation methodInst = Method.Instantiation;
IEnumerable<CombinedDependencyListEntry> staticDependencies = CanonicalMethodNode.GetConditionalStaticDependencies(factory);

if (staticDependencies != null)
{
foreach (CombinedDependencyListEntry canonDep in staticDependencies)
{
Debug.Assert(canonDep.OtherReasonNode is not INodeWithRuntimeDeterminedDependencies);

var node = canonDep.Node;
if (node is INodeWithRuntimeDeterminedDependencies runtimeDeterminedNode)
{
foreach (var nodeInner in runtimeDeterminedNode.InstantiateDependencies(factory, typeInst, methodInst))
yield return new CombinedDependencyListEntry(nodeInner.Node, canonDep.OtherReasonNode, nodeInner.Reason);
}
}
}
}

protected override string GetName(NodeFactory factory) => $"{Method} backed by {CanonicalMethodNode.GetMangledName(factory.NameMangler)}";

public sealed override bool HasConditionalStaticDependencies => false;
public sealed override bool HasConditionalStaticDependencies => CanonicalMethodNode.HasConditionalStaticDependencies;
public sealed override bool HasDynamicDependencies => false;
public sealed override bool InterestingForDynamicDependencyAnalysis => false;

public sealed override IEnumerable<CombinedDependencyListEntry> SearchDynamicDependencies(List<DependencyNodeCore<NodeFactory>> markedNodes, int firstNode, NodeFactory factory) => null;
public sealed override IEnumerable<CombinedDependencyListEntry> GetConditionalStaticDependencies(NodeFactory factory) => null;

int ISortableNode.ClassCode => -1440570971;

Expand Down
9 changes: 7 additions & 2 deletions src/coreclr/tools/Common/TypeSystem/IL/ILImporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -301,12 +301,17 @@ private void ImportBasicBlocks()
}

private void MarkBasicBlock(BasicBlock basicBlock)
{
MarkBasicBlock(basicBlock, ref _pendingBasicBlocks);
}

private static void MarkBasicBlock(BasicBlock basicBlock, ref BasicBlock list)
{
if (basicBlock.State == BasicBlock.ImportState.Unmarked)
{
// Link
basicBlock.Next = _pendingBasicBlocks;
_pendingBasicBlocks = basicBlock;
basicBlock.Next = list;
list = basicBlock;

basicBlock.State = BasicBlock.ImportState.IsPending;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class ScannedMethodNode : DependencyNodeCore<NodeFactory>, IMethodBodyNod
{
private readonly MethodDesc _method;
private DependencyList _dependencies;
private CombinedDependencyList _conditionalDependencies;

// If we failed to scan the method, the dependencies reported by the node will
// be for a throwing method body. This field will store the underlying cause of the failure.
Expand All @@ -42,11 +43,14 @@ public ScannedMethodNode(MethodDesc method)

public bool RepresentsIndirectionCell => false;

public override bool HasConditionalStaticDependencies => _conditionalDependencies != null;

public override bool StaticDependenciesAreComputed => _dependencies != null;

public void InitializeDependencies(NodeFactory factory, IEnumerable<DependencyListEntry> dependencies, TypeSystemException scanningException = null)
public void InitializeDependencies(NodeFactory factory, (DependencyList, CombinedDependencyList) dependencies, TypeSystemException scanningException = null)
{
_dependencies = new DependencyList(dependencies);
_dependencies = dependencies.Item1;
_conditionalDependencies = dependencies.Item2;

if (factory.TypeSystemContext.IsSpecialUnboxingThunk(_method))
{
Expand All @@ -72,19 +76,13 @@ public override IEnumerable<DependencyListEntry> GetStaticDependencies(NodeFacto
return _dependencies;
}

protected override string GetName(NodeFactory factory) => this.GetMangledName(factory.NameMangler);
public override IEnumerable<CombinedDependencyListEntry> GetConditionalStaticDependencies(NodeFactory factory) => _conditionalDependencies;

public override IEnumerable<CombinedDependencyListEntry> GetConditionalStaticDependencies(NodeFactory factory)
{
CombinedDependencyList dependencies = null;
CodeBasedDependencyAlgorithm.AddConditionalDependenciesDueToMethodCodePresence(ref dependencies, factory, _method);
return dependencies ?? (IEnumerable<CombinedDependencyListEntry>)Array.Empty<CombinedDependencyListEntry>();
}
protected override string GetName(NodeFactory factory) => this.GetMangledName(factory.NameMangler);

public override IEnumerable<CombinedDependencyListEntry> SearchDynamicDependencies(List<DependencyNodeCore<NodeFactory>> markedNodes, int firstNode, NodeFactory factory) => null;
public override bool InterestingForDynamicDependencyAnalysis => _method.HasInstantiation || _method.OwningType.HasInstantiation;
public override bool HasDynamicDependencies => false;
public override bool HasConditionalStaticDependencies => CodeBasedDependencyAlgorithm.HasConditionalDependenciesDueToMethodCodePresence(_method);

int ISortableNode.ClassCode => -1381809560;

Expand Down
115 changes: 105 additions & 10 deletions src/coreclr/tools/aot/ILCompiler.Compiler/IL/ILImporter.Scanner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

using Debug = System.Diagnostics.Debug;
using DependencyList = ILCompiler.DependencyAnalysisFramework.DependencyNodeCore<ILCompiler.DependencyAnalysis.NodeFactory>.DependencyList;
using CombinedDependencyList = System.Collections.Generic.List<ILCompiler.DependencyAnalysisFramework.DependencyNodeCore<ILCompiler.DependencyAnalysis.NodeFactory>.CombinedDependencyListEntry>;
using DependencyListEntry = ILCompiler.DependencyAnalysisFramework.DependencyNodeCore<ILCompiler.DependencyAnalysis.NodeFactory>.DependencyListEntry;

#pragma warning disable IDE0060

Expand All @@ -28,7 +30,7 @@ internal partial class ILImporter

private readonly MethodDesc _canonMethod;

private DependencyList _dependencies = new DependencyList();
private DependencyList _unconditionalDependencies = new DependencyList();

private readonly byte[] _ilBytes;

Expand All @@ -51,11 +53,17 @@ public enum ImportState : byte
public bool TryStart;
public bool FilterStart;
public bool HandlerStart;

public object Condition;
public DependencyList Dependencies;
}

private bool _isReadOnly;
private TypeDesc _constrained;

private DependencyList _dependencies;
private BasicBlock _lateBasicBlocks;

private sealed class ExceptionRegion
{
public ILExceptionRegion ILRegion;
Expand Down Expand Up @@ -107,9 +115,11 @@ public ILImporter(ILScanner compilation, MethodDesc method, MethodIL methodIL =
{
_exceptionRegions[i] = new ExceptionRegion() { ILRegion = ilExceptionRegions[i] };
}

_dependencies = _unconditionalDependencies;
}

public DependencyList Import()
public (DependencyList, CombinedDependencyList) Import()
{
TypeDesc owningType = _canonMethod.OwningType;
if (_compilation.HasLazyStaticConstructor(owningType))
Expand Down Expand Up @@ -172,9 +182,21 @@ public DependencyList Import()
FindBasicBlocks();
ImportBasicBlocks();

CodeBasedDependencyAlgorithm.AddDependenciesDueToMethodCodePresence(ref _dependencies, _factory, _canonMethod, _canonMethodIL);
CombinedDependencyList conditionalDependencies = null;
foreach (BasicBlock bb in _basicBlocks)
{
if (bb?.Condition == null)
continue;

conditionalDependencies ??= new CombinedDependencyList();
foreach (DependencyListEntry dep in bb.Dependencies)
conditionalDependencies.Add(new(dep.Node, bb.Condition, dep.Reason));
}

CodeBasedDependencyAlgorithm.AddDependenciesDueToMethodCodePresence(ref _unconditionalDependencies, _factory, _canonMethod, _canonMethodIL);
CodeBasedDependencyAlgorithm.AddConditionalDependenciesDueToMethodCodePresence(ref conditionalDependencies, _factory, _canonMethod);

return _dependencies;
return (_unconditionalDependencies, conditionalDependencies);
}

private ISymbolNode GetGenericLookupHelper(ReadyToRunHelperId helperId, object helperArgument)
Expand All @@ -199,19 +221,29 @@ private ISymbolNode GetHelperEntrypoint(ReadyToRunHelper helper)
}

private static void MarkInstructionBoundary() { }
private static void EndImportingBasicBlock(BasicBlock basicBlock) { }

private void EndImportingBasicBlock(BasicBlock basicBlock)
{
if (_pendingBasicBlocks == null)
{
_pendingBasicBlocks = _lateBasicBlocks;
_lateBasicBlocks = null;
}
}

private void StartImportingBasicBlock(BasicBlock basicBlock)
{
_dependencies = basicBlock.Condition != null ? basicBlock.Dependencies : _unconditionalDependencies;

// Import all associated EH regions
foreach (ExceptionRegion ehRegion in _exceptionRegions)
{
ILExceptionRegion region = ehRegion.ILRegion;
if (region.TryOffset == basicBlock.StartOffset)
{
MarkBasicBlock(_basicBlocks[region.HandlerOffset]);
ImportBasicBlockEdge(basicBlock, _basicBlocks[region.HandlerOffset]);
if (region.Kind == ILExceptionRegionKind.Filter)
MarkBasicBlock(_basicBlocks[region.FilterOffset]);
ImportBasicBlockEdge(basicBlock, _basicBlocks[region.FilterOffset]);

if (region.Kind == ILExceptionRegionKind.Catch)
{
Expand Down Expand Up @@ -789,10 +821,26 @@ private void ImportCalli(int token)

private void ImportBranch(ILOpcode opcode, BasicBlock target, BasicBlock fallthrough)
{
object condition = null;

if (opcode == ILOpcode.brfalse
&& _typeEqualityPatternAnalyzer.IsTypeEqualityBranch
&& !_typeEqualityPatternAnalyzer.IsTwoTokens
&& !_typeEqualityPatternAnalyzer.IsInequality)
{
TypeDesc typeEqualityCheckType = (TypeDesc)_canonMethodIL.GetObject(_typeEqualityPatternAnalyzer.Token1);
if (!typeEqualityCheckType.IsGenericDefinition
&& ConstructedEETypeNode.CreationAllowed(typeEqualityCheckType)
&& !typeEqualityCheckType.ConvertToCanonForm(CanonicalFormKind.Specific).IsCanonicalSubtype(CanonicalFormKind.Any))
{
condition = _factory.MaximallyConstructableType(typeEqualityCheckType);
}
}

ImportFallthrough(target);

if (fallthrough != null)
ImportFallthrough(fallthrough);
ImportFallthrough(fallthrough, condition);
}

private void ImportSwitchJump(int jmpBase, int[] jmpDelta, BasicBlock fallthrough)
Expand Down Expand Up @@ -1278,9 +1326,56 @@ private void ImportConvert(WellKnownType wellKnownType, bool checkOverflow, bool
}
}

private void ImportFallthrough(BasicBlock next)
private void ImportBasicBlockEdge(BasicBlock source, BasicBlock next, object condition = null)
{
// We don't track multiple conditions; if the source basic block is only reachable under a condition,
// this condition will be used for the next basic block, irrespective if we could make it more narrow.
object effectiveCondition = source.Condition ?? condition;

// Did we already look at this basic block?
if (next.State != BasicBlock.ImportState.Unmarked)
{
// If next is not conditioned, it stays not conditioned.
// If it was conditioned on something else, it needs to become unconditional.
// If the conditions match, it stays conditioned on the same thing.
if (next.Condition != null && next.Condition != effectiveCondition)
{
// Now we need to make `next` not conditioned. We move all of its dependencies to
// unconditional dependencies, and do this for all basic blocks that are reachable
// from it.
// TODO-SIZE: below doesn't do it for all basic blocks reachable from `next`, but
// for all basic blocks with the same conditon. This is a shortcut. It likely
// doesn't matter in practice.
object conditionToRemove = next.Condition;
foreach (BasicBlock bb in _basicBlocks)
{
if (bb?.Condition == conditionToRemove)
{
_unconditionalDependencies.AddRange(bb.Dependencies);
bb.Dependencies = null;
bb.Condition = null;
}
}
}
}
else
{
if (effectiveCondition != null)
{
next.Condition = effectiveCondition;
next.Dependencies = new DependencyList();
MarkBasicBlock(next, ref _lateBasicBlocks);
}
else
{
MarkBasicBlock(next);
}
}
}

private void ImportFallthrough(BasicBlock next, object condition = null)
{
MarkBasicBlock(next);
ImportBasicBlockEdge(_currentBasicBlock, next, condition);
}

private int ReadILTokenAt(int ilOffset)
Expand Down
4 changes: 2 additions & 2 deletions src/tests/nativeaot/SmokeTests/HardwareIntrinsics/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ static int Main()
Console.WriteLine("****************************************************");

long lowerBound, upperBound;
lowerBound = 1300 * 1024; // ~1.3 MB
upperBound = 1900 * 1024; // ~1.90 MB
lowerBound = 1200 * 1024; // ~1.2 MB
upperBound = 1600 * 1024; // ~1.6 MB

if (fileSize < lowerBound || fileSize > upperBound)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public static int Run()
TestArrayElementTypeOperations.Run();
TestStaticVirtualMethodOptimizations.Run();
TestTypeEquals.Run();
TestTypeEqualityDeadBranchScanRemoval.Run();
TestTypeIsEnum.Run();
TestTypeIsValueType.Run();
TestBranchesInGenericCodeRemoval.Run();
Expand Down Expand Up @@ -452,6 +453,42 @@ static void RunCheck<T>(Type t)
}
}

class TestTypeEqualityDeadBranchScanRemoval
{
class NeverAllocated1 { }
class NeverAllocated2 { }

class PossiblyAllocated1 { }
class PossiblyAllocated2 { }

[MethodImpl(MethodImplOptions.NoInlining)]
static Type GetNeverObject() => null;

static volatile Type s_sink;

public static void Run()
{
if (GetNeverObject() == typeof(NeverAllocated1))
{
Console.WriteLine(new NeverAllocated1().ToString());
Console.WriteLine(new NeverAllocated2().ToString());
}
#if !DEBUG
ThrowIfPresentWithUsableMethodTable(typeof(TestTypeEqualityDeadBranchScanRemoval), nameof(NeverAllocated1));
ThrowIfPresent(typeof(TestTypeEqualityDeadBranchScanRemoval), nameof(NeverAllocated2));
#endif

if (GetNeverObject() == typeof(PossiblyAllocated1))
{
Console.WriteLine(new PossiblyAllocated1().ToString());
Console.WriteLine(new PossiblyAllocated2().ToString());
}
if (Environment.GetEnvironmentVariable("SURETHING") != null)
s_sink = typeof(PossiblyAllocated1);
}
}


class TestTypeIsEnum
{
class Never { }
Expand Down

0 comments on commit 9c7ee97

Please sign in to comment.