Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 86 additions & 70 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -892,99 +892,115 @@ internal class NativeLib
/// On Windows, it explicitly loads the library with a lowercase .dll extension to handle
/// case-sensitive filesystems.
/// </summary>
#if NET5_0_OR_GREATER
[System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("SingleFile", "IL3000:Avoid accessing Assembly file path when publishing as a single file", Justification = "We also check AppContext.BaseDirectory as a fallback")]
#endif
private static IntPtr DllImportResolver(string libraryName, Assembly assembly, DllImportSearchPath? searchPath)
{
if (libraryName == NativeLib.DllName || libraryName == OrtExtensionsNativeMethods.ExtensionsDllName)
try
{
string mappedName = null;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
// Explicitly load with .dll extension to avoid issues where the OS might try .DLL
mappedName = libraryName + ".dll";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
if (libraryName == NativeLib.DllName || libraryName == OrtExtensionsNativeMethods.ExtensionsDllName)
{
// Explicitly load with .so extension and lib prefix
mappedName = "lib" + libraryName + ".so";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
// Explicitly load with .dylib extension and lib prefix
mappedName = "lib" + libraryName + ".dylib";
}

if (mappedName != null)
{
// 1. Try default loading (name only)
if (NativeLibrary.TryLoad(mappedName, assembly, searchPath, out IntPtr handle))
string mappedName = null;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
return handle;
// Explicitly load with .dll extension to avoid issues where the OS might try .DLL
mappedName = libraryName + ".dll";
}

// 2. Try relative to assembly location (look into runtimes subfolders)
string assemblyLocation = null;
try { assemblyLocation = assembly.Location; } catch { }
if (!string.IsNullOrEmpty(assemblyLocation))
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
// Explicitly load with .so extension and lib prefix
mappedName = "lib" + libraryName + ".so";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
string assemblyDir = System.IO.Path.GetDirectoryName(assemblyLocation);
string rid = RuntimeInformation.RuntimeIdentifier;
// Explicitly load with .dylib extension and lib prefix
mappedName = "lib" + libraryName + ".dylib";
}

// Probe the specific RID first, then common fallbacks for the current OS
string[] ridsToTry;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
ridsToTry = new[] { rid, "win-x64", "win-arm64" };
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
ridsToTry = new[] { rid, "linux-x64", "linux-arm64" };
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
// We no longer provide osx-x64 in official package since 1.24.
// However, we keep it in the list for build-from-source users.
ridsToTry = new[] { rid, "osx-arm64", "osx-x64" };
}
else
if (mappedName != null)
{
// 1. Try default loading (name only)
if (NativeLibrary.TryLoad(mappedName, assembly, searchPath, out IntPtr handle))
{
ridsToTry = new[] { rid };
return handle;
}

foreach (var tryRid in ridsToTry)
// 2. Try relative to assembly location (look into runtimes subfolders)
string assemblyLocation = null;
try { assemblyLocation = assembly.Location; } catch { }
if (!string.IsNullOrEmpty(assemblyLocation))
{
string probePath = System.IO.Path.Combine(assemblyDir, "runtimes", tryRid, "native", mappedName);
if (System.IO.File.Exists(probePath) && NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle))
string assemblyDir = System.IO.Path.GetDirectoryName(assemblyLocation);
string rid = RuntimeInformation.RuntimeIdentifier;

// Probe the specific RID first, then common fallbacks for the current OS
string[] ridsToTry;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}");
return handle;
ridsToTry = new[] { rid, "win-x64", "win-arm64" };
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
ridsToTry = new[] { rid, "linux-x64", "linux-arm64" };
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
// We no longer provide osx-x64 in official package since 1.24.
// However, we keep it in the list for build-from-source users.
ridsToTry = new[] { rid, "osx-arm64", "osx-x64" };
}
else
{
ridsToTry = new[] { rid };
}
}
}

