From 26f74ec311772a0e5930820f0e0c167a150cde08 Mon Sep 17 00:00:00 2001 From: Sergey Subbotin Date: Sat, 11 Apr 2026 10:01:03 +0200 Subject: [PATCH] Add default global_scratch allocator fallback for Blackwell SM 12.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit On Blackwell (SM 12.0+), Triton kernels may require global_scratch memory for cooperative operations. When no explicit allocator is configured via triton.set_allocator(), the NullAllocator raises RuntimeError, crashing any kernel that uses global_scratch. This adds allocate_default_global_scratch() to GPUDriver — mirroring the existing allocate_default_profile_scratch() pattern — and uses it as a fallback in both NVIDIA and AMD backend launchers when NullAllocator is detected. Fixes kernel crashes on RTX PRO 6000, RTX 5090, and other Blackwell consumer GPUs when running Triton kernels that use global_scratch (e.g., FLA solve_tril in vLLM for MoE+Mamba models). --- python/triton/backends/driver.py | 20 ++++++++++++++++++++ third_party/amd/backend/driver.py | 2 ++ third_party/nvidia/backend/driver.py | 2 ++ 3 files changed, 24 insertions(+) diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py index 44a4253efa64..0ec215ea3a89 100644 --- a/python/triton/backends/driver.py +++ b/python/triton/backends/driver.py @@ -146,6 +146,14 @@ def get_benchmarker(self) -> Benchmarker: """ raise NotImplementedError + def allocate_default_global_scratch(self, size: int, alignment: int, stream): + """ + Allocate global scratch when no explicit allocator was installed via set_allocator(). + Kernels on Blackwell (SM 12.0+) may require global scratch memory for cooperative + operations. This fallback prevents RuntimeError when no allocator is configured. + """ + raise NotImplementedError + def allocate_default_profile_scratch(self, size: int, alignment: int, stream): """ Allocate profile scratch when no explicit profile allocator override was installed. @@ -174,6 +182,18 @@ def __init__(self): def assemble_tensormap_to_arg(self, tensormaps_info, args): return args + def allocate_default_global_scratch(self, size: int, alignment: int, stream): + import torch + device = self.get_active_torch_device() + device_interface = self.get_device_interface() + if stream is None: + return torch.empty(size, dtype=torch.uint8, device=device).data_ptr() + launch_stream = device_interface.ExternalStream(stream, device=device) + with device_interface.stream(launch_stream): + scratch = torch.empty(size, dtype=torch.uint8, device=device) + scratch.record_stream(launch_stream) + return scratch.data_ptr() + def allocate_default_profile_scratch(self, size: int, alignment: int, stream): import torch device = self.get_active_torch_device() diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index bc7cee8972c9..ebefa34d3520 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -349,6 +349,8 @@ def allocate_scratch(size, align, allocator): grid_size = gridX * gridY * gridZ alloc_size = grid_size * size alloc_fn = allocator.get() + if isinstance(alloc_fn, _allocation.NullAllocator): + return active_driver.allocate_default_global_scratch(alloc_size, align, stream) return alloc_fn(alloc_size, align, stream) return None diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 9ea4894660e9..0207779a3dcb 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -303,6 +303,8 @@ def allocate_scratch(size, align, allocator): grid_size = gridX * gridY * gridZ alloc_size = grid_size * self.num_ctas * size alloc_fn = allocator.get() + if isinstance(alloc_fn, _allocation.NullAllocator): + return active_driver.allocate_default_global_scratch(alloc_size, align, stream) return alloc_fn(alloc_size, align, stream) return None