From bd75d25dff28af24a57fc8ad30ef22478a55bdd7 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Fri, 21 Mar 2025 16:09:55 -0700 Subject: [PATCH 1/2] Show better DLL errors on library load failure --- src/dll_load_error.cpp | 58 ++++++++++++++++++++++++++++++++++++++++++ src/dll_load_error.h | 1 + src/generators.cpp | 4 ++- src/models/model.h | 8 +++--- 4 files changed, 66 insertions(+), 5 deletions(-) create mode 100644 src/dll_load_error.cpp create mode 100644 src/dll_load_error.h diff --git a/src/dll_load_error.cpp b/src/dll_load_error.cpp new file mode 100644 index 0000000000..5950834f6c --- /dev/null +++ b/src/dll_load_error.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#if defined(_WIN32) + +#include +#include +#include +#include +#pragma comment(lib, "dbghelp.lib") + +struct HMODULE_Deleter { + typedef HMODULE pointer; + void operator()(HMODULE h) { FreeLibrary(h); } +}; + +using ModulePtr = std::unique_ptr; + +std::string DetermineLoadLibraryError(const char* filename) { + std::string error("Error loading"); + + ModulePtr hModule; // Here so that filename is valid until the next iteration + while (filename) { + error += std::string(" \"") + filename + "\""; + + // We use DONT_RESOLVE_DLL_REFERENCES instead of LOAD_LIBRARY_AS_DATAFILE because the latter will not process the import table + // and will result in the IMAGE_IMPORT_DESCRIPTOR table names being uninitialized. + hModule = ModulePtr{LoadLibraryEx(filename, NULL, DONT_RESOLVE_DLL_REFERENCES)}; + if (!hModule) { + error += " which is missing. (Error " + std::to_string(GetLastError()) + ')'; + return error; + } + + // Get the address of the Import Directory + ULONG size{}; + IMAGE_IMPORT_DESCRIPTOR* import_desc = reinterpret_cast(ImageDirectoryEntryToData(hModule.get(), FALSE, IMAGE_DIRECTORY_ENTRY_IMPORT, &size)); + if (!import_desc) { + error += " No import directory found."; // This is unexpected, and I'm not sure how it could happen but we handle it just in case. + return error; + } + + // Iterate through the import descriptors to see which dependent DLL can't load + filename = nullptr; + for (; import_desc->Characteristics; import_desc++) { + char* dll_name = (char*)((BYTE*)(hModule.get()) + import_desc->Name); + // Try to load the dependent DLL, and if it fails, we loop again with this as the DLL and we'll be one step closer to the missing file. + ModulePtr hDepModule{LoadLibrary(dll_name)}; + if (!hDepModule) { + filename = dll_name; + error += " which depends on"; + break; + } + } + } + error += " But no dependency issue could be determined."; + return error; +} + +#endif diff --git a/src/dll_load_error.h b/src/dll_load_error.h new file mode 100644 index 0000000000..85eb2cb204 --- /dev/null +++ b/src/dll_load_error.h @@ -0,0 +1 @@ +std::string DetermineLoadLibraryError(const char* filename); \ No newline at end of file diff --git a/src/generators.cpp b/src/generators.cpp index 992664a237..c792f8da36 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -30,6 +30,8 @@ std::string CurrentModulePath() { return out_path; } + +#include "dll_load_error.h" #endif void ThrowErrorIfSessionTerminated(bool is_session_terminated) { @@ -132,7 +134,7 @@ struct LibraryHandle { auto path = CurrentModulePath() + filename; handle_ = LoadLibrary(path.c_str()); if (!handle_) - throw std::runtime_error(std::string("Failed to load library: ") + path + " Error: " + std::to_string(GetLastError())); + throw std::runtime_error(std::string("Failed to load library: ") + path + " Error: " + std::to_string(GetLastError()) + DetermineLoadLibraryError(filename)); }; ~LibraryHandle() { FreeLibrary(handle_); } diff --git a/src/models/model.h b/src/models/model.h index 0306dc13e2..3b8f8fe6f2 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -27,7 +27,7 @@ struct State { void SetTerminate(); void UnsetTerminate(); - mutable bool session_terminated_{}; + bool session_terminated_{}; OrtValue* GetInput(const char* name); virtual void RewindTo(size_t index) { (void)index; }; @@ -138,9 +138,9 @@ struct Model : std::enable_shared_from_this, LeakChecked, External std::unique_ptr config_; std::unique_ptr session_options_; - mutable DeviceInterface* p_device_{}; // The device we're running on (matches device_type_) used for things that work the same on all devices - mutable DeviceInterface* p_device_inputs_{}; // For some model inputs, the device might be the CPU device (all but KV cache currently for WebGPU and DML) - mutable DeviceInterface* p_device_kvcache_{}; // The kvcache is always allocated in device memory (TODO: Remove in favor of just p_device_?) + DeviceInterface* p_device_{}; // The device we're running on (matches device_type_) used for things that work the same on all devices + DeviceInterface* p_device_inputs_{}; // For some model inputs, the device might be the CPU device (all but KV cache currently for WebGPU and DML) + DeviceInterface* p_device_kvcache_{}; // The kvcache is always allocated in device memory (TODO: Remove in favor of just p_device_?) Ort::Allocator& allocator_cpu_{GetDeviceInterface(DeviceType::CPU)->GetAllocator()}; From e2a1adb256fd2ab2369d1f10e90828f09115aa51 Mon Sep 17 00:00:00 2001 From: RyanUnderhill <38674843+RyanUnderhill@users.noreply.github.com> Date: Tue, 25 Mar 2025 02:02:10 -0700 Subject: [PATCH 2/2] Detailed DLL dependency errors --- src/dll_load_error.cpp | 2 +- src/dll_load_error.h | 2 ++ src/generators.cpp | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/dll_load_error.cpp b/src/dll_load_error.cpp index 5950834f6c..31c956403f 100644 --- a/src/dll_load_error.cpp +++ b/src/dll_load_error.cpp @@ -4,7 +4,7 @@ #include #include -#include +#include #include #pragma comment(lib, "dbghelp.lib") diff --git a/src/dll_load_error.h b/src/dll_load_error.h index 85eb2cb204..3881e9fd82 100644 --- a/src/dll_load_error.h +++ b/src/dll_load_error.h @@ -1 +1,3 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. std::string DetermineLoadLibraryError(const char* filename); \ No newline at end of file diff --git a/src/generators.cpp b/src/generators.cpp index c792f8da36..4c3bd6832e 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -134,7 +134,7 @@ struct LibraryHandle { auto path = CurrentModulePath() + filename; handle_ = LoadLibrary(path.c_str()); if (!handle_) - throw std::runtime_error(std::string("Failed to load library: ") + path + " Error: " + std::to_string(GetLastError()) + DetermineLoadLibraryError(filename)); + throw std::runtime_error(std::string("Failed to load library: ") + DetermineLoadLibraryError(filename)); }; ~LibraryHandle() { FreeLibrary(handle_); }