Skip to content

Add default global_scratch allocator fallback for Blackwell SM 12.0#10002

Closed
ssubbotin wants to merge 1 commit intotriton-lang:mainfrom
ssubbotin:fix/blackwell-global-scratch-allocator
Closed

Add default global_scratch allocator fallback for Blackwell SM 12.0#10002
ssubbotin wants to merge 1 commit intotriton-lang:mainfrom
ssubbotin:fix/blackwell-global-scratch-allocator

Conversation

@ssubbotin
Copy link
Copy Markdown

Summary

On Blackwell (SM 12.0+), Triton kernels may require global_scratch memory for cooperative operations. When no explicit allocator is configured via triton.set_allocator(), the NullAllocator raises RuntimeError, crashing any kernel that uses global_scratch.

This adds allocate_default_global_scratch() to GPUDriver — mirroring the existing allocate_default_profile_scratch() pattern — and uses it as a fallback in both NVIDIA and AMD backend launchers when NullAllocator is detected.

Problem

RuntimeError: Kernel requires a runtime memory allocation, but no allocator was set.
Use triton.set_allocator to specify an allocator.

This crashes on consumer Blackwell GPUs (RTX PRO 6000, RTX 5090, RTX 5080) when running Triton kernels that use global_scratch — e.g., FLA solve_tril in vLLM for MoE+Mamba models like Qwen3-Coder-Next.

On pre-Blackwell GPUs, these kernels don't use global_scratch, so the issue doesn't surface.

Fix

  • Add allocate_default_global_scratch() to GPUDriver (mirrors allocate_default_profile_scratch())
  • In allocate_scratch() in both NVIDIA and AMD launchers, fall back to the new method when NullAllocator is the current allocator
  • Uses torch.empty() for allocation, consistent with the existing profile scratch pattern

Testing

Verified on RTX PRO 6000 Blackwell (SM 12.0, CUDA 13.1):

  • Without fix: RuntimeError on any kernel using global_scratch
  • With fix: FLA chunk_gated_delta_rule compiles and runs correctly (27.4s first compile, correct output)

Related Issues

On Blackwell (SM 12.0+), Triton kernels may require global_scratch
memory for cooperative operations. When no explicit allocator is
configured via triton.set_allocator(), the NullAllocator raises
RuntimeError, crashing any kernel that uses global_scratch.

This adds allocate_default_global_scratch() to GPUDriver — mirroring
the existing allocate_default_profile_scratch() pattern — and uses it
as a fallback in both NVIDIA and AMD backend launchers when
NullAllocator is detected.

Fixes kernel crashes on RTX PRO 6000, RTX 5090, and other Blackwell
consumer GPUs when running Triton kernels that use global_scratch
(e.g., FLA solve_tril in vLLM for MoE+Mamba models).
@ssubbotin
Copy link
Copy Markdown
Author

Performance impact of this fix

Without this fix, the standard workaround for Blackwell is --enforce-eager (disabling CUDA graphs and torch.compile). We measured the cost of that workaround:

Qwen3-Coder-Next 80B (AWQ 4-bit) on RTX PRO 6000 Blackwell:

Mode tok/s Notes
--enforce-eager (workaround) 13.3 No CUDA graphs, no torch.compile
With this fix (CUDA graphs enabled) 156.4 torch.compile + CUDA graph capture works

The workaround costs 12x performance. This fix restores full CUDA graph + torch.compile performance on Blackwell by providing a default global_scratch allocator.

Tested on RTX PRO 6000 (SM 12.0, CUDA 13.1, Triton 3.6.0) running vLLM with FLA kernels for MoE+Mamba models.

@Jokeren
Copy link
Copy Markdown
Contributor

Jokeren commented Apr 11, 2026

Triton kernels may require global_scratch memory for cooperative operations.

Can you explain a bit more with an example?

@Jokeren
Copy link
Copy Markdown
Contributor

Jokeren commented Apr 11, 2026

More specifically, why cannot you use set_allocator here? Maybe I missed the context

@ssubbotin
Copy link
Copy Markdown
Author

@Jokeren Sure — here's the concrete example and a standalone reproduce script.

Example: FLA solve_tril kernel in vLLM

When running MoE+Mamba models (Qwen3-Coder-Next, Qwen3.5) on Blackwell GPUs via vLLM, the FLA (Flash Linear Attention) solve_tril kernel compiles to use global_scratch memory on SM 12.0. This doesn't happen on pre-Blackwell GPUs for the same kernel.

The kernel in question is the @triton.autotune-decorated merge_fn in fla/ops/solve_tril.py. When the autotuner benchmarks each configuration, it launches the kernel, which calls allocate_scratchNullAllocator.__call__()RuntimeError.

Since the autotuner catches exceptions during benchmarking, the RuntimeError doesn't surface to the user. Instead, the CUDA synchronization state is corrupted and the process deadlocks on futex_wait_queue.

The standard workaround is --enforce-eager (disabling CUDA graphs and torch.compile), but this costs 12x performance (13 tok/s vs 156 tok/s on RTX PRO 6000).

Reproduce script

