Enable loss-parallel in example (#19882)
awaelchli authored May 20, 2024
1 parent 82e6e61 commit d76feef
Showing 4 changed files with 20 additions and 3 deletions.
8 changes: 7 additions & 1 deletion examples/fabric/tensor_parallel/parallelism.py
@@ -35,7 +35,13 @@ def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
     # Parallelize the first embedding and the last linear out projection
     plan = {
         "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
-        "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
+        "output": ColwiseParallel(
+            input_layouts=Shard(1),
+            # Optional: Shard the output along the class dimension to compute the loss in parallel.
+            # See `loss_parallel` in `train.py`
+            output_layouts=Shard(-1),
+            use_local_output=False,
+        ),
         "norm": SequenceParallel(),
         "layers.0": PrepareModuleInput(
             input_layouts=(Replicate(), None),

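For context, this is roughly how a plan like the one above gets applied to the model via torch.distributed.tensor.parallel; the names `model` and `tp_mesh` below are placeholders, the import path for the placements may differ slightly between PyTorch releases, and the example's full plan contains more entries than this hunk shows:

# Sketch only: applying a tensor-parallel plan like the one above.
from torch.distributed._tensor import Replicate, Shard  # torch.distributed.tensor in newer releases
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

plan = {
    "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
    "output": ColwiseParallel(
        input_layouts=Shard(1),
        # Keep the logits as a DTensor sharded over the class dimension so that
        # `loss_parallel()` can compute the cross-entropy without gathering them.
        output_layouts=Shard(-1),
        use_local_output=False,
    ),
}
model = parallelize_module(model, tp_mesh, plan)  # tp_mesh: the tensor-parallel DeviceMesh (placeholder)
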
2 changes: 1 addition & 1 deletion examples/fabric/tensor_parallel/train.py
@@ -57,8 +57,8 @@ def train():
 
         with loss_parallel():
             loss = F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
+            fabric.backward(loss)
 
-        fabric.backward(loss)
         optimizer.step()
         optimizer.zero_grad()
         fabric.print(f"Iteration {i} complete")

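The fabric.backward(loss) call moves inside the context manager because loss_parallel() has to be active for both the forward and the backward pass of the sharded cross-entropy. A condensed sketch of the resulting loop body; the data-loading and loop details are paraphrased placeholders, not the example's exact code:

# Condensed sketch; `num_iterations` and `dataloader_iter` are placeholders,
# F is torch.nn.functional and loss_parallel comes from torch.distributed.tensor.parallel.
for i in range(num_iterations):
    batch = next(dataloader_iter)
    input_ids, labels = batch[:, :-1], batch[:, 1:]
    output = model(input_ids)  # logits arrive as a DTensor sharded over the class dimension
    with loss_parallel():
        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
        fabric.backward(loss)  # backward also runs under loss_parallel()
    optimizer.step()
    optimizer.zero_grad()
    fabric.print(f"Iteration {i} complete")
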
8 changes: 7 additions & 1 deletion examples/pytorch/tensor_parallel/parallelism.py
@@ -35,7 +35,13 @@ def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
     # Parallelize the first embedding and the last linear out projection
     plan = {
         "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
-        "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
+        "output": ColwiseParallel(
+            input_layouts=Shard(1),
+            # Optional: Shard the output along the class dimension to compute the loss in parallel.
+            # See `loss_parallel` in `train.py`
+            output_layouts=Shard(-1),
+            use_local_output=False,
+        ),
         "norm": SequenceParallel(),
         "layers.0": PrepareModuleInput(
             input_layouts=(Replicate(), None),

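This is the same change as in the Fabric example above. For intuition on why Shard(-1) plus loss_parallel() works: softmax cross-entropy over class-sharded logits decomposes into local reductions plus a few small all-reduces, so the full logit matrix never has to be materialized on any rank. A rough illustration of the underlying math using raw collectives (this is not how torch implements loss_parallel, and the function below is purely hypothetical):

# Rough sketch of loss-parallel cross-entropy; each rank holds a [N, V_local]
# shard of the full [N, V] logits, and `class_offset` is the first class id it owns.
import torch
import torch.distributed as dist

def sharded_cross_entropy(local_logits, labels, class_offset):
    # Global log-sum-exp from local pieces: max then sum, each via one all-reduce.
    global_max = local_logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX)
    sum_exp = (local_logits - global_max).exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM)
    # The target logit lives on exactly one rank; the others contribute zero.
    local_idx = (labels - class_offset).clamp(0, local_logits.size(-1) - 1)
    in_shard = (labels >= class_offset) & (labels < class_offset + local_logits.size(-1))
    target = torch.where(
        in_shard,
        local_logits.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1),
        torch.zeros_like(labels, dtype=local_logits.dtype),
    )
    dist.all_reduce(target, op=dist.ReduceOp.SUM)
    # Cross-entropy = logsumexp(logits) - target_logit, averaged over tokens.
    return (global_max.squeeze(-1) + sum_exp.log() - target).mean()
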
5 changes: 5 additions & 0 deletions examples/pytorch/tensor_parallel/train.py
@@ -27,9 +27,14 @@ def training_step(self, batch):
         inputs = batch[:, :-1]
         labels = batch[:, 1:]
         output = self.model(inputs)
+        # Optional: Parallelize loss computation across class dimension (see parallelism.py)
         with loss_parallel():
             return F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
 
+    def backward(self, *args, **kwargs):
+        with loss_parallel():
+            super().backward(*args, **kwargs)
+
     def configure_optimizers(self):
         return torch.optim.AdamW(self.model.parameters(), lr=3e-3, foreach=True)
 
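The backward override is needed because the Trainer runs the backward pass outside of training_step, after the loss_parallel() context opened there has already exited; overriding LightningModule.backward re-enters the context around the actual backward call. A condensed sketch of the resulting module, assuming `import lightning as L`, with the class name and surrounding details paraphrased rather than copied from the example:

# Condensed sketch; only the two hooks shown in the diff are taken from the example.
class LLM(L.LightningModule):
    def training_step(self, batch):
        inputs, labels = batch[:, :-1], batch[:, 1:]
        output = self.model(inputs)
        with loss_parallel():
            return F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))

    def backward(self, *args, **kwargs):
        # The Trainer calls this hook for the actual backward pass, so
        # loss_parallel() has to be re-entered here as well.
        with loss_parallel():
            super().backward(*args, **kwargs)
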
