Commit e1f7f31
authored
remove scatter_add in MoE implementation (#1974)
PR for removing `scatter_add` in the MoE implementation. `scatter_add`
is somewhat problematic as it is non-deterministic due to the necessity
of [atomic
adds](https://discuss.pytorch.org/t/why-does-index-add-and-scatter-add-induce-non-deterministic-behavior-on-the-cuda-backend/45544/2)
for correctness.
Determinism, correctness, and performance tests using scripts under
`torchtitan/moe_bench_and_test`:
```
# Determinism: run same forward 100x and compute standard deviations
pytest -rsfP torchtitan/moe_bench_and_test/test_moe.py -k test_determinism
out_old_std=tensor(0.0297, device='cuda:0', dtype=torch.bfloat16)
out_std=tensor(0., device='cuda:0', dtype=torch.bfloat16)
out_old_std/out_moe_old.abs().mean()=tensor(0.0006, device='cuda:0', dtype=torch.bfloat16)
out_std/out_moe.abs().mean()=tensor(0., device='cuda:0', dtype=torch.bfloat16)
```
```
# Accuracy: compare MoE outputs to FFN outputs, with weights set such that outputs should be the same
# Relative error decreased by 3x
pytest -rsfP torchtitan/moe_bench_and_test/test_moe.py -k test_moe_ffn_equivalence
moe_old_rel_err=0.009754068047048696
moe_rel_err=0.002507858727736454
moe_old_rel_err/moe_rel_err=3.8894009216589858
```
```
# Timing: triton do_bench for DSv3 16B layer fwd + bwd. ~3% faster runtime
python torchtitan/moe_bench_and_test/moe_timing.py moe_old && python torchtitan/moe_bench_and_test/moe_timing.py moe
args=Namespace(cls='moe_old', perf_reps=1000, perf_warmups=100, seqlen=4096, bsz=4)
moe_time_ms=19.712812881469727
args=Namespace(cls='moe', perf_reps=1000, perf_warmups=100, seqlen=4096, bsz=4)
moe_time_ms=19.03301840562087
```
```
# Memory: for DSv3 16B layer fwd + bwd. ~15% reduction in active mem, ~18% in reserved mem.
python torchtitan/moe_bench_and_test/moe_memory.py moe_old && python torchtitan/moe_bench_and_test/moe_memory.py moe
args=Namespace(cls='moe_old', iters=1, seqlen=4096, bsz=4)
peak_stats.max_active_gib=5.926029682159424
peak_stats.max_reserved_gib=7.224609375
args=Namespace(cls='moe', iters=1, seqlen=4096, bsz=4)
peak_stats.max_active_gib=5.051033020019531
peak_stats.max_reserved_gib=5.91015625
```
Testing fwd + bwd correctness for `tp_degree=ep_degree=world_size=8` and
`etp=1`
```
# Similar relative errors
torchrun --nproc-per-node 8 torchtitan/moe_bench_and_test/test_tp.py
args=Namespace(seqlen=256, bsz=4, tol=0.01), world_size=8, tp=8, ep=8, etp=1
err_ratio_fsdp_ep_old=0.0028211805268959435
err_ratio_fsdp_ep=0.002805679534989922
err_ratio_ep_ep_old=0.0022941468020912068
kl_fsdp_ep_old=tensor(2.4915e-05, device='cuda:0', dtype=torch.bfloat16)
kl_fsdp_ep=tensor(2.0981e-05, device='cuda:0', dtype=torch.bfloat16)
kl_ep_ep_old=tensor(2.1458e-05, device='cuda:0', dtype=torch.bfloat16)
```
Everything under `torchtitan/moe_bench_and_test` is temporary testing
utilities and is to be deleted prior to merging.1 parent f8fa21e commit e1f7f31
2 files changed
+30
-28
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
264 | 264 | | |
265 | 265 | | |
266 | 266 | | |
267 | | - | |
268 | | - | |
269 | | - | |
270 | | - | |
271 | | - | |
272 | | - | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
273 | 270 | | |
274 | 271 | | |
275 | 272 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
345 | 345 | | |
346 | 346 | | |
347 | 347 | | |
348 | | - | |
349 | 348 | | |
350 | 349 | | |
351 | 350 | | |
| |||
414 | 413 | | |
415 | 414 | | |
416 | 415 | | |
417 | | - | |
| 416 | + | |
418 | 417 | | |
419 | 418 | | |
420 | 419 | | |
| |||
430 | 429 | | |
431 | 430 | | |
432 | 431 | | |
433 | | - | |
| 432 | + | |
434 | 433 | | |
435 | 434 | | |
436 | 435 | | |
| |||
445 | 444 | | |
446 | 445 | | |
447 | 446 | | |
448 | | - | |
449 | | - | |
450 | | - | |
451 | | - | |
452 | | - | |
453 | | - | |
| 447 | + | |
454 | 448 | | |
455 | 449 | | |
456 | 450 | | |
| |||
464 | 458 | | |
465 | 459 | | |
466 | 460 | | |
467 | | - | |
468 | | - | |
469 | | - | |
470 | | - | |
| 461 | + | |
471 | 462 | | |
| 463 | + | |
| 464 | + | |
| 465 | + | |
| 466 | + | |
| 467 | + | |
| 468 | + | |
| 469 | + | |
| 470 | + | |
| 471 | + | |
| 472 | + | |
472 | 473 | | |
473 | | - | |
474 | | - | |
475 | | - | |
476 | | - | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
| 480 | + | |
| 481 | + | |
| 482 | + | |
| 483 | + | |
477 | 484 | | |
478 | | - | |
479 | | - | |
480 | | - | |
481 | | - | |
482 | | - | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
483 | 488 | | |
484 | 489 | | |
485 | 490 | | |
| |||
0 commit comments