Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
9f4ddbc
Adds initial PDL setup.
aendk Jan 21, 2026
73d28e4
Adds PDL barriers based on simple heuristic: place "sync" before firs…
aendk Jan 21, 2026
000f462
Further optimization pass of the first half of kernels
aendk Jan 21, 2026
b68aee7
Optimized PDL barriers for the second batch of kernels
aendk Jan 22, 2026
101583e
Further refinements after rebase.
aendk Feb 4, 2026
0e7aa04
Moves pdl logic to separate function, removes some whitespace
aendk Feb 5, 2026
d8eb8ab
Strips post-hoc PDL logic
aendk Feb 13, 2026
12ddf12
Adds stream capture PDL setup. Enrolls quantize_q8_1 to leverage pdl to
aendk Feb 13, 2026
adfd442
Enrolls mul_mat_vec_q, rms_norm_f32 and k_bin_bcast (partly) into PDL
aendk Feb 13, 2026
7f1342a
Enrolls mmvf, rope, set-rows and topk kernels for gpt-oss into PDL
aendk Feb 18, 2026
f3fe281
Merge branch 'master' into akieslinger/pdl-cuda
aendk Feb 18, 2026
c2d9d47
Introduce ggml_cuda_kernel_launch, to abstract away cudaLaunchKernelEx,
aendk Feb 18, 2026
d942a3a
Enrolls cpy_scalar_contiguous, k_get_rows_float and rms_norm_f32
aendk Feb 18, 2026
11150f0
Enrolls flash_attn_combine_results
aendk Feb 18, 2026
71f8f58
Fix: Drops needless and broken check of CUDA arch for PDL. PDL either
aendk Feb 19, 2026
8664310
Enrolls flash-attention kernels to pdl
aendk Feb 19, 2026
909ec1f
Fix: inlines ggml_cuda_kernel_launch, and uses perfect forwarding for
aendk Feb 20, 2026
3c584d0
Merge branch 'master' into akieslinger/pdl-cuda
aendk Feb 20, 2026
25bbc88
Perf: Enrolls k_bin_bcast variadic template invocation into PDL, via
aendk Feb 20, 2026
c5044bf
Enrolls all remaining kernels for qwen3-coder-next into PDL
aendk Feb 20, 2026
7e76151
Remove all PDL LC calls to create a baseline
aendk Mar 11, 2026
8746582
Merge branch 'master' into akieslinger/pdl-cuda
aendk Mar 11, 2026
dac466d
Merge branch 'master' into akieslinger/pdl-cuda
aendk Mar 24, 2026
23a24c5
Added LC according to internal guidance and tested kernel performance.
aendk Mar 25, 2026
ef28cda
Enrols missing qwen3-5 kernels passively into PDL.
aendk Apr 2, 2026
5e318bf
Kernel optimizations (LC signals) for qwen3.5
aendk Apr 10, 2026
f3b8665
Enrolls ssm-scan kernels into PDL
aendk Apr 10, 2026
0a7d8c3
Merge branch 'master' into akieslinger/pdl-cuda-lc-experiments
aendk Apr 16, 2026
75cd1b0
Merge branch 'master' into akieslinger/pdl-cuda-lc-experiments
aendk Apr 20, 2026
338477a
Merge branch 'master' into akieslinger-pdl-cuda-merge-test
aendk Apr 29, 2026
83e3c79
Adds GGML_CUDA_PDL command line option to toggle PDL.
aendk Apr 29, 2026
3b2d1d1
Fix: Ada and lower compilation by guarding PDL calls correctly
aendk May 8, 2026
2196115
Cleanup: Removes commented out GGML_CUDA_PDL_LC
aendk May 8, 2026
c471996
Cleanup: Removes experimental comments
aendk May 8, 2026
ad7bb69
Adds 90-virtual to build script so that Hopper GPUs can leverage PDL.
aendk May 8, 2026
ff4a9c7
Adds stricter checks to enable PDL, adds env-check to disable it, an…
aendk May 11, 2026
98ee686
Fix: Correct PDL en/disablement based on device-side arch check. Host
aendk May 12, 2026
54483ad
Fix: default-disable PDL. Enable by setting GGML_CUDA_ENABLE_PDL=1
aendk May 13, 2026
fee1c65
Merge branch 'master' into akieslinger/pdl-cuda-lc-experiments
aendk May 13, 2026
a083acc
Enable PDL by default for Hopper+ devices
aendk May 15, 2026
ac33653
Enrolls softcap_f32 and two flash_attn kernels into PDL.
aendk May 15, 2026
a459f2f
Improves flash attn PDL barrier placement
aendk May 15, 2026
4346c54
Merge branch 'master' into akieslinger/pdl-cuda-lc-experiments
aendk May 15, 2026
378e8e7
Fix: Perf regression on ada; excludes ada and below from PDL launches
aendk May 15, 2026
5683763
Improves some sync barrier placements
aendk May 15, 2026
0196e69
Merge branch 'master' into akieslinger/pdl-cuda-lc-experiments
aendk May 18, 2026
12b1d25
Drops superfluous constructor
aendk May 18, 2026
fc8099c
Adds #endif guard comments
aendk May 18, 2026
7bd9f64
Reverts experimental change to top-k-moe.cu, which moved expensive al…
aendk May 18, 2026
aac3b12
Exchanges GGML_CUDA_DISABLE_PDL with GGML_CUDA_PDL. IFF GGML_CUDA_PDL=0
aendk May 18, 2026
47a6072
Revert "Drops superfluous constructor". Adds const to remaining
aendk May 18, 2026
a48bc30
Cleanup: Removes and fixes some comments and whitespace
aendk May 18, 2026
da242a8
Clarifies comment of sync-barrier position
aendk May 18, 2026
0b104c4
Relocates and refactors PDL launch functions and accessories
aendk May 18, 2026
ee9c7b1
Adds error checking to the regular kernel launch path
aendk May 18, 2026
42c6310
Drops "auto" in favor of "ggml_cuda_kernel_params"
aendk May 19, 2026
72eaf40
Adds "const" to ggml_cuda_kernel_launch_params
aendk May 19, 2026
a82defd
[Whitespace] Adds final newline to common.cuh to make editorconfig CI…
aendk May 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
# 80 == Ampere, asynchronous data loading, faster tensor core instructions
# 86 == RTX 3000, needs CUDA v11.1
# 89 == RTX 4000, needs CUDA v11.8
# 90 == Hopper H100/200, needs CUDA v11.8
# 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores
#
# XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
Expand All @@ -33,7 +34,7 @@ if (CUDAToolkit_FOUND)
list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real)

