Skip to content

Commit

Permalink
Implement Assembly.GetCallingAssembly for native AOT
Browse files Browse the repository at this point in the history
The problem with this API has always been the possible non-existence of reflection metadata about the calling method. But I realized that the reflection metadata can be supplemented by stack trace metadata that also knows assemblies of all methods on the stack.

So this API is supportable as long as stack traces are not disabled. It's also conveniently easy to implement with the new `DiagnosticMethodInfo` API we added in .NET 9.

I'm not sure if we want to go as far as make this API not work on CoreCLR with JIT if `StackTraceSupport` is false, but we could do that. It might be preferable, but a bit breaking.

Resolves dotnet#94200.
  • Loading branch information
MichalStrehovsky committed Jul 1, 2024
1 parent fcdb6db commit 718501b
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Configuration.Assemblies;
using System.Diagnostics;
using System.IO;
using System.Runtime.Serialization;
using System.Security;

using Internal.Reflection.Augments;

Expand All @@ -16,12 +18,42 @@ public abstract partial class Assembly : ICustomAttributeProvider, ISerializable
[System.Runtime.CompilerServices.Intrinsic]
public static Assembly GetExecutingAssembly() { throw NotImplemented.ByDesign; } //Implemented by toolchain.

[DynamicSecurityMethod]
public static Assembly GetCallingAssembly()
{
if (AppContext.TryGetSwitch("Switch.System.Reflection.Assembly.SimulatedCallingAssembly", out bool isSimulated) && isSimulated)
return GetEntryAssembly();

throw new PlatformNotSupportedException();
if (!StackTrace.IsSupported)
throw new NotSupportedException(SR.NotSupported_StackTraceSupportDisabled);

// We want to be able to handle GetCallingAssembly being called from Main (CoreCLR returns
// the assembly of Main), and also GetCallingAssembly being called from things like
// delegate invoke thunks. We do this by making the the definition of "calling assembly"
// a bit more loose.

// Technically we want skipFrames: 2 since we're interested in the frame that
// called the method that calls GetCallingAssembly, but we might need the first frame
// later in this method.
var stackTrace = new StackTrace(skipFrames: 1);

DiagnosticMethodInfo? dmi = null;

// Note: starting at index 1 since we want to skip the method that called GetCallingAssembly.
// We do a foreach so that we can skip any compiler-generated thunks that don't have method info.
for (int i = 1; i < stackTrace.FrameCount; i++)
{
dmi = DiagnosticMethodInfo.Create(stackTrace.GetFrame(i));
if (dmi != null)
break;
}

// If we haven't found anything in the entire stack trace, fall back
// to the method that called this method. This simulates what CoreCLR would
// do if GetCallingAssembly is called from e.g. Main.
dmi ??= stackTrace.GetFrame(0) is StackFrame sf ? DiagnosticMethodInfo.Create(sf) : null;

return dmi.DeclaringAssemblyName is string asmName ? Load(asmName) : null;
}

public static Assembly Load(AssemblyName assemblyRef) => ReflectionAugments.ReflectionCoreCallbacks.Load(assemblyRef, throwOnFileNotFound: true);
Expand Down
36 changes: 34 additions & 2 deletions src/coreclr/tools/Common/JitInterface/CorInfoImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1104,12 +1104,10 @@ private uint getMethodAttribsInternal(MethodDesc method)
if (method.IsPInvoke)
result |= CorInfoFlag.CORINFO_FLG_PINVOKE;

#if READYTORUN
if (method.RequireSecObject)
{
result |= CorInfoFlag.CORINFO_FLG_DONT_INLINE_CALLER;
}
#endif

