From 3cfcfa33f6247c6106dd1bf19c4ae688ce2c1a4e Mon Sep 17 00:00:00 2001
From: yhna940
Date: Mon, 17 Apr 2023 22:53:47 +0900
Subject: [PATCH] Fix minor bug for single output

---
 .../parallel/data_parallel/data_parallel.py   |   7 +-
 .../parallel/data_parallel/data_parallel.py   | 104 ++++++++++++++++++
 2 files changed, 109 insertions(+), 2 deletions(-)
 create mode 100644 tests_deprecated/torch/nn/parallel/data_parallel/data_parallel.py

diff --git a/oslo/torch/nn/parallel/data_parallel/data_parallel.py b/oslo/torch/nn/parallel/data_parallel/data_parallel.py
index 67d6f5ae..fee247cd 100644
--- a/oslo/torch/nn/parallel/data_parallel/data_parallel.py
+++ b/oslo/torch/nn/parallel/data_parallel/data_parallel.py
@@ -114,9 +114,12 @@ def forward(self, *args, **kwargs):
             }
         )
 
-        if isinstance(inputs, torch.Tensor):
+        single_output = isinstance(inputs, torch.Tensor)
+        if single_output:
             inputs = (inputs,)
-        return _DistributedBackwardFunction.apply(self, *inputs)
+
+        outputs = _DistributedBackwardFunction.apply(self, *inputs)
+        return outputs[0] if single_output else outputs
 
     def _pre_backward(self):
         pass
diff --git a/tests_deprecated/torch/nn/parallel/data_parallel/data_parallel.py b/tests_deprecated/torch/nn/parallel/data_parallel/data_parallel.py
new file mode 100644
index 00000000..9115f347
--- /dev/null
+++ b/tests_deprecated/torch/nn/parallel/data_parallel/data_parallel.py
@@ -0,0 +1,104 @@
+import os
+import torch
+import torch.nn as nn
+import torch.multiprocessing as mp
+
+from oslo.torch.utils import get_free_port
+from oslo.torch.distributed.parallel_context import ParallelContext
+from oslo.torch.nn.parallel.data_parallel.data_parallel import _DistributedDataParallel
+import copy
+
+import pytest
+
+skip_if_dist_unavailable = pytest.mark.skipif(
+    torch.cuda.device_count() < 2, reason="dist required"
+)
+
+
+class MlpModel(nn.Module):
+    def __init__(self):
+        super(MlpModel, self).__init__()
+        self.linear1 = nn.Linear(128, 256)
+        self.linear2 = nn.Linear(256, 512)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+
+class DictOutputModel(nn.Module):
+    def __init__(self):
+        super(DictOutputModel, self).__init__()
+        self.linear1 = nn.Linear(128, 256)
+        self.linear2 = nn.Linear(128, 512)
+
+    def forward(self, x):
+        x1 = self.linear1(x)
+        x2 = self.linear2(x)
+        return {"output1": x1, "output2": x2}
+
+
+class MultiOutputModel(nn.Module):
+    def __init__(self):
+        super(MultiOutputModel, self).__init__()
+        self.linear1 = nn.Linear(128, 256)
+        self.linear2 = nn.Linear(128, 512)
+
+    def forward(self, x):
+        x1 = self.linear1(x)
+        x2 = self.linear2(x)
+        return x1, x2
+
+
+def run_dist(rank, world_size, model_class):
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    parallel_context = ParallelContext.from_torch(data_parallel_size=world_size)
+
+    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
+
+    model = model_class()
+    ddp_model = _DistributedDataParallel(
+        copy.deepcopy(model).to(device), parallel_context
+    )
+    ddp_model.parallelize()
+    model = model.to(device)
+
+    input_data = torch.randn(32, 128).to(device)
+
+    output_normal = model(input_data)
+
+    output_ddp = ddp_model(input_data)
+
+    if isinstance(output_normal, dict):
+        for key in output_normal:
+            assert torch.allclose(
+                output_normal[key], output_ddp[key], rtol=1e-03, atol=1e-03
+            ), f"Outputs do not match for key {key}!"
+    elif isinstance(output_normal, tuple):
+        for i, (out_n, out_f) in enumerate(zip(output_normal, output_ddp)):
+            assert torch.allclose(
+                out_n, out_f, rtol=1e-03, atol=1e-03
+            ), f"Outputs do not match for index {i}!"
+    else:
+        assert torch.allclose(
+            output_normal, output_ddp, rtol=1e-03, atol=1e-03
+        ), "Outputs do not match!"
+
+
+@skip_if_dist_unavailable
+def test_ddp():
+    world_size = 2
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = str(get_free_port())
+
+    models_to_test = [MlpModel, DictOutputModel, MultiOutputModel]
+
+    for model_class in models_to_test:
+        print(f"Testing {model_class.__name__}")
+        mp.spawn(run_dist, args=(world_size, model_class), nprocs=world_size, join=True)
+
+    print("All tests passed!")
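
Reviewer note (not part of the patch): the root cause is that torch.autograd.Function.apply returns a tuple whenever forward() returns one, even a 1-element tuple, so the old `return _DistributedBackwardFunction.apply(self, *inputs)` handed callers `(tensor,)` when the wrapped module produced a single tensor. Below is a minimal, self-contained sketch of the wrap/unwrap control flow the patch introduces; _PassthroughFunction and forward_like_patched are illustrative stand-ins, not oslo APIs, and the stand-in deliberately omits the distributed backward logic of the real _DistributedBackwardFunction.

import torch


class _PassthroughFunction(torch.autograd.Function):
    # Stand-in for _DistributedBackwardFunction. Like the real function,
    # forward() returns a tuple, and autograd.Function.apply preserves
    # that tuple even for a single element -- the root of the bug.
    @staticmethod
    def forward(ctx, module, *inputs):
        # Clone so we return fresh output tensors rather than the inputs.
        return tuple(inp.clone() for inp in inputs)

    @staticmethod
    def backward(ctx, *grads):
        # No gradient for the (non-tensor) module argument.
        return (None, *grads)


def forward_like_patched(module, inputs):
    # Mirrors the patched _DistributedDataParallel.forward: remember
    # whether a bare tensor came in, wrap it for apply(), unwrap after.
    single_output = isinstance(inputs, torch.Tensor)
    if single_output:
        inputs = (inputs,)
    outputs = _PassthroughFunction.apply(module, *inputs)
    return outputs[0] if single_output else outputs


x = torch.randn(2, 3, requires_grad=True)
y = forward_like_patched(None, x)
assert isinstance(y, torch.Tensor)  # pre-patch, a 1-tuple escaped here
y.sum().backward()
assert x.grad is not None  # gradients still flow through the wrapper

The new test file guards all three output shapes (single tensor, dict, tuple), so a regression in the unwrap path would trip the allclose assert in the `else` branch.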