diff --git a/sycl/CMakeLists.txt b/sycl/CMakeLists.txt index 389342f8f71fb..609d758689a1d 100644 --- a/sycl/CMakeLists.txt +++ b/sycl/CMakeLists.txt @@ -309,6 +309,7 @@ add_custom_target( sycl-toolchain DEPENDS sycl-runtime-libraries sycl-compiler sycl-ls + win_proxy_loader ${XPTIFW_LIBS} COMMENT "Building SYCL compiler toolchain..." ) @@ -341,6 +342,8 @@ add_subdirectory( plugins ) add_subdirectory(tools) +add_subdirectory(win_proxy_loader) + if(SYCL_INCLUDE_TESTS) if(NOT LLVM_INCLUDE_TESTS) message(FATAL_ERROR @@ -383,6 +386,7 @@ set( SYCL_TOOLCHAIN_DEPLOY_COMPONENTS sycl libsycldevice level-zero-sycl-dev + win_proxy_loader ${XPTIFW_LIBS} ${SYCL_TOOLCHAIN_DEPS} ) @@ -450,6 +454,7 @@ if("esimd_emulator" IN_LIST SYCL_ENABLE_PLUGINS) endif() endif() + # Use it as fake dependency in order to force another command(s) to execute. add_custom_command(OUTPUT __force_it COMMAND "${CMAKE_COMMAND}" -E echo diff --git a/sycl/include/sycl/detail/pi.hpp b/sycl/include/sycl/detail/pi.hpp index 5d1272a2724d9..bfa4c55164cb9 100644 --- a/sycl/include/sycl/detail/pi.hpp +++ b/sycl/include/sycl/detail/pi.hpp @@ -62,6 +62,8 @@ enum TraceLevel { bool trace(TraceLevel level); #ifdef __SYCL_RT_OS_WINDOWS +// these same constants are used by win_proxy_loader.dll +// if a plugin is added here, add it there as well. #ifdef _MSC_VER #define __SYCL_OPENCL_PLUGIN_NAME "pi_opencl.dll" #define __SYCL_LEVEL_ZERO_PLUGIN_NAME "pi_level_zero.dll" diff --git a/sycl/plugins/common_win_pi_trace/common_win_pi_trace.hpp b/sycl/plugins/common_win_pi_trace/common_win_pi_trace.hpp new file mode 100644 index 0000000000000..754560db80dbe --- /dev/null +++ b/sycl/plugins/common_win_pi_trace/common_win_pi_trace.hpp @@ -0,0 +1,34 @@ +// this .hpp is injected. 
Be sure to define __SYCL_PLUGIN_DLL_NAME before +// including +#ifdef _WIN32 +#include +BOOL WINAPI DllMain(HINSTANCE hinstDLL, // handle to DLL module + DWORD fdwReason, // reason for calling function + LPVOID lpReserved) { // reserved + + bool PrintPiTrace = false; + static const char *PiTrace = std::getenv("SYCL_PI_TRACE"); + static const int PiTraceValue = PiTrace ? std::stoi(PiTrace) : 0; + if (PiTraceValue == -1 || PiTraceValue == 2) { // Means print all PI traces + PrintPiTrace = true; + } + + // Perform actions based on the reason for calling. + switch (fdwReason) { + case DLL_PROCESS_DETACH: + if (PrintPiTrace) + std::cout << "---> DLL_PROCESS_DETACH " << __SYCL_PLUGIN_DLL_NAME << "\n" + << std::endl; + + break; + case DLL_PROCESS_ATTACH: + if (PrintPiTrace) + std::cout << "---> DLL_PROCESS_ATTACH " << __SYCL_PLUGIN_DLL_NAME << "\n" + << std::endl; + case DLL_THREAD_ATTACH: + case DLL_THREAD_DETACH: + break; + } + return TRUE; +} +#endif // WIN32 diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp index b02462259ea9d..728e102cc7814 100644 --- a/sycl/plugins/cuda/pi_cuda.cpp +++ b/sycl/plugins/cuda/pi_cuda.cpp @@ -5862,6 +5862,12 @@ pi_result piPluginInit(pi_plugin *PluginInit) { return PI_SUCCESS; } +#ifdef _WIN32 +#define __SYCL_PLUGIN_DLL_NAME "pi_cuda.dll" +#include "../common_win_pi_trace/common_win_pi_trace.hpp" +#undef __SYCL_PLUGIN_DLL_NAME +#endif + } // extern "C" CUevent _pi_platform::evBase_{nullptr}; diff --git a/sycl/plugins/esimd_emulator/pi_esimd_emulator.cpp b/sycl/plugins/esimd_emulator/pi_esimd_emulator.cpp index dfe2630e72708..6672f3840dd1e 100644 --- a/sycl/plugins/esimd_emulator/pi_esimd_emulator.cpp +++ b/sycl/plugins/esimd_emulator/pi_esimd_emulator.cpp @@ -2102,4 +2102,10 @@ pi_result piPluginInit(pi_plugin *PluginInit) { return PI_SUCCESS; } +#ifdef _WIN32 +#define __SYCL_PLUGIN_DLL_NAME "pi_esimd_emulator.dll" +#include "../common_win_pi_trace/common_win_pi_trace.hpp" +#undef __SYCL_PLUGIN_DLL_NAME +#endif + 
} // extern C diff --git a/sycl/plugins/hip/pi_hip.cpp b/sycl/plugins/hip/pi_hip.cpp index 37284eadbb81a..45ad1c9fc9800 100644 --- a/sycl/plugins/hip/pi_hip.cpp +++ b/sycl/plugins/hip/pi_hip.cpp @@ -5510,6 +5510,12 @@ pi_result piPluginInit(pi_plugin *PluginInit) { return PI_SUCCESS; } +#ifdef _WIN32 +#define __SYCL_PLUGIN_DLL_NAME "pi_hip.dll" +#include "../common_win_pi_trace/common_win_pi_trace.hpp" +#undef __SYCL_PLUGIN_DLL_NAME +#endif + } // extern "C" hipEvent_t _pi_platform::evBase_{nullptr}; diff --git a/sycl/plugins/level_zero/pi_level_zero.cpp b/sycl/plugins/level_zero/pi_level_zero.cpp index b8af995d3b56f..c7b5c038560ae 100644 --- a/sycl/plugins/level_zero/pi_level_zero.cpp +++ b/sycl/plugins/level_zero/pi_level_zero.cpp @@ -9438,4 +9438,10 @@ pi_result piGetDeviceAndHostTimer(pi_device Device, uint64_t *DeviceTime, } return PI_SUCCESS; } + +#ifdef _WIN32 +#define __SYCL_PLUGIN_DLL_NAME "pi_level_zero.dll" +#include "../common_win_pi_trace/common_win_pi_trace.hpp" +#undef __SYCL_PLUGIN_DLL_NAME +#endif } // extern "C" diff --git a/sycl/plugins/opencl/pi_opencl.cpp b/sycl/plugins/opencl/pi_opencl.cpp index 8c30389285c83..2d08a1cc5d2ce 100644 --- a/sycl/plugins/opencl/pi_opencl.cpp +++ b/sycl/plugins/opencl/pi_opencl.cpp @@ -1941,4 +1941,10 @@ pi_result piPluginInit(pi_plugin *PluginInit) { return PI_SUCCESS; } +#ifdef _WIN32 +#define __SYCL_PLUGIN_DLL_NAME "pi_opencl.dll" +#include "../common_win_pi_trace/common_win_pi_trace.hpp" +#undef __SYCL_PLUGIN_DLL_NAME +#endif + } // end extern 'C' diff --git a/sycl/source/CMakeLists.txt b/sycl/source/CMakeLists.txt index c9a1ba4101296..d642aee133130 100644 --- a/sycl/source/CMakeLists.txt +++ b/sycl/source/CMakeLists.txt @@ -16,6 +16,7 @@ if (SYCL_ENABLE_XPTI_TRACING) include_directories(${LLVM_EXTERNAL_XPTI_SOURCE_DIR}/include) endif() + function(add_sycl_rt_library LIB_NAME LIB_OBJ_NAME) # Add an optional argument so we can get the library name to # link with for Windows Debug version @@ -53,6 +54,14 @@ 
function(add_sycl_rt_library LIB_NAME LIB_OBJ_NAME) target_link_libraries(${LIB_NAME} PRIVATE ${ARG_XPTI_LIB}) endif() + # win_proxy_loader + include_directories(${LLVM_EXTERNAL_SYCL_SOURCE_DIR}/win_proxy_loader) + if(WIN_DUPE) + target_link_libraries(${LIB_NAME} PUBLIC win_proxy_loaderd) + else() + target_link_libraries(${LIB_NAME} PUBLIC win_proxy_loader) + endif() + target_compile_definitions(${LIB_OBJ_NAME} PRIVATE __SYCL_INTERNAL_API ) if (WIN32) @@ -215,11 +224,13 @@ if (MSVC) string(REGEX REPLACE "/MT" "" ${flag_var} "${${flag_var}}") endforeach() + set(WIN_DUPE "1") if (SYCL_ENABLE_XPTI_TRACING) add_sycl_rt_library(sycl${SYCL_MAJOR_VERSION}d sycld_object XPTI_LIB xptid COMPILE_OPTIONS "/MDd" SOURCES ${SYCL_SOURCES}) else() add_sycl_rt_library(sycl${SYCL_MAJOR_VERSION}d sycld_object COMPILE_OPTIONS "/MDd" SOURCES ${SYCL_SOURCES}) endif() + unset(WIN_DUPE) add_library(sycld ALIAS sycl${SYCL_MAJOR_VERSION}d) set(SYCL_EXTRA_OPTS "/MD") diff --git a/sycl/source/detail/context_impl.cpp b/sycl/source/detail/context_impl.cpp index 55b9ae796a45d..071bb155bf642 100644 --- a/sycl/source/detail/context_impl.cpp +++ b/sycl/source/detail/context_impl.cpp @@ -136,7 +136,7 @@ context_impl::~context_impl() { } if (!MHostContext) { // TODO catch an exception and put it to list of asynchronous exceptions - getPlugin().call(MContext); + getPlugin().call_nocheck(MContext); } } diff --git a/sycl/source/detail/event_impl.cpp b/sycl/source/detail/event_impl.cpp index 3e2cd116cfaad..33a337a6734dc 100644 --- a/sycl/source/detail/event_impl.cpp +++ b/sycl/source/detail/event_impl.cpp @@ -69,6 +69,13 @@ void event_impl::waitInternal() { "waitInternal method cannot be used for a discarded event."); } else if (MState != HES_Complete) { // Wait for the host event +#ifdef _WIN32 + // during shutdown it's possible that outstanding threads on win may + // be terminated, in which case the NotifyHostTaskComplete will not be + // called and the cv.wait() below will hang forever. 
+ if (Scheduler::getInstance().isShuttingDown) + return; +#endif std::unique_lock lock(MMutex); cv.wait(lock, [this] { return MState == HES_Complete; }); } diff --git a/sycl/source/detail/global_handler.cpp b/sycl/source/detail/global_handler.cpp index 91aac69cd30aa..ec0865d80902b 100644 --- a/sycl/source/detail/global_handler.cpp +++ b/sycl/source/detail/global_handler.cpp @@ -44,6 +44,10 @@ class ObjectUsageCounter { MCounter++; } ~ObjectUsageCounter() { +#if defined(XPTI_ENABLE_INSTRUMENTATION) && defined(_WIN32) + if (xptiTraceEnabled()) + return; // When xpti tracing, can't safely perform some shutdown ops. +#endif if (!MModifyCounter) return; @@ -165,20 +169,14 @@ void GlobalHandler::releaseDefaultContexts() { // finished. To avoid calls to nowhere, intentionally leak platform to device // cache. This will prevent destructors from being called, thus no PI cleanup // routines will be called in the end. + // Update: the win_proxy_loader addresses this for SYCL's own dependencies, + // but the GPU device dlls seem to manually load yet another DLL which may + // have been released when this function is called. So we still release() and + // leak until that is addressed. DefaultContext destructs fine on CPU device. MPlatformToDefaultContextCache.Inst.release(); #endif } -struct DefaultContextReleaseHandler { - ~DefaultContextReleaseHandler() { - GlobalHandler::instance().releaseDefaultContexts(); - } -}; - -void GlobalHandler::registerDefaultContextReleaseHandler() { - static DefaultContextReleaseHandler handler{}; -} - // Note: Split from shutdown so it is available to the unittests for ensuring // that the mock plugin is the lone plugin. 
void GlobalHandler::unloadPlugins() { @@ -202,9 +200,9 @@ void GlobalHandler::unloadPlugins() { void GlobalHandler::prepareSchedulerToRelease() { #ifndef _WIN32 drainThreadPool(); +#endif if (MScheduler.Inst) MScheduler.Inst->releaseResources(); -#endif } void GlobalHandler::drainThreadPool() { @@ -251,13 +249,30 @@ void shutdown() { extern "C" __SYCL_EXPORT BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved) { + bool PrintPiTrace = false; + static const char *PiTrace = std::getenv("SYCL_PI_TRACE"); + static const int PiTraceValue = PiTrace ? std::stoi(PiTrace) : 0; + if (PiTraceValue == -1 || PiTraceValue == 2) { // Means print all PI traces + PrintPiTrace = true; + } + // Perform actions based on the reason for calling. switch (fdwReason) { case DLL_PROCESS_DETACH: - if (!lpReserved) - shutdown(); + if (PrintPiTrace) + std::cout << "---> DLL_PROCESS_DETACH syclx.dll\n" << std::endl; + +#ifdef XPTI_ENABLE_INSTRUMENTATION + if (xptiTraceEnabled()) + return TRUE; // When doing xpti tracing, we can't safely call shutdown. + // TODO: figure out what XPTI is doing that prevents release. +#endif + + shutdown(); break; case DLL_PROCESS_ATTACH: + if (PrintPiTrace) + std::cout << "---> DLL_PROCESS_ATTACH syclx.dll\n" << std::endl; case DLL_THREAD_ATTACH: case DLL_THREAD_DETACH: break; diff --git a/sycl/source/detail/os_util.cpp b/sycl/source/detail/os_util.cpp index 4d710d8bde8be..0185a990444a1 100644 --- a/sycl/source/detail/os_util.cpp +++ b/sycl/source/detail/os_util.cpp @@ -194,6 +194,8 @@ OSModuleHandle OSUtil::getOSModuleHandle(const void *VirtAddr) { } /// Returns an absolute path where the object was found. +// win_proxy_loader.dll uses this same logic. If it is changed +// significantly, it might be wise to change it there too. 
std::string OSUtil::getCurrentDSODir() { char Path[MAX_PATH]; Path[0] = '\0'; diff --git a/sycl/source/detail/platform_impl.cpp b/sycl/source/detail/platform_impl.cpp index 2e9db4ac71632..78a7fe837d26e 100644 --- a/sycl/source/detail/platform_impl.cpp +++ b/sycl/source/detail/platform_impl.cpp @@ -137,16 +137,6 @@ std::vector platform_impl::get_platforms() { } } - // Register default context release handler after plugins have been loaded and - // after the first calls to each plugin. This initializes a function-local - // variable that should be destroyed before any global variables in the - // plugins are destroyed. This is done after the first call to the backends to - // ensure any lazy-loaded dependencies are loaded prior to the handler - // variable's initialization. Note: The default context release handler is not - // guaranteed to be destroyed before function-local static variables as they - // may be initialized after. - GlobalHandler::registerDefaultContextReleaseHandler(); - return Platforms; } diff --git a/sycl/source/detail/scheduler/scheduler.cpp b/sycl/source/detail/scheduler/scheduler.cpp index ed1a3ff8a2a47..5c3c063a5e24a 100644 --- a/sycl/source/detail/scheduler/scheduler.cpp +++ b/sycl/source/detail/scheduler/scheduler.cpp @@ -392,21 +392,29 @@ Scheduler::Scheduler() { Scheduler::~Scheduler() { DefaultHostQueue.reset(); } void Scheduler::releaseResources() { + + BlockingT blockValue = BlockingT::BLOCKING; + // There might be some commands scheduled for post enqueue cleanup that // haven't been freed because of the graph mutex being locked at the time, // clean them up now. cleanupCommands({}); - cleanupAuxiliaryResources(BlockingT::BLOCKING); + cleanupAuxiliaryResources(blockValue); + // We need loop since sometimes we may need new objects to be added to // deferred mem objects storage during cleanup. 
Known example is: we cleanup // existing deferred mem objects under write lock, during this process we // cleanup commands related to this record, command may have last reference to // queue_impl, ~queue_impl is called and buffer for assert (which is created - // with size only so all confitions for deferred release are satisfied) is + // with size only so all conditions for deferred release are satisfied) is // added to deferred mem obj storage. So we may end up with leak. + // Windows: once we are shutting down, we can't rely on thread completion. + // So we clean up the deferred mem objects, but don't loop. +#ifndef _WIN32 while (!isDeferredMemObjectsEmpty()) - cleanupDeferredMemObjects(BlockingT::BLOCKING); +#endif + cleanupDeferredMemObjects(blockValue); } MemObjRecord *Scheduler::getMemObjRecord(const Requirement *const Req) { @@ -492,6 +500,7 @@ void Scheduler::cleanupDeferredMemObjects(BlockingT Blocking) { if (isDeferredMemObjectsEmpty()) return; if (Blocking == BlockingT::BLOCKING) { + this->isShuttingDown = true; std::vector> TempStorage; { std::lock_guard LockDef{MDeferredMemReleaseMutex}; @@ -499,6 +508,7 @@ void Scheduler::cleanupDeferredMemObjects(BlockingT Blocking) { } // if any objects in TempStorage exist - it is leaving scope and being // deleted + return; } std::vector> ObjsReadyToRelease; diff --git a/sycl/source/detail/scheduler/scheduler.hpp b/sycl/source/detail/scheduler/scheduler.hpp index 3ee7abf36a0e3..2fd8b3efff481 100644 --- a/sycl/source/detail/scheduler/scheduler.hpp +++ b/sycl/source/detail/scheduler/scheduler.hpp @@ -499,6 +499,7 @@ class Scheduler { // May lock graph with read and write modes during execution. void cleanupDeferredMemObjects(BlockingT Blocking); + bool isShuttingDown = false; // POD struct to convey some additional information from GraphBuilder::addCG // to the Scheduler to support kernel fusion. 
diff --git a/sycl/source/detail/windows_pi.cpp b/sycl/source/detail/windows_pi.cpp index 034a5d49033de..50c6c582b3a12 100644 --- a/sycl/source/detail/windows_pi.cpp +++ b/sycl/source/detail/windows_pi.cpp @@ -13,29 +13,17 @@ #include #include +#include "win_proxy_loader.hpp" + namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { namespace detail { namespace pi { void *loadOsLibrary(const std::string &PluginPath) { - // Tells the system to not display the critical-error-handler message box. - // Instead, the system sends the error to the calling process. - // This is crucial for graceful handling of plugins that couldn't be - // loaded, e.g. due to missing native run-times. - // TODO: add reporting in case of an error. - // NOTE: we restore the old mode to not affect user app behavior. - // - UINT SavedMode = SetErrorMode(SEM_FAILCRITICALERRORS); - // Exclude current directory from DLL search path - if (!SetDllDirectoryA("")) { - assert(false && "Failed to update DLL search path"); - } - auto Result = (void *)LoadLibraryA(PluginPath.c_str()); - (void)SetErrorMode(SavedMode); - if (!SetDllDirectoryA(nullptr)) { - assert(false && "Failed to restore DLL search path"); - } + // We fetch the preloaded plugin from the win_proxy_loader. + // The proxy_loader handles any required error suppression. 
+ auto Result = getPreloadedPlugin(PluginPath); return Result; } diff --git a/sycl/unittests/buffer/BufferDestructionCheck.cpp b/sycl/unittests/buffer/BufferDestructionCheck.cpp index 0abe2918aea36..549eb806b178c 100644 --- a/sycl/unittests/buffer/BufferDestructionCheck.cpp +++ b/sycl/unittests/buffer/BufferDestructionCheck.cpp @@ -91,6 +91,7 @@ TEST_F(BufferDestructionCheck, BufferWithSizeOnlyDefault) { ASSERT_EQ(MockSchedulerPtr->MDeferredMemObjRelease.size(), 1u); EXPECT_EQ(MockSchedulerPtr->MDeferredMemObjRelease[0].get(), RawBufferImplPtr); + EXPECT_CALL(*MockCmd, Release).Times(1); } @@ -126,6 +127,7 @@ TEST_F(BufferDestructionCheck, BufferWithSizeOnlyNonDefaultAllocator) { MockCmd = addCommandToBuffer(Buf, Q); EXPECT_CALL(*MockCmd, Release).Times(1); } + ASSERT_EQ(MockSchedulerPtr->MDeferredMemObjRelease.size(), 1u); EXPECT_EQ(MockSchedulerPtr->MDeferredMemObjRelease[0].get(), RawBufferImplPtr); @@ -328,4 +330,4 @@ TEST_F(BufferDestructionCheck, ReadyToReleaseLogic) { // call here, no need for extra limitation EXPECT_CALL(*ReadCmd, Release).Times(1); EXPECT_CALL(*WriteCmd, Release).Times(1); -} \ No newline at end of file +} diff --git a/sycl/unittests/scheduler/GraphCleanup.cpp b/sycl/unittests/scheduler/GraphCleanup.cpp index ab1af38ac4da6..31f93d403499d 100644 --- a/sycl/unittests/scheduler/GraphCleanup.cpp +++ b/sycl/unittests/scheduler/GraphCleanup.cpp @@ -305,7 +305,10 @@ TEST_F(SchedulerTest, HostTaskCleanup) { // submitted to the thread pool, shortly after the event is marked // as complete. detail::GlobalHandler::instance().drainThreadPool(); +#ifndef _WIN32 + // Windows drainThreadPool does nothing at this time. 
ASSERT_EQ(EventImpl->getCommand(), nullptr); +#endif } struct AttachSchedulerWrapper { diff --git a/sycl/win_proxy_loader/CMakeLists.txt b/sycl/win_proxy_loader/CMakeLists.txt new file mode 100644 index 0000000000000..b1d87f59cbfaf --- /dev/null +++ b/sycl/win_proxy_loader/CMakeLists.txt @@ -0,0 +1,62 @@ + +project(win_proxy_loader) +add_library(win_proxy_loader SHARED win_proxy_loader.cpp) +if (WIN32) + install(TARGETS win_proxy_loader + RUNTIME DESTINATION "bin" COMPONENT win_proxy_loader + ) +endif() + +if (MSVC) + # MSVC provides two incompatible build variants for its CRT: release and debug + # To avoid potential issues in user code we also need to provide two kinds + # of win_proxy_loader library for release and debug configurations. + set(WINUNLOAD_CXX_FLAGS "") + if (CMAKE_BUILD_TYPE MATCHES "Debug") + set(WINUNLOAD_CXX_FLAGS "${CMAKE_CXX_FLAGS_DEBUG}") + string(REPLACE "/MDd" "" WINUNLOAD_CXX_FLAGS "${WINUNLOAD_CXX_FLAGS}") + string(REPLACE "/MTd" "" WINUNLOAD_CXX_FLAGS "${WINUNLOAD_CXX_FLAGS}") + else() + if (CMAKE_BUILD_TYPE MATCHES "Release") + set(WINUNLOAD_CXX_FLAGS "${CMAKE_CXX_FLAGS_RELEASE}") + elseif (CMAKE_BUILD_TYPE MATCHES "RelWithDebInfo") + set(WINUNLOAD_CXX_FLAGS "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}") + elseif (CMAKE_BUILD_TYPE MATCHES "MinSizeRel") + set(WINUNLOAD_CXX_FLAGS "${CMAKE_CXX_FLAGS_MINSIZEREL}") + endif() + string(REPLACE "/MD" "" WINUNLOAD_CXX_FLAGS "${WINUNLOAD_CXX_FLAGS}") + string(REPLACE "/MT" "" WINUNLOAD_CXX_FLAGS "${WINUNLOAD_CXX_FLAGS}") + endif() + + # target_compile_options requires list of options, not a string + string(REPLACE " " ";" WINUNLOAD_CXX_FLAGS "${WINUNLOAD_CXX_FLAGS}") + + set(WINUNLOAD_CXX_FLAGS_RELEASE "${WINUNLOAD_CXX_FLAGS};/MD") + set(WINUNLOAD_CXX_FLAGS_DEBUG "${WINUNLOAD_CXX_FLAGS};/MDd") + + # CMake automatically applies these flags to all targets. To override this + # behavior, options lists are reset.
+ set(CMAKE_CXX_FLAGS_RELEASE "") + set(CMAKE_CXX_FLAGS_MINSIZEREL "") + set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "") + set(CMAKE_CXX_FLAGS_DEBUG "") +endif() + + +# Handle the debug version for the Microsoft compiler as a special case by +# creating a debug version of the shared library that uses the flags used by +# the SYCL runtime +if (MSVC) + add_library(win_proxy_loaderd SHARED win_proxy_loader.cpp) + target_compile_options(win_proxy_loaderd PRIVATE ${WINUNLOAD_CXX_FLAGS_DEBUG}) + target_compile_options(win_proxy_loader PRIVATE ${WINUNLOAD_CXX_FLAGS_RELEASE}) + target_link_libraries(win_proxy_loaderd PRIVATE shlwapi) + target_link_libraries(win_proxy_loader PRIVATE shlwapi) + if (WIN32) + install(TARGETS win_proxy_loaderd + RUNTIME DESTINATION "bin" COMPONENT win_proxy_loader + ) + endif() +endif() + + diff --git a/sycl/win_proxy_loader/win_proxy_loader.cpp b/sycl/win_proxy_loader/win_proxy_loader.cpp new file mode 100644 index 0000000000000..998352bf00ffd --- /dev/null +++ b/sycl/win_proxy_loader/win_proxy_loader.cpp @@ -0,0 +1,194 @@ +#include + +#ifdef _WIN32 + +#include +#include +#include +#include + +#endif + +#include +#include +#include + +#include "win_proxy_loader.hpp" + +#ifdef _WIN32 + +// ------------------------------------ + +static constexpr const char *DirSep = "\\"; +using OSModuleHandle = intptr_t; +/// Module handle for the executable module - it is assumed there is always +/// single one at most.
+static constexpr OSModuleHandle ExeModuleHandle = -1; + +// cribbed from sycl/source/detail/os_util.cpp +std::string getDirName(const char *Path) { + std::string Tmp(Path); + // Remove trailing directory separators + Tmp.erase(Tmp.find_last_not_of("/\\") + 1, std::string::npos); + + size_t pos = Tmp.find_last_of("/\\"); + if (pos != std::string::npos) + return Tmp.substr(0, pos); + + // If no directory separator is present return initial path like dirname does + return Tmp; +} + +// cribbed from sycl/source/detail/os_util.cpp +OSModuleHandle getOSModuleHandle(const void *VirtAddr) { + HMODULE PhModule; + DWORD Flag = GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT; + auto LpModuleAddr = reinterpret_cast(VirtAddr); + if (!GetModuleHandleExA(Flag, LpModuleAddr, &PhModule)) { + // Expect the caller to check for zero and take + // necessary action + return 0; + } + if (PhModule == GetModuleHandleA(nullptr)) + return ExeModuleHandle; + return reinterpret_cast(PhModule); +} + +// cribbed from sycl/source/detail/os_util.cpp +/// Returns an absolute path where the object was found. +std::string getCurrentDSODir() { + char Path[MAX_PATH]; + Path[0] = '\0'; + Path[sizeof(Path) - 1] = '\0'; + auto Handle = getOSModuleHandle(reinterpret_cast(&getCurrentDSODir)); + DWORD Ret = GetModuleFileNameA( + reinterpret_cast(ExeModuleHandle == Handle ? 0 : Handle), + reinterpret_cast(&Path), sizeof(Path)); + assert(Ret < sizeof(Path) && "Path is longer than PATH_MAX?"); + assert(Ret > 0 && "GetModuleFileNameA failed"); + (void)Ret; + + BOOL RetCode = PathRemoveFileSpecA(reinterpret_cast(&Path)); + assert(RetCode && "PathRemoveFileSpecA failed"); + (void)RetCode; + + return Path; +} + +// these are cribbed from include/sycl/detail/pi.hpp +// a new plugin must be added to both places. 
+#ifdef _MSC_VER +#define __SYCL_OPENCL_PLUGIN_NAME "pi_opencl.dll" +#define __SYCL_LEVEL_ZERO_PLUGIN_NAME "pi_level_zero.dll" +#define __SYCL_CUDA_PLUGIN_NAME "pi_cuda.dll" +#define __SYCL_ESIMD_EMULATOR_PLUGIN_NAME "pi_esimd_emulator.dll" +#define __SYCL_HIP_PLUGIN_NAME "libpi_hip.dll" +#define __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME "pi_unified_runtime.dll" +#else // llvm-mingw +#define __SYCL_OPENCL_PLUGIN_NAME "libpi_opencl.dll" +#define __SYCL_LEVEL_ZERO_PLUGIN_NAME "libpi_level_zero.dll" +#define __SYCL_CUDA_PLUGIN_NAME "libpi_cuda.dll" +#define __SYCL_ESIMD_EMULATOR_PLUGIN_NAME "libpi_esimd_emulator.dll" +#define __SYCL_HIP_PLUGIN_NAME "libpi_hip.dll" +#define __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME "libpi_unified_runtime.dll" +#endif +// ------------------------------------ + +static std::map dllMap; + +/// load the six libraries and store them in a map. +void preloadLibraries() { + // Suppress system errors. + // Tells the system to not display the critical-error-handler message box. + // Instead, the system sends the error to the calling process. + // This is crucial for graceful handling of plugins that couldn't be + // loaded, e.g. due to missing native run-times. + // Sometimes affects L0 or the unified runtime. + // TODO: add reporting in case of an error. + // NOTE: we restore the old mode to not affect user app behavior.
+ // + UINT SavedMode = SetErrorMode(SEM_FAILCRITICALERRORS); + // Exclude current directory from DLL search path + if (!SetDllDirectoryA("")) { + assert(false && "Failed to update DLL search path"); + } + + // this path duplicates sycl/detail/pi.cpp:initializePlugins + const std::string LibSYCLDir = getCurrentDSODir() + DirSep; + + std::string ocl_path = LibSYCLDir + __SYCL_OPENCL_PLUGIN_NAME; + dllMap.emplace(ocl_path, LoadLibraryA(ocl_path.c_str())); + + std::string l0_path = LibSYCLDir + __SYCL_LEVEL_ZERO_PLUGIN_NAME; + dllMap.emplace(l0_path, LoadLibraryA(l0_path.c_str())); + + std::string cuda_path = LibSYCLDir + __SYCL_CUDA_PLUGIN_NAME; + dllMap.emplace(cuda_path, LoadLibraryA(cuda_path.c_str())); + + std::string esimd_path = LibSYCLDir + __SYCL_ESIMD_EMULATOR_PLUGIN_NAME; + dllMap.emplace(esimd_path, LoadLibraryA(esimd_path.c_str())); + + std::string hip_path = LibSYCLDir + __SYCL_HIP_PLUGIN_NAME; + dllMap.emplace(hip_path, LoadLibraryA(hip_path.c_str())); + + std::string ur_path = LibSYCLDir + __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME; + dllMap.emplace(ur_path, LoadLibraryA(ur_path.c_str())); + + // Restore system error handling. + (void)SetErrorMode(SavedMode); + if (!SetDllDirectoryA(nullptr)) { + assert(false && "Failed to restore DLL search path"); + } +} + +/// windows_pi.cpp:loadOsLibrary() calls this to get the DLL we loaded earlier. +__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) { + + auto match = dllMap.find( + PluginPath); // result might be nullptr, which is perfectly valid. + if (match == dllMap.end()) { + // unit testing? Just return nullptr. + if (PluginPath.find("unittests") != std::string::npos) + return nullptr; + + // Otherwise, asking for something we don't know about at all, is an issue. 
+ std::cout << "unknown plugin: " << PluginPath << std::endl; + assert(false && "getPreloadedPlugin was given an unknown plugin path."); + return nullptr; + } + return match->second; +} + +BOOL WINAPI DllMain(HINSTANCE hinstDLL, // handle to DLL module + DWORD fdwReason, // reason for calling function + LPVOID lpReserved) // reserved +{ + bool PrintPiTrace = false; + static const char *PiTrace = std::getenv("SYCL_PI_TRACE"); + static const int PiTraceValue = PiTrace ? std::stoi(PiTrace) : 0; + if (PiTraceValue == -1 || PiTraceValue == 2) { // Means print all PI traces + PrintPiTrace = true; + } + + switch (fdwReason) { + case DLL_PROCESS_ATTACH: + if (PrintPiTrace) + std::cout << "---> DLL_PROCESS_ATTACH win_proxy_loader.dll\n" + << std::endl; + + preloadLibraries(); + break; + case DLL_PROCESS_DETACH: + if (PrintPiTrace) + std::cout << "---> DLL_PROCESS_DETACH win_proxy_loader.dll\n" + << std::endl; + + case DLL_THREAD_ATTACH: + case DLL_THREAD_DETACH: + break; + } + return TRUE; +} + +#endif // WIN32 diff --git a/sycl/win_proxy_loader/win_proxy_loader.hpp b/sycl/win_proxy_loader/win_proxy_loader.hpp new file mode 100644 index 0000000000000..226e40d8888cb --- /dev/null +++ b/sycl/win_proxy_loader/win_proxy_loader.hpp @@ -0,0 +1,7 @@ +#pragma once + +#ifdef _WIN32 +#include + +__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath); +#endif