From 3ff63eef23a058038d7c9a8da0d46d017330c0b8 Mon Sep 17 00:00:00 2001
From: Yuzhou Nie
Date: Thu, 18 Dec 2025 02:34:17 -0800
Subject: [PATCH 1/7] Temporary workaround to disable TritonGPUHoistTMEMAlloc
 in b_dk += tl.dot(tl.trans(b_dA), b_kb)

---
 fla/ops/gated_delta_rule/wy_fast.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py
index c5119dcf7c..838f5e4ea3 100644
--- a/fla/ops/gated_delta_rule/wy_fast.py
+++ b/fla/ops/gated_delta_rule/wy_fast.py
@@ -198,7 +198,14 @@ def prepare_wy_repr_bwd_kernel(
     b_A += tl.dot(b_kb, tl.trans(b_k))
     b_dkb = tl.dot(b_dA, b_k)
     b_db += tl.sum(b_dkb * b_k, 1)
-    b_dk += tl.dot(tl.trans(b_dA), b_kb)
+    b_dk += tl.inline_asm_elementwise(
+        asm="mov.f32 $0, $1;",
+        constraints="=r,r",
+        args=[tl.dot(tl.trans(b_dA), b_kb)],
+        dtype=tl.float32,
+        is_pure=True,
+        pack=1,
+    )
     b_dk += b_dkb * b_b[:, None]
     tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
     tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))

From 7ac536e29034e48fe8c89ab60464b42095f5430f Mon Sep 17 00:00:00 2001
From: Yuzhou Nie
Date: Thu, 18 Dec 2025 02:49:14 -0800
Subject: [PATCH 2/7] Fix Triton compiler bug workaround in wy_fast.py

---
 fla/ops/gated_delta_rule/wy_fast.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py