if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real 90-virtual)
Comment on lines -36 to +37
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally don't see a benefit of adding 90-virtual to the defaults, given we can expect poor performance here due to the cuda backend being unoptimized for Data-Center GPUs

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a survey on what % of users use hopper devices to estimate the trade off?
I added it as it was requested here #22522 (comment)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a survey on what % of users use hopper devices to estimate the trade off?
I added it as it was requested here #22522 (comment)

I'd expect it to be low as we use neither wgmma for Hopper nor tcgen05 for BW Tensor Core acceleration. But I'll not push back on this further, was just hoping to keep binary bloat on our Windows-releases small (iirc llama.cpp only builds/ships binaries for Windows so far).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some people have been using it, but report suboptimal pre-fill perf as expected #18005

endif()

if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
Expand Down
32 changes: 14 additions & 18 deletions ggml/src/ggml-cuda/binbcast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#include <cstdint>
#include <utility>

// Alias template that discards its (unnamed) index argument and always names T.
// Expanding `type_for_index<T, I>...` over an index pack I therefore produces one T per index; the
// k_bin_bcast / k_bin_bcast_unravel launches in this file use it to spell out one `const src1_t *`
// template argument for every extra src1 pointer they pass.
template<typename T, size_t>
using type_for_index = T;

static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
GGML_UNUSED(a);
Expand Down Expand Up @@ -52,6 +55,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const int s12,
const int s13,
src1_ptrs... src1s) {
ggml_cuda_pdl_lc();
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
Expand All @@ -72,6 +76,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;

ggml_cuda_pdl_sync();
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const uint32_t i10 = fastmodulo(i0, ne10);

Expand Down Expand Up @@ -141,6 +146,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,

const int i10 = fastmodulo(i0, ne10);

ggml_cuda_pdl_sync();
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
Expand Down Expand Up @@ -282,35 +288,24 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);

