diff --git a/examples/dsa_sparse_finetune/sparse_mla_bwd.py b/examples/dsa_sparse_finetune/sparse_mla_bwd.py index 53e5f8bfe..06eaa8eb3 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_bwd.py +++ b/examples/dsa_sparse_finetune/sparse_mla_bwd.py @@ -226,17 +226,17 @@ def sparse_mla_bwd_kernel( if bi_i < BS // split_store: acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] - for bi_i, d_i in T.Parallel(BS // split_store, D // 4): - T.atomic_addx4( - dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4], - acc_dkv_shared[bi_i, d_i * 4], + for bi_i, d_i in T.Parallel(BS // split_store, D): + T.atomic_add( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i], + acc_dkv_shared[bi_i, d_i], ) # Atomically update dKV, dKV_tail tensors - for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): - T.atomic_addx4( - dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4], - acc_dkv_tail_shared[bi_i, d_i * 4], + for bi_i, d_i in T.Parallel(BS // split_store, D_tail): + T.atomic_add( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i], + acc_dkv_tail_shared[bi_i, d_i], ) # Store the accumulated dQ