diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 0b0488963ba..de6bceca843 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -412,8 +412,12 @@ def __call__( assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA" # Assume all strides are divisible by 128 bits except the last stride + # Skip assume for Python ints (e.g., stride=0 from GQA expand) new_stride = lambda t: ( - *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + *( + s if isinstance(s, int) else cute.assume(s, divby=128 // t.element_type.width) + for s in t.stride[:-1] + ), t.stride[-1], ) ( diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index dd81c1d6db5..26c02f853f2 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -290,8 +290,9 @@ def __call__( self.v_dtype = mV.element_type self.o_dtype = mO.element_type # Assume all strides are divisible by 128 bits except the last stride + # Skip assume for Python ints (e.g., stride=0 from GQA expand) new_stride = lambda t: ( - *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + *(s if isinstance(s, int) else cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1], ) mQ, mK, mV, mO = [