Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prototype torch.distributed #165

Closed
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
95 changes: 53 additions & 42 deletions rl_games/common/a2c_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tensorboardX import SummaryWriter
import torch
from torch import nn
import torch.distributed as dist

from time import sleep

Expand Down Expand Up @@ -68,11 +69,15 @@ def __init__(self, base_name, params):
self.rank_size = 1
self.curr_frames = 0
if self.multi_gpu:
from rl_games.distributed.hvd_wrapper import HorovodWrapper
self.hvd = HorovodWrapper()
self.config = self.hvd.update_algo_config(config)
self.rank = self.hvd.rank
self.rank_size = self.hvd.rank_size
self.rank = int(os.getenv("LOCAL_RANK", "0"))
self.rank_size = int(os.getenv("WORLD_SIZE", "1"))
dist.init_process_group("nccl", rank=self.rank, world_size=self.rank_size)

self.device_name = 'cuda:' + str(self.rank)
config['device'] = self.device_name
if self.rank != 0:
config['print_stats'] = False
config['lr_schedule'] = None

self.use_diagnostics = config.get('use_diagnostics', False)

Expand Down Expand Up @@ -256,19 +261,25 @@ def __init__(self, base_name, params):

def trancate_gradients_and_step(self):
    """Average gradients across ranks (multi-GPU), optionally clip them, and step the optimizer.

    Multi-GPU path: all per-parameter gradients are flattened into one buffer and
    reduced with a single ``dist.all_reduce`` call instead of one op per tensor
    (see https://github.com/entity-neural-network/incubator/pull/220), then divided
    by the world size to get the mean gradient.

    NOTE(review): the method name keeps the historical typo ("trancate") because
    external callers use it; renaming would break the interface.
    """
    if self.multi_gpu:
        # Batch the allreduce: one flat buffer + one collective is much cheaper
        # than a per-parameter all_reduce.
        grad_views = [p.grad.view(-1) for p in self.model.parameters() if p.grad is not None]
        flat_grads = torch.cat(grad_views)
        dist.all_reduce(flat_grads, op=dist.ReduceOp.SUM)
        offset = 0
        for param in self.model.parameters():
            if param.grad is not None:
                n = param.numel()
                # SUM / world_size == mean gradient across ranks.
                param.grad.data.copy_(
                    flat_grads[offset : offset + n].view_as(param.grad.data) / self.rank_size
                )
                offset += n

    if self.truncate_grads:
        # Gradients must be unscaled before clipping so the norm is measured
        # in true (unscaled) units.
        self.scaler.unscale_(self.optimizer)
        nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)

    # Step THROUGH the GradScaler: calling self.optimizer.step() directly would
    # skip the inf/nan check and leave the loss scale stale under AMP.
    # (With a disabled scaler this degenerates to a plain optimizer.step().)
    self.scaler.step(self.optimizer)
    self.scaler.update()

def load_networks(self, params):
builder = model_builder.ModelBuilder()
Expand Down Expand Up @@ -315,10 +326,10 @@ def set_train(self):


def update_lr(self, lr):
    """Apply the learning rate *lr* to every param group of the optimizer.

    No cross-rank broadcast is needed here anymore: after the move from
    Horovod to torch.distributed, non-zero ranks have their lr_schedule
    disabled at construction time, so each rank computes the same lr locally.
    """
    # NOTE(review): the old multi-GPU branch called self.hvd.broadcast_value,
    # but self.hvd is no longer created anywhere after the torch.distributed
    # migration — keeping it would raise AttributeError. Removed along with
    # the commented-out remnants.
    for param_group in self.optimizer.param_groups:
        param_group['lr'] = lr
Expand Down Expand Up @@ -809,8 +820,8 @@ def train_epoch(self):
entropies.append(entropy)

av_kls = torch_ext.mean_list(ep_kls)
if self.multi_gpu:
av_kls = self.hvd.average_value(av_kls, 'ep_kls')
# if self.multi_gpu:
# av_kls = self.hvd.average_value(av_kls, 'ep_kls')

