fix: register default global_scratch allocator on Blackwell GPUs #825
Conversation
On Blackwell (SM 10.0+), the Triton compiler emits global_scratch memory for autotuned kernels even when TMA is not used (FLA_USE_TMA=0). Without an allocator registered, this causes NullAllocator crashes during kernel autotuning, which corrupts CUDA synchronization state and leads to process deadlocks. The existing allocator registration only runs when IS_TMA_SUPPORTED is True (requires FLA_USE_TMA=1). This change also registers the allocator on Blackwell when TMA is disabled, since the compiler still needs scratch space for other purposes on SM 10.0+. Fixes deadlocks when running MoE+Mamba models (Qwen3-Coder-Next, Qwen3.5) on Blackwell GPUs via vLLM. See: triton-lang/triton#10002
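The gating logic described above can be sketched as a small predicate. `should_register_allocator` is a hypothetical helper name for illustration; the actual change in `fla/utils.py` inlines this condition rather than wrapping it in a function:

```python
def should_register_allocator(tma_supported: bool, is_nvidia: bool,
                              capability_major: int) -> bool:
    """Decide whether to call triton.set_allocator at import time.

    Register when TMA is enabled, and also on SM 10.0+ (Blackwell),
    where the Triton compiler emits global_scratch for autotuned
    kernels even with FLA_USE_TMA=0.
    """
    if tma_supported:
        return True
    return is_nvidia and capability_major >= 10

# Hopper (SM 9.x) without TMA: no allocator registration needed
print(should_register_allocator(False, True, 9))   # False
# Blackwell (SM 10.x) without TMA: register to avoid NullAllocator crashes
print(should_register_allocator(False, True, 10))  # True
```

Note the condition is `>= 10`, not `== 10`, so future architectures beyond Blackwell are covered as well.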
Walkthrough: Broadened Blackwell detection to include devices with compute capability major >= 10; added a module-level `_default_alloc_fn` and registered it with Triton on those devices.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as PythonModule
    participant Triton as TritonRuntime
    participant CUDA as NVIDIA_GPU
    Py->>Triton: triton.set_allocator(_default_alloc_fn)
    Note over Triton,CUDA: Triton stores allocator callback
    Triton->>CUDA: request global_scratch allocation via callback
    CUDA-->>Triton: returns device memory tensor (int8)
    Triton-->>Py: allocator registered (log)
```
Code Review
This pull request refactors the Triton allocator setup in fla/utils.py by extracting a default allocation function and extending its registration to NVIDIA Blackwell (SM 10.0+) devices to prevent NullAllocator crashes. A review comment identifies that hardcoding the GPU index in the capability check may cause issues in multi-GPU systems and suggests using the current device instead.
```diff
-    triton.set_allocator(alloc_fn)
+    triton.set_allocator(_default_alloc_fn)
+elif IS_NVIDIA and torch.cuda.get_device_capability(0)[0] >= 10:
```
Using torch.cuda.get_device_capability(0) hardcodes the check to the first GPU. In heterogeneous multi-GPU systems where the current device (e.g., device 1) is Blackwell but device 0 is an older architecture, this check will incorrectly return False at import time if device 0 is the default, potentially leading to the NullAllocator crash you are addressing.
It is better to use torch.cuda.get_device_capability() without arguments to check the current device, which is also consistent with how IS_NVIDIA_BLACKWELL is defined at line 467. Alternatively, you could use the IS_NVIDIA_BLACKWELL constant directly if you update its definition to use >= 10.
```diff
-elif IS_NVIDIA and torch.cuda.get_device_capability(0)[0] >= 10:
+elif IS_NVIDIA and torch.cuda.get_device_capability()[0] >= 10:
```
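The multi-GPU concern can be illustrated with a small sketch. `blackwell_on_current_device` and the injected `get_capability` callable are illustrative names, not part of the fla codebase; passing `torch.cuda.get_device_capability` with no index queries whichever device is current, so the check stays correct on heterogeneous hosts:

```python
from typing import Callable, Tuple

def blackwell_on_current_device(get_capability: Callable[[], Tuple[int, int]]) -> bool:
    # get_capability should report the *current* device, e.g.
    # torch.cuda.get_device_capability called with no index argument.
    major, _minor = get_capability()
    return major >= 10

# Simulate a host where device 0 is Hopper (9, 0) but the current device
# is Blackwell (10, 0): querying the current device gives the right
# answer where a hardcoded get_device_capability(0) would not.
print(blackwell_on_current_device(lambda: (10, 0)))  # True
print(blackwell_on_current_device(lambda: (9, 0)))   # False
```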
🧹 Nitpick comments (1)
fla/utils.py (1)
488-488: Avoid Blackwell capability check drift. Consider reusing `IS_NVIDIA_BLACKWELL` here and updating that constant to match `>= 10`; right now the constant (`== 10`) and this branch (`>= 10`) encode different definitions.

♻️ Proposed refactor

```diff
-IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] == 10)
+IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] >= 10)
-elif IS_NVIDIA and torch.cuda.get_device_capability(0)[0] >= 10:
+elif IS_NVIDIA_BLACKWELL:
```
🧹 Nitpick comments (1)
fla/utils.py (1)
490-495: Unify Blackwell detection to avoid semantic drift. Line 490 uses `>= 10`, but the exported `IS_NVIDIA_BLACKWELL` flag is currently `== 10` and is consumed in other files (fla/ops/gated_delta_rule/wy_fast.py:14-30). Please align this to one source of truth to prevent divergent behavior across modules.

Suggested diff

```diff
-IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] == 10)
+IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] >= 10)
-elif IS_NVIDIA and torch.cuda.get_device_capability()[0] >= 10:
+elif IS_NVIDIA_BLACKWELL:
```
…eedback)
- Use shared IS_NVIDIA_BLACKWELL constant instead of inline capability check
- Change IS_NVIDIA_BLACKWELL from `== 10` to `>= 10` for forward compatibility with future NVIDIA architectures beyond Blackwell
- Addresses CodeRabbit and Gemini review feedback
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/utils.py`:
- Around lines 483-485: The allocator `_default_alloc_fn` currently ignores the `alignment` parameter and returns `torch.empty(...)`, which may produce misaligned buffers. Update `_default_alloc_fn` to honor `alignment` by allocating with extra padding and returning an aligned buffer (or by using an alignment-aware PyTorch allocator if available): allocate at least `size + alignment - 1` bytes on the same device (using `device_name` and `device_torch_lib.current_device()`), compute an aligned start offset based on the returned `tensor.data_ptr()`, and return a tensor view that begins at that aligned address with the requested size and appropriate dtype (`int8`) so Triton's alignment contract is satisfied.
```python
def _default_alloc_fn(size: int, alignment: int, stream: int | None):
    return torch.empty(size, device=torch.device(device_name, device_torch_lib.current_device()), dtype=torch.int8)
```
🧩 Analysis chain
🌐 Web query:
triton set_allocator callback signature alignment stream requirements
💡 Result:
The `triton.set_allocator` function in the Triton Language (triton-lang/triton Python package) allows users to specify a custom memory allocator callback for GPU kernel memory allocations. This is particularly useful for custom memory management, such as using `torch.empty` for temporary buffers or integrating with custom pools.

Callback signature:

```python
def alloc_fn(size: int, alignment: int, stream: Optional[int] = None) -> torch.Tensor
```

- `size`: number of bytes to allocate.
- `alignment`: required alignment in bytes for the allocation (e.g., 128 or 256 bytes, common for shared memory or cache lines; Triton's internal shared-memory allocator uses 128-byte scratch alignment).
- `stream`: optional CUDA stream handle (int); if provided, the allocation should be stream-ordered (e.g., using `cudaMallocAsync` if implementing in C++/CUDA). Allocations must respect stream ordering to avoid synchronization issues.

Usage example:

```python
def alloc_fn(size: int, alignment: int, stream: Optional[int]):
    return torch.empty(size, device="cuda", dtype=torch.int8)  # or pinned memory, custom pool slice, etc.

triton.set_allocator(alloc_fn)
```

Requirements:
- Alignment: must return a pointer aligned to the requested `alignment` bytes. Triton enforces alignments like 128 bytes for scratch buffers (lib/Analysis/Allocation.cpp); misalignment can cause warp memory errors (e.g., issue #5136).
- Stream: optional, but should use stream-ordered allocation (`torch.empty` implicitly uses the current stream; for explicit control, use CUDA stream-ordered APIs like `cudaMallocAsync`, available since CUDA 11.2). Supports the CUDA stream-ordered memory allocator for async alloc/free without global sync.
- Return: a `torch.Tensor` on `device="cuda"` whose `data_ptr` is the allocated memory. The tensor should be contiguous and match the requested size and alignment.
- Should be thread-safe and efficient for kernel compilation/runtime.

This feature is recent (not in all released versions; build from source or use the latest PyTorch/Triton for availability, per issues #6191 and #155584). Used in advanced kernels like grouped-GEMM tutorials and third-party libraries (e.g., gemlite). No official docs page found; inferred from source code patterns, tutorials, and issues.
Citations:
- 1: https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/utils.py
- 2: Bug module 'triton' has no attribute 'set_allocator' triton-lang/triton#6191
- 3: [Upstream Triton] Handle user-specified triton.set_allocator function pytorch/pytorch#155584
- 4: https://github.com/openai/triton/blob/main/lib/Analysis/Allocation.cpp
- 5: [Alloc] Fixed alignment for shared memory allocation triton-lang/triton#3854
- 6: Warp memory alignment error when manually launching compiled PTX triton-lang/triton#5136
- 7: https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/stream-ordered-memory-allocation.html
The alignment parameter must not be ignored in the allocator callback.
While ignoring stream is acceptable (since torch.empty implicitly uses the current stream), the alignment parameter is part of Triton's allocator contract and must be respected. Triton enforces alignment requirements (e.g., 128 bytes for scratch buffers), and misaligned allocations can cause warp memory errors. The current implementation using torch.empty(size, ...) does not guarantee the required alignment. This needs to be fixed by either using an alignment-aware allocation method or manually aligning the returned pointer.
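One way to honor the alignment contract, as a hedged sketch: over-allocate by `alignment - 1` bytes and return a view starting at the first aligned address. `aligned_alloc_fn` is an illustrative name, not the code merged in this PR, and the `device` parameter exists only so the pointer arithmetic can be exercised on CPU tensors (a real Triton allocator must return CUDA memory):

```python
import torch

def aligned_alloc_fn(size: int, alignment: int, stream=None, device="cuda"):
    # Over-allocate so an aligned start address is guaranteed to exist
    # somewhere within the buffer.
    buf = torch.empty(size + alignment - 1, device=device, dtype=torch.int8)
    # Bytes to skip from the base pointer to the next multiple of `alignment`.
    offset = -buf.data_ptr() % alignment
    # A slice of a 1-D int8 tensor is a view: its data_ptr() equals
    # base + offset, which is now aligned, and its numel() equals `size`.
    return buf[offset:offset + size]
```

The returned view keeps the padded parent buffer alive through its base tensor, so the extra bytes are not freed early.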
Problem

On Blackwell GPUs (SM 10.0+), the Triton compiler emits `global_scratch` memory for autotuned kernels even when TMA is not used (`FLA_USE_TMA=0`, the default). Without an allocator registered, this causes `NullAllocator` crashes during kernel autotuning. The exception is caught by the autotuner's `_bench()`, but this corrupts CUDA synchronization state, leading to process deadlocks on `futex_wait_queue`. This affects all MoE+Mamba models (Qwen3-Coder-Next, Qwen3.5) running on Blackwell via vLLM.

The standard workaround (`--enforce-eager`) disables CUDA graphs and `torch.compile`, costing 12x performance (13 tok/s vs 156 tok/s on RTX PRO 6000).

Root cause

The existing allocator registration in `fla/utils.py` only runs when `IS_TMA_SUPPORTED` is True, which requires `FLA_USE_TMA=1`. On Blackwell with the default `FLA_USE_TMA=0`, no allocator is registered, but the Triton compiler still needs `global_scratch` for inter-CTA workspace in autotuned kernel configurations.

Fix

Register the default `torch.empty`-based allocator on Blackwell (capability >= 10) regardless of `FLA_USE_TMA`. This is a 5-line change in `fla/utils.py`.

Testing

Tested on RTX PRO 6000 Blackwell (SM 12.0, CUDA 13.1, Triton 3.6.0) with vLLM serving Qwen3-Coder-Next 80B, compared against the `--enforce-eager` workaround.

Reproduce script: https://gist.github.com/ssubbotin/2cfa8ac4f3904df66872a882e44eeb86