diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index 898a05aa728..67847a0bd6c 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -705,7 +705,7 @@ def handle_block_sparse_empty_tile_correction_sm100( scale_row_idx = tidx + stage * m_block_size sScale[scale_row_idx] = row_sum_value if const_expr(mLSE is not None or learnable_sink is not None): - sScale[scale_row_idx + m_block_size * 2] = row_max_value + sScale[scale_row_idx + q_stage * m_block_size] = row_max_value acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value stats[stage] = (row_sum_value, row_max_value, acc_flag) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index ccf8edbc43d..c66ca7553a3 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1784,7 +1784,7 @@ def softmax_loop( sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] if const_expr(mLSE is not None or learnable_sink is not None): sScale[ - tidx + stage * self.m_block_size + self.m_block_size * 2 + tidx + stage * self.m_block_size + self.q_stage * self.m_block_size ] = softmax.row_max[0] # if tidx == 0: # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) @@ -1853,7 +1853,7 @@ def softmax_loop( sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] if const_expr(mLSE is not None or learnable_sink is not None): sScale[ - tidx + stage * self.m_block_size + self.m_block_size * 2 + tidx + stage * self.m_block_size + self.q_stage * self.m_block_size ] = softmax.row_max[0] cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) @@ -2159,7 +2159,7 @@ def correction_loop( # scale = tSrScale_t2r[0] row_sum = sScale[tidx + stage * self.m_block_size] if const_expr(mLSE is not None or learnable_sink is not None): - row_max = sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] + row_max = sScale[tidx + stage * self.m_block_size + self.q_stage * self.m_block_size] else: row_max = None cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage)