Initial compile support for llama4 #1365
```diff
@@ -13,6 +13,75 @@
 from .args import TransformerModelArgs
 
 
+# TODO: keeping this for-loop implementation for comparison
+# and readability, may remove later
+@expert_parallel
+def _run_experts_for_loop(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        # NOTE: this would incur a synchronization between device and host
+        num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+        # side-effect code due to the usage of generate_permute_indices
+        num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+        # a tuple of tensors indexed by experts
+        # each with shape (tokens_per_expert(varying), dim)
+        x = torch.split(
+            x[: sum(num_tokens_per_expert)],
+            split_size_or_sections=num_tokens_per_expert,
+            dim=0,
+        )
+        out_experts_splits = []
+        for expert_idx, x_expert in enumerate(x):
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
+            h = h * torch.matmul(x_expert, w3[expert_idx])
+            h = torch.matmul(h, w2[expert_idx])
+            # h shape (tokens_per_expert(varying), dim)
+            out_experts_splits.append(h)
+        out = torch.cat(out_experts_splits, dim=0)
+
+        # side-effect code due to the usage of generate_permute_indices
+        out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+    else:
+        # x shape (num_experts, tokens_per_expert, dim)
+        h = F.silu(torch.bmm(x, w1))
+        h = h * torch.bmm(x, w3)
+        # out shape (num_experts, tokens_per_expert, dim)
+        out = torch.bmm(h, w2)
+
+    return out
+
+
+@expert_parallel
+def _run_experts_grouped_mm(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+        # grouped mm between a 2D tensor and a 3D tensor
+        assert x.dim() == 2
+    else:
+        offsets = None
+        # fall back to regular bmm between 3D tensors
+        assert x.dim() == 3
+
+    h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
+    h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
+    out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
+
+    return out
+
+
 class GroupedExperts(nn.Module):
     def __init__(
         self,
```
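For reference (not part of the diff), the `offs` tensor consumed by the grouped path above is just the inclusive prefix sum of the per-expert token counts, i.e. the end index of each expert's contiguous slice of tokens along dim 0 of the flattened token tensor:

```python
import torch

# Illustration only: how the offsets fed to the grouped mm relate to
# num_tokens_per_expert. Each entry is the end offset of that expert's slice.
num_tokens_per_expert = torch.tensor([3, 0, 5, 2])
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
print(offsets)  # tensor([ 3,  3,  8, 10], dtype=torch.int32)
```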
```diff
@@ -28,89 +97,21 @@ def __init__(
         self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
         self.use_grouped_mm = use_grouped_mm
 
+    @torch._dynamo.set_fullgraph(True)
```
Contributor: What is this annotation for?

Member (author): Compiling the block with `fullgraph=False` could allow graph breaks to creep in silently as dynamo changes, and we wouldn't know about them until we manually inspect the graph or suspect that QPS has regressed. This API gives more granular control over the `fullgraph` argument of `torch.compile`: you can flip it on and off within a compiled region. In this case, we allow graph breaks between `GroupedExperts.__call__` and `GroupedExperts.forward`, i.e. we allow graph breaks on the forward hooks from FSDP.

Contributor: In addition to FSDP comms, EP a2a also happens before & after.
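A minimal sketch of the pattern described above, assuming a PyTorch build that exposes `torch._dynamo.set_fullgraph` as used in this diff (the module names here are illustrative, not the PR's actual classes): the outer module is compiled, the inner `forward` is pinned to fullgraph so breaks inside it error loudly, and the call site re-allows breaks for wrapper hooks such as FSDP's.

```python
import torch
import torch.nn as nn


class Experts(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    # Graph breaks inside forward now raise instead of silently splitting the graph.
    @torch._dynamo.set_fullgraph(True)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.experts = Experts()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Re-allow graph breaks around the call itself, e.g. for hooks that
        # run between Experts.__call__ and Experts.forward.
        with torch._dynamo.set_fullgraph(False):
            return self.experts(x)


block = torch.compile(Block(), fullgraph=True)
out = block(torch.randn(2, 16))
```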
```diff
     def forward(
         self,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
         if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
+            return _run_experts_grouped_mm(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
         else:
-            return GroupedExperts._run_experts_for_loop(
+            return _run_experts_for_loop(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
 
-    # TODO: keeping this for-loop implementation for comparison
```
Contributor: Could you explain more? Does it mean that if we move them out of the class, torch.compile can trace them in the same graph the caller module is in?

Member (author): Yes.
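As a toy illustration of that point (separate from the PR's code): a module-level helper called from a compiled `forward` is traced and inlined into the caller's graph, which `fullgraph=True` verifies by erroring out on any break.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Module-level helper, analogous to _run_experts_grouped_mm after the move:
# dynamo inlines it into whichever compiled frame calls it.
def _mlp(w1: torch.Tensor, w2: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    return F.silu(x @ w1) @ w2


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(8, 32))
        self.w2 = nn.Parameter(torch.randn(32, 8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _mlp(self.w1, self.w2, x)


# fullgraph=True would raise if the helper could not be traced into the
# same graph as Toy.forward.
toy = torch.compile(Toy(), fullgraph=True)
out = toy(torch.randn(4, 8))
```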
```diff
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
-
-        return out
-
-    @expert_parallel
-    @staticmethod
-    def _run_experts_grouped_mm(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-            # grouped mm between a 2D tensor and a 3D tensor
-            assert x.dim() == 2
-        else:
-            offsets = None
-            # fall back to regular bmm between 3D tensors
-            assert x.dim() == 3
-
-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
-        return out
-
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
```
```diff
@@ -297,7 +298,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         # shape (bs*slen*top_k, dim)
-        routed_output = self.experts(routed_input, num_tokens_per_expert)
+        with torch._dynamo.set_fullgraph(False):
```
Contributor: IIUC, this annotation is for the FSDP-caused graph break, correct?

Member (author): This API can't decorate `GroupedExperts.__call__` right now. If that turns out to be a problem, we can just compile MoE with `fullgraph=False`.
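For completeness, the fallback mentioned above would look roughly like the sketch below (the `moe` argument is a stand-in for the actual MoE module, not code from this PR): compiling with `fullgraph=False` simply lets dynamo split the graph at the hooks instead of erroring.

```python
import torch
import torch.nn as nn


# Hypothetical fallback sketch: compile the MoE module directly with
# fullgraph=False so graph breaks at wrapper hooks (e.g. FSDP) are tolerated,
# instead of scoping them with torch._dynamo.set_fullgraph(False).
def compile_moe(moe: nn.Module) -> nn.Module:
    return torch.compile(moe, fullgraph=False)
```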
```diff
+            routed_output = self.experts(routed_input, num_tokens_per_expert)
 
         # shared expert
         if self.shared_expert is not None:
```
Reviewer: What is this for? Other than this, it seems we can just apply the same function Llama 3 uses.

Author: This is to raise a loud error if we recompile more than 8 times (the default limit). Currently, we would just silently fall back to eager if that happens.

Reviewer: Should we do the same for Llama 3? If so, we can still reuse this function.
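A sketch of the kind of switch being discussed; the actual change is not visible in these hunks, and the exact config name here is an assumption, so treat it as illustrative only.

```python
import torch

# Assumption: recent PyTorch exposes a dynamo config flag that turns hitting
# the recompile limit (torch._dynamo.config.cache_size_limit, default 8) into
# a hard error rather than a silent fallback to eager.
torch._dynamo.config.fail_on_recompile_limit_hit = True
```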