self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())
self.update_lr(self.last_lr)
Expand Down Expand Up @@ -900,12 +911,12 @@ def train(self):
while True:
epoch_num = self.update_epoch()
step_time, play_time, update_time, sum_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul = self.train_epoch()
if self.multi_gpu:
self.hvd.sync_stats(self)
# if self.multi_gpu:
# self.hvd.sync_stats(self)
# cleaning memory to optimize space
self.dataset.update_values_dict(None)
total_time += sum_time
curr_frames = self.curr_frames
curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames
self.frame += curr_frames
should_exit = False
if self.rank == 0:
Expand Down Expand Up @@ -964,10 +975,10 @@ def train(self):
print('MAX EPOCHS NUM!')
should_exit = True
update_time = 0
if self.multi_gpu:
should_exit_t = torch.tensor(should_exit).float()
self.hvd.broadcast_value(should_exit_t, 'should_exit')
should_exit = should_exit_t.bool().item()
# if self.multi_gpu:
# should_exit_t = torch.tensor(should_exit).float()
# self.hvd.broadcast_value(should_exit_t, 'should_exit')
# should_exit = should_exit_t.bool().item()
if should_exit:
return self.last_mean_rewards, epoch_num

Expand Down Expand Up @@ -1046,25 +1057,25 @@ def train_epoch(self):
self.dataset.update_mu_sigma(cmu, csigma)

if self.schedule_type == 'legacy':
if self.multi_gpu:
kl = self.hvd.average_value(kl, 'ep_kls')
# if self.multi_gpu:
# kl = self.hvd.average_value(kl, 'ep_kls')
self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0,kl.item())
self.update_lr(self.last_lr)

av_kls = torch_ext.mean_list(ep_kls)

if self.schedule_type == 'standard':
if self.multi_gpu:
av_kls = self.hvd.average_value(av_kls, 'ep_kls')
# if self.multi_gpu:
# av_kls = self.hvd.average_value(av_kls, 'ep_kls')
self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0,av_kls.item())
self.update_lr(self.last_lr)
kls.append(av_kls)
self.diagnostics.mini_epoch(self, mini_ep)
if self.normalize_input:
self.model.running_mean_std.eval() # don't need to update statstics more than one miniepoch
if self.schedule_type == 'standard_epoch':
if self.multi_gpu:
av_kls = self.hvd.average_value(torch_ext.mean_list(kls), 'ep_kls')
# if self.multi_gpu:
# av_kls = self.hvd.average_value(torch_ext.mean_list(kls), 'ep_kls')
self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0,av_kls.item())
self.update_lr(self.last_lr)

Expand Down Expand Up @@ -1144,16 +1155,16 @@ def train(self):
self.obs = self.env_reset()
self.curr_frames = self.batch_size_envs

if self.multi_gpu:
self.hvd.setup_algo(self)
# if self.multi_gpu:
# self.hvd.setup_algo(self)

while True:
epoch_num = self.update_epoch()
step_time, play_time, update_time, sum_time, a_losses, c_losses, b_losses, entropies, kls, last_lr, lr_mul = self.train_epoch()
total_time += sum_time
frame = self.frame // self.num_agents
if self.multi_gpu:
self.hvd.sync_stats(self)
# if self.multi_gpu:
# self.hvd.sync_stats(self)
# cleaning memory to optimize space
self.dataset.update_values_dict(None)
should_exit = False
Expand All @@ -1162,7 +1173,7 @@ def train(self):
# do we need scaled_time?
scaled_time = self.num_agents * sum_time
scaled_play_time = self.num_agents * play_time
curr_frames = self.curr_frames
curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames
self.frame += curr_frames
if self.print_stats:
fps_step = curr_frames / step_time
Expand Down Expand Up @@ -1218,10 +1229,10 @@ def train(self):
should_exit = True

update_time = 0
if self.multi_gpu:
should_exit_t = torch.tensor(should_exit).float()
self.hvd.broadcast_value(should_exit_t, 'should_exit')
should_exit = should_exit_t.float().item()
# if self.multi_gpu:
# should_exit_t = torch.tensor(should_exit).float()
# self.hvd.broadcast_value(should_exit_t, 'should_exit')
# should_exit = should_exit_t.float().item()
if should_exit:
return self.last_mean_rewards, epoch_num