Avoid PDL race conditions by disabling __restrict__ when PDL is used#24030

Merged
ORippler merged 3 commits into
ggml-org:masterfrom
aendk:akieslinger/pdl-bugfix-remove-restrict
Jun 3, 2026

Conversation

@aendk
Contributor

@aendk aendk commented Jun 2, 2026

Overview

Follow up to #23825.
Together with CUDA engineers, we identified the cause of the suspected bug from #23825: PDL and __restrict__ cannot coexist, because __restrict__ lets the compiler move data reads ahead of the PDL barrier, which causes race conditions in the GPU byte code.
This PR disables __restrict__ in device code that uses PDL and retains __restrict__ (and thus performance) on all other GPU architectures.

Performance

I tested performance on an RTX PRO 6000 Blackwell Max-Q and an RTX 6000 Ada Generation. On Ada, any performance impact is indistinguishable from noise. This matches expectations: Ada does not use PDL, so it keeps __restrict__ and its benefits.
On Blackwell, we see a small improvement on Qwen3.6 (~1%) because we re-enable the flash-attention kernels disabled in #23825.

Detailed Bug Report

Using __restrict__ can lead to data loads that go through the read-only texture cache path, which is a non-coherent cache.
An example of such a load is the ld.global.nc PTX instruction. It is explicitly excluded from the CUDA/PTX memory consistency model within which PDL operates (https://docs.nvidia.com/cuda/parallel-thread-execution/#scope-and-applicability).
This means that the compiler is free to move any load from a __restrict__ input that is compiled to ld.global.nc ahead of the PDL synchronization barrier (cudaGridDependencySynchronize() is compiled to ACQBULK), so race conditions are possible.
Until now, this was not common knowledge; we'll improve tooling and guides.

The bug addressed in #23825 was caused by a race condition of this kind. A load through a __restrict__ pointer was hoisted above (moved in front of) the SASS instruction for the ggml_pdl_cuda_sync() barrier. The data read was thus executed before the PDL synchronization barrier in the GPU byte code, even though the C++ source correctly placed the synchronization before the data read.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES, but every line committed was double-checked manually.

@ORippler for viz

PDL. Adds preprocessor directives based on arch in kernel body to add
__restrict__ to retain performance on older architectures.
@aendk aendk requested review from a team and IMbackK as code owners June 2, 2026 15:42
github-actions bot added the "Nvidia GPU" (Issues specific to Nvidia GPUs) and "ggml" (changes relating to the ggml tensor library for machine learning) labels on Jun 2, 2026
Collaborator

@ORippler ORippler left a comment

Thanks for your work. Adding some more context:

  1. For those interested, a minimal repro with a workaround is here: https://godbolt.org/z/zzrrarWss
    • In broken, the SASS for sm_100 has LDG hoisted before ACQBULK, which effectively introduces the data race.
    • fixed conditionally applies __restrict__ to b based on __CUDA_ARCH__ and does not show this pattern on post-Hopper GPUs, while maintaining the __restrict__ perf gains on older HW.
  2. While we could simply remove __restrict__ without a perf penalty on newer hardware that has iso-VRAM throughput for the texture and compute memory paths, we saw some perf regression on the Ampere-based A40 and hence opted for device-side restriction. Fun fact: CCCL does not use __restrict__ anywhere and still manages to achieve SOL perf across generations.
  3. The conditional __restrict__ cannot happen in the function signature of the kernel, as this will fail for templated __global__ kernels.
  4. We will file an update to CUDA docs to explicitly state that PDL and __restrict__ are mutually exclusive for the time being.

From my side this is generally good to go, but I'd love for us to conditionally #define the restrict once, which should make the kernels significantly easier to read (see the godbolt I shared). Given PDL is currently potentially bugged, we can do this in a follow-up PR.

Comment thread ggml/src/ggml-cuda/fattn-common.cuh Outdated
Contributor

@JohannesGaessler JohannesGaessler left a comment

The current solution is quite cumbersome. As suggested by Oliver, please define a macro like GGML_CUDA_RESTRICT.

More generally, beyond just documentation my opinion is that the combination of __restrict__ and PDL in a single kernel should result in a compiler warning since evidently even for NVIDIA engineers this is a non-trivial issue to track down.

@ORippler
Collaborator

ORippler commented Jun 2, 2026

More generally, beyond just documentation my opinion is that the combination of __restrict__ and PDL in a single kernel should result in a compiler warning since evidently even for NVIDIA engineers this is a non-trivial issue to track down.

Is that a feature request for the NVCC/CUDA compiler chain, or something you would like to see implemented in ggml's CUDA backend? For the former, adding a warning when the hoisting happens is something we are discussing as one of the ways to improve developer UX.

@JohannesGaessler
Contributor

Is that a feature request for the NVCC/CUDA compiler chain, or something you would like to see implemented in ggml's CUDA backend?

It's not so much a feature request (since I personally am now aware of the problem) but rather a feature suggestion to avoid time spent on debugging in other projects. For ggml I would consider a macro to be 100% sufficient.

@jeffbolznv
Contributor

We will file an update to CUDA docs to explicitly state that PDL and __restrict__ are mutually exclusive for the time being.

Is this "working as designed" or is it just a compiler bug that needs to be fixed?

@aendk
Contributor Author

aendk commented Jun 3, 2026

@jeffbolznv everything works as designed; we need to add a big disclaimer that PDL cannot be mixed with __restrict__.
PDL works within the confines of the memory model, whilst __restrict__ can lead to SASS that is explicitly excluded from the guarantees of this model.

@aendk
Contributor Author

aendk commented Jun 3, 2026

@ORippler @JohannesGaessler ready for review.
I double-checked performance (no change) and validity (via Nsight Compute). The byte code shows the intended behavior on Ada (no PDL, but __restrict__) and Blackwell (PDL, no __restrict__).

Comment on lines +681 to +682
float * dst_ptr,
const float2 * dst_fixup_ptr,
Contributor

Suggested change
- float * dst_ptr,
- const float2 * dst_fixup_ptr,
+ float * GGML_CUDA_RESTRICT dst_ptr,
+ const float2 * GGML_CUDA_RESTRICT dst_fixup_ptr,

I would say to just use the macro in the declaration. Also, if __restrict__ has a negligible impact on Blackwell anyways, please replace all instances of __restrict__ in the CUDA backend with GGML_CUDA_RESTRICT, even if those kernels currently do not use PDL.

Contributor Author

That sadly does not work in templated function headers which are __global__.
The reason is that __CUDA_ARCH__ is undefined on host compilation passes, but defined on device compilation passes, and this leads to ABI mismatches.

We could do that on some of the function headers (e.g. quantize_q8_1), but it is discouraged by our engineering and possibly/likely unstable.
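
The constraint described above can be sketched as follows: the parameter list stays identical in every compilation pass, and the architecture-dependent qualifier only appears on local copies inside the body. This is a minimal illustration, not the code merged in this PR. It is written as a plain C++ function template so that it also builds with a host compiler (GCC or Clang is assumed for the __restrict__ spelling); in a .cu file it would be a templated __global__ kernel. The names SKETCH_RESTRICT and sketch_axpy and the threshold of 900 are assumptions made for the example.

```cpp
#include <cassert>

// Hypothetical macro for illustration; the discussion proposes the name GGML_CUDA_RESTRICT.
// __CUDA_ARCH__ is only defined during device compilation passes, so host passes
// (and plain C++ compilers) take the #else branch.
// Assumed threshold: 900, i.e. Hopper and newer.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#define SKETCH_RESTRICT
#else
#define SKETCH_RESTRICT __restrict__
#endif

// The signature uses plain pointers, so it does not depend on __CUDA_ARCH__.
template <typename T>
void sketch_axpy(T * y_arg, const T * x_arg, T a, int n) {
    assert(n >= 0);
    // The qualifier is applied to body-local aliases only.
    T       * SKETCH_RESTRICT y = y_arg;
    const T * SKETCH_RESTRICT x = x_arg;
    for (int i = 0; i < n; ++i) {
        y[i] += a * x[i];
    }
}
```

As with any use of __restrict__, the caller must pass buffers that do not overlap. Whether the compiler exploits the qualifier on body-local aliases as well as it does on parameters is not something this sketch demonstrates.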

Collaborator

@ORippler ORippler Jun 3, 2026

We could do that on some of the function headers (e.g. quantize_q8_1), but it is discouraged by our engineering and possibly/likely unstable.

This is not supported; quoting from the CUDA docs:
https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/cpp-language-extensions.html#cuda-arch

[image: screenshot of the linked documentation section]

Contributor Author

@aendk aendk Jun 3, 2026

Also, if __restrict__ has a negligible impact on Blackwell anyways, please replace all instances of __restrict__ in the CUDA backend with GGML_CUDA_RESTRICT, even if those kernels currently do not use PDL.

Would you still want me to do this by adding code outside of the function header?
I assume that, because it's not a seamless replacement, it does not make sense at this point.

Collaborator

@ORippler ORippler left a comment

  1. Please include Hopper in the fix
  2. We could think about naming this to GGML_CUDA_RESTRICT_PREHOPPER, but that may be too verbose (I don't have any stance on this)

I'd also recommend a comment linking this PR to where we #define GGML_CUDA_RESTRICT, so people can understand why this is necessary until the CUDA documentation update is in place.
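
A minimal sketch of what such a single, commented definition could look like, assuming that "including Hopper" translates to a threshold of __CUDA_ARCH__ >= 900 (compute capability 9.0). The macro is named SKETCH_RESTRICT here to make clear that this is an illustration and not the definition from ggml/src/ggml-cuda/common.cuh. The stringification helpers are hypothetical and exist only to make the expansion visible.

```cpp
#include <cassert>
#include <cstring>

// PDL and __restrict__ must not be combined: loads through __restrict__ pointers
// may be hoisted above the PDL synchronization barrier. See PR #24030.
// Assumed threshold: 900 (Hopper) and newer device passes drop the qualifier;
// older device architectures and host passes keep it.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#define SKETCH_RESTRICT
#else
#define SKETCH_RESTRICT __restrict__
#endif

// Hypothetical helpers that turn the macro's expansion into a string literal.
#define SKETCH_STR_IMPL(x) #x
#define SKETCH_STR(x) SKETCH_STR_IMPL(x)

// Returns "__restrict__" where the qualifier is kept and "" where it is dropped.
inline const char * sketch_restrict_expansion() {
    return SKETCH_STR(SKETCH_RESTRICT);
}

// True when the current compilation pass keeps the qualifier.
inline bool sketch_restrict_enabled() {
    return std::strlen(sketch_restrict_expansion()) > 0;
}
```

Built with a host compiler, where __CUDA_ARCH__ is undefined, the macro keeps the qualifier; only device passes at or above the assumed threshold expand it to nothing.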

Comment thread ggml/src/ggml-cuda/common.cuh Outdated
Co-authored-by: Oliver Simons <osimons@nvidia.com>
Collaborator

@ORippler ORippler left a comment

LGTM. I'd personally recommend to fix PDL first and do another round of mechanical __restrict__ changes + associated perf analysis in a follow-up PR

Contributor

@JohannesGaessler JohannesGaessler left a comment

I don't really like this syntax but it seems it's the least bad option we have. Thank you for debugging this!

Comment thread ggml/src/ggml-cuda/ssm-scan.cu
@ORippler ORippler merged commit 9e58d4d into ggml-org:master Jun 3, 2026
21 of 22 checks passed
Labels

ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs

5 participants