diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index 2c8bf7fc7..ddde11f5b 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -161,9 +161,7 @@ def main( for h_i in T.Parallel(H_per_block): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale - T.copy(acc_o, O_shared) T.copy(acc_o, Output[b_i, s_i, H0:H1, :]) - T.copy(sumexp, Lse_shared) T.copy(sumexp, Lse[b_i, s_i, H0:H1]) return main