if constexpr (sizeof...(I) > 0) {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
{
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)block_num, block_size, 0, stream);
ggml_cuda_kernel_launch(k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t, type_for_index<const src1_t *, I>...>, launch_params,
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
ne12, ne13,
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
}
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
{
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
ggml_cuda_kernel_launch(k_bin_bcast<bin_op, src0_t, src1_t, dst_t, type_for_index<const src1_t *, I>...>, launch_params,
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/*s0,*/ s1, s2, s3,
s00 ,s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
}
}
}
Expand All @@ -333,6 +328,7 @@ static __global__ void k_repeat_back(
}

T sum = 0;
ggml_cuda_pdl_sync();
for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
Expand Down
86 changes: 86 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "ggml-cuda.h"

#include <cstdint>
#include <cstdlib>
#include <memory>

#if defined(GGML_USE_HIP)
Expand All @@ -27,6 +28,7 @@
#include <cstdio>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#if defined(GGML_USE_HIP)
Expand All @@ -50,6 +52,7 @@
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_HOPPER 900
Comment thread
aendk marked this conversation as resolved.
// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
#define GGML_CUDA_CC_BLACKWELL 1200
Expand Down Expand Up @@ -107,6 +110,24 @@
# define GGML_CUDA_USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070

// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8 and excludes HIP/MUSA.
// __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code.
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw actually I think it's not just 11.8, I am compiling on 12.1 (cudatoolkit from conda) and I still get error : identifier "cudaLaunchKernelEx" is undefined

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cudaLaunchKernelEx is supported in 12.1 and is mentioned in the programming guide of this version (cf. https://docs.nvidia.com/cuda/archive/12.1.0/cuda-c-programming-guide/index.html).
Your error is likely specific to your setup, my suggestion to fix this is by (1) building llama.cpp from scratch again as this is easy & fast, and if this does not work, (2) to update your CUDA toolkit and conda.

Just out of curiosity, what is your nvidia-smi output?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that cudaLaunchKernelEx is a templated version of cudaLaunchKernelExC as described here. Apparently C++11 or newer is needed but for llama.cpp/ggml we are already requiring that. The function seems to be defined in cuda_runtime.h. While grepping the CUDA toolkit I found this in targets/x86_64-linux/include/cccl/thrust/system/cuda/detail/core/triple_chevron_launch.h:

  // cudaLaunchKernelEx requires C++11, but unfortunately <cuda_runtime.h> checks this using the __cplusplus macro,
  // which is reported wrongly for MSVC. CTK 12.3 fixed this by additionally detecting _MSV_VER. As a workaround, we
  // provide our own copy of cudaLaunchKernelEx when it is not available from the CTK.
#if _CCCL_COMPILER(MSVC) && _CCCL_CUDACC_BELOW(12, 3)
  // Copied from <cuda_runtime.h>
  template <typename... ExpTypes, typename... ActTypes>
  static cudaError_t _CCCL_HOST     
  cudaLaunchKernelEx_MSVC_workaround(const cudaLaunchConfig_t* config, void (*kernel)(ExpTypes...), ActTypes&&... args)
  {                                 
    return [&](ExpTypes... coercedArgs) {
      void* pArgs[] = {&coercedArgs...};
      return ::cudaLaunchKernelExC(config, (const void*) kernel, pArgs);
    }(std::forward<ActTypes>(args)...);
  }                                 
#endif                              

So presumably this is the problem.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the easiest fix would be to just bump the minimum required CUDA version on our end to 12.3. PDL only works on Hopper and Blackwell, and Blackwell needs a higher CUDA version anyways. And I think there basically are no llama.cpp users with both H100s and CUDA versions older than 12.3.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or rather, we can conditionally require CUDA 12.3 for PDL + MSVC since no one is going to use Windows for H100s.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we fix this? I am getting the same problem for the CUDA 11.8 builds of whisper.cpp: https://github.com/ggml-org/whisper.cpp/actions/runs/26453742390/job/77916503088#step:11:921

Copy link
Copy Markdown
Collaborator

@ORippler ORippler May 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we fix this? I am getting the same problem for the CUDA 11.8 builds of whisper.cpp: https://github.com/ggml-org/whisper.cpp/actions/runs/26453742390/job/77916503088#step:11:921

@LostRuins #23742 I filed a PR that restricts PDL on MSVC tool-chain, would you mind verifying this fixes your build issues? 😇

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just saw that we are waiting for confirmation - hope I didn't merge the PR too soon?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LostRuins #23742 I filed a PR that restricts PDL on MSVC tool-chain, would you mind verifying this fixes your build issues? 😇

Seems to work now, the build completed successfully.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for testing ❤️

// The cudaLaunchKernelEx template in <cuda_runtime.h> is only declared when the header detects C++11 through
// __cplusplus, which MSVC misreports by default; CUDA 12.3 started checking _MSC_VER as well. On MSVC with an
// older toolkit the symbol is therefore missing, so PDL stays compiled out there instead of breaking the build.
# if !defined(_MSC_VER) || CUDART_VERSION >= 12030
#  define GGML_CUDA_USE_PDL
# endif // !defined(_MSC_VER) || CUDART_VERSION >= 12030
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080

// Device-side PDL "sync" barrier.
// When PDL is compiled in (GGML_CUDA_USE_PDL) and the current device pass targets Hopper or newer
// (__CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER) this issues cudaGridDependencySynchronize(); in every other
// configuration the body is empty, so kernels may call it unconditionally.
// NOTE(review): per the CUDA programming guide the intrinsic waits for the grids this launch depends on, so it
// presumably has to precede the first access to data produced by the previous kernel - confirm per call site.
static __device__ __forceinline__ void ggml_cuda_pdl_sync() {
#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
cudaGridDependencySynchronize();
#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
}

// Device-side PDL "launch completion" (LC) signal.
// When PDL is compiled in (GGML_CUDA_USE_PDL) and the current device pass targets Hopper or newer
// (__CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER) this issues cudaTriggerProgrammaticLaunchCompletion(); otherwise the
// body is empty, so kernels may call it unconditionally.
// NOTE(review): per the CUDA programming guide the trigger lets the dependent (next) kernel start launching
// early; it presumably gives no memory-visibility guarantee by itself - confirm against the PDL documentation.
static __device__ __forceinline__ void ggml_cuda_pdl_lc() {
#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
cudaTriggerProgrammaticLaunchCompletion();
#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
}

#ifdef __CUDA_ARCH_LIST__
constexpr bool ggml_cuda_has_arch_impl(int) {
return false;
Expand Down Expand Up @@ -165,6 +186,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in

#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)


#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
static const char * cublas_get_error_str(const cublasStatus_t err) {
return cublasGetStatusString(err);
Expand Down Expand Up @@ -1487,3 +1509,67 @@ struct ggml_cuda_mm_fusion_args_device {
const void * gate_bias = nullptr;
ggml_glu_op glu_op;
};

// Launch geometry consumed by ggml_cuda_kernel_launch: grid size, block size, dynamic shared memory and stream.
struct ggml_cuda_kernel_launch_params {
    dim3         block_nums;  // grid dimensions (number of blocks)
    dim3         block_dims;  // block dimensions (threads per block)
    size_t       shmem;       // dynamic shared memory size in bytes
    cudaStream_t stream;      // stream the kernel is enqueued on

    // Primary constructor: shared memory size already given as size_t.
    ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const size_t shmem_, const cudaStream_t stream_)
        : block_nums(block_nums_), block_dims(block_dims_), shmem(shmem_), stream(stream_) {}

    // Convenience overload for call sites that hold the shared memory size as an int; the explicit conversion
    // to size_t keeps those call sites free of -Wnarrowing warnings.
    ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const int shmem_, const cudaStream_t stream_)
        : ggml_cuda_kernel_launch_params(block_nums_, block_dims_, static_cast<size_t>(shmem_), stream_) {}
};

#if defined(GGML_CUDA_USE_PDL)
// Owns a cudaLaunchConfig_t together with the single launch attribute it points at, so both live exactly as
// long as the launch call that uses them.
struct ggml_cuda_pdl_config {
    cudaLaunchAttribute attr;
    cudaLaunchConfig_t  cfg;

    // Builds a launch configuration from `params` with programmatic stream serialization allowed.
    ggml_cuda_pdl_config(const ggml_cuda_kernel_launch_params & params) {
        // Start from a cleared config, then copy the launch geometry over.
        cfg = {};
        cfg.stream           = params.stream;
        cfg.dynamicSmemBytes = params.shmem;
        cfg.blockDim         = params.block_dims;
        cfg.gridDim          = params.block_nums;

        // The one attribute attached to the launch.
        attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
        attr.val.programmaticStreamSerializationAllowed = 1;

        // cfg refers to the attribute stored in this very object.
        cfg.numAttrs = 1;
        cfg.attrs    = &attr;
    }

    // cfg.attrs holds the address of this object's own `attr`; a copied or assigned-to object would keep
    // pointing at the source object's attribute, so copying and assignment are disabled.
    ggml_cuda_pdl_config(const ggml_cuda_pdl_config&) = delete;
    ggml_cuda_pdl_config& operator=(const ggml_cuda_pdl_config&) = delete;
    ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete;
};
#endif //defined(GGML_CUDA_USE_PDL)


// Launches `kernel` with the geometry in `launch_params`, forwarding `args` as the kernel arguments.
//
// When PDL support is compiled in (GGML_CUDA_USE_PDL), the launch goes through cudaLaunchKernelEx with a
// ggml_cuda_pdl_config provided that
//   - the GGML_CUDA_PDL environment variable is unset or parses (std::atoi) to a non-zero value, and
//   - the current device reports a compute capability of at least GGML_CUDA_CC_HOPPER.
// In every other case the kernel is started with a regular <<<...>>> launch. Both paths are error-checked
// through CUDA_CHECK.
template<typename Kernel, typename... Args>
static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) {
#if defined(GGML_CUDA_USE_PDL)
    // The environment is inspected once; the result is cached in a function-local static.
    static const bool pdl_env_on = []() -> bool {
        const char * value = getenv("GGML_CUDA_PDL");
        if (value == nullptr) {
            return true;
        }
        return std::atoi(value) != 0;
    }();

    // Short-circuit: the device lookup only happens when the environment did not switch PDL off.
    const bool use_pdl = pdl_env_on && ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_HOPPER;
    if (use_pdl) {
        ggml_cuda_pdl_config pdl_cfg(launch_params);

        CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward<Args>(args)...));
        return;
    }
