-
-
Notifications
You must be signed in to change notification settings - Fork 17.6k
[MoE Refactor] Integrate Naive Prepare Finalize into MK #32567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ca114bd
d51a1a6
4a37fbb
6ab0934
5de3d38
bc30ec4
1666529
8f9e6cd
6121050
2d16269
8e60d88
c3ee917
be84e3b
71aa335
21f3c10
6ac652a
fece963
8d0cc52
4fc2917
60a15b0
0a84042
c5f4734
5010925
1d09143
62efd36
92fd9b0
2b406e3
01f004e
a3462dd
70b7909
9f9557e
859fd35
15c0112
530b463
ddc5b2a
03fe6ec
10b4922
8e6783f
ceb5ec0
e726362
dbd27e6
a29abcc
dfcadb8
f2531a7
be22a24
c466250
ba6903f
1e24906
45bc73e
66c9388
14be251
e76eb02
a5b25d9
19f8470
ad2a758
b456504
d9438b5
c4d86e0
1d95eab
dd71aee
f5d44bc
038af58
226ad44
24ce77b
b66c07c
732abd5
a5cf835
4a5fce9
7cbeae8
af0c268
b5474b7
9105ead
0da54f1
ef152ad
b366ad4
135cc53
a9faf58
609298d
53cfc42
08c442a
2ca7950
f184c48
881f436
9e38aca
98220b0
7bd2a44
b6d5107
678c34a
f9b5b92
e2b9969
3e3f17c
4f058f4
41c1264
0ad06db
bf4d327
578c183
bf0acb1
fe8fb7f
8f46f93
52124ea
59bab2d
ab22193
eeca003
9217c64
ed6d5fd
7e57390
0e79444
ec4edbd
acd7150
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you add a test/registration for
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yes |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,7 +59,7 @@ def naive_multicast( | |
|
|
||
| return buffer | ||
|
|
||
| def dispatch( | ||
| def dispatch_router_logits( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are we going to deprecate dispatch_router_logits eventually or keep it around for different cases?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We will still need it for monolithic kernels (trtllm) |
||
| self, | ||
| hidden_states: torch.Tensor, | ||
| router_logits: torch.Tensor, | ||
|
|
@@ -84,6 +84,34 @@ def dispatch( | |
|
|
||
| return hidden_states, router_logits | ||
|
|
||
| def dispatch( | |
| self, | |
| hidden_states: torch.Tensor, | |
| topk_weights: torch.Tensor, | |
| topk_ids: torch.Tensor, | |
| is_sequence_parallel: bool = False, | |
| extra_tensors: list[torch.Tensor] | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """ | |
| Multicast hidden_states, topk_weights and topk_ids with naive_multicast. | |
| Returns the three multicast results in that order. | |
| Raises NotImplementedError when extra_tensors is not None. | |
| Requires dp_metadata to be set on the current forward context. | |
| """ | |
| if extra_tensors is not None: | |
| raise NotImplementedError( | |
| "extra_tensors is not supported for NaiveAll2AllManager" | |
| ) | |
| # Use the TP group's world size under sequence parallelism, otherwise 1. | |
| sp_size = self.tp_group.world_size if is_sequence_parallel else 1 | |
| dp_metadata = get_forward_context().dp_metadata | |
| assert dp_metadata is not None | |
| cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size) | |
|

||
| # The same cu_tokens_across_sp_cpu value is passed to all three multicasts. | |
| hidden_states = self.naive_multicast( | |
| hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel | |
| ) | |
| topk_weights = self.naive_multicast( | |
| topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel | |
| ) | |
| topk_ids = self.naive_multicast( | |
| topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel | |
| ) | |
| return hidden_states, topk_weights, topk_ids | |
|
|
||
| def combine( | ||
| self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False | ||
| ) -> torch.Tensor: | ||
|
|
@@ -114,7 +142,7 @@ class AgRsAll2AllManager(All2AllManagerBase): | |
| def __init__(self, cpu_group): | |
| # Delegate all setup to the base class constructor. | |
| super().__init__(cpu_group) | |
|
|
||
| def dispatch( | ||
| def dispatch_router_logits( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| router_logits: torch.Tensor, | ||
|
|
@@ -148,6 +176,46 @@ def dispatch( | |
| return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) | ||
| return gathered_tensors[0], gathered_tensors[1] | ||
|
|
||
| def dispatch( | |
| self, | |
| hidden_states: torch.Tensor, | |
| topk_weights: torch.Tensor, | |
| topk_ids: torch.Tensor, | |
| is_sequence_parallel: bool = False, | |
| extra_tensors: list[torch.Tensor] | None = None, | |
| ) -> ( | |
| tuple[torch.Tensor, torch.Tensor, torch.Tensor] | |
| | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] | |
| ): | |
| """ | |
| Gather hidden_states, topk_weights and topk_ids (plus any extra_tensors) | |
| along dim 0 with all_gatherv, using the EP group when | |
| is_sequence_parallel is set and the DP group otherwise. | |
| Returns (hidden_states, topk_weights, topk_ids) when extra_tensors is | |
| None; otherwise a fourth element holds the gathered extra tensors. | |
| """ | |
| dp_metadata = get_forward_context().dp_metadata | |
| assert dp_metadata is not None | |
| sizes = dp_metadata.get_chunk_sizes_across_dp_rank() | |
| assert sizes is not None | |
| dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() | |
| # The local first dimension must match this rank's entry in sizes. | |
| assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] | |
|

||
| tensors_to_gather = [hidden_states, topk_weights, topk_ids] | |
| if extra_tensors is not None: | |
| tensors_to_gather.extend(extra_tensors) | |
|

||
| # One all_gatherv call handles every tensor in the list. | |
| gathered_tensors = dist_group.all_gatherv( | |
| tensors_to_gather, | |
| dim=0, | |
| sizes=sizes, | |
| ) | |
|

||
| # Results are indexed in the same order the inputs were listed. | |
| hidden_states = gathered_tensors[0] | |
| topk_weights = gathered_tensors[1] | |
| topk_ids = gathered_tensors[2] | |
|

||
| if extra_tensors is None: | |
| return hidden_states, topk_weights, topk_ids | |
|

||
| return hidden_states, topk_weights, topk_ids, gathered_tensors[3:] | |
|
|
||
| def combine( | ||
| self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False | ||
| ) -> torch.Tensor: | ||
|
|
@@ -216,7 +284,7 @@ def get_handle(self, kwargs): | |
| pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode, | ||
| ) | ||
|
|
||
| def dispatch( | ||
| def dispatch_router_logits( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| router_logits: torch.Tensor, | ||
|
|
@@ -225,6 +293,19 @@ def dispatch( | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| raise NotImplementedError | ||
|
|
||
| def dispatch( | |
| self, | |
| hidden_states: torch.Tensor, | |
| topk_weights: torch.Tensor, | |
| topk_ids: torch.Tensor, | |
| is_sequence_parallel: bool = False, | |
| extra_tensors: list[torch.Tensor] | None = None, | |
| ) -> ( | |
| tuple[torch.Tensor, torch.Tensor, torch.Tensor] | |
| | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] | |
| ): | |
| """ | |
| Stub for top-k based dispatch; always raises NotImplementedError. | |
| """ | |
| raise NotImplementedError | |
|
|
||
| def combine( | ||
| self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False | ||
| ) -> torch.Tensor: | ||
|
|
@@ -264,7 +345,7 @@ def __init__(self, cpu_group): | |
| def get_handle(self, kwargs): | |
| # Stub: no handle is provided here; always raises NotImplementedError. | |
| raise NotImplementedError | |
|
|
||
| def dispatch( | ||
| def dispatch_router_logits( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| router_logits: torch.Tensor, | ||
|
|
@@ -273,6 +354,19 @@ def dispatch( | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| raise NotImplementedError | ||
|
|
||
| def dispatch( | |
| self, | |
| hidden_states: torch.Tensor, | |
| topk_weights: torch.Tensor, | |
| topk_ids: torch.Tensor, | |
| is_sequence_parallel: bool = False, | |
| extra_tensors: list[torch.Tensor] | None = None, | |
| ) -> ( | |
| tuple[torch.Tensor, torch.Tensor, torch.Tensor] | |
| | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] | |
| ): | |
| """ | |
| Stub for top-k based dispatch; always raises NotImplementedError. | |
| """ | |
| raise NotImplementedError | |
|
|
||
| def combine( | ||
| self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False | ||
| ) -> torch.Tensor: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should there be an entry for
MoEPrepareAndFinalizeNaiveEP?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes