diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index 041d2e0ebf..12caaf242f 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -115,6 +115,10 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
+
+    # Get world size and rank. Ensure some consistencies.
+    assert torch.distributed.is_initialized()
+    world_size: int = torch.distributed.get_world_size()
 
     if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
         raise RuntimeError(
@@ -157,7 +161,10 @@ def initialize_model_parallel(
             ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
             all_data_parallel_group_ranks.append(list(ranks))
             group = torch.distributed.new_group(ranks)
-            group_gloo = torch.distributed.new_group(ranks, backend="gloo")
+            if world_size <= 8:
+                group_gloo = torch.distributed.new_group(ranks, backend="gloo")
+            else:
+                group_gloo = torch.distributed.new_group(ranks, backend="nccl")
             if rank in ranks:
                 _DATA_PARALLEL_GROUP = group
                 _DATA_PARALLEL_GROUP_GLOO = group_gloo
diff --git a/megatron/optimizer/distrib_optimizer.py b/megatron/optimizer/distrib_optimizer.py
index 38e244afc4..870570316e 100644
--- a/megatron/optimizer/distrib_optimizer.py
+++ b/megatron/optimizer/distrib_optimizer.py
@@ -567,7 +567,10 @@ def save_parameter_state(self, filename):
           buffers.
         - Save world buffers to disk (i.e., distrib_opt.pt).
         """
-
+        # The world size decides whether shards are staged on CPU or CUDA below.
+        assert torch.distributed.is_initialized()
+        world_size: int = torch.distributed.get_world_size()
+
         # Data parallelism variables.
         data_parallel_world_size = mpu.get_data_parallel_world_size()
         data_parallel_rank = mpu.get_data_parallel_rank()
@@ -587,10 +590,16 @@ def save_parameter_state(self, filename):
                 model = self.models[model_idx]
                 gbuf_world_numel = model._grad_buffers[dtype].numel_padded
                 gbuf_local_numel = int(gbuf_world_numel / data_parallel_world_size)
-                local_shards = {key: torch.empty((gbuf_local_numel,),
-                                                 dtype=torch.float32,
-                                                 device="cpu")
-                                for key in ("param", "exp_avg", "exp_avg_sq")}
+                if world_size <= 8:
+                    local_shards = {key: torch.empty((gbuf_local_numel,),
+                                                     dtype=torch.float32,
+                                                     device="cpu")
+                                    for key in ("param", "exp_avg", "exp_avg_sq")}
+                else:
+                    local_shards = {key: torch.empty((gbuf_local_numel,),
+                                                     dtype=torch.float32,
+                                                     device="cuda")
+                                    for key in ("param", "exp_avg", "exp_avg_sq")}
 
                 # Build contiguous DP rank shards (for param + optim states).
                 for model_param, param_range_map in \
@@ -612,8 +621,12 @@ def save_parameter_state(self, filename):
                     gbuf_local_start = param_range_map["gbuf_local"].start
                     gbuf_local_end = param_range_map["gbuf_local"].end
                     for key in local_shards:
-                        local_shards[key][gbuf_local_start:gbuf_local_end] \
-                            .data.copy_(tensors[key].detach().cpu())
+                        if world_size <= 8:
+                            local_shards[key][gbuf_local_start:gbuf_local_end] \
+                                .data.copy_(tensors[key].detach().cpu())
+                        else:
+                            local_shards[key][gbuf_local_start:gbuf_local_end] \
+                                .data.copy_(tensors[key].detach().cuda())
 
                 # Gather contiguous shards on DP rank 0.
                 world_tensors = {}
@@ -621,10 +634,16 @@ def save_parameter_state(self, filename):
 
                     # Gather tensor list.
                     if data_parallel_rank == 0:
-                        recv_tensors = [torch.empty((gbuf_local_numel,),
-                                                    dtype=torch.float32,
-                                                    device="cpu")
-                                        for _ in range(data_parallel_world_size)]
+                        if world_size <= 8:
+                            recv_tensors = [torch.empty((gbuf_local_numel,),
+                                                        dtype=torch.float32,
+                                                        device="cpu")
+                                            for _ in range(data_parallel_world_size)]
+                        else:
+                            recv_tensors = [torch.empty((gbuf_local_numel,),
+                                                        dtype=torch.float32,
+                                                        device="cuda")
+                                            for _ in range(data_parallel_world_size)]
                     else:
                         recv_tensors = None
 
diff --git a/megatron/optimizer/overlapped_dist_optimizer.py b/megatron/optimizer/overlapped_dist_optimizer.py
index eaab434248..8ff86985dc 100644
--- a/megatron/optimizer/overlapped_dist_optimizer.py
+++ b/megatron/optimizer/overlapped_dist_optimizer.py
@@ -953,7 +953,10 @@ def save_parameter_state(self, filename):
           buffers.
         - Save world buffers to disk (i.e., distrib_opt.pt).
         """
-
+        # The world size decides whether shards are staged on CPU or CUDA below.
+        assert torch.distributed.is_initialized()
+        world_size: int = torch.distributed.get_world_size()
+
         # Data parallelism variables.
         data_parallel_world_size = mpu.get_data_parallel_world_size()
         data_parallel_rank = mpu.get_data_parallel_rank()
@@ -970,30 +973,49 @@ def save_parameter_state(self, filename):
            optim_state = self.optimizer.state[inner_param]
            # copy to CPU buffer
            num_elements = self._param_buffer[i].get_partitioned_size()
-           local_shards = {
-               key: torch.empty(
-                   num_elements,
-                   dtype=torch.float32,
-                   device="cpu",
-               )
-               for key in ("param", "exp_avg", "exp_avg_sq")
-           }
+           if world_size <= 8:
+               local_shards = {
+                   key: torch.empty(
+                       num_elements,
+                       dtype=torch.float32,
+                       device="cpu",
+                   )
+                   for key in ("param", "exp_avg", "exp_avg_sq")
+               }
+           else:
+               local_shards = {
+                   key: torch.empty(
+                       num_elements,
+                       dtype=torch.float32,
+                       device="cuda",
+                   )
+                   for key in ("param", "exp_avg", "exp_avg_sq")
+               }
            tensors = {
                "param": self._param_buffer[i].get_fp32_partitioned_param(),
                **optim_state,
            }
            for key in local_shards:
-               local_shards[key].data.copy_(tensors[key].detach().cpu())
+               if world_size <= 8:
+                   local_shards[key].data.copy_(tensors[key].detach().cpu())
+               else:
+                   local_shards[key].data.copy_(tensors[key].detach().cuda())
 
            # Gather contiguous shards on DP rank 0.
            world_tensors = {}
            for key, send_tensor in local_shards.items():
                # Gather tensor list.
                if data_parallel_rank == 0:
-                   recv_tensors = [torch.empty((num_elements,),
-                                               dtype=torch.float32,
-                                               device="cpu")
-                                   for _ in range(data_parallel_world_size)]
+                   if world_size <= 8:
+                       recv_tensors = [torch.empty((num_elements,),
+                                                   dtype=torch.float32,
+                                                   device="cpu")
+                                       for _ in range(data_parallel_world_size)]
+                   else:
+                       recv_tensors = [torch.empty((num_elements,),
+                                                   dtype=torch.float32,
+                                                   device="cuda")
+                                       for _ in range(data_parallel_world_size)]
                else:
                    recv_tensors = None
                # Gather.
@@ -1150,4 +1172,4 @@ def load_parameter_state_in_parallel(self, filename):
                tensors[key].data.copy_(local_shards[key])
 
        self.copy_updated_parameters()
-       self.gather_parameters()
\ No newline at end of file
+       self.gather_parameters()
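
For reference, the pattern these hunks apply is sketched below in standalone form: pick the staging device from the world size, then gather one shard per data-parallel rank onto rank 0. The helper names (`pick_shard_device`, `gather_shards`) and the `threshold` parameter are illustrative only, not part of Megatron-LM, and the sketch assumes a PyTorch build whose NCCL backend supports `torch.distributed.gather` on CUDA tensors.

```python
import torch
import torch.distributed as dist


def pick_shard_device(world_size: int, threshold: int = 8) -> str:
    # Mirrors the patch's heuristic: small jobs stage optimizer shards on the
    # host (CPU tensors, gathered over a Gloo group); larger jobs keep shards
    # on the GPU so the gather can run over NCCL instead.
    return "cpu" if world_size <= threshold else "cuda"


def gather_shards(send_tensor: torch.Tensor):
    # Gather one shard per rank onto global rank 0 over the default process
    # group. Receive buffers are allocated on the same device as the send
    # tensor, matching the CPU/CUDA split above; non-zero ranks return None.
    world_size = dist.get_world_size()
    if dist.get_rank() == 0:
        recv_tensors = [torch.empty_like(send_tensor) for _ in range(world_size)]
    else:
        recv_tensors = None
    dist.gather(send_tensor, recv_tensors, dst=0)
    return recv_tensors
```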