Commit e1f7f31

remove scatter_add in MoE implementation (#1974)
PR for removing `scatter_add` in the MoE implementation. `scatter_add` is somewhat problematic as it is non-deterministic, due to the [atomic adds](https://discuss.pytorch.org/t/why-does-index-add-and-scatter-add-induce-non-deterministic-behavior-on-the-cuda-backend/45544/2) it needs for correctness.

Determinism, correctness, and performance were tested using scripts under `torchtitan/moe_bench_and_test`:

```
# Determinism: run the same forward 100x and compute standard deviations
pytest -rsfP torchtitan/moe_bench_and_test/test_moe.py -k test_determinism

out_old_std=tensor(0.0297, device='cuda:0', dtype=torch.bfloat16)
out_std=tensor(0., device='cuda:0', dtype=torch.bfloat16)
out_old_std/out_moe_old.abs().mean()=tensor(0.0006, device='cuda:0', dtype=torch.bfloat16)
out_std/out_moe.abs().mean()=tensor(0., device='cuda:0', dtype=torch.bfloat16)
```

```
# Accuracy: compare MoE outputs to FFN outputs, with weights set such that outputs should be the same
# Relative error decreased by ~3.9x
pytest -rsfP torchtitan/moe_bench_and_test/test_moe.py -k test_moe_ffn_equivalence

moe_old_rel_err=0.009754068047048696
moe_rel_err=0.002507858727736454
moe_old_rel_err/moe_rel_err=3.8894009216589858
```

```
# Timing: triton do_bench for a DSv3 16B layer fwd + bwd. ~3% faster runtime
python torchtitan/moe_bench_and_test/moe_timing.py moe_old && python torchtitan/moe_bench_and_test/moe_timing.py moe

args=Namespace(cls='moe_old', perf_reps=1000, perf_warmups=100, seqlen=4096, bsz=4)
moe_time_ms=19.712812881469727
args=Namespace(cls='moe', perf_reps=1000, perf_warmups=100, seqlen=4096, bsz=4)
moe_time_ms=19.03301840562087
```

```
# Memory: for a DSv3 16B layer fwd + bwd. ~15% reduction in active mem, ~18% in reserved mem
python torchtitan/moe_bench_and_test/moe_memory.py moe_old && python torchtitan/moe_bench_and_test/moe_memory.py moe

args=Namespace(cls='moe_old', iters=1, seqlen=4096, bsz=4)
peak_stats.max_active_gib=5.926029682159424
peak_stats.max_reserved_gib=7.224609375
args=Namespace(cls='moe', iters=1, seqlen=4096, bsz=4)
peak_stats.max_active_gib=5.051033020019531
peak_stats.max_reserved_gib=5.91015625
```

Testing fwd + bwd correctness for `tp_degree=ep_degree=world_size=8` and `etp=1`:

```
# Similar relative errors
torchrun --nproc-per-node 8 torchtitan/moe_bench_and_test/test_tp.py

args=Namespace(seqlen=256, bsz=4, tol=0.01), world_size=8, tp=8, ep=8, etp=1
err_ratio_fsdp_ep_old=0.0028211805268959435
err_ratio_fsdp_ep=0.002805679534989922
err_ratio_ep_ep_old=0.0022941468020912068
kl_fsdp_ep_old=tensor(2.4915e-05, device='cuda:0', dtype=torch.bfloat16)
kl_fsdp_ep=tensor(2.0981e-05, device='cuda:0', dtype=torch.bfloat16)
kl_ep_ep_old=tensor(2.1458e-05, device='cuda:0', dtype=torch.bfloat16)
```

Everything under `torchtitan/moe_bench_and_test` consists of temporary testing utilities and will be deleted prior to merging.
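For context on where the non-determinism comes from: with `top_k > 1`, several expert outputs map back to the same token row, so the old combine's `scatter_add` hits duplicate indices and relies on CUDA atomic adds, whose accumulation order can vary between runs. The new combine writes each (token, choice) output into a unique slot and then reduces over the `top_k` dimension in a fixed order. Below is a minimal sketch of the two combine styles with toy shapes (scores omitted; the tensor names echo the diff, but the sizes and setup are illustrative rather than the torchtitan code path):

```python
import torch

num_tokens, top_k, dim = 4, 2, 8  # toy sizes, not the DSv3 config

# Expert outputs already sorted by expert: one row per (token, choice) pair.
routed_output = torch.randn(num_tokens * top_k, dim)
# For each sorted row, the flat (token, choice) slot it came from.
token_indices_experts_sorted = torch.randperm(num_tokens * top_k)

# Old-style combine: destination rows (index // top_k) repeat top_k times per
# token, so scatter_add must use atomic adds on CUDA and the summation order
# is not fixed across runs.
out_old = torch.zeros(num_tokens, dim).scatter_add(
    0,
    (token_indices_experts_sorted // top_k).unsqueeze(1).expand(-1, dim),
    routed_output,
)

# New-style combine: every destination slot is unique, so plain indexed
# assignment is deterministic, and the top_k reduction runs in a fixed order.
unsorted = torch.zeros(num_tokens * top_k, dim)
unsorted[token_indices_experts_sorted] = routed_output
out_new = unsorted.view(num_tokens, top_k, dim).sum(dim=1)

print(torch.allclose(out_old, out_new, atol=1e-6))  # same math, different determinism
```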
1 parent f8fa21e commit e1f7f31

File tree

2 files changed (+30, -28 lines)

torchtitan/distributed/expert_parallel.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -264,12 +264,9 @@ def _prepare_output_fn(self, mod, outputs, device_mesh):
         # NOTE: As we shard routed tokens along bs*slen dim across the TP ranks,
         # the MoE gather and scatter still require global token indices.
         local_rank = device_mesh.get_local_rank()
-        # fact: top_scores.shape[0] // mod.top_k = batch_size * seq_len // ep_degree
-        if not hasattr(mod, "top_k"):
-            raise ValueError(
-                "TokenReorderer class in MoE should always have top_k attribute."
-            )
-        token_indices_experts_sorted += top_scores.shape[0] // mod.top_k * local_rank
+        token_indices_experts_sorted = (
+            token_indices_experts_sorted + top_scores.shape[0] * local_rank
+        )
 
         return top_scores, token_indices_experts_sorted, num_tokens_per_expert
```
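The offset change here pairs with the reorderer change in `moe.py` below: since `token_indices_experts_sorted` is no longer divided by `top_k` inside the reorderer, the indices now live in the per-rank `(bs*slen*top_k,)` slot space rather than the `(bs*slen,)` token space, so the per-rank shift grows by a factor of `top_k`. A tiny worked example of the new offset, with hypothetical sizes (2 TP ranks, 3 local tokens, `top_k=2`), assuming `top_scores.shape[0]` is the per-rank number of (token, choice) pairs as the removed comment implied:

```python
import torch

top_k, local_tokens, tp_degree = 2, 3, 2  # hypothetical sizes for illustration
local_slots = local_tokens * top_k        # plays the role of top_scores.shape[0]

# Each rank's sorted indices address its own (bs*slen*top_k,) slot space...
local_indices = torch.arange(local_slots)

# ...and the per-rank offset shifts them into the global slot space, mirroring
# `token_indices_experts_sorted + top_scores.shape[0] * local_rank` above.
for local_rank in range(tp_degree):
    global_indices = local_indices + local_slots * local_rank
    print(f"rank {local_rank}: {global_indices.tolist()}")
# rank 0: [0, 1, 2, 3, 4, 5]
# rank 1: [6, 7, 8, 9, 10, 11]
```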

torchtitan/models/moe/moe.py

Lines changed: 27 additions & 22 deletions
```diff
@@ -345,7 +345,6 @@ def forward(
         )
 
         top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted]
-        token_indices_experts_sorted = token_indices_experts_sorted // self.top_k
 
         return (
             top_scores_experts_sorted,
@@ -414,7 +413,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         bs, slen, dim = x.shape
         x = x.view(-1, dim)
 
-        # top_scores and selected_experts_indices shape (bs*slen*top_k,)
+        # top_scores and selected_experts_indices shape (bs*slen, top_k)
         # num_tokens_per_expert shape (num_experts,)
         (
             top_scores,
@@ -430,7 +429,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         with torch.no_grad():
             self.tokens_per_expert.add_(num_tokens_per_expert)
 
-        # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
+        # top_scores_experts_sorted and token_indices_experts_sorted shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)
         # NOTE: the reason we need to compute num_tokens_per_expert again is:
         # 1st computation in router is to update self.tokens_per_expert
@@ -445,12 +444,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ) = self.reorderer(top_scores, selected_experts_indices)
 
         # shape (bs*slen*top_k, dim)
-        token_indices_experts_sorted = token_indices_experts_sorted.reshape(
-            -1, 1
-        ).expand(-1, dim)
-
-        # shape (bs*slen*top_k, dim)
-        routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)
+        routed_input = x[token_indices_experts_sorted // self.router.top_k]
 
         if self.score_before_experts:
             routed_input = (
@@ -464,22 +458,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # shared expert
         # Note: we execute the shared expert before scoring the output of the routed expert
         # to "implicitly" overlap the shared expert compute with token combine communication
-        if self.shared_experts is not None:
-            out = self.shared_experts(x)
-        else:
-            out = torch.zeros_like(x)
+        out = self.shared_experts(x) if self.shared_experts is not None else None
 
+        # Unsort routed outputs
+        routed_output_unsorted = torch.zeros(
+            (bs * slen * self.router.top_k, dim),
+            dtype=routed_output.dtype,
+            device=routed_output.device,
+        )
+        routed_output_unsorted[token_indices_experts_sorted] = routed_output
+        routed_output_unsorted = routed_output_unsorted.reshape(
+            -1, self.router.top_k, dim
+        )
         if not self.score_before_experts:
-            routed_output = (
-                routed_output.to(torch.float32)
-                * top_scores_experts_sorted.reshape(-1, 1)
-            ).to(x.dtype)
+            out_experts = (
+                torch.bmm(
+                    top_scores.reshape(-1, 1, self.router.top_k),
+                    routed_output_unsorted.float(),
+                )
+                .to(x.dtype)
+                .squeeze(1)
+            )
+        else:
+            out_experts = routed_output_unsorted.sum(dim=1)
 
-        out = out.scatter_add(
-            dim=0, index=token_indices_experts_sorted, src=routed_output
-        )
-        out = out.reshape(bs, slen, dim)
-        return out
+        if out is None:
+            return out_experts.reshape(bs, slen, dim)
+        return (out + out_experts).reshape(bs, slen, dim)
 
     def init_weights(
         self,
```
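In the new `score_before_experts=False` branch, the score multiply and the `top_k` reduction are folded into a single `bmm`. A short sketch, with toy shapes and assumed names, showing that this bmm is just a per-token, score-weighted sum over the unsorted expert outputs (the same quantity the old element-wise multiply followed by `scatter_add` computed, here accumulated in fp32 in a fixed order):

```python
import torch

num_tokens, top_k, dim = 4, 2, 8  # toy sizes
top_scores = torch.rand(num_tokens, top_k)                    # (bs*slen, top_k)
routed_output_unsorted = torch.randn(num_tokens, top_k, dim)  # after the unsort + reshape

# bmm of (num_tokens, 1, top_k) with (num_tokens, top_k, dim) -> (num_tokens, 1, dim):
# for each token, a score-weighted sum over its top_k expert outputs, in fp32.
out_bmm = torch.bmm(
    top_scores.reshape(-1, 1, top_k), routed_output_unsorted.float()
).squeeze(1)

# Reference: scale each (token, choice) row by its score, then sum over top_k.
out_ref = (top_scores.unsqueeze(-1) * routed_output_unsorted.float()).sum(dim=1)

print(torch.allclose(out_bmm, out_ref, atol=1e-5))
```

When `score_before_experts` is true, the scores have already been applied to the routed input, so the branch reduces to a plain `sum(dim=1)` over the same unsorted tensor, as in the diff above.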
