diff --git a/tests_deprecated/torch/nn/parallel/data_parallel/zero/test_hybrid.py b/tests_deprecated/torch/nn/parallel/data_parallel/zero/test_hybrid.py
index 89e1ffe8..920b61a2 100644
--- a/tests_deprecated/torch/nn/parallel/data_parallel/zero/test_hybrid.py
+++ b/tests_deprecated/torch/nn/parallel/data_parallel/zero/test_hybrid.py
@@ -7,29 +7,18 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 
+import oslo
 from oslo.torch.distributed.parallel_context import ParallelContext
 from oslo.torch.utils import get_free_port, set_seed
 from oslo.torch.nn.parallel.data_parallel.zero import ZeroRedundancyOptimizer
-from torch.testing import assert_close
 from oslo.torch.nn.parallel import TensorParallel
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 skip_if_dist_unavailable = pytest.mark.skipif(
     torch.cuda.device_count() < 2, reason="dist required"
 )
 
 
-class MlpModel(nn.Module):
-    def __init__(self):
-        super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(128, 256)
-        self.linear2 = nn.Linear(256, 512)
-
-    def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
-        return x
-
-
 def assert_shard_close(
     tensor: torch.Tensor,
     shard: torch.Tensor,
@@ -40,14 +29,14 @@ def assert_shard_close(
 ):
     assert tensor.ndim == shard.ndim
     if tensor.shape == shard.shape:
-        return assert_close(tensor, shard, rtol=rtol, atol=atol)
+        return torch.allclose(tensor, shard, rtol=rtol, atol=atol)
     else:
         dims_not_eq = torch.nonzero(
             torch.tensor(tensor.shape) != torch.tensor(shard.shape)
         )
         if dims_not_eq.numel() == 1:
            dim = dims_not_eq.item()
-            return assert_close(
+            return torch.allclose(
                 tensor.chunk(world_size, dim)[rank], shard, rtol=rtol, atol=atol
             )
         else:
@@ -58,38 +47,43 @@ def run(parallel_context: ParallelContext):
     local_rank = torch.distributed.get_rank()
 
     # create model
-    model = MlpModel().cuda()
+    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
     hybrid_model = TensorParallel(
         copy.deepcopy(model), parallel_context=parallel_context
     )
-    zero_model = model
+    oslo.ready(hybrid_model, parallel_context)
+    zero_model = model.cuda()
 
     # create optimizer
     hybrid_optimizer = ZeroRedundancyOptimizer(
-        torch.optim.Adam(hybrid_model.parameters(), lr=1),
+        torch.optim.Adam(hybrid_model.parameters(), lr=1e-2),
         parallel_context=parallel_context,
-        overlap_communication=True,
-        partition_grad=True,
     )
     zero_optimizer = ZeroRedundancyOptimizer(
-        torch.optim.Adam(zero_model.parameters(), lr=1),
+        torch.optim.Adam(zero_model.parameters(), lr=1e-2),
         parallel_context=parallel_context,
-        overlap_communication=True,
     )
 
+    # create tokenizer
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
     # create data
     set_seed(2021 + local_rank)
-    input_data = torch.randn(32, 128).cuda()
+    input_text = ["This is a sample text."] * 32
+    inputs = tokenizer(
+        input_text, return_tensors="pt", padding=True, truncation=True
+    ).to("cuda")
+    labels = torch.randint(0, model.config.num_labels, (32,)).long().cuda()
 
     # zero-dp forward
-    hybrid_output = hybrid_model(input_data)
-    zero_output = zero_model(input_data)
+    hybrid_output = hybrid_model(**inputs, labels=labels).loss
+    zero_output = zero_model(**inputs, labels=labels).loss
     assert torch.allclose(hybrid_output, zero_output)
 
     # zero-dp backward
-    hybrid_output.sum().float().backward()
-    zero_output.sum().float().backward()
+    hybrid_output.backward()
+    zero_output.backward()
 
     # step
     hybrid_optimizer.step()
@@ -97,7 +91,9 @@ def run(parallel_context: ParallelContext):
 
     # check updated param
     for hp, zp in zip(hybrid_model.parameters(), zero_model.parameters()):
-        assert torch.allclose(hp.data, zp.data)
+        assert assert_shard_close(
+            zp.data, hp.data, local_rank, torch.distributed.get_world_size()
+        )
 
 
 def run_dist(rank, world_size):