diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py
index c5119dcf7c..f727880a49 100644
--- a/fla/ops/gated_delta_rule/wy_fast.py
+++ b/fla/ops/gated_delta_rule/wy_fast.py
@@ -7,7 +7,33 @@
 
 from fla.ops.utils import prepare_chunk_indices
 from fla.ops.utils.op import exp
-from fla.utils import autotune_cache_kwargs, check_shared_mem
+from fla.utils import IS_NVIDIA_BLACKWELL, autotune_cache_kwargs, check_shared_mem
+
+if IS_NVIDIA_BLACKWELL:
+    """
+    Compute tl.dot with SM100 workaround.
+
+    On SM100 (Blackwell) GPUs, wraps the result in inline assembly to prevent
+    the TritonGPUHoistTMEMAlloc pass from incorrectly fusing add and dot operations.
+    See: https://github.com/fla-org/flash-linear-attention/issues/638
+
+    TODO: Remove this workaround once the Triton compiler bug is fixed.
+    Track upstream issue at: https://github.com/triton-lang/triton/issues/8695
+    """
+    @triton.jit
+    def safe_dot(a, b):
+        return tl.inline_asm_elementwise(
+            asm="mov.f32 $0, $1;",
+            constraints="=r,r",
+            args=[tl.dot(a, b)],
+            dtype=tl.float32,
+            is_pure=True,
+            pack=1,
+        )
+else:
+    @triton.jit
+    def safe_dot(a, b):
+        return tl.dot(a, b)
 
 
 @triton.heuristics({
@@ -198,7 +224,7 @@ def prepare_wy_repr_bwd_kernel(
         b_A += tl.dot(b_kb, tl.trans(b_k))
         b_dkb = tl.dot(b_dA, b_k)
         b_db += tl.sum(b_dkb * b_k, 1)
-        b_dk += tl.dot(tl.trans(b_dA), b_kb)
+        b_dk += safe_dot(tl.trans(b_dA), b_kb)
         b_dk += b_dkb * b_b[:, None]
         tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
         tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))
diff --git a/fla/utils.py b/fla/utils.py
index af9658aa00..f472f11bc3 100644
--- a/fla/utils.py
+++ b/fla/utils.py
@@ -395,6 +395,7 @@ def map_triton_backend_to_torch_device() -> str:
 IS_NVIDIA = (device_platform == 'cuda')
 IS_INTEL_ALCHEMIST = (IS_INTEL and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
 IS_NVIDIA_HOPPER = (IS_NVIDIA and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
+IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] == 10)
 USE_CUDA_GRAPH = (IS_NVIDIA and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
 
 # Nvidia Ampere or newer, haven't check AMD and intel yet.
@@ -477,6 +478,7 @@ def _register_aliases():
         'IS_NVIDIA',
         'IS_INTEL_ALCHEMIST',
         'IS_NVIDIA_HOPPER',
+        'IS_NVIDIA_BLACKWELL',
         'USE_CUDA_GRAPH',
         'IS_TF32_SUPPORTED',
         'IS_GATHER_SUPPORTED',