Standalone — requires only torch + triton + vllm (for the FLA kernel). Run on any Blackwell GPU:

Gist: https://gist.github.com/ssubbotin/2cfa8ac4f3904df66872a882e44eeb86

# Shows the crash (Blackwell only):
python reproduce_blackwell_deadlock.py

# With the allocator fix applied:
python reproduce_blackwell_deadlock.py --fix

Output without fix (RTX PRO 6000, SM 12.0, Triton 3.6.0):

Test 1: FLA chunk_gated_delta_rule (prefill path)...
  CRASH: Kernel requires a runtime memory allocation, but no allocator was set.

Output with fix:

Test 1: FLA chunk_gated_delta_rule (prefill path)...
  PASS (27.2s, output shape torch.Size([1, 128, 8, 128]))

The 27.2s is one-time JIT compilation. After that, CUDA graphs capture the compiled kernel and decode runs at 156 tok/s.

Why global_scratch on Blackwell?

The Triton compiler emits global_scratch for kernels that need inter-CTA communication or large workspace on SM 12.0. The same kernel compiled for SM 8.9 (Ada) or SM 9.0 (Hopper) doesn't use global_scratch, which is why this issue only surfaces on Blackwell.

The NullAllocator default was fine when no kernels needed scratch, but Blackwell changes that assumption.

@ssubbotin
Copy link
Copy Markdown
Author

@Jokeren Re: why not use set_allocator — two reasons:

1. Library code can't call set_allocator without conflicting with user code.

The crash happens inside vLLM's FLA kernel library, which is third-party code consumed by vLLM. If FLA calls triton.set_allocator() at import time, it would overwrite any allocator the end user (or vLLM, or another framework) already configured. There's no "add a fallback" API — set_allocator is global and last-writer-wins.

Every library that uses Triton kernels with global_scratch would need to call set_allocator, and they'd conflict with each other. The allocator should have a sensible default, not require every caller to configure it.

2. set_allocator doesn't propagate across process boundaries.

vLLM spawns the EngineCore as a separate process via multiprocessing.spawn. The allocator is stored in a ContextVar, which is per-process. Even if the parent process calls set_allocator, the child process (where the kernels actually run) still has NullAllocator.

We tried patching via usercustomize.py and PYTHONPATH — the ContextVar gets reset during process initialization. The only reliable workaround we found was monkey-patching allocate_scratch in driver.py directly.

The fix in this PR mirrors the existing allocate_default_profile_scratch() pattern — Triton already has a torch-based fallback for profile scratch when no allocator is set. This PR does the same for global scratch.

@masahi
Copy link
Copy Markdown
Collaborator

masahi commented Apr 11, 2026

does the kernel in question use make_tensor_descriptor? This is the only case that I am aware of that needs scratch space allocation from user. This is TMA with descriptor created in kernel. The right solution is to create descriptors on host via TensorDescriptor constructor. I dont think this is sm120-specific issue.

Copy link
Copy Markdown
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be user's responsibility to attach an allocator we don't want the compiler to do allocation under the hood

@ThomasRaoux
Copy link
Copy Markdown
Collaborator

does the kernel in question use make_tensor_descriptor? This is the only case that I am aware of that needs scratch space allocation from user. This is TMA with descriptor created in kernel. The right solution is to create descriptors on host via TensorDescriptor constructor. I dont think this is sm120-specific issue.

+1

@ssubbotin
Copy link
Copy Markdown
Author

@masahi @ThomasRaoux Thank you for the review. I want to clarify an important detail:

The crash happens WITHOUT make_tensor_descriptor. TMA is disabled by default in FLA (FLA_USE_TMA=0), and the kernel uses tl.make_block_ptr, not tl.make_tensor_descriptor. The Triton compiler still emits global_scratch on SM 12.0 (Blackwell) for the autotuned solve_tril kernel.

To verify:

# FLA_USE_TMA is NOT set (defaults to '0')
# The kernel path is the non-TMA branch: tl.make_block_ptr
python reproduce_blackwell_deadlock.py
# → CRASH: Kernel requires a runtime memory allocation, but no allocator was set.

The USE_TMA code path (with make_tensor_descriptor) is behind an opt-in env var and is NOT active during the crash.

So the suggestion to "use host-side TensorDescriptor" doesn't apply here — the kernel isn't using TMA at all. The global_scratch allocation is generated by the Triton compiler for SM 12.0 for reasons unrelated to TMA descriptors (likely inter-CTA workspace for the autotuned configurations).

This is why we believe a default allocator fallback is the right fix — user code shouldn't need to know that the compiler decided to use scratch space internally.

Happy to investigate further which compiler pass introduces the global_scratch on SM 12.0 if that would help.

ssubbotin added a commit to ssubbotin/flash-linear-attention that referenced this pull request Apr 12, 2026
On Blackwell (SM 10.0+), the Triton compiler emits global_scratch
memory for autotuned kernels even when TMA is not used (FLA_USE_TMA=0).
Without an allocator registered, this causes NullAllocator crashes
during kernel autotuning, which corrupts CUDA synchronization state
and leads to process deadlocks.

