diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index abe73b77f4071..a6b267c6802cf 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -892,99 +892,115 @@ internal class NativeLib /// On Windows, it explicitly loads the library with a lowercase .dll extension to handle /// case-sensitive filesystems. /// +#if NET5_0_OR_GREATER + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("SingleFile", "IL3000:Avoid accessing Assembly file path when publishing as a single file", Justification = "We also check AppContext.BaseDirectory as a fallback")] +#endif private static IntPtr DllImportResolver(string libraryName, Assembly assembly, DllImportSearchPath? searchPath) { - if (libraryName == NativeLib.DllName || libraryName == OrtExtensionsNativeMethods.ExtensionsDllName) + try { - string mappedName = null; - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - // Explicitly load with .dll extension to avoid issues where the OS might try .DLL - mappedName = libraryName + ".dll"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + if (libraryName == NativeLib.DllName || libraryName == OrtExtensionsNativeMethods.ExtensionsDllName) { - // Explicitly load with .so extension and lib prefix - mappedName = "lib" + libraryName + ".so"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - // Explicitly load with .dylib extension and lib prefix - mappedName = "lib" + libraryName + ".dylib"; - } - - if (mappedName != null) - { - // 1. Try default loading (name only) - if (NativeLibrary.TryLoad(mappedName, assembly, searchPath, out IntPtr handle)) + string mappedName = null; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - return handle; + // Explicitly load with .dll extension to avoid issues where the OS might try .DLL + mappedName = libraryName + ".dll"; } - - // 2. Try relative to assembly location (look into runtimes subfolders) - string assemblyLocation = null; - try { assemblyLocation = assembly.Location; } catch { } - if (!string.IsNullOrEmpty(assemblyLocation)) + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + // Explicitly load with .so extension and lib prefix + mappedName = "lib" + libraryName + ".so"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { - string assemblyDir = System.IO.Path.GetDirectoryName(assemblyLocation); - string rid = RuntimeInformation.RuntimeIdentifier; + // Explicitly load with .dylib extension and lib prefix + mappedName = "lib" + libraryName + ".dylib"; + } - // Probe the specific RID first, then common fallbacks for the current OS - string[] ridsToTry; - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - ridsToTry = new[] { rid, "win-x64", "win-arm64" }; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - ridsToTry = new[] { rid, "linux-x64", "linux-arm64" }; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - // We no longer provide osx-x64 in official package since 1.24. - // However, we keep it in the list for build-from-source users. - ridsToTry = new[] { rid, "osx-arm64", "osx-x64" }; - } - else + if (mappedName != null) + { + // 1. Try default loading (name only) + if (NativeLibrary.TryLoad(mappedName, assembly, searchPath, out IntPtr handle)) { - ridsToTry = new[] { rid }; + return handle; } - foreach (var tryRid in ridsToTry) + // 2. Try relative to assembly location (look into runtimes subfolders) + string assemblyLocation = null; + try { assemblyLocation = assembly.Location; } catch { } + if (!string.IsNullOrEmpty(assemblyLocation)) { - string probePath = System.IO.Path.Combine(assemblyDir, "runtimes", tryRid, "native", mappedName); - if (System.IO.File.Exists(probePath) && NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle)) + string assemblyDir = System.IO.Path.GetDirectoryName(assemblyLocation); + string rid = RuntimeInformation.RuntimeIdentifier; + + // Probe the specific RID first, then common fallbacks for the current OS + string[] ridsToTry; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}"); - return handle; + ridsToTry = new[] { rid, "win-x64", "win-arm64" }; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + ridsToTry = new[] { rid, "linux-x64", "linux-arm64" }; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + // We no longer provide osx-x64 in official package since 1.24. + // However, we keep it in the list for build-from-source users. + ridsToTry = new[] { rid, "osx-arm64", "osx-x64" }; + } + else + { + ridsToTry = new[] { rid }; } - } - } - // 3. Try AppContext.BaseDirectory as a fallback - string baseDir = AppContext.BaseDirectory; - if (!string.IsNullOrEmpty(baseDir)) - { - string probePath = System.IO.Path.Combine(baseDir, mappedName); - if (NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle)) - { - LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}"); - return handle; + foreach (var tryRid in ridsToTry) + { + string probePath = System.IO.Path.Combine(assemblyDir, "runtimes", tryRid, "native", mappedName); + if (System.IO.File.Exists(probePath) && NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle)) + { + LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}"); + return handle; + } + } } - string rid = RuntimeInformation.RuntimeIdentifier; - probePath = System.IO.Path.Combine(baseDir, "runtimes", rid, "native", mappedName); - if (NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle)) + // 3. Try AppContext.BaseDirectory as a fallback + try { - LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}"); - return handle; + string baseDir = AppContext.BaseDirectory; + if (!string.IsNullOrEmpty(baseDir)) + { + string probePath = System.IO.Path.Combine(baseDir, mappedName); + if (NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle)) + { + LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}"); + return handle; + } + + string rid = RuntimeInformation.RuntimeIdentifier; + probePath = System.IO.Path.Combine(baseDir, "runtimes", rid, "native", mappedName); + if (NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle)) + { + LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}"); + return handle; + } + } } - } + catch { } // Ignore AppDomainUnloadedException or similar from AppContext.BaseDirectory - LogLibLoad($"[DllImportResolver] Failed loading {mappedName} (RID: {RuntimeInformation.RuntimeIdentifier}, Assembly: {assemblyLocation})"); + LogLibLoad($"[DllImportResolver] Failed loading {mappedName} (RID: {RuntimeInformation.RuntimeIdentifier}, Assembly: {assemblyLocation})"); + } } } + catch (Exception ex) + { + // Unhandled exceptions inside DllImportResolver can result in TypeInitializationException. + // Log and swallow the error, returning IntPtr.Zero to fall back to default CLR logic. + try { System.Diagnostics.Trace.WriteLine($"[DllImportResolver] Exception during resolution: {ex}"); } catch { } + } // Fall back to default resolution return IntPtr.Zero; diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs index 94f8e927c1331..aa1b683acd668 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs @@ -489,4 +489,47 @@ void TestCopyTensors() } } } + + [Collection("Ort Inference Tests")] + public class OrtEnvDllImportResolverTest + { + [Fact(DisplayName = "TestDllImportResolverDoesNotThrow")] + public void TestDllImportResolverDoesNotThrow() + { + // The DllImportResolver is a private static method in NativeMethods. + var nativeMethodsType = typeof(OrtEnv).Assembly.GetType("Microsoft.ML.OnnxRuntime.NativeMethods"); + Assert.NotNull(nativeMethodsType); + + // It might not be defined on all platforms (defined when !NETSTANDARD2_0 && !__ANDROID__ && !__IOS__). + var resolverMethod = nativeMethodsType.GetMethod("DllImportResolver", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static); + + if (resolverMethod != null) + { + try + { + // Invoke with null assembly to force it into edge cases where assembly.Location would throw NullReferenceException. + // It should catch the exception and return IntPtr.Zero gracefully rather than throwing. + var result = resolverMethod.Invoke(null, new object[] { "onnxruntime", null, null }); + + // If it reaches here without throwing TargetInvocationException, the try-catch in DllImportResolver works. + Assert.True(result is IntPtr); + } + catch (System.Reflection.TargetInvocationException ex) + { + // If NativeMethods..cctor() threw because the native library is missing, + // we will get a TypeInitializationException wrapping a DllNotFoundException (or DllImportException). + // This is acceptable locally. What we want to avoid is NullReferenceException from DllImportResolver. + if (ex.InnerException is TypeInitializationException typeInitEx) + { + Assert.IsNotType(typeInitEx.InnerException); + } + else + { + Assert.IsNotType(ex.InnerException); + throw; + } + } + } + } + } }