Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions fla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def map_triton_backend_to_torch_device() -> str:
IS_NVIDIA = (device_platform == 'cuda')
IS_INTEL_ALCHEMIST = (IS_INTEL and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
IS_NVIDIA_HOPPER = (IS_NVIDIA and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] == 10)
IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] >= 10)
USE_CUDA_GRAPH = (IS_NVIDIA and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')

# Nvidia Ampere or newer, haven't check AMD and intel yet.
Expand All @@ -479,13 +479,20 @@ def map_triton_backend_to_torch_device() -> str:
# This is a workaround for old nvidia card.
os.environ['TRITON_F32_DEFAULT'] = 'ieee'

if IS_TMA_SUPPORTED:
logger.info('TMA is supported, using TMA by default.')

def alloc_fn(size: int, alignment: int, stream: int | None):
return torch.empty(size, device=torch.device(device_name, device_torch_lib.current_device()), dtype=torch.int8)
def _default_alloc_fn(size: int, alignment: int, stream: int | None):
return torch.empty(size, device=torch.device(device_name, device_torch_lib.current_device()), dtype=torch.int8)

Comment on lines +483 to +485

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

triton set_allocator callback signature alignment stream requirements

💡 Result:

The triton.set_allocator function in the Triton Language (triton-lang/triton Python package) allows users to specify a custom memory allocator callback for GPU kernel memory allocations. This is particularly useful for custom memory management, such as using torch.empty for temporary buffers or integrating with custom pools. Callback signature: def alloc_fn(size: int, alignment: int, stream: Optional[int] = None) -> torch.Tensor - size: Number of bytes to allocate. - alignment: Required alignment in bytes for the allocation (e.g., 128 or 256 bytes common for shared memory or cache lines; Triton's internal shared memory allocator uses 128-byte scratchAlignment). - stream: Optional CUDA stream handle (int); if provided, allocation should be stream-ordered (e.g., using cudaMallocAsync if implementing in C++/CUDA). Allocations must respect stream ordering to avoid synchronization issues. Usage example: def alloc_fn(size: int, alignment: int, stream: Optional[int]): return torch.empty(size, device="cuda", dtype=torch.int8) # or pinned memory, custom pool slice, etc. triton.set_allocator(alloc_fn) Requirements: - Alignment: Must return a pointer aligned to the requested 'alignment' bytes. Triton enforces alignments like 128 bytes for scratch buffers (lib/Analysis/Allocation.cpp). Misalignment can cause warp memory errors (e.g., Issue #5136). - Stream: Optional but should use stream-ordered allocation (e.g., torch.empty implicitly uses current stream; for explicit control, use CUDA stream-ordered APIs like cudaMallocAsync from CUDA 11.2+ docs). Supports CUDA stream-ordered memory allocator for async alloc/free without global sync. - Return: A torch.Tensor (device="cuda") whose data_ptr is the allocated memory. Tensor should be contiguous and match size/alignment. - Thread-safe and efficient for kernel compilation/runtime. This feature is recent (not in all released versions; build from source or use latest PyTorch/Triton for availability, per Issues #6191, #155584). Used in advanced kernels like grouped-GEMM tutorials and third-party libs (e.g., gemlite). No official docs page found; inferred from source code patterns, tutorials, and issues.

Citations:


The alignment parameter must not be ignored in the allocator callback.

While ignoring stream is acceptable (since torch.empty implicitly uses the current stream), the alignment parameter is part of Triton's allocator contract and must be respected. Triton enforces alignment requirements (e.g., 128 bytes for scratch buffers), and misaligned allocations can cause warp memory errors. The current implementation using torch.empty(size, ...) does not guarantee the required alignment. This needs to be fixed by either using an alignment-aware allocation method or manually aligning the returned pointer.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/utils.py` around lines 483 - 485, The allocator _default_alloc_fn
currently ignores the alignment parameter and returns torch.empty(...) which may
produce misaligned buffers; update _default_alloc_fn to honor alignment by
allocating with extra padding and returning an aligned buffer (or using an
alignment-aware PyTorch allocator if available): allocate at least size +
alignment-1 bytes on the same device (using device_name and
device_torch_lib.current_device()), compute an aligned start offset based on the
returned tensor.data_ptr(), and return a tensor view that begins at that aligned
address with the requested size and appropriate dtype (int8) so Triton’s
alignment contract is satisfied.


triton.set_allocator(alloc_fn)
if IS_TMA_SUPPORTED:
logger.info('TMA is supported, using TMA by default.')
triton.set_allocator(_default_alloc_fn)
elif IS_NVIDIA_BLACKWELL:
# Blackwell (SM 10.0+): Triton compiler may emit global_scratch for
# autotuned kernels even without TMA. Register a default allocator to
# prevent NullAllocator crashes. See triton-lang/triton#10002.
logger.info('Blackwell detected: registering default global_scratch allocator.')
triton.set_allocator(_default_alloc_fn)


def get_all_max_shared_mem():
Expand Down
Loading