Solve the RuntimeError: Tensors must be CUDA and dense #33

Status: Open. Wants to merge 3 commits into base: main.
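For context (this note and the helper below are not part of the PR, and the function name is hypothetical): the error in the title is typically raised when CPU tensors are passed to a collective running over a NCCL-backed process group. NCCL collectives require dense CUDA tensors, while Gloo collectives operate on CPU tensors, so the staging device has to match the backend of the group the tensor is handed to. A minimal sketch of that rule:

import torch
import torch.distributed as dist

def gather_on_group_device(tensor, group=None):
    """Return a list with every rank's copy of `tensor`, gathered over `group`."""
    # NCCL needs dense CUDA tensors; Gloo runs on CPU tensors, so stage the
    # tensor on the device that matches the group's backend.
    backend = dist.get_backend(group)
    if backend == "nccl":
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        device = torch.device("cpu")
    tensor = tensor.detach().to(device)
    out = [torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group))]
    dist.all_gather(out, tensor, group=group)
    return out

The PR below applies the same pairing by hand: Gloo groups with CPU buffers when the world size is at most 8, NCCL groups with CUDA buffers otherwise.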
megatron/core/parallel_state.py: 8 additions & 1 deletion
@@ -115,6 +115,10 @@ def initialize_model_parallel(
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()

# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()

if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
raise RuntimeError(
@@ -157,7 +161,10 @@ def initialize_model_parallel(
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
group_gloo = torch.distributed.new_group(ranks, backend="gloo")
if world_size <= 8:
group_gloo = torch.distributed.new_group(ranks, backend="gloo")
else:
group_gloo = torch.distributed.new_group(ranks, backend="nccl")
if rank in ranks:
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_GLOO = group_gloo
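A note on the hunk above: when world_size > 8, the variable still named group_gloo now holds a NCCL-backed group, so every later collective over it must use CUDA tensors, which is what the optimizer changes below arrange. A hedged alternative sketch, not in the PR, that keys the decision off backend availability rather than a fixed world-size threshold (is_gloo_available is an existing torch.distributed query; the surrounding names follow the diff):

import torch.distributed as dist

def new_data_parallel_groups(ranks):
    # Default-backend group (NCCL for GPU training) used for gradient collectives.
    group = dist.new_group(ranks)
    # Keep a genuine Gloo group for CPU-side gathers whenever the Gloo backend
    # is built in; fall back to the default-backend group only if it is not.
    if dist.is_gloo_available():
        group_gloo = dist.new_group(ranks, backend="gloo")
    else:
        group_gloo = group
    return group, group_gloo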
megatron/optimizer/distrib_optimizer.py: 30 additions & 11 deletions
@@ -567,7 +567,10 @@ def save_parameter_state(self, filename):
buffers.
- Save world buffers to disk (i.e., distrib_opt.pt).
"""

# Check the world size to decide between CPU and CUDA staging buffers.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()

# Data parallelism variables.
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_rank = mpu.get_data_parallel_rank()
@@ -587,10 +590,16 @@ def save_parameter_state(self, filename):
model = self.models[model_idx]
gbuf_world_numel = model._grad_buffers[dtype].numel_padded
gbuf_local_numel = int(gbuf_world_numel / data_parallel_world_size)
local_shards = {key: torch.empty((gbuf_local_numel,),
dtype=torch.float32,
device="cpu")
for key in ("param", "exp_avg", "exp_avg_sq")}
if world_size <= 8:
local_shards = {key: torch.empty((gbuf_local_numel,),
dtype=torch.float32,
device="cpu")
for key in ("param", "exp_avg", "exp_avg_sq")}
else:
local_shards = {key: torch.empty((gbuf_local_numel,),
dtype=torch.float32,
device="cuda")
for key in ("param", "exp_avg", "exp_avg_sq")}

# Build contiguous DP rank shards (for param + optim states).
for model_param, param_range_map in \
@@ -612,19 +621,29 @@ def save_parameter_state(self, filename):
gbuf_local_start = param_range_map["gbuf_local"].start
gbuf_local_end = param_range_map["gbuf_local"].end
for key in local_shards:
local_shards[key][gbuf_local_start:gbuf_local_end] \
.data.copy_(tensors[key].detach().cpu())
if world_size <= 8:
local_shards[key][gbuf_local_start:gbuf_local_end] \
.data.copy_(tensors[key].detach().cpu())
else:
local_shards[key][gbuf_local_start:gbuf_local_end] \
.data.copy_(tensors[key].detach().cuda())

# Gather contiguous shards on DP rank 0.
world_tensors = {}
for key, send_tensor in local_shards.items():

# Gather tensor list.
if data_parallel_rank == 0:
recv_tensors = [torch.empty((gbuf_local_numel,),
dtype=torch.float32,
device="cpu")
for _ in range(data_parallel_world_size)]
if world_size <= 8:
recv_tensors = [torch.empty((gbuf_local_numel,),
dtype=torch.float32,
device="cpu")
for _ in range(data_parallel_world_size)]
else:
recv_tensors = [torch.empty((gbuf_local_numel,),
dtype=torch.float32,
device="cuda")
for _ in range(data_parallel_world_size)]
else:
recv_tensors = None

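The same world_size check is repeated at each allocation and copy site in this file and the next. A small hypothetical helper, not part of the PR, that centralizes the choice using the PR's own threshold of 8:

import torch

def staging_device(world_size: int, threshold: int = 8) -> torch.device:
    # CPU staging buffers while the data-parallel gather group is Gloo-backed;
    # CUDA buffers once it is NCCL-backed (the PR uses world_size > 8 as the proxy).
    return torch.device("cpu") if world_size <= threshold else torch.device("cuda")

# Possible usage with the names from the surrounding diff:
# device = staging_device(world_size)
# local_shards = {key: torch.empty((gbuf_local_numel,), dtype=torch.float32, device=device)
#                 for key in ("param", "exp_avg", "exp_avg_sq")}
# local_shards[key][gbuf_local_start:gbuf_local_end].data.copy_(
#     tensors[key].detach().to(device))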
megatron/optimizer/overlapped_dist_optimizer.py: 37 additions & 15 deletions
@@ -953,7 +953,10 @@ def save_parameter_state(self, filename):
buffers.
- Save world buffers to disk (i.e., distrib_opt.pt).
"""

# Check the world size to decide between CPU and CUDA staging buffers.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()

# Data parallelism variables.
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_rank = mpu.get_data_parallel_rank()
@@ -970,30 +973,49 @@ def save_parameter_state(self, filename):
optim_state = self.optimizer.state[inner_param]
# copy to CPU buffer
num_elements = self._param_buffer[i].get_partitioned_size()
local_shards = {
key: torch.empty(
num_elements,
dtype=torch.float32,
device="cpu",
)
for key in ("param", "exp_avg", "exp_avg_sq")
}
if world_size <= 8:
local_shards = {
key: torch.empty(
num_elements,
dtype=torch.float32,
device="cpu",
)
for key in ("param", "exp_avg", "exp_avg_sq")
}
else:
local_shards = {
key: torch.empty(
num_elements,
dtype=torch.float32,
device="cuda",
)
for key in ("param", "exp_avg", "exp_avg_sq")
}
tensors = {
"param": self._param_buffer[i].get_fp32_partitioned_param(),
**optim_state,
}
for key in local_shards:
local_shards[key].data.copy_(tensors[key].detach().cpu())
if world_size <= 8:
local_shards[key].data.copy_(tensors[key].detach().cpu())
else:
local_shards[key].data.copy_(tensors[key].detach().cuda())

# Gather contiguous shards on DP rank 0.
world_tensors = {}
for key, send_tensor in local_shards.items():
# Gather tensor list.
if data_parallel_rank == 0:
recv_tensors = [torch.empty((num_elements,),
dtype=torch.float32,
device="cpu")
for _ in range(data_parallel_world_size)]
if world_size <= 8:
recv_tensors = [torch.empty((num_elements,),
dtype=torch.float32,
device="cpu")
for _ in range(data_parallel_world_size)]
else:
recv_tensors = [torch.empty((num_elements,),
dtype=torch.float32,
device="cuda")
for _ in range(data_parallel_world_size)]
else:
recv_tensors = None
# Gather.
@@ -1150,4 +1172,4 @@ def load_parameter_state_in_parallel(self, filename):
tensors[key].data.copy_(local_shards[key])

self.copy_updated_parameters()
self.gather_parameters()
self.gather_parameters()