Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions src/dll_load_error.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if defined(_WIN32)

#include <Windows.h>
#include <dbghelp.h>
#include <memory>
#include <string>
#pragma comment(lib, "dbghelp.lib")

struct HMODULE_Deleter {
typedef HMODULE pointer;
void operator()(HMODULE h) { FreeLibrary(h); }
};

using ModulePtr = std::unique_ptr<HMODULE, HMODULE_Deleter>;

std::string DetermineLoadLibraryError(const char* filename) {
std::string error("Error loading");

ModulePtr hModule; // Here so that filename is valid until the next iteration
while (filename) {
error += std::string(" \"") + filename + "\"";

// We use DONT_RESOLVE_DLL_REFERENCES instead of LOAD_LIBRARY_AS_DATAFILE because the latter will not process the import table
// and will result in the IMAGE_IMPORT_DESCRIPTOR table names being uninitialized.
hModule = ModulePtr{LoadLibraryEx(filename, NULL, DONT_RESOLVE_DLL_REFERENCES)};
if (!hModule) {
error += " which is missing. (Error " + std::to_string(GetLastError()) + ')';
return error;
}

// Get the address of the Import Directory
ULONG size{};
IMAGE_IMPORT_DESCRIPTOR* import_desc = reinterpret_cast<IMAGE_IMPORT_DESCRIPTOR*>(ImageDirectoryEntryToData(hModule.get(), FALSE, IMAGE_DIRECTORY_ENTRY_IMPORT, &size));
if (!import_desc) {
error += " No import directory found."; // This is unexpected, and I'm not sure how it could happen but we handle it just in case.
return error;
}

// Iterate through the import descriptors to see which dependent DLL can't load
filename = nullptr;
for (; import_desc->Characteristics; import_desc++) {
char* dll_name = (char*)((BYTE*)(hModule.get()) + import_desc->Name);
// Try to load the dependent DLL, and if it fails, we loop again with this as the DLL and we'll be one step closer to the missing file.
ModulePtr hDepModule{LoadLibrary(dll_name)};
if (!hDepModule) {
filename = dll_name;
error += " which depends on";
break;
}
}
}
error += " But no dependency issue could be determined.";
return error;
}

#endif
3 changes: 3 additions & 0 deletions src/dll_load_error.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
std::string DetermineLoadLibraryError(const char* filename);
4 changes: 3 additions & 1 deletion src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ std::string CurrentModulePath() {

return out_path;
}

#include "dll_load_error.h"
#endif

void ThrowErrorIfSessionTerminated(bool is_session_terminated) {
Expand Down Expand Up @@ -132,7 +134,7 @@ struct LibraryHandle {
auto path = CurrentModulePath() + filename;
handle_ = LoadLibrary(path.c_str());
if (!handle_)
throw std::runtime_error(std::string("Failed to load library: ") + path + " Error: " + std::to_string(GetLastError()));
throw std::runtime_error(std::string("Failed to load library: ") + DetermineLoadLibraryError(filename));
};

~LibraryHandle() { FreeLibrary(handle_); }
Expand Down
8 changes: 4 additions & 4 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct State {

void SetTerminate();
void UnsetTerminate();
mutable bool session_terminated_{};
bool session_terminated_{};
OrtValue* GetInput(const char* name);

virtual void RewindTo(size_t index) { (void)index; };
Expand Down Expand Up @@ -135,9 +135,9 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model>, External
std::unique_ptr<Config> config_;
std::unique_ptr<OrtSessionOptions> session_options_;

mutable DeviceInterface* p_device_{}; // The device we're running on (matches device_type_) used for things that work the same on all devices
mutable DeviceInterface* p_device_inputs_{}; // For some model inputs, the device might be the CPU device (all but KV cache currently for WebGPU and DML)
mutable DeviceInterface* p_device_kvcache_{}; // The kvcache is always allocated in device memory (TODO: Remove in favor of just p_device_?)
DeviceInterface* p_device_{}; // The device we're running on (matches device_type_) used for things that work the same on all devices
DeviceInterface* p_device_inputs_{}; // For some model inputs, the device might be the CPU device (all but KV cache currently for WebGPU and DML)
DeviceInterface* p_device_kvcache_{}; // The kvcache is always allocated in device memory (TODO: Remove in favor of just p_device_?)

Ort::Allocator& allocator_cpu_{GetDeviceInterface(DeviceType::CPU)->GetAllocator()};

Expand Down
Loading