diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py
index 44a4253efa64..0ec215ea3a89 100644
--- a/python/triton/backends/driver.py
+++ b/python/triton/backends/driver.py
@@ -146,6 +146,14 @@ def get_benchmarker(self) -> Benchmarker:
         """
         raise NotImplementedError
 
+    def allocate_default_global_scratch(self, size: int, alignment: int, stream):
+        """
+        Allocate global scratch when no explicit allocator was installed via set_allocator().
+        Kernels on Blackwell (SM 12.0+) may require global scratch memory for cooperative
+        operations. This fallback prevents RuntimeError when no allocator is configured.
+        """
+        raise NotImplementedError  # no default at this level; a torch-based implementation is added further down in this file
+
     def allocate_default_profile_scratch(self, size: int, alignment: int, stream):
         """
         Allocate profile scratch when no explicit profile allocator override was installed.
@@ -174,6 +182,18 @@ def __init__(self):
     def assemble_tensormap_to_arg(self, tensormaps_info, args):
         return args
 
+    def allocate_default_global_scratch(self, size: int, alignment: int, stream):  # NOTE(review): `alignment` is never read in this body
+        import torch  # function-scope import, same as allocate_default_profile_scratch below
+        device = self.get_active_torch_device()
+        device_interface = self.get_device_interface()  # presumably the torch.<backend> module (e.g. torch.cuda) — confirm
+        if stream is None:
+            return torch.empty(size, dtype=torch.uint8, device=device).data_ptr()  # no stream handle: allocate `size` bytes and return the raw address
+        launch_stream = device_interface.ExternalStream(stream, device=device)  # wrap the caller's stream handle in a torch stream object
+        with device_interface.stream(launch_stream):  # allocate while launch_stream is the selected stream
+            scratch = torch.empty(size, dtype=torch.uint8, device=device)
+        scratch.record_stream(launch_stream)  # mark the buffer as used on launch_stream for the allocator
+        return scratch.data_ptr()  # NOTE(review): only the address escapes and `scratch` is dropped here — confirm the memory stays valid for the launch
+
     def allocate_default_profile_scratch(self, size: int, alignment: int, stream):
         import torch
         device = self.get_active_torch_device()
diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py
index bc7cee8972c9..ebefa34d3520 100644
--- a/third_party/amd/backend/driver.py
+++ b/third_party/amd/backend/driver.py
@@ -349,6 +349,8 @@ def allocate_scratch(size, align, allocator):
                 grid_size = gridX * gridY * gridZ
                 alloc_size = grid_size * size
                 alloc_fn = allocator.get()
+                if isinstance(alloc_fn, _allocation.NullAllocator):
+                    return active_driver.allocate_default_global_scratch(alloc_size, align, stream)
                 return alloc_fn(alloc_size, align, stream)
             return None
 
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 9ea4894660e9..0207779a3dcb 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -303,6 +303,8 @@ def allocate_scratch(size, align, allocator):
                 grid_size = gridX * gridY * gridZ
                 alloc_size = grid_size * self.num_ctas * size
                 alloc_fn = allocator.get()
+                if isinstance(alloc_fn, _allocation.NullAllocator):
+                    return active_driver.allocate_default_global_scratch(alloc_size, align, stream)
                 return alloc_fn(alloc_size, align, stream)
             return None
 