[Bugfix] Disable TMA on Blackwell GPUs (sm_12x) to fix Triton autotuner OOM in fla/solve_tril #36325
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a subset of tests runs automatically, and you can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request addresses an out-of-memory error on Blackwell GPUs by disabling the Tensor Memory Accelerator (TMA) for that architecture. The fix is correct and targeted. My review includes a suggestion to refactor the check for readability and efficiency by avoiding a redundant function call.
```diff
 is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
     hasattr(triton.language, "_experimental_make_tensor_descriptor")
     or hasattr(triton.language, "make_tensor_descriptor")
-)
+) and torch.cuda.get_device_capability(0)[0] < 12  # Disable on Blackwell (sm_12x): Triton autotuner OOM
```
While this fix is correct, the implementation can be improved for readability and efficiency. The expression for `is_tma_supported` now calls `torch.cuda.get_device_capability(0)[0]` twice, and the formatting makes the line very long. It's better to combine the two compute-capability checks into a single range check to avoid the redundant call and make the condition clearer.
```diff
-is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
-    hasattr(triton.language, "_experimental_make_tensor_descriptor")
-    or hasattr(triton.language, "make_tensor_descriptor")
-) and torch.cuda.get_device_capability(0)[0] < 12  # Disable on Blackwell (sm_12x): Triton autotuner OOM
+# Disable on Blackwell (sm_12x): Triton autotuner OOM
+is_tma_supported = (is_nvidia and 9 <= torch.cuda.get_device_capability(0)[0] < 12) and (
+    hasattr(triton.language, "_experimental_make_tensor_descriptor")
+    or hasattr(triton.language, "make_tensor_descriptor")
+)
```
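The suggested range check also leans on a Python detail worth noting: a chained comparison such as `9 <= f() < 12` evaluates the middle expression only once, unlike spelling it as two separate comparisons. A minimal sketch (plain Python, no torch required; `get_capability_major` is a hypothetical stand-in for `torch.cuda.get_device_capability(0)[0]`):

```python
# Count how often the capability getter is invoked under each style.
calls = 0

def get_capability_major():
    """Hypothetical stand-in for torch.cuda.get_device_capability(0)[0]."""
    global calls
    calls += 1
    return 12  # pretend we are on Blackwell (sm_12x)

# Chained form: one call, short-circuits like `9 <= x and x < 12`.
in_range = 9 <= get_capability_major() < 12
assert in_range is False  # Blackwell is excluded
assert calls == 1

# Two-comparison form: the getter runs twice.
calls = 0
in_range = get_capability_major() >= 9 and get_capability_major() < 12
assert in_range is False
assert calls == 2
```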
Hi @Rks2302, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch. For future commits, the installed hooks will run automatically.
1 similar comment
fix: disable TMA on Blackwell (sm_12x) to prevent Triton autotuner OOM in solve_tril

Signed-off-by: Rks2302 <rahulksharma2302@gmail.com>

force-pushed from efd1eeb to ae40230 Compare
I remember Hopper also had this OOM issue. We should find a better way to both avoid OOM and maintain performance.
Summary

Fixes a Triton autotuner OOM crash in `fla/ops/solve_tril.py` when running Qwen3.5 models on Blackwell GPUs (RTX 5090, compute capability sm_12x).

Root Cause

`is_tma_supported` evaluates to `True` on any GPU with compute capability >= 9, which includes Blackwell (sm_12x). During the first inference, the Triton autotuner benchmarks the `merge_fn` kernel in `solve_tril` with TMA enabled, causing oversized descriptor buffer allocations that OOM even when the model weights fit comfortably in VRAM.
Error

```
RuntimeError: Triton Error [CUDA]: out of memory
  File "fla/ops/solve_tril.py", line 545, in solve_tril
    merge_fn[NT, B * H](..., USE_TMA=is_tma_supported)
  File "triton/runtime/autotuner.py"
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
```
Fix

Add an upper bound of `< 12` to restrict TMA to Hopper (sm_90x) only. TMA works correctly on Hopper but causes Triton autotuner OOM on Blackwell (sm_12x).
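The combined gate can be sketched as a pure function. This is an illustration of the condition only, not vLLM's actual code: `tma_supported` is a hypothetical helper, and the capability major version and descriptor-API probe are passed in rather than queried from torch/triton.

```python
def tma_supported(is_nvidia: bool, cc_major: int, has_descriptor_api: bool) -> bool:
    """Mirror of the PR's condition: enable TMA only on NVIDIA Hopper-class parts.

    sm_90x (Hopper) has working TMA; sm_12x (Blackwell) is excluded because
    the Triton autotuner OOMs there; anything below sm_90 lacks TMA hardware.
    """
    return is_nvidia and 9 <= cc_major < 12 and has_descriptor_api

assert tma_supported(True, 9, True)       # Hopper: enabled
assert not tma_supported(True, 12, True)  # Blackwell: disabled by this PR
assert not tma_supported(True, 8, True)   # Ampere: no TMA hardware
assert not tma_supported(False, 9, True)  # non-NVIDIA: disabled
```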
Testing

After this fix, Qwen3.5 AWQ models run successfully on RTX 5090 without `--enforce-eager`. The full inference pipeline was verified working.

Related Issues