diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 331c0088dc3..43567a52ecf 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -140,7 +140,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { static_cast(params.softmax_lse_ptr), {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE static_cast(params.softmax_lseaccum_ptr), - {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial + {_1{}, params.lseaccum_head_stride, !is_varlen_q ? params.lseaccum_batch_stride : 0, params.lseaccum_split_stride}, // stride_LSE_partial params.h_k, params.cu_seqlens_q, params.seqused_q };