// 3. Try AppContext.BaseDirectory as a fallback
string baseDir = AppContext.BaseDirectory;
if (!string.IsNullOrEmpty(baseDir))
{
string probePath = System.IO.Path.Combine(baseDir, mappedName);
if (NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle))
{
LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}");
return handle;
foreach (var tryRid in ridsToTry)
{
string probePath = System.IO.Path.Combine(assemblyDir, "runtimes", tryRid, "native", mappedName);
if (System.IO.File.Exists(probePath) && NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle))
{
LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}");
return handle;
}
}
}

string rid = RuntimeInformation.RuntimeIdentifier;
probePath = System.IO.Path.Combine(baseDir, "runtimes", rid, "native", mappedName);
if (NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle))
// 3. Try AppContext.BaseDirectory as a fallback
try
{
LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}");
return handle;
string baseDir = AppContext.BaseDirectory;
if (!string.IsNullOrEmpty(baseDir))
{
string probePath = System.IO.Path.Combine(baseDir, mappedName);
if (NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle))
{
LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}");
return handle;
}

string rid = RuntimeInformation.RuntimeIdentifier;
probePath = System.IO.Path.Combine(baseDir, "runtimes", rid, "native", mappedName);
if (NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle))
{
LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}");
return handle;
}
}
}
}
catch { } // Ignore AppDomainUnloadedException or similar from AppContext.BaseDirectory

LogLibLoad($"[DllImportResolver] Failed loading {mappedName} (RID: {RuntimeInformation.RuntimeIdentifier}, Assembly: {assemblyLocation})");
LogLibLoad($"[DllImportResolver] Failed loading {mappedName} (RID: {RuntimeInformation.RuntimeIdentifier}, Assembly: {assemblyLocation})");

}
}
}
catch (Exception ex)
{
// Unhandled exceptions inside DllImportResolver can result in TypeInitializationException.
// Log and swallow the error, returning IntPtr.Zero to fall back to default CLR logic.
try { System.Diagnostics.Trace.WriteLine($"[DllImportResolver] Exception during resolution: {ex}"); } catch { }
}

// Fall back to default resolution
return IntPtr.Zero;
Expand Down
43 changes: 43 additions & 0 deletions csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -489,4 +489,47 @@ void TestCopyTensors()
}
}
}

[Collection("Ort Inference Tests")]
public class OrtEnvDllImportResolverTest
{
[Fact(DisplayName = "TestDllImportResolverDoesNotThrow")]
public void TestDllImportResolverDoesNotThrow()
{
// The DllImportResolver is a private static method in NativeMethods.
var nativeMethodsType = typeof(OrtEnv).Assembly.GetType("Microsoft.ML.OnnxRuntime.NativeMethods");
Assert.NotNull(nativeMethodsType);

// It might not be defined on all platforms (defined when !NETSTANDARD2_0 && !__ANDROID__ && !__IOS__).
var resolverMethod = nativeMethodsType.GetMethod("DllImportResolver", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);

if (resolverMethod != null)
{
try
{
// Invoke with null assembly to force it into edge cases where assembly.Location would throw NullReferenceException.
// It should catch the exception and return IntPtr.Zero gracefully rather than throwing.
var result = resolverMethod.Invoke(null, new object[] { "onnxruntime", null, null });

// If it reaches here without throwing TargetInvocationException, the try-catch in DllImportResolver works.
Assert.True(result is IntPtr);
}
catch (System.Reflection.TargetInvocationException ex)
{
// If NativeMethods..cctor() threw because the native library is missing,
// we will get a TypeInitializationException wrapping a DllNotFoundException (or DllImportException).
// This is acceptable locally. What we want to avoid is NullReferenceException from DllImportResolver.
if (ex.InnerException is TypeInitializationException typeInitEx)
{
Assert.IsNotType<NullReferenceException>(typeInitEx.InnerException);
}
else
{
Assert.IsNotType<NullReferenceException>(ex.InnerException);
throw;
}
}
}
}
}
}
Loading