if (method.IsAggressiveOptimization)
{
Expand Down Expand Up @@ -1177,6 +1175,40 @@ private void setMethodAttribs(CORINFO_METHOD_STRUCT_* ftn, CorInfoMethodRuntimeF
// TODO: Inlining
}

private bool canTailCall(CORINFO_METHOD_STRUCT_* callerHnd, CORINFO_METHOD_STRUCT_* declaredCalleeHnd, CORINFO_METHOD_STRUCT_* exactCalleeHnd, bool fIsTailPrefix)
{
if (!fIsTailPrefix)
{
MethodDesc caller = HandleToObject(callerHnd);

// Do not tailcall out of the entry point as it results in a confusing debugger experience.
if (caller is EcmaMethod em && em.Module.EntryPoint == caller)
{
return false;
}

// Do not tailcall from methods that are marked as noinline (people often use no-inline
// to mean "I want to always see this method in stacktrace")
if (caller.IsNoInlining)
{
return false;
}

// Methods with StackCrawlMark depend on finding their caller on the stack.
// If we tail call one of these guys, they get confused. For lack of
// a better way of identifying them, we use DynamicSecurity attribute to identify
// them.
//
MethodDesc callee = exactCalleeHnd == null ? null : HandleToObject(exactCalleeHnd);
if (callee != null && callee.RequireSecObject)
{
return false;
}
}

return true;
}

private void getMethodSig(CORINFO_METHOD_STRUCT_* ftn, CORINFO_SIG_INFO* sig, CORINFO_CLASS_STRUCT_* memberParent)
{
MethodDesc method = HandleToObject(ftn);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1294,40 +1294,6 @@ private void getFunctionEntryPoint(CORINFO_METHOD_STRUCT_* ftn, ref CORINFO_CONS
throw new RequiresRuntimeJitException(HandleToObject(ftn).ToString());
}

private bool canTailCall(CORINFO_METHOD_STRUCT_* callerHnd, CORINFO_METHOD_STRUCT_* declaredCalleeHnd, CORINFO_METHOD_STRUCT_* exactCalleeHnd, bool fIsTailPrefix)
{
if (!fIsTailPrefix)
{
MethodDesc caller = HandleToObject(callerHnd);

// Do not tailcall out of the entry point as it results in a confusing debugger experience.
if (caller is EcmaMethod em && em.Module.EntryPoint == caller)
{
return false;
}

// Do not tailcall from methods that are marked as noinline (people often use no-inline
// to mean "I want to always see this method in stacktrace")
if (caller.IsNoInlining)
{
return false;
}

// Methods with StackCrawlMark depend on finding their caller on the stack.
// If we tail call one of these guys, they get confused. For lack of
// a better way of identifying them, we use DynamicSecurity attribute to identify
// them.
//
MethodDesc callee = exactCalleeHnd == null ? null : HandleToObject(exactCalleeHnd);
if (callee != null && callee.RequireSecObject)
{
return false;
}
}

return true;
}

private FieldWithToken ComputeFieldWithToken(FieldDesc field, ref CORINFO_RESOLVED_TOKEN pResolvedToken)
{
ModuleToken token = HandleToModuleToken(ref pResolvedToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -824,34 +824,6 @@ method.OwningType is MetadataType mdType &&
pResult = CreateConstLookupToSymbol(_compilation.NodeFactory.MethodEntrypoint(method));
}

private bool canTailCall(CORINFO_METHOD_STRUCT_* callerHnd, CORINFO_METHOD_STRUCT_* declaredCalleeHnd, CORINFO_METHOD_STRUCT_* exactCalleeHnd, bool fIsTailPrefix)
{
// Assume we can tail call unless proved otherwise
bool result = true;

if (!fIsTailPrefix)
{
MethodDesc caller = HandleToObject(callerHnd);

if (caller.OwningType is EcmaType ecmaOwningType
&& ecmaOwningType.EcmaModule.EntryPoint == caller)
{
// Do not tailcall from the application entrypoint.
// We want Main to be visible in stack traces.
result = false;
}

if (caller.IsNoInlining)
{
// Do not tailcall from methods that are marked as noinline (people often use no-inline
// to mean "I want to always see this method in stacktrace")
result = false;
}
}

return result;
}

private InfoAccessType constructStringLiteral(CORINFO_MODULE_STRUCT_* module, mdToken metaTok, ref void* ppValue)
{
MethodIL methodIL = (MethodIL)HandleToObject((void*)module);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3017,6 +3017,9 @@
<data name="InvalidOperation_SpanOverlappedOperation" xml:space="preserve">
<value>This operation is invalid on overlapping buffers.</value>
</data>
<data name="NotSupported_StackTraceSupportDisabled" xml:space="preserve">
<value>Unable to retreive stack trace information when StackTraceSupport feature switch is set to false.</value>
</data>
<data name="InvalidOperation_TimeProviderNullLocalTimeZone" xml:space="preserve">
<value>The operation cannot be performed when TimeProvider.LocalTimeZone is null.</value>
</data>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -725,13 +725,21 @@ public static IEnumerable<object[]> GetCallingAssembly_TestData()

[Theory]
[ActiveIssue("https://github.com/dotnet/runtime/issues/51673", typeof(PlatformDetection), nameof(PlatformDetection.IsBrowser), nameof(PlatformDetection.IsMonoAOT))]
[ActiveIssue("https://github.com/dotnet/runtime/issues/69919", typeof(PlatformDetection), nameof(PlatformDetection.IsNativeAot))]
[MemberData(nameof(GetCallingAssembly_TestData))]
public void GetCallingAssembly(Assembly assembly1, Assembly assembly2, bool expected)
{
Assert.Equal(expected, assembly1.Equals(assembly2));
}

[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/51673", typeof(PlatformDetection), nameof(PlatformDetection.IsBrowser), nameof(PlatformDetection.IsMonoAOT))]
public void GetCallingAssemblyThroughDelegate()
{
Func<Assembly> del = static () => Assembly.GetCallingAssembly();
Assert.Equal(Assembly.GetExecutingAssembly(), del());
Assert.Equal(typeof(System.Reflection.TestAssembly.ClassToInvoke).Assembly, System.Reflection.TestAssembly.ClassToInvoke.InvokeDelegate(del));
}

[Fact]
public void GetExecutingAssembly()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,6 @@ public static void InvokeCopiesBackMissingParameterAndArgument()

[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/50957", typeof(PlatformDetection), nameof(PlatformDetection.IsMonoInterpreter))]
[ActiveIssue("https://github.com/dotnet/runtime/issues/69919", typeof(PlatformDetection), nameof(PlatformDetection.IsNativeAot))]
public static void CallStackFrame_AggressiveInlining()
{
MethodInfo mi = typeof(System.Reflection.TestAssembly.ClassToInvoke).GetMethod(nameof(System.Reflection.TestAssembly.ClassToInvoke.CallMe_AggressiveInlining),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,8 @@ private static Assembly CallMeActual()

[MethodImpl(MethodImplOptions.NoInlining)]
public static int CallMe_AvoidTailcall() => 42;

[MethodImpl(MethodImplOptions.NoInlining)]
public static Assembly InvokeDelegate(Func<Assembly> del) => del();
}
}

0 comments on commit 718501b

Please sign in to comment.