The existing allocator registration only runs when IS_TMA_SUPPORTED
is True (requires FLA_USE_TMA=1). This change also registers the
allocator on Blackwell when TMA is disabled, since the compiler
still needs scratch space for other purposes on SM 10.0+.

Fixes deadlocks when running MoE+Mamba models (Qwen3-Coder-Next,
Qwen3.5) on Blackwell GPUs via vLLM.

See: triton-lang/triton#10002
@ssubbotin
Copy link
Copy Markdown
Author

For reference, we submitted a workaround to FLA: fla-org/flash-linear-attention#825

That PR registers a default allocator on Blackwell regardless of FLA_USE_TMA. This addresses the immediate user-facing issue.

However, we still believe Triton should provide a default allocator (or at least a better error path) when the compiler decides to use global_scratch — user code has no way to know this will happen, and the current behavior (silent deadlock via corrupted CUDA state in autotuner) is hard to debug.

Happy to keep this PR open or close it depending on your preference.

@Jokeren
Copy link
Copy Markdown
Contributor

Jokeren commented Apr 12, 2026

However, we still believe Triton should provide a default allocator (or at least a better error path) when the compiler decides to use global_scratch — user code has no way to know this will happen, and the current behavior (silent deadlock via corrupted CUDA state in autotuner) is hard to debug.

We shouldn't provide a default allocator. I still suspect there's something wrong. Are you able to provide a single script reproducer? Thanks

@Jokeren
Copy link
Copy Markdown
Contributor

Jokeren commented Apr 12, 2026

However, we still believe Triton should provide a default allocator (or at least a better error path) when the compiler decides to use global_scratch — user code has no way to know this will happen, and the current behavior (silent deadlock via corrupted CUDA state in autotuner) is hard to debug.

We shouldn't provide a default allocator. I still suspect there's something wrong. Are you able to provide a single script reproducer? Thanks

Looks like reproduce_blackwell_deadlock.py has a lot of dependencies

@masahi
Copy link
Copy Markdown
Collaborator

masahi commented Apr 12, 2026

Running your script with MLIR_ENABLE_DUMP=1 TRITON_ALWAYS_COMPILE=1 should tell why and how "the compiler decides to use global_scratch". I still think this is coming from make_tensor_descriptor.

@masahi
Copy link
Copy Markdown
Collaborator

masahi commented Apr 12, 2026

So your reproducer uses the FLA kernel vendored in vllm. The decision to use TMA in vllm seems to have changed last week: vllm-project/vllm#38981. So if you are using vllm prior to that commit, this explains what's happening.

@ssubbotin
Copy link
Copy Markdown
Author

@masahi @Jokeren Thank you for digging into this — you were right.

Our vLLM Docker image was built before vllm-project/vllm#38981 (merged April 4), which aligned vLLM's vendored FLA copy with upstream's FLA_USE_TMA gating. Our image had the old code where is_tma_supported = True unconditionally on SM≥9, so the TMA code path (make_tensor_descriptor) was active on Blackwell — exactly what you suspected.

Updating to a vLLM build that includes #38981 resolves the issue since the non-TMA path doesn't need global_scratch.

I apologize for the confusion — I should have verified more carefully which code path was active before asserting it wasn't TMA-related.

Happy to close this PR. We also submitted fla-org/flash-linear-attention#825 which registers a default allocator on Blackwell as defense-in-depth, but the real fix was already in vLLM #38981.

Thank you for your patience and the pointer to the vLLM change.

@ssubbotin
Copy link
Copy Markdown
Author

Closing — the root cause was our vLLM image using a pre-#38981 vendored FLA copy that unconditionally enabled TMA on Blackwell. Updating vLLM resolves the issue. Thank you for the review.

@ssubbotin ssubbotin closed this Apr 12, 2026
zhiyuan1i pushed a commit to fla-org/flash-linear-attention that referenced this pull request Apr 13, 2026
* fix: register default global_scratch allocator on Blackwell GPUs

On Blackwell (SM 10.0+), the Triton compiler emits global_scratch
memory for autotuned kernels even when TMA is not used (FLA_USE_TMA=0).
Without an allocator registered, this causes NullAllocator crashes
during kernel autotuning, which corrupts CUDA synchronization state
and leads to process deadlocks.

The existing allocator registration only runs when IS_TMA_SUPPORTED
is True (requires FLA_USE_TMA=1). This change also registers the
allocator on Blackwell when TMA is disabled, since the compiler
still needs scratch space for other purposes on SM 10.0+.

Fixes deadlocks when running MoE+Mamba models (Qwen3-Coder-Next,
Qwen3.5) on Blackwell GPUs via vLLM.

See: triton-lang/triton#10002

* style: fix autopep8 blank lines

* fix: use current device for capability check (review feedback)

* refactor: use IS_NVIDIA_BLACKWELL constant, update to >= 10 (review feedback)

- Use shared IS_NVIDIA_BLACKWELL constant instead of inline capability check
- Change IS_NVIDIA_BLACKWELL from == 10 to >= 10 for forward compatibility
  with future NVIDIA architectures beyond Blackwell
- Addresses CodeRabbit and Gemini review feedback
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants