diff --git a/src/dll_load_error.cpp b/src/dll_load_error.cpp new file mode 100644 index 0000000000..31c956403f --- /dev/null +++ b/src/dll_load_error.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#if defined(_WIN32) + +#include +#include +#include +#include +#pragma comment(lib, "dbghelp.lib") + +struct HMODULE_Deleter { + typedef HMODULE pointer; + void operator()(HMODULE h) { FreeLibrary(h); } +}; + +using ModulePtr = std::unique_ptr; + +std::string DetermineLoadLibraryError(const char* filename) { + std::string error("Error loading"); + + ModulePtr hModule; // Here so that filename is valid until the next iteration + while (filename) { + error += std::string(" \"") + filename + "\""; + + // We use DONT_RESOLVE_DLL_REFERENCES instead of LOAD_LIBRARY_AS_DATAFILE because the latter will not process the import table + // and will result in the IMAGE_IMPORT_DESCRIPTOR table names being uninitialized. + hModule = ModulePtr{LoadLibraryEx(filename, NULL, DONT_RESOLVE_DLL_REFERENCES)}; + if (!hModule) { + error += " which is missing. (Error " + std::to_string(GetLastError()) + ')'; + return error; + } + + // Get the address of the Import Directory + ULONG size{}; + IMAGE_IMPORT_DESCRIPTOR* import_desc = reinterpret_cast(ImageDirectoryEntryToData(hModule.get(), FALSE, IMAGE_DIRECTORY_ENTRY_IMPORT, &size)); + if (!import_desc) { + error += " No import directory found."; // This is unexpected, and I'm not sure how it could happen but we handle it just in case. + return error; + } + + // Iterate through the import descriptors to see which dependent DLL can't load + filename = nullptr; + for (; import_desc->Characteristics; import_desc++) { + char* dll_name = (char*)((BYTE*)(hModule.get()) + import_desc->Name); + // Try to load the dependent DLL, and if it fails, we loop again with this as the DLL and we'll be one step closer to the missing file. + ModulePtr hDepModule{LoadLibrary(dll_name)}; + if (!hDepModule) { + filename = dll_name; + error += " which depends on"; + break; + } + } + } + error += " But no dependency issue could be determined."; + return error; +} + +#endif diff --git a/src/dll_load_error.h b/src/dll_load_error.h new file mode 100644 index 0000000000..3881e9fd82 --- /dev/null +++ b/src/dll_load_error.h @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +std::string DetermineLoadLibraryError(const char* filename); \ No newline at end of file diff --git a/src/generators.cpp b/src/generators.cpp index b7e8fd588d..d3e271098e 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -30,6 +30,8 @@ std::string CurrentModulePath() { return out_path; } + +#include "dll_load_error.h" #endif void ThrowErrorIfSessionTerminated(bool is_session_terminated) { @@ -132,7 +134,7 @@ struct LibraryHandle { auto path = CurrentModulePath() + filename; handle_ = LoadLibrary(path.c_str()); if (!handle_) - throw std::runtime_error(std::string("Failed to load library: ") + path + " Error: " + std::to_string(GetLastError())); + throw std::runtime_error(std::string("Failed to load library: ") + DetermineLoadLibraryError(filename)); }; ~LibraryHandle() { FreeLibrary(handle_); } diff --git a/src/models/model.h b/src/models/model.h index 9e10366d86..85a9635daa 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -26,7 +26,7 @@ struct State { void SetTerminate(); void UnsetTerminate(); - mutable bool session_terminated_{}; + bool session_terminated_{}; OrtValue* GetInput(const char* name); virtual void RewindTo(size_t index) { (void)index; }; @@ -135,9 +135,9 @@ struct Model : std::enable_shared_from_this, LeakChecked, External std::unique_ptr config_; std::unique_ptr session_options_; - mutable DeviceInterface* p_device_{}; // The device we're running on (matches device_type_) used for things that work the same on all devices - mutable DeviceInterface* p_device_inputs_{}; // For some model inputs, the device might be the CPU device (all but KV cache currently for WebGPU and DML) - mutable DeviceInterface* p_device_kvcache_{}; // The kvcache is always allocated in device memory (TODO: Remove in favor of just p_device_?) + DeviceInterface* p_device_{}; // The device we're running on (matches device_type_) used for things that work the same on all devices + DeviceInterface* p_device_inputs_{}; // For some model inputs, the device might be the CPU device (all but KV cache currently for WebGPU and DML) + DeviceInterface* p_device_kvcache_{}; // The kvcache is always allocated in device memory (TODO: Remove in favor of just p_device_?) Ort::Allocator& allocator_cpu_{GetDeviceInterface(DeviceType::CPU)->GetAllocator()};