Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 43 additions & 28 deletions csrc/driver_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,34 +61,49 @@
//
// Driver APIs are loaded using cudaGetDriverEntryPoint as recommended by
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-the-runtime-api
#define DEFINE_DRIVER_API_WRAPPER(funcName, version) \
namespace { \
template <typename ReturnType, typename... Args> \
struct funcName##Loader { \
static ReturnType lazilyLoadAndInvoke(Args... args) { \
static decltype(::funcName)* f; \
static std::once_flag once; \
std::call_once(once, [&]() { \
NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDriverEntryPointByVersion( \
#funcName, \
reinterpret_cast<void**>(&f), \
version, \
cudaEnableDefault)); \
}); \
return f(args...); \
} \
/* This ctor is just a CTAD helper, it is only used in a */ \
/* non-evaluated environment*/ \
funcName##Loader(ReturnType(Args...)){}; \
}; \
\
/* Use CTAD rule to deduct return and argument types */ \
template <typename ReturnType, typename... Args> \
funcName##Loader(ReturnType(Args...)) \
->funcName##Loader<ReturnType, Args...>; \
} /* namespace */ \
\
decltype(::funcName)* funcName = \
namespace {
void getDriverEntryPoint(
const char* symbol,
unsigned int version,
void** entry_point) {
#if (CUDA_VERSION >= 12050)
NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDriverEntryPointByVersion(
symbol, entry_point, version, cudaEnableDefault));
#else
(void)version;
NVFUSER_CUDA_RT_SAFE_CALL(
cudaGetDriverEntryPoint(symbol, entry_point, cudaEnableDefault));
#endif
}
} // namespace

#define DEFINE_DRIVER_API_WRAPPER(funcName, version) \
namespace { \
template <typename ReturnType, typename... Args> \
struct funcName##Loader { \
static ReturnType lazilyLoadAndInvoke(Args... args) { \
static decltype(::funcName)* entry_point; \
static std::once_flag once; \
std::call_once( \
once, \
getDriverEntryPoint, \
#funcName, \
version, \
reinterpret_cast<void**>(&entry_point)); \
return entry_point(args...); \
} \
/* This ctor is just a CTAD helper, it is only used in a */ \
/* non-evaluated environment*/ \
funcName##Loader(ReturnType(Args...)){}; \
}; \
\
/* Use CTAD rule to deduct return and argument types */ \
template <typename ReturnType, typename... Args> \
funcName##Loader(ReturnType(Args...)) \
->funcName##Loader<ReturnType, Args...>; \
} /* namespace */ \
\
decltype(::funcName)* funcName = \
decltype(funcName##Loader(::funcName))::lazilyLoadAndInvoke

namespace nvfuser {
Expand Down