Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions fla/ops/gated_delta_rule/wy_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,33 @@

from fla.ops.utils import prepare_chunk_indices
from fla.ops.utils.op import exp
from fla.utils import autotune_cache_kwargs, check_shared_mem
from fla.utils import IS_NVIDIA_BLACKWELL, autotune_cache_kwargs, check_shared_mem

if IS_NVIDIA_BLACKWELL:
    # On SM100 (Blackwell) GPUs the TritonGPUHoistTMEMAlloc pass can
    # incorrectly fuse add and dot operations; routing the dot result
    # through a no-op inline-asm move keeps it opaque to that pass.
    # See: https://github.com/fla-org/flash-linear-attention/issues/638
    # TODO: Remove this workaround once the Triton compiler bug is fixed.
    # Track upstream issue at: https://github.com/triton-lang/triton/issues/8695
    @triton.jit
    def safe_dot(a, b):
        """Compute ``tl.dot(a, b)`` with the SM100 (Blackwell) workaround.

        The identity ``mov.f32`` wraps each element of the dot result so
        the compiler cannot fuse it with a subsequent add; the result is
        returned as float32 (``dtype=tl.float32``).
        """
        return tl.inline_asm_elementwise(
            asm="mov.f32 $0, $1;",
            constraints="=r,r",
            args=[tl.dot(a, b)],
            dtype=tl.float32,
            is_pure=True,
            pack=1,
        )
else:
    @triton.jit
    def safe_dot(a, b):
        """Compute ``tl.dot(a, b)`` directly (no workaround needed off SM100)."""
        return tl.dot(a, b)


@triton.heuristics({
Expand Down Expand Up @@ -198,7 +224,7 @@ def prepare_wy_repr_bwd_kernel(
b_A += tl.dot(b_kb, tl.trans(b_k))
b_dkb = tl.dot(b_dA, b_k)
b_db += tl.sum(b_dkb * b_k, 1)
b_dk += tl.dot(tl.trans(b_dA), b_kb)
b_dk += safe_dot(tl.trans(b_dA), b_kb)
b_dk += b_dkb * b_b[:, None]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))
Expand Down
2 changes: 2 additions & 0 deletions fla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def map_triton_backend_to_torch_device() -> str:
# Platform capability flags, evaluated once at import time.
# NOTE(review): `device_platform` / `IS_INTEL` are defined earlier in this
# module (outside this view) — presumably from the Triton/torch backend probe.
IS_NVIDIA = (device_platform == 'cuda')
# Intel Arc A-series (Alchemist) discrete GPUs, matched by device name.
IS_INTEL_ALCHEMIST = (IS_INTEL and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
# Hopper or newer: device name starts with 'NVIDIA H' or compute capability major >= 9.
IS_NVIDIA_HOPPER = (IS_NVIDIA and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
# Blackwell (SM100): compute capability major == 10 exactly — deliberately not
# `>= 10`, since the associated Triton workaround targets this generation only.
IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] == 10)
# Opt-in CUDA graph capture, enabled via environment variable FLA_USE_CUDA_GRAPH=1.
USE_CUDA_GRAPH = (IS_NVIDIA and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')

# Nvidia Ampere or newer, haven't check AMD and intel yet.
Expand Down Expand Up @@ -477,6 +478,7 @@ def _register_aliases():
'IS_NVIDIA',
'IS_INTEL_ALCHEMIST',
'IS_NVIDIA_HOPPER',
'IS_NVIDIA_BLACKWELL',
'USE_CUDA_GRAPH',
'IS_TF32_SUPPORTED',
'IS_GATHER_SUPPORTED',
Expand Down
Loading