From 64153b2480340dafdd9398e13208ac09dc9bcf31 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 25 May 2022 09:07:34 -0700 Subject: [PATCH] prototype torch.distributed --- rl_games/common/a2c_common.py | 95 +++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 42 deletions(-) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index a29af62a..2f183a2a 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -20,6 +20,7 @@ from tensorboardX import SummaryWriter import torch from torch import nn +import torch.distributed as dist from time import sleep @@ -68,11 +69,15 @@ def __init__(self, base_name, params): self.rank_size = 1 self.curr_frames = 0 if self.multi_gpu: - from rl_games.distributed.hvd_wrapper import HorovodWrapper - self.hvd = HorovodWrapper() - self.config = self.hvd.update_algo_config(config) - self.rank = self.hvd.rank - self.rank_size = self.hvd.rank_size + self.rank = int(os.getenv("LOCAL_RANK", "0")) + self.rank_size = int(os.getenv("WORLD_SIZE", "1")) + dist.init_process_group("nccl", rank=self.rank, world_size=self.rank_size) + + self.device_name = 'cuda:' + str(self.rank) + config['device'] = self.device_name + if self.rank != 0: + config['print_stats'] = False + config['lr_schedule'] = None self.use_diagnostics = config.get('use_diagnostics', False) @@ -256,19 +261,25 @@ def __init__(self, base_name, params): def trancate_gradients_and_step(self): if self.multi_gpu: - self.optimizer.synchronize() + # batch allreduce ops: see https://github.com/entity-neural-network/incubator/pull/220 + all_grads_list = [] + for param in self.model.parameters(): + if param.grad is not None: + all_grads_list.append(param.grad.view(-1)) + all_grads = torch.cat(all_grads_list) + dist.all_reduce(all_grads, op=dist.ReduceOp.SUM) + offset = 0 + for param in self.model.parameters(): + if param.grad is not None: + param.grad.data.copy_( + all_grads[offset : offset + param.numel()].view_as(param.grad.data) / self.rank_size + ) + offset += param.numel() if self.truncate_grads: - self.scaler.unscale_(self.optimizer) nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm) - if self.multi_gpu: - with self.optimizer.skip_synchronize(): - self.scaler.step(self.optimizer) - self.scaler.update() - else: - self.scaler.step(self.optimizer) - self.scaler.update() + self.optimizer.step() def load_networks(self, params): builder = model_builder.ModelBuilder() @@ -315,10 +326,10 @@ def set_train(self): def update_lr(self, lr): - if self.multi_gpu: - lr_tensor = torch.tensor([lr]) - self.hvd.broadcast_value(lr_tensor, 'learning_rate') - lr = lr_tensor.item() + # if self.multi_gpu: + # lr_tensor = torch.tensor([lr]) + # self.hvd.broadcast_value(lr_tensor, 'learning_rate') + # lr = lr_tensor.item() for param_group in self.optimizer.param_groups: param_group['lr'] = lr @@ -809,8 +820,8 @@ def train_epoch(self): entropies.append(entropy) av_kls = torch_ext.mean_list(ep_kls) - if self.multi_gpu: - av_kls = self.hvd.average_value(av_kls, 'ep_kls') + # if self.multi_gpu: + # av_kls = self.hvd.average_value(av_kls, 'ep_kls') self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) self.update_lr(self.last_lr) @@ -900,12 +911,12 @@ def train(self): while True: epoch_num = self.update_epoch() step_time, play_time, update_time, sum_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul = self.train_epoch() - if self.multi_gpu: - self.hvd.sync_stats(self) + # if self.multi_gpu: + # self.hvd.sync_stats(self) # cleaning memory to optimize space self.dataset.update_values_dict(None) total_time += sum_time - curr_frames = self.curr_frames + curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames self.frame += curr_frames should_exit = False if self.rank == 0: @@ -964,10 +975,10 @@ def train(self): print('MAX EPOCHS NUM!') should_exit = True update_time = 0 - if self.multi_gpu: - should_exit_t = torch.tensor(should_exit).float() - self.hvd.broadcast_value(should_exit_t, 'should_exit') - should_exit = should_exit_t.bool().item() + # if self.multi_gpu: + # should_exit_t = torch.tensor(should_exit).float() + # self.hvd.broadcast_value(should_exit_t, 'should_exit') + # should_exit = should_exit_t.bool().item() if should_exit: return self.last_mean_rewards, epoch_num @@ -1046,16 +1057,16 @@ def train_epoch(self): self.dataset.update_mu_sigma(cmu, csigma) if self.schedule_type == 'legacy': - if self.multi_gpu: - kl = self.hvd.average_value(kl, 'ep_kls') + # if self.multi_gpu: + # kl = self.hvd.average_value(kl, 'ep_kls') self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0,kl.item()) self.update_lr(self.last_lr) av_kls = torch_ext.mean_list(ep_kls) if self.schedule_type == 'standard': - if self.multi_gpu: - av_kls = self.hvd.average_value(av_kls, 'ep_kls') + # if self.multi_gpu: + # av_kls = self.hvd.average_value(av_kls, 'ep_kls') self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0,av_kls.item()) self.update_lr(self.last_lr) kls.append(av_kls) @@ -1063,8 +1074,8 @@ def train_epoch(self): if self.normalize_input: self.model.running_mean_std.eval() # don't need to update statstics more than one miniepoch if self.schedule_type == 'standard_epoch': - if self.multi_gpu: - av_kls = self.hvd.average_value(torch_ext.mean_list(kls), 'ep_kls') + # if self.multi_gpu: + # av_kls = self.hvd.average_value(torch_ext.mean_list(kls), 'ep_kls') self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0,av_kls.item()) self.update_lr(self.last_lr) @@ -1144,16 +1155,16 @@ def train(self): self.obs = self.env_reset() self.curr_frames = self.batch_size_envs - if self.multi_gpu: - self.hvd.setup_algo(self) + # if self.multi_gpu: + # self.hvd.setup_algo(self) while True: epoch_num = self.update_epoch() step_time, play_time, update_time, sum_time, a_losses, c_losses, b_losses, entropies, kls, last_lr, lr_mul = self.train_epoch() total_time += sum_time frame = self.frame // self.num_agents - if self.multi_gpu: - self.hvd.sync_stats(self) + # if self.multi_gpu: + # self.hvd.sync_stats(self) # cleaning memory to optimize space self.dataset.update_values_dict(None) should_exit = False @@ -1162,7 +1173,7 @@ def train(self): # do we need scaled_time? scaled_time = self.num_agents * sum_time scaled_play_time = self.num_agents * play_time - curr_frames = self.curr_frames + curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames self.frame += curr_frames if self.print_stats: fps_step = curr_frames / step_time @@ -1218,10 +1229,10 @@ def train(self): should_exit = True update_time = 0 - if self.multi_gpu: - should_exit_t = torch.tensor(should_exit).float() - self.hvd.broadcast_value(should_exit_t, 'should_exit') - should_exit = should_exit_t.float().item() + # if self.multi_gpu: + # should_exit_t = torch.tensor(should_exit).float() + # self.hvd.broadcast_value(should_exit_t, 'should_exit') + # should_exit = should_exit_t.float().item() if should_exit: return self.last_mean_rewards, epoch_num