20 changes: 15 additions & 5 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -21,11 +21,7 @@
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

from torchtitan.models.llama3.infra.parallelize import (
apply_ac,
apply_compile,
apply_ddp,
)
from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
from torchtitan.tools.logging import logger

from .expert_parallel import (
@@ -36,6 +32,20 @@
)


def apply_compile(model: nn.Module):
"""
Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
repeated structure. Alternatively one can compile the whole model (after applying DP).
"""
torch._dynamo.config.fail_on_recompile_limit_hit = True
Contributor

What is this for?
Other than this, it seems we can just apply the same function llama 3 uses.

Member Author

This is to raise a loud error if we recompile more than 8 times (the default limit). Currently, we would just silently fall back to eager when that happens.

Contributor

Should we do the same for Llama 3? If so, we can still reuse this function.

for layer_id, transformer_block in model.layers.named_children():
# NOTE: we allow graph breaks on FSDP hooks for MoE experts, see `set_fullgraph(False)` in moe.py
transformer_block = torch.compile(transformer_block, fullgraph=True)
model.layers.register_module(layer_id, transformer_block)

logger.info("Compiling each TransformerBlock with torch.compile")
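
As a quick aside on the flag discussed above, the toy below (not part of this PR) shows the intended effect of fail_on_recompile_limit_hit. Exact recompile triggers depend on the PyTorch version; changing tensor rank is used here only as an easy way to force recompilations.

import torch

torch._dynamo.config.fail_on_recompile_limit_hit = True


@torch.compile
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


# Each new input rank forces a recompile of `double`. Past the recompile limit
# (8 by default), dynamo would normally fall back to eager silently; with the
# flag set it raises instead, so the regression is visible right away rather
# than showing up later as a slow QPS drop.
for ndim in range(1, 12):
    double(torch.randn(*([2] * ndim)))  # expected to raise once the limit is hit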


def parallelize_llama(
model: nn.Module,
parallel_dims: ParallelDims,
146 changes: 74 additions & 72 deletions torchtitan/experiments/llama4/model/moe.py
@@ -13,6 +13,75 @@
from .args import TransformerModelArgs


# TODO: keeping this for-loop implementation for comparison
# and readability, may remove later
@expert_parallel
def _run_experts_for_loop(
w1: torch.Tensor,
w2: torch.Tensor,
w3: torch.Tensor,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor | None = None,
) -> torch.Tensor:
if num_tokens_per_expert is not None:
# NOTE: this would incur a synchronization between device and host
num_tokens_per_expert = num_tokens_per_expert.tolist()

# side-effect code due to the usage of generate_permute_indices
num_padding = x.shape[0] - sum(num_tokens_per_expert)

# a tuple of tensors indexed by experts
# each with shape (tokens_per_expert(varying), dim)
x = torch.split(
x[: sum(num_tokens_per_expert)],
split_size_or_sections=num_tokens_per_expert,
dim=0,
)
out_experts_splits = []
for expert_idx, x_expert in enumerate(x):
h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
h = h * torch.matmul(x_expert, w3[expert_idx])
h = torch.matmul(h, w2[expert_idx])
# h shape (tokens_per_expert(varying), dim)
out_experts_splits.append(h)
out = torch.cat(out_experts_splits, dim=0)

# side-effect code due to the usage of generate_permute_indices
out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
else:
# x shape (num_experts, tokens_per_expert, dim)
h = F.silu(torch.bmm(x, w1))
h = h * torch.bmm(x, w3)
# out shape (num_experts, tokens_per_expert, dim)
out = torch.bmm(h, w2)

return out


@expert_parallel
def _run_experts_grouped_mm(
w1: torch.Tensor,
w2: torch.Tensor,
w3: torch.Tensor,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor | None = None,
) -> torch.Tensor:
if num_tokens_per_expert is not None:
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
# grouped mm between a 2D tensor and a 3D tensor
assert x.dim() == 2
else:
offsets = None
# fall back to regular bmm between 3D tensors
assert x.dim() == 3

h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)

return out


class GroupedExperts(nn.Module):
def __init__(
self,
@@ -28,89 +97,21 @@ def __init__(
self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
self.use_grouped_mm = use_grouped_mm

@torch._dynamo.set_fullgraph(True)
Contributor

What is this annotation for?

Member Author

Compiling the block with fullgraph=False could allow graph breaks to creep in silently with dynamo changes, and we wouldn't know about them until we manually inspect the graph or suspect QPS to have regressed.

This API gives more granular control over the fullgraph argument of torch.compile: you can flip it on and off within a compiled region. In this case, we allow graph breaks between GroupedExperts.__call__ and GroupedExperts.forward, i.e. we allow graph breaks on the forward hooks from FSDP.

Contributor

In addition to FSDP comms, the EP all-to-all also happens before and after GroupedExperts.forward. Does that mean it's still not fine-grained enough to capture graphs in EP?
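
For readers who haven't seen the API, here is a toy illustration (not from this PR) of toggling fullgraph within one compiled region; the function names are made up. It mirrors the pattern in this diff: the outer region is compiled with fullgraph=True, breaks are tolerated only inside an explicit set_fullgraph(False) scope, and the inner function re-enables the strict requirement.

import torch
import torch.nn.functional as F


@torch._dynamo.set_fullgraph(True)
def experts_like(x: torch.Tensor) -> torch.Tensor:
    # Inside this region a graph break is a hard error, so breaks cannot
    # creep in silently with future dynamo changes.
    return F.silu(x) * x


@torch.compile(fullgraph=True)
def block_like(x: torch.Tensor) -> torch.Tensor:
    x = x + 1
    with torch._dynamo.set_fullgraph(False):
        # Graph breaks are tolerated only here, e.g. around FSDP hooks or
        # collectives that dynamo cannot trace.
        x = experts_like(x)
    return x


out = block_like(torch.randn(8, 16))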

def forward(
self,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor | None = None,
) -> torch.Tensor:
if self.use_grouped_mm:
return GroupedExperts._run_experts_grouped_mm(
return _run_experts_grouped_mm(
self.w1, self.w2, self.w3, x, num_tokens_per_expert
)
else:
return GroupedExperts._run_experts_for_loop(
return _run_experts_for_loop(
self.w1, self.w2, self.w3, x, num_tokens_per_expert
)

# TODO: keeping this for-loop implementation for comparison
Contributor

Staticmethods on user-defined classes cannot be generically supported, so I moved those out.

Could you explain more? Does it mean if we move them out, then torch.compile can trace them in the same graph as the caller module is in?

Member Author

yes

# and readability, may remove later
@expert_parallel
@staticmethod
def _run_experts_for_loop(
w1: torch.Tensor,
w2: torch.Tensor,
w3: torch.Tensor,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor | None = None,
) -> torch.Tensor:
if num_tokens_per_expert is not None:
# NOTE: this would incur a synchronization between device and host
num_tokens_per_expert = num_tokens_per_expert.tolist()

# side-effect code due to the usage of generate_permute_indices
num_padding = x.shape[0] - sum(num_tokens_per_expert)

# a tuple of tensors indexed by experts
# each with shape (tokens_per_expert(varying), dim)
x = torch.split(
x[: sum(num_tokens_per_expert)],
split_size_or_sections=num_tokens_per_expert,
dim=0,
)
out_experts_splits = []
for expert_idx, x_expert in enumerate(x):
h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
h = h * torch.matmul(x_expert, w3[expert_idx])
h = torch.matmul(h, w2[expert_idx])
# h shape (tokens_per_expert(varying), dim)
out_experts_splits.append(h)
out = torch.cat(out_experts_splits, dim=0)

# side-effect code due to the usage of generate_permute_indices
out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
else:
# x shape (num_experts, tokens_per_expert, dim)
h = F.silu(torch.bmm(x, w1))
h = h * torch.bmm(x, w3)
# out shape (num_experts, tokens_per_expert, dim)
out = torch.bmm(h, w2)

return out

@expert_parallel
@staticmethod
def _run_experts_grouped_mm(
w1: torch.Tensor,
w2: torch.Tensor,
w3: torch.Tensor,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor | None = None,
) -> torch.Tensor:
if num_tokens_per_expert is not None:
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
# grouped mm between a 2D tensor and a 3D tensor
assert x.dim() == 2
else:
offsets = None
# fall back to regular bmm between 3D tensors
assert x.dim() == 3

h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)

return out

def init_weights(self, init_std: float):
nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
@@ -297,7 +298,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)

# shape (bs*slen*top_k, dim)
routed_output = self.experts(routed_input, num_tokens_per_expert)
with torch._dynamo.set_fullgraph(False):
Contributor

IIUC, this annotation is for the FSDP-caused graph break, correct?
Could we possibly incorporate this into the apply_compile function instead? Technically this change is model-intrusive, despite being small.

Member Author

This API can't decorate GroupedExperts.__call__ right now. If it's a problem, we can just compile MoE with fullgraph=False.
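
For concreteness, one way that fallback could look (hypothetical, not what this PR does): drop the in-model annotation and instead relax fullgraph for the whole MoE-containing block inside the compile pass. The function name and the moe_enabled attribute below are assumptions.

import torch
import torch.nn as nn


def apply_compile_relaxed(model: nn.Module) -> None:
    # Sketch: tolerate graph breaks only in MoE blocks (FSDP hooks, EP a2a),
    # keep dense blocks under the strict single-graph guarantee.
    for layer_id, transformer_block in model.layers.named_children():
        is_moe_block = getattr(transformer_block, "moe_enabled", False)  # assumed flag
        compiled = torch.compile(transformer_block, fullgraph=not is_moe_block)
        model.layers.register_module(layer_id, compiled)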

routed_output = self.experts(routed_input, num_tokens_per_expert)

# shared expert
if self.shared_expert is not None: