-
Notifications
You must be signed in to change notification settings - Fork 19.2k
Programmatic Dependent Launch (PDL) for more performance on newer NVIDIA GPUs (Hopper+) #22522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9f4ddbc
73d28e4
000f462
b68aee7
101583e
0e7aa04
d8eb8ab
12ddf12
adfd442
7f1342a
f3fe281
c2d9d47
d942a3a
11150f0
71f8f58
8664310
909ec1f
3c584d0
25bbc88
c5044bf
7e76151
8746582
dac466d
23a24c5
ef28cda
5e318bf
f3b8665
0a7d8c3
75cd1b0
338477a
83e3c79
3b2d1d1
2196115
c471996
ad7bb69
ff4a9c7
98ee686
54483ad
fee1c65
a083acc
ac33653
a459f2f
4346c54
378e8e7
5683763
0196e69
12b1d25
fc8099c
7bd9f64
aac3b12
47a6072
a48bc30
da242a8
0b104c4
ee9c7b1
42c6310
72eaf40
a82defd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| #include "ggml-cuda.h" | ||
|
|
||
| #include <cstdint> | ||
| #include <cstdlib> | ||
| #include <memory> | ||
|
|
||
| #if defined(GGML_USE_HIP) | ||
|
|
@@ -27,6 +28,7 @@ | |
| #include <cstdio> | ||
| #include <string> | ||
| #include <unordered_map> | ||
| #include <utility> | ||
| #include <vector> | ||
|
|
||
| #if defined(GGML_USE_HIP) | ||
|
|
@@ -50,6 +52,7 @@ | |
| #define GGML_CUDA_CC_TURING 750 | ||
| #define GGML_CUDA_CC_AMPERE 800 | ||
| #define GGML_CUDA_CC_ADA_LOVELACE 890 | ||
| #define GGML_CUDA_CC_HOPPER 900 | ||
|
aendk marked this conversation as resolved.
|
||
| // While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see | ||
| // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms | ||
| #define GGML_CUDA_CC_BLACKWELL 1200 | ||
|
|
@@ -107,6 +110,24 @@ | |
| # define GGML_CUDA_USE_CUB | ||
| #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 | ||
|
|
||
| // PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8 and excludes HIP/MUSA. | ||
| // __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code. | ||
| #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Btw actually I think its not just 11.8, I am compiling on 12.1 (cudatoolkit from conda) and I still get
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Just out of curiosity, what is your
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that // cudaLaunchKernelEx requires C++11, but unfortunately <cuda_runtime.h> checks this using the __cplusplus macro,
// which is reported wrongly for MSVC. CTK 12.3 fixed this by additionally detecting _MSV_VER. As a workaround, we
// provide our own copy of cudaLaunchKernelEx when it is not available from the CTK.
#if _CCCL_COMPILER(MSVC) && _CCCL_CUDACC_BELOW(12, 3)
// Copied from <cuda_runtime.h>
template <typename... ExpTypes, typename... ActTypes>
static cudaError_t _CCCL_HOST
cudaLaunchKernelEx_MSVC_workaround(const cudaLaunchConfig_t* config, void (*kernel)(ExpTypes...), ActTypes&&... args)
{
return [&](ExpTypes... coercedArgs) {
void* pArgs[] = {&coercedArgs...};
return ::cudaLaunchKernelExC(config, (const void*) kernel, pArgs);
}(std::forward<ActTypes>(args)...);
}
#endif So presumably this is the problem.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the easiest fix would be to just bump the minimum required CUDA version on our end to 12.3. PDL only works on Hopper and Blackwell, and Blackwell needs a higher CUDA version anyways. And I think there basically are no llama.cpp users with both H100s and CUDA versions older than 12.3.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or rather, we can conditionally require CUDA 12.3 for PDL + MSVC since no one is going to use Windows for H100s.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did we fix this? I am getting the same problem for the CUDA 11.8 builds of
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@LostRuins #23742 I filed a PR that restricts PDL on MSVC tool-chain, would you mind verifying this fixes your build issues? 😇
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just saw that we are waiting for confirmation - hope I didn't merge the PR to soon?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Seems to work now, the build completed successfully.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for testing ❤️ |
||
| # define GGML_CUDA_USE_PDL | ||
| #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 | ||
|
|
||
| static __device__ __forceinline__ void ggml_cuda_pdl_sync() { | ||
| #if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER | ||
| cudaGridDependencySynchronize(); | ||
| #endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER | ||
| } | ||
|
|
||
| static __device__ __forceinline__ void ggml_cuda_pdl_lc() { | ||
| #if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER | ||
| cudaTriggerProgrammaticLaunchCompletion(); | ||
| #endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER | ||
| } | ||
|
|
||
| #ifdef __CUDA_ARCH_LIST__ | ||
| constexpr bool ggml_cuda_has_arch_impl(int) { | ||
| return false; | ||
|
|
@@ -165,6 +186,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in | |
|
|
||
| #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) | ||
|
|
||
|
|
||
| #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA) | ||
| static const char * cublas_get_error_str(const cublasStatus_t err) { | ||
| return cublasGetStatusString(err); | ||
|
|
@@ -1487,3 +1509,67 @@ struct ggml_cuda_mm_fusion_args_device { | |
| const void * gate_bias = nullptr; | ||
| ggml_glu_op glu_op; | ||
| }; | ||
|
|
||
| struct ggml_cuda_kernel_launch_params { | ||
| dim3 block_nums; | ||
| dim3 block_dims; | ||
| size_t shmem; | ||
| cudaStream_t stream; | ||
|
|
||
| // size_t shmem | ||
| ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const size_t shmem_, const cudaStream_t stream_) | ||
| : block_nums(block_nums_), block_dims(block_dims_), shmem(shmem_), stream(stream_) {} | ||
|
|
||
| // Some call sites pass ints instead of the required size_t. This 2nd constructor casts int->size_t to avoid these -Wnarrowing warnings. | ||
| ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const int shmem_, const cudaStream_t stream_) | ||
| : block_nums(block_nums_), block_dims(block_dims_), shmem((size_t)shmem_), stream(stream_) {} | ||
| }; | ||
|
|
||
| #if defined(GGML_CUDA_USE_PDL) | ||
| struct ggml_cuda_pdl_config { | ||
| cudaLaunchAttribute attr; | ||
| cudaLaunchConfig_t cfg; | ||
|
|
||
| ggml_cuda_pdl_config(const ggml_cuda_kernel_launch_params & params) { | ||
| attr.id = cudaLaunchAttributeProgrammaticStreamSerialization; | ||
| attr.val.programmaticStreamSerializationAllowed = 1; | ||
|
|
||
| cfg = {}; | ||
| cfg.gridDim = params.block_nums; | ||
| cfg.blockDim = params.block_dims; | ||
| cfg.dynamicSmemBytes = params.shmem; | ||
| cfg.stream = params.stream; | ||
| cfg.attrs = &attr; | ||
| cfg.numAttrs = 1; | ||
| } | ||
|
|
||
| // Delete due to &attr | ||
| ggml_cuda_pdl_config(const ggml_cuda_pdl_config&) = delete; | ||
| ggml_cuda_pdl_config& operator=(const ggml_cuda_pdl_config&) = delete; | ||
| ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete; | ||
|
|
||
| }; | ||
| #endif //defined(GGML_CUDA_USE_PDL) | ||
|
|
||
|
|
||
| template<typename Kernel, typename... Args> | ||
| static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) { | ||
| #if defined(GGML_CUDA_USE_PDL) | ||
|
|
||
| static const bool env_pdl_enabled = []() { | ||
| const char * env = getenv("GGML_CUDA_PDL"); | ||
| return env == nullptr || std::atoi(env) != 0; | ||
| }(); | ||
|
|
||
| if (env_pdl_enabled && ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_HOPPER) { | ||
| auto pdl_cfg = ggml_cuda_pdl_config(launch_params); | ||
|
|
||
| CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward<Args>(args)... )); | ||
| return; | ||
| } | ||
| #endif //defined(GGML_CUDA_USE_PDL) | ||
|
|
||
| kernel<<<launch_params.block_nums, launch_params.block_dims, launch_params.shmem, launch_params.stream>>>(std::forward<Args>(args)... ); | ||
| CUDA_CHECK(cudaGetLastError()); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I personally don't see a benefit of adding
90-virtualto the defaults, given we can expect poor performance here due to the cuda backend being unoptimized for Data-Center GPUsThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have a survey on what % of users use hopper devices to estimate the trade off?
I added it as it was requested here #22522 (comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd expect it to be low as we use neither
wgmmafor hopper nottcgen05for BW Tensor Core acceleration. But I'll not push back on this further, was just hoping to keep binary bloat on our Windows-releases small (iirc llama.cpp only builds/ships binaries for Windows so far).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some people have been using it, but report suboptimal pre-fill perf as expected #18005