Skip to content

fp8_model_init doesn't work with DDP #1135

@MaciejBalaNV

Description

@MaciejBalaNV

When I try to use the `fp8_model_init` feature, it does not seem to be compatible with DDP. Wrapping the model in DDP throws the following error:
RuntimeError: Modules with uninitialized parameters can't be used with "DistributedDataParallel". Run a dummy forward pass to correctly initialize the modules

Running a dummy forward pass doesn't help, and neither does calling `reset_parameters`. Using a separate stream for DDP does not fix the issue either.

A minimal reproducible example:

import os
import torch
import torch.nn as nn
import functools
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.distributed.fsdp import ShardingStrategy
import transformer_engine as te

def setup(rank, world_size):
    """Join the NCCL default process group as ``rank`` of ``world_size``.

    Rendezvous information is passed through environment variables:
    every worker connects to ``localhost`` on port ``12364``.
    """
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12364"})
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """Destroy the default process group created by ``setup``."""
    dist.destroy_process_group()


class Net(nn.Module):
    """Two stacked Transformer Engine ``Linear`` layers: 1024 -> 1024 -> 10."""

    def __init__(self):
        super().__init__()
        self.fc1 = te.pytorch.Linear(1024, 1024)
        self.fc2 = te.pytorch.Linear(1024, 10)

    def forward(self, x):
        # Feed the input through fc1, then project the hidden state with fc2.
        hidden = self.fc1(x)
        return self.fc2(hidden)


def fsdp_main(rank, world_size):
    """Per-rank worker that reproduces the ``fp8_model_init`` + DDP failure.

    Despite its name, this wraps the model in ``DistributedDataParallel``
    (not FSDP). Steps: join the process group, build ``Net`` inside
    ``fp8_model_init``, call ``reset_parameters`` on every submodule that
    has one, run a single no-grad forward pass, then wrap the model in DDP.
    The DDP wrap is the step reported to raise ``RuntimeError``.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of processes in the group.
    """
    setup(rank, world_size)
    # Always tear down the process group, even when the DDP wrap raises,
    # so the reported error is not followed by leaked NCCL state.
    try:
        torch.cuda.set_device(rank)

        # NOTE(review): presumably parameters created in this context are
        # stored in FP8 — confirm against the Transformer Engine docs.
        with te.pytorch.fp8.fp8_model_init(enabled=True):
            model = Net().to(rank)

        # Attempted workaround from the issue: re-initialise parameters.
        for i, m in enumerate(model.modules()):
            if hasattr(m, "reset_parameters"):
                print(f"resetting {i}")
                m.reset_parameters()

        # Attempted workaround from the issue: a dummy forward pass.
        input_data = torch.randn((16, 1024)).cuda()
        with torch.no_grad():
            model(input_data)
        torch.cuda.synchronize()

        model = DDP(model)  # reported to raise the RuntimeError
        torch.cuda.synchronize()

        dist.barrier()
    finally:
        cleanup()


if __name__ == "__main__":
    # Launch one worker per visible GPU. Each worker calls
    # torch.cuda.set_device(rank), so a hard-coded count larger than the
    # number of visible GPUs would break on smaller nodes.
    WORLD_SIZE = torch.cuda.device_count()
    if WORLD_SIZE < 1:
        raise SystemExit("This reproducer needs at least one CUDA device.")
    mp.spawn(fsdp_main, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)

@denera

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions