diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py
index dcaa3656b52..1607a8b80b5 100644
--- a/flash_attn/cute/block_sparsity.py
+++ b/flash_attn/cute/block_sparsity.py
@@ -53,7 +53,7 @@ def _expand_sparsity_tensor(
             f"{tensor_name}{context_clause} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}."
             f"{hint_clause}"
         )
-    return tensor.expand(*expected_shape).contiguous()
+    return tensor.expand(*expected_shape)
 
 
 def _check_and_expand_block(
diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py
index 745fa01a588..96e051c5655 100644
--- a/tests/cute/test_mask_mod.py
+++ b/tests/cute/test_mask_mod.py
@@ -400,7 +400,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias):
     bwd_rtol = 2
 
     min_seqlen = min(seqlen_q, seqlen_k)
-    bwd_atol_floor = 1e-5 if min_seqlen >= 64 else 2e-5
+    bwd_atol_floor = 1e-5 if min_seqlen >= 64 else 3e-5
     dq_atol = max(bwd_atol_floor, 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item())
     dk_atol = max(bwd_atol_floor, 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item())
     dv_atol = max(bwd_atol_floor, 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item())