diff --git a/flashinfer/cute_dsl/__init__.py b/flashinfer/cute_dsl/__init__.py index 9e6762531a..940031453d 100644 --- a/flashinfer/cute_dsl/__init__.py +++ b/flashinfer/cute_dsl/__init__.py @@ -53,10 +53,6 @@ add_rmsnorm_fp4quant, AddRMSNormFP4QuantKernel, ) - from .gated_delta_rule import ( - gated_delta_rule, - GatedDeltaRuleKernel, - ) __all__ = [ # Utils (always available) @@ -83,7 +79,4 @@ # Add + RMSNorm + FP4 Quantization "add_rmsnorm_fp4quant", "AddRMSNormFP4QuantKernel", - # Gated Delta Rule - "gated_delta_rule", - "GatedDeltaRuleKernel", ] diff --git a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py index 0bd9acea06..9a6bfe55af 100644 --- a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py +++ b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py @@ -536,6 +536,14 @@ def kernel( if row_in_bounds: cute.copy(copy_atom_store, tRrO, tRgO, pred=tXpX) + # In cluster mode, Phase 3 reads from mR across the FULL hidden dimension, + # including slices written by other CTAs. We must ensure all CTAs' global + # memory writes are visible before any CTA proceeds to Phase 3. + if cutlass.const_expr(cluster_n > 1): + cute.arch.fence_acq_rel_cluster() + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + actual_row_idx = bidx * rows_per_block + row_in_block # Phase 3: RMSNorm + Quantize