When I try to use the fp8_model_init feature, it does not appear to be compatible with DDP. It throws the following error:
RuntimeError: Modules with uninitialized parameters can't be used with "DistributedDataParallel". Run a dummy forward pass to correctly initialize the modules
Running a dummy forward pass doesn't help, and calling reset_parameters doesn't help either. Using a separate stream for DDP also does not fix the issue.
A simple reproducible case:
import os
import torch
import torch.nn as nn
import functools
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.distributed.fsdp import ShardingStrategy
import transformer_engine as te


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12364"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = te.pytorch.Linear(1024, 1024)
        self.fc2 = te.pytorch.Linear(1024, 10)

    def forward(self, x):
        return self.fc2(self.fc1(x))


def fsdp_main(rank, world_size):
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    # Build the model with FP8 weights.
    with te.pytorch.fp8.fp8_model_init(enabled=True):
        model = Net().to(rank)

    # Attempted workaround 1: explicitly reset parameters.
    for i, m in enumerate(model.modules()):
        if hasattr(m, "reset_parameters"):
            print(f"resetting {i}")
            m.reset_parameters()

    # Attempted workaround 2: run a dummy forward pass before wrapping with DDP.
    input_data = torch.randn((16, 1024)).cuda()
    with torch.no_grad():
        model(input_data)
    torch.cuda.synchronize()

    # Fails here with the "Modules with uninitialized parameters" RuntimeError.
    model = DDP(model)

    torch.cuda.synchronize()
    dist.barrier()
    cleanup()


if __name__ == "__main__":
    WORLD_SIZE = 8
    mp.spawn(fsdp_main, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)
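For reference, the error string matches the check that PyTorch's DistributedDataParallel constructor performs for uninitialized (lazy) parameters. Below is a minimal diagnostic sketch, under the assumption that the check boils down to an isinstance test against torch.nn.parameter.UninitializedParameter; the helper name report_uninitialized is mine and not part of either library.

import torch.nn as nn
from torch.nn.parameter import UninitializedParameter


def report_uninitialized(model: nn.Module) -> None:
    # List parameters that DDP would presumably reject as uninitialized
    # (assumption: DDP's check is isinstance(..., UninitializedParameter)).
    flagged = [name for name, p in model.named_parameters()
               if isinstance(p, UninitializedParameter)]
    if flagged:
        print("uninitialized parameters:", flagged)
    else:
        print("no UninitializedParameter instances found")

Calling report_uninitialized(model) right before the model = DDP(model) line in the script above should show whether the te.pytorch.Linear weights created under fp8_model_init still look uninitialized to DDP after the reset_parameters calls and the dummy forward pass.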