Conflict with torch.distributed? #145

Open
Wuzimeng opened this issue Jan 13, 2024 · 1 comment

@Wuzimeng

Hello, I encountered an error when calling flop_count_table() in my distributed training code.
The error message is below, but I checked the inputs to all_gather() and didn't find anything unusual.

File "/xxx/anaconda3/envs/torch13/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2275, in all_gather
work = default_pg.allgather([tensor_list], [tensor])
RuntimeError: unsupported input list type: Tensor[]
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 628698) of binary: /xxx/anaconda3/envs/torch13/bin/python

Here's a brief script that reproduces the error when launched with python -m torch.distributed.run --nproc_per_node=1 --master_port 10603 try.py:

import torch
import torch.nn as nn

from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        y = self.fc(x)
        concat_all_gather(y)
        return y.sum()

@torch.no_grad()
def concat_all_gather(tensor):
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output

torch.distributed.init_process_group(backend='nccl')
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

model = SimpleModel().cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

flop = FlopCountAnalysis(model.module, torch.randn(100, 10).cuda())
print(flop_count_table(flop, max_depth=7, show_param_shapes=True))

torch.distributed.destroy_process_group()

Additionally, my environment is:
Python 3.9.18, CUDA 11.7, fvcore==0.1.5.post20221221, torch 1.13

Another confusing thing is that in a Python 3.8.18 / CUDA 11.4 / torch 1.10 environment, the same code doesn't raise an error.
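
One possible workaround (only a sketch, and it assumes FlopCountAnalysis runs the forward pass under torch.jit tracing, so that torch.jit.is_tracing() is True while FLOPs are being counted) would be to skip the collective during tracing:

@torch.no_grad()
def concat_all_gather(tensor):
    # Hypothetical guard: during FLOP counting, fake the gathered output
    # instead of calling into c10d.
    if torch.jit.is_tracing():
        return torch.cat([tensor] * torch.distributed.get_world_size(), dim=0)
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)

The real all_gather still runs during normal training; only the traced forward used for FLOP counting takes the shortcut.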

@philipwan

I'm having the same problem. Have you solved it?
In my case, it seems like there is a conflict between the jit.trace module and dist.all_gather.
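
If the conflict really is between the jit-based tracer and dist.all_gather, another option (an untested sketch; all_gather_as_local_copy is a hypothetical helper, not part of fvcore or torch) is to temporarily swap torch.distributed.all_gather for a local copy while the FLOPs are being counted:

import contextlib
import torch.distributed as dist

@contextlib.contextmanager
def all_gather_as_local_copy():
    # Replace dist.all_gather with a version that just copies the local tensor
    # into the output list, so the tracer never reaches the c10d allgather.
    real_all_gather = dist.all_gather

    def fake_all_gather(tensor_list, tensor, group=None, async_op=False):
        for t in tensor_list:
            t.copy_(tensor)

    dist.all_gather = fake_all_gather
    try:
        yield
    finally:
        dist.all_gather = real_all_gather

# Usage: only the FLOP-counting pass sees the patched collective.
with all_gather_as_local_copy():
    flops = FlopCountAnalysis(model.module, torch.randn(100, 10).cuda())
    print(flop_count_table(flops, max_depth=7, show_param_shapes=True))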
