fix: register default global_scratch allocator on Blackwell GPUs #825

Merged

zhiyuan1i merged 4 commits into fla-org:main from ssubbotin:fix/blackwell-global-scratch-allocator
Apr 13, 2026

Conversation

@ssubbotin (Contributor) commented Apr 12, 2026

Problem

On Blackwell GPUs (SM 10.0+), the Triton compiler emits global_scratch memory for autotuned kernels even when TMA is not used (FLA_USE_TMA=0, the default). Without an allocator registered, this causes NullAllocator crashes during kernel autotuning:

RuntimeError: Kernel requires a runtime memory allocation, but no allocator was set.

The exception is caught by the autotuner's _bench(), but this corrupts CUDA synchronization state, leading to process deadlocks on futex_wait_queue. This affects all MoE+Mamba models (Qwen3-Coder-Next, Qwen3.5) running on Blackwell via vLLM.

The standard workaround (--enforce-eager) disables CUDA graphs and torch.compile, costing 12x performance (13 tok/s vs 156 tok/s on RTX PRO 6000).

Root cause

The existing allocator registration in fla/utils.py only runs when IS_TMA_SUPPORTED is True, which requires FLA_USE_TMA=1. On Blackwell with the default FLA_USE_TMA=0, no allocator is registered, but the Triton compiler still needs global_scratch for inter-CTA workspace in autotuned kernel configurations.

Fix

Register the default torch.empty-based allocator on Blackwell (capability >= 10) regardless of FLA_USE_TMA. This is a 5-line change in fla/utils.py.
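The shape of the change can be sketched with a hypothetical helper (`should_register_allocator` is illustrative only; the real code in fla/utils.py inlines this condition in an if/elif and calls triton.set_allocator(_default_alloc_fn) at import time):

```python
def should_register_allocator(is_tma_supported: bool,
                              is_nvidia: bool,
                              capability_major: int) -> bool:
    """Models the registration condition in fla/utils.py after this PR.

    Before the fix, the allocator was registered only when IS_TMA_SUPPORTED
    was true (i.e., FLA_USE_TMA=1). The fix adds a branch so that
    Blackwell-class GPUs (compute capability major >= 10) also register the
    default torch.empty-based allocator, because the Triton compiler emits
    global_scratch there regardless of TMA.
    """
    if is_tma_supported:
        return True
    if is_nvidia and capability_major >= 10:  # Blackwell and newer
        return True
    return False


# Hopper (SM 9.0) with TMA off: no registration needed.
print(should_register_allocator(False, True, 9))    # False
# Blackwell (SM 10.x) and RTX PRO 6000 (SM 12.0) with TMA off: register.
print(should_register_allocator(False, True, 10))   # True
print(should_register_allocator(False, True, 12))   # True
```

In the actual module, the truthy branches call triton.set_allocator(_default_alloc_fn) once at import time, so autotuned kernels never hit the NullAllocator path.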

Testing

Tested on RTX PRO 6000 Blackwell (SM 12.0, CUDA 13.1, Triton 3.6.0) with vLLM serving Qwen3-Coder-Next 80B:

Configuration                    | tok/s
---------------------------------|------
Without fix (--enforce-eager)    |  13.3
With this fix                    | 156.4

Reproduce script: https://gist.github.com/ssubbotin/2cfa8ac4f3904df66872a882e44eeb86

Related

Summary by CodeRabbit

  • Chores
    • Improved GPU memory allocation for more consistent behavior across device families.
    • Expanded allocator registration to cover newer NVIDIA GPU classes and added logging for global scratch allocations.
    • Centralized allocation logic and reduced conditional divergence for more predictable performance.
    • Improved reliability across varying TMA support configurations.

On Blackwell (SM 10.0+), the Triton compiler emits global_scratch
memory for autotuned kernels even when TMA is not used (FLA_USE_TMA=0).
Without an allocator registered, this causes NullAllocator crashes
during kernel autotuning, which corrupts CUDA synchronization state
and leads to process deadlocks.

The existing allocator registration only runs when IS_TMA_SUPPORTED
is True (requires FLA_USE_TMA=1). This change also registers the
allocator on Blackwell when TMA is disabled, since the compiler
still needs scratch space for other purposes on SM 10.0+.

Fixes deadlocks when running MoE+Mamba models (Qwen3-Coder-Next,
Qwen3.5) on Blackwell GPUs via vLLM.

See: triton-lang/triton#10002
@coderabbitai bot (Contributor) commented Apr 12, 2026

Walkthrough

Broadened Blackwell detection to include devices with compute capability major >= 10; added a module-level _default_alloc_fn(size, alignment, stream) that creates an int8 tensor on the current CUDA device; register _default_alloc_fn with triton.set_allocator for TMA-supported devices and for NVIDIA Blackwell-class GPUs.

Changes

Cohort / File(s): Allocator / GPU detection — fla/utils.py

Summary: Changed IS_NVIDIA_BLACKWELL from torch.cuda.get_device_capability()[0] == 10 to >= 10; added a module-level _default_alloc_fn(size, alignment, stream) that allocates an int8 tensor on the current mapped CUDA device; replaced the prior local alloc_fn with _default_alloc_fn in triton.set_allocator() when IS_TMA_SUPPORTED is true; added an elif IS_NVIDIA_BLACKWELL branch that calls triton.set_allocator(_default_alloc_fn) and emits a log message about Blackwell global_scratch behavior.

Sequence Diagram(s)

sequenceDiagram
    participant Py as PythonModule
    participant Triton as TritonRuntime
    participant CUDA as NVIDIA_GPU
    Py->>Triton: triton.set_allocator(_default_alloc_fn)
    Note over Triton,CUDA: Triton stores allocator callback
    Triton->>CUDA: request global_scratch allocation via callback
    CUDA-->>Triton: returns device memory tensor (int8)
    Triton-->>Py: allocator registered (log)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes


🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning):
- Docstring Coverage ⚠️: coverage is 0.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2):
- Description Check: skipped — CodeRabbit's high-level summary is enabled.
- Title Check: the title clearly and concisely describes the main change (registering a default allocator on Blackwell GPUs), which directly addresses the runtime issue detailed in the PR objectives.


@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request refactors the Triton allocator setup in fla/utils.py by extracting a default allocation function and extending its registration to NVIDIA Blackwell (SM 10.0+) devices to prevent NullAllocator crashes. A review comment identifies that hardcoding the GPU index in the capability check may cause issues in multi-GPU systems and suggests using the current device instead.

Comment thread: fla/utils.py (Outdated)

-        triton.set_allocator(alloc_fn)
+        triton.set_allocator(_default_alloc_fn)
+    elif IS_NVIDIA and torch.cuda.get_device_capability(0)[0] >= 10:
Severity: medium

Using torch.cuda.get_device_capability(0) hardcodes the check to the first GPU. In heterogeneous multi-GPU systems where the current device (e.g., device 1) is Blackwell but device 0 is an older architecture, this check will incorrectly return False at import time if device 0 is the default, potentially leading to the NullAllocator crash you are addressing.

It is better to use torch.cuda.get_device_capability() without arguments to check the current device, which is also consistent with how IS_NVIDIA_BLACKWELL is defined at line 467. Alternatively, you could use the IS_NVIDIA_BLACKWELL constant directly if you update its definition to use >= 10.

Suggested change:

-elif IS_NVIDIA and torch.cuda.get_device_capability(0)[0] >= 10:
+elif IS_NVIDIA and torch.cuda.get_device_capability()[0] >= 10:
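The multi-GPU pitfall can be modeled with a toy capability table (device indices and capabilities below are hypothetical stand-ins; real code queries torch.cuda.get_device_capability):

```python
# Hypothetical heterogeneous box: device 0 is an older card (SM 8.9),
# device 1 is Blackwell-class (SM 12.0).
CAPABILITIES = {0: (8, 9), 1: (12, 0)}
CURRENT_DEVICE = 1  # plays the role of torch.cuda.current_device()

def get_device_capability(device=None):
    """Mimics torch.cuda.get_device_capability: None means the current device."""
    return CAPABILITIES[CURRENT_DEVICE if device is None else device]

# Hardcoding index 0 checks the wrong card and would skip registration:
print(get_device_capability(0)[0] >= 10)   # False
# Checking the current device matches how IS_NVIDIA_BLACKWELL is defined:
print(get_device_capability()[0] >= 10)    # True
```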

@coderabbitai bot (Contributor) left a comment


🧹 Nitpick comments (1)
fla/utils.py (1)

488-488: Avoid Blackwell capability check drift.

Consider reusing IS_NVIDIA_BLACKWELL here and updating that constant to match >= 10; right now the constant (== 10) and this branch (>= 10) encode different definitions.

♻️ Proposed refactor
-IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] == 10)
+IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] >= 10)

-elif IS_NVIDIA and torch.cuda.get_device_capability(0)[0] >= 10:
+elif IS_NVIDIA_BLACKWELL:

ℹ️ Review info: reviewed fla/utils.py; commits 967f8c0..a668170.

@coderabbitai bot (Contributor) left a comment

🧹 Nitpick comments (1)
fla/utils.py (1)

490-495: Unify Blackwell detection to avoid semantic drift.

Line 490 uses >= 10, but the exported IS_NVIDIA_BLACKWELL flag is currently == 10 and is consumed in other files (fla/ops/gated_delta_rule/wy_fast.py:14-30). Please align this to one source of truth to prevent divergent behavior across modules.

Suggested diff
-IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] == 10)
+IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] >= 10)
...
-elif IS_NVIDIA and torch.cuda.get_device_capability()[0] >= 10:
+elif IS_NVIDIA_BLACKWELL:

ℹ️ Review info: reviewed fla/utils.py; commits a668170..84f040a.

Comment thread: fla/utils.py (Outdated)
…eedback)

- Use shared IS_NVIDIA_BLACKWELL constant instead of inline capability check
- Change IS_NVIDIA_BLACKWELL from == 10 to >= 10 for forward compatibility
  with future NVIDIA architectures beyond Blackwell
- Addresses CodeRabbit and Gemini review feedback
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

ℹ️ Review info: reviewed fla/utils.py; commits 84f040a..c426461.

Comment thread: fla/utils.py, lines +483 to +485

def _default_alloc_fn(size: int, alignment: int, stream: int | None):
    return torch.empty(size, device=torch.device(device_name, device_torch_lib.current_device()), dtype=torch.int8)


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

triton set_allocator callback signature alignment stream requirements

💡 Result:

triton.set_allocator (in the triton-lang/triton Python package) lets users register a custom memory allocator callback for GPU kernel memory allocations — useful for custom memory management, such as using torch.empty for temporary buffers or integrating with custom pools.

Callback signature:

    def alloc_fn(size: int, alignment: int, stream: Optional[int] = None) -> torch.Tensor

- size: number of bytes to allocate.
- alignment: required alignment in bytes (e.g., 128 or 256 bytes are common; Triton's internal shared memory allocator uses 128-byte scratch alignment).
- stream: optional CUDA stream handle; if provided, the allocation should be stream-ordered (e.g., via cudaMallocAsync when implementing in C++/CUDA) to avoid synchronization issues.

Requirements:

- Alignment: the returned pointer must be aligned to the requested alignment; Triton enforces alignments like 128 bytes for scratch buffers (lib/Analysis/Allocation.cpp), and misalignment can cause warp memory errors (Issue #5136).
- Stream: optional, but allocation should respect stream ordering; torch.empty implicitly uses the current stream.
- Return: a contiguous torch.Tensor on a CUDA device whose data_ptr is the allocated memory, matching the requested size and alignment.

The callback should be thread-safe and efficient for use during kernel compilation and at runtime. The feature is relatively recent (per Issues #6191, #155584) and is used in advanced kernels such as grouped-GEMM tutorials and third-party libraries (e.g., gemlite). No official docs page was found; the above is inferred from source code, tutorials, and issues.


The alignment parameter must not be ignored in the allocator callback.

While ignoring stream is acceptable (since torch.empty implicitly uses the current stream), the alignment parameter is part of Triton's allocator contract and must be respected. Triton enforces alignment requirements (e.g., 128 bytes for scratch buffers), and misaligned allocations can cause warp memory errors. The current implementation using torch.empty(size, ...) does not guarantee the required alignment. This needs to be fixed by either using an alignment-aware allocation method or manually aligning the returned pointer.
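One way to honor the alignment contract is over-allocation plus an offset. The sketch below uses stdlib ctypes so the pointer arithmetic is visible; in a torch-based allocator one would instead over-allocate the int8 tensor and return the slice buf[offset:offset + size]. This is an illustrative sketch, not the merged code.

```python
import ctypes

def aligned_offset(ptr: int, alignment: int) -> int:
    # Bytes to skip so that ptr + offset is a multiple of alignment.
    return (-ptr) % alignment

def aligned_alloc(size: int, alignment: int):
    """Over-allocate by alignment - 1 bytes and start at an aligned address.

    Mirrors the review suggestion for _default_alloc_fn: a plain
    torch.empty(size) does not guarantee the 128-byte alignment Triton
    expects for scratch buffers.
    """
    raw = (ctypes.c_int8 * (size + alignment - 1))()
    offset = aligned_offset(ctypes.addressof(raw), alignment)
    # Return the backing buffer together with the offset; a tensor slice
    # would keep the backing storage alive implicitly.
    return raw, offset

buf, off = aligned_alloc(1000, 128)
assert (ctypes.addressof(buf) + off) % 128 == 0
```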


@zhiyuan1i zhiyuan1i merged commit 8b05e2f into fla-org:main Apr 13, 2026
4 of 5 checks passed