#endif // defined(GGML_CUDA_USE_PDL)

    // Regular launch path; launch-configuration errors surface through cudaGetLastError().
    kernel<<<launch_params.block_nums, launch_params.block_dims, launch_params.shmem, launch_params.stream>>>(std::forward<Args>(args)...);
    CUDA_CHECK(cudaGetLastError());
}

5 changes: 3 additions & 2 deletions ggml/src/ggml-cuda/concat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont

const int64_t n = ne0 * ne1 * ne2;

ggml_cuda_pdl_sync();
for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) {
if constexpr (dim == 0) {
const int64_t row = i / ne0;
Expand Down Expand Up @@ -64,8 +65,8 @@ static void concat_f32_cuda(const float * x,
const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;

if (dim == 0) {
concat_f32_cont<0>
<<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
return;
}
if (dim == 1) {
Expand Down
20 changes: 14 additions & 6 deletions ggml/src/ggml-cuda/cpy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int64_t nb12, const int64_t nb13) {
ggml_cuda_pdl_lc();
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

if (i >= ne) {
Expand All @@ -36,6 +37,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;

ggml_cuda_pdl_sync();
cpy_1(cx + x_offset, cdst + dst_offset);
}

Expand All @@ -59,6 +61,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
__shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
int cur_tile_buf = 0;

ggml_cuda_pdl_sync();
#pragma unroll
for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {

Expand Down Expand Up @@ -142,6 +145,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

ggml_cuda_pdl_sync();
cpy_blck(cx + x_offset, cdst + dst_offset);
}

Expand All @@ -168,6 +172,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

ggml_cuda_pdl_sync();
cpy_blck(cx + x_offset, cdst + dst_offset);
}

Expand All @@ -182,6 +187,7 @@ static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const
const src_t * x = (const src_t *) cx;
dst_t * dst = (dst_t *) cdst;

ggml_cuda_pdl_sync();
dst[i] = ggml_cuda_cast<dst_t>(x[i]);
}

Expand All @@ -192,8 +198,8 @@ cudaStream_t stream) {

const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar_contiguous<src_t, dst_t>, launch_params, cx, cdst, ne);
}

template<typename src_t, typename dst_t, bool transposed = false>
Expand Down Expand Up @@ -223,13 +229,15 @@ static void ggml_cpy_scalar_cuda(
GGML_ASSERT(grid_z < USHRT_MAX);
dim3 dimGrid(grid_x, grid_y, grid_z);
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar_transpose<dst_t>, launch_params,
cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} else {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar<cpy_1_scalar<src_t, dst_t>>, launch_params,
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
}

Expand Down
Loading
Loading