Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions python/triton/backends/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ def get_benchmarker(self) -> Benchmarker:
"""
raise NotImplementedError

def allocate_default_global_scratch(self, size: int, alignment: int, stream):
    """
    Fallback source of global scratch memory for launches that happen while
    no allocator has been installed through set_allocator().

    Some kernels (for example on Blackwell, SM 12.0+) may need global scratch
    for cooperative operations; this hook lets such a launch obtain a buffer
    rather than fail with a RuntimeError because nothing was configured.

    This base definition only declares the hook: it always raises
    NotImplementedError, so a concrete driver has to supply the allocation.
    """
    raise NotImplementedError

def allocate_default_profile_scratch(self, size: int, alignment: int, stream):
"""
Allocate profile scratch when no explicit profile allocator override was installed.
Expand Down Expand Up @@ -174,6 +182,18 @@ def __init__(self):
def assemble_tensormap_to_arg(self, tensormaps_info, args):
    """
    Hand the launch arguments back untouched.

    ``tensormaps_info`` is accepted but not consulted here; the very same
    ``args`` object that was passed in is returned.
    """
    return args

def allocate_default_global_scratch(self, size: int, alignment: int, stream):
    """
    Allocate ``size`` bytes of scratch with torch and return the buffer address.

    The buffer is a ``torch.uint8`` tensor created on the device reported by
    ``get_active_torch_device()``. When ``stream`` is given, the allocation is
    made while that stream (wrapped as an ``ExternalStream``) is current, and
    the tensor is recorded against it.

    The returned address is rounded up to a multiple of ``alignment``. To make
    that always possible, the tensor is over-allocated by ``alignment - 1``
    bytes, so at least ``size`` usable bytes follow the returned address. An
    ``alignment`` of ``None``, 0 or 1 requests no extra alignment.

    Returns the integer device address (``data_ptr()`` plus any alignment
    offset), not the tensor itself.
    """
    import torch

    device = self.get_active_torch_device()
    device_interface = self.get_device_interface()

    # torch.empty() takes no alignment argument, so reserve enough slack that
    # an aligned address is guaranteed to exist inside the buffer.
    padding = alignment - 1 if alignment and alignment > 1 else 0
    nbytes = size + padding

    if stream is None:
        scratch = torch.empty(nbytes, dtype=torch.uint8, device=device)
    else:
        launch_stream = device_interface.ExternalStream(stream, device=device)
        with device_interface.stream(launch_stream):
            scratch = torch.empty(nbytes, dtype=torch.uint8, device=device)
        scratch.record_stream(launch_stream)

    # NOTE(review): only the raw address leaves this function, so the tensor
    # is released to torch's allocator on return. This presumably relies on
    # stream-ordered reuse in the allocator to keep the memory valid for the
    # upcoming launch - confirm, especially for the stream=None path.
    base = scratch.data_ptr()
    if padding:
        # Round up to the next multiple of `alignment` (no-op if already aligned).
        base += -base % alignment
    return base

def allocate_default_profile_scratch(self, size: int, alignment: int, stream):
import torch
device = self.get_active_torch_device()
Expand Down
2 changes: 2 additions & 0 deletions third_party/amd/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ def allocate_scratch(size, align, allocator):
grid_size = gridX * gridY * gridZ
alloc_size = grid_size * size
alloc_fn = allocator.get()
if isinstance(alloc_fn, _allocation.NullAllocator):
return active_driver.allocate_default_global_scratch(alloc_size, align, stream)
return alloc_fn(alloc_size, align, stream)
return None

Expand Down
2 changes: 2 additions & 0 deletions third_party/nvidia/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def allocate_scratch(size, align, allocator):
grid_size = gridX * gridY * gridZ
alloc_size = grid_size * self.num_ctas * size
alloc_fn = allocator.get()
if isinstance(alloc_fn, _allocation.NullAllocator):
return active_driver.allocate_default_global_scratch(alloc_size, align, stream)
return alloc_fn(alloc_size, align, stream)
return None

Expand Down