index 838f5e4ea3..f6f3a913a6 100644
--- a/fla/ops/gated_delta_rule/wy_fast.py
+++ b/fla/ops/gated_delta_rule/wy_fast.py
@@ -198,6 +198,10 @@ def prepare_wy_repr_bwd_kernel(
     b_A += tl.dot(b_kb, tl.trans(b_k))
     b_dkb = tl.dot(b_dA, b_k)
     b_db += tl.sum(b_dkb * b_k, 1)
+    # Temporary workaround for a Triton compiler bug on B200 GPUs. (https://github.com/fla-org/flash-linear-attention/issues/638)
+    # The `TritonGPUHoistTMEMAlloc` pass incorrectly fuses the add and dot operations,
+    # leading to a dominance error. The inline assembly prevents this fusion. 
+    # TODO: Remove this workaround when the compiler bug is fixed.
     b_dk += tl.inline_asm_elementwise(
         asm="mov.f32 $0, $1;",
         constraints="=r,r",

From 435f0694530a4eaade2c04656be205fa81029441 Mon Sep 17 00:00:00 2001
From: Yuzhou Nie
Date: Thu, 18 Dec 2025 14:34:25 -0800
Subject: [PATCH 3/7] Use Blackwell-specific safe dot workaround in wy_fast.py

---
 fla/ops/gated_delta_rule/wy_fast.py | 38 +++++++++++++++++++----------
 fla/utils.py                        |  1 +
 2 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py
index f6f3a913a6..d2e68193fe 100644
--- a/fla/ops/gated_delta_rule/wy_fast.py
+++ b/fla/ops/gated_delta_rule/wy_fast.py
@@ -7,7 +7,30 @@
 
 from fla.ops.utils import prepare_chunk_indices
 from fla.ops.utils.op import exp
-from fla.utils import autotune_cache_kwargs, check_shared_mem
+from fla.utils import autotune_cache_kwargs, check_shared_mem, IS_NVIDIA_BLACKWELL
+
+if IS_NVIDIA_BLACKWELL:
+    """
+    Compute tl.dot with SM100 workaround.
+
+    On SM100 (Blackwell) GPUs, wraps the result in inline assembly to prevent
+    the TritonGPUHoistTMEMAlloc pass from incorrectly fusing add and dot operations.
+    See: https://github.com/fla-org/flash-linear-attention/issues/638
+    """
+    @triton.jit
+    def safe_dot(a, b):
+        return tl.inline_asm_elementwise(
+            asm="mov.f32 $0, $1;",
+            constraints="=r,r",
+            args=[tl.dot(a, b)],
+            dtype=tl.float32,
+            is_pure=True,
+            pack=1,
+        )
+else:
+    @triton.jit
+    def safe_dot(a, b):
+        return tl.dot(a, b)
 
 
 @triton.heuristics({
@@ -198,18 +221,7 @@ def prepare_wy_repr_bwd_kernel(
     b_A += tl.dot(b_kb, tl.trans(b_k))
     b_dkb = tl.dot(b_dA, b_k)
     b_db += tl.sum(b_dkb * b_k, 1)
-    # Temporary workaround for a Triton compiler bug on B200 GPUs. (https://github.com/fla-org/flash-linear-attention/issues/638)
-    # The `TritonGPUHoistTMEMAlloc` pass incorrectly fuses the add and dot operations,
-    # leading to a dominance error. The inline assembly prevents this fusion. 
-    # TODO: Remove this workaround when the compiler bug is fixed.
-    b_dk += tl.inline_asm_elementwise(
-        asm="mov.f32 $0, $1;",
-        constraints="=r,r",
-        args=[tl.dot(tl.trans(b_dA), b_kb)],
-        dtype=tl.float32,
-        is_pure=True,
-        pack=1,
-    )
+    b_dk += safe_dot(tl.trans(b_dA), b_kb)
     b_dk += b_dkb * b_b[:, None]
     tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
     tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))
diff --git a/fla/utils.py b/fla/utils.py
index af9658aa00..04e6187b9b 100644
--- a/fla/utils.py
+++ b/fla/utils.py
@@ -395,6 +395,7 @@ def map_triton_backend_to_torch_device() -> str:
 IS_NVIDIA = (device_platform == 'cuda')
 IS_INTEL_ALCHEMIST = (IS_INTEL and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
 IS_NVIDIA_HOPPER = (IS_NVIDIA and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
+IS_NVIDIA_BLACKWELL = (IS_NVIDIA and torch.cuda.get_device_capability()[0] == 10)
 USE_CUDA_GRAPH = (IS_NVIDIA and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
 
 # Nvidia Ampere or newer, haven't check AMD and intel yet. 
From dea8e1c8420fcf34bb56dbab6d31fc04c81c0c06 Mon Sep 17 00:00:00 2001
From: Yuzhou Nie
Date: Thu, 18 Dec 2025 14:39:34 -0800
Subject: [PATCH 4/7] fix link error

---
 fla/ops/gated_delta_rule/wy_fast.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py
index d2e68193fe..8bcbccb7bc 100644
--- a/fla/ops/gated_delta_rule/wy_fast.py
+++ b/fla/ops/gated_delta_rule/wy_fast.py
@@ -7,7 +7,7 @@
 
 from fla.ops.utils import prepare_chunk_indices
 from fla.ops.utils.op import exp
-from fla.utils import autotune_cache_kwargs, check_shared_mem, IS_NVIDIA_BLACKWELL
+from fla.utils import IS_NVIDIA_BLACKWELL, autotune_cache_kwargs, check_shared_mem
 
 if IS_NVIDIA_BLACKWELL:
     """

From 79036e42db761784ddf9d3ed35e88db7048afead Mon Sep 17 00:00:00 2001
From: Yuzhou Nie
Date: Thu, 18 Dec 2025 14:43:08 -0800
Subject: [PATCH 5/7] Add NVIDIA Blackwell detection flag

---
 fla/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/fla/utils.py b/fla/utils.py
index 04e6187b9b..f472f11bc3 100644
--- a/fla/utils.py
+++ b/fla/utils.py
@@ -478,6 +478,7 @@ def _register_aliases():
     'IS_NVIDIA',
     'IS_INTEL_ALCHEMIST',
     'IS_NVIDIA_HOPPER',
+    'IS_NVIDIA_BLACKWELL',
     'USE_CUDA_GRAPH',
     'IS_TF32_SUPPORTED',
     'IS_GATHER_SUPPORTED',

From 34ff8cee11882e46d25b6f8a1905a3ff1e6dce88 Mon Sep 17 00:00:00 2001
From: Yuzhou Nie
Date: Thu, 18 Dec 2025 14:49:25 -0800
Subject: [PATCH 6/7] add doc

---
 fla/ops/gated_delta_rule/wy_fast.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py
index 8bcbccb7bc..6b885fc4d9 100644
--- a/fla/ops/gated_delta_rule/wy_fast.py
+++ b/fla/ops/gated_delta_rule/wy_fast.py
@@ -16,6 +16,9 @@
     On SM100 (Blackwell) GPUs, wraps the result in inline assembly to prevent
     the TritonGPUHoistTMEMAlloc pass from incorrectly fusing add and dot operations. 
     See: https://github.com/fla-org/flash-linear-attention/issues/638
+    
+    TODO: Remove this workaround once the Triton compiler bug is fixed.
+    Track upstream issue at: https://github.com/triton-lang/triton/issues/8695
     """
     @triton.jit
     def safe_dot(a, b):

From 7ca0a4ab698761a9f24d5a9ee2589f3b8708a853 Mon Sep 17 00:00:00 2001
From: Yuzhou Nie
Date: Thu, 18 Dec 2025 14:54:09 -0800
Subject: [PATCH 7/7] fix lint

---
 fla/ops/gated_delta_rule/wy_fast.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py
index 6b885fc4d9..f727880a49 100644
--- a/fla/ops/gated_delta_rule/wy_fast.py
+++ b/fla/ops/gated_delta_rule/wy_fast.py
@@ -16,7 +16,7 @@
     On SM100 (Blackwell) GPUs, wraps the result in inline assembly to prevent
     the TritonGPUHoistTMEMAlloc pass from incorrectly fusing add and dot operations.
     See: https://github.com/fla-org/flash-linear-attention/issues/638
-    
+
     TODO: Remove this workaround once the Triton compiler bug is fixed.
     Track upstream issue at: https://github.com/triton-lang/triton/issues/8695
     """