Deprecate horovod in favor of torch.distributed (#171)
* Prototype `torch.distributed`

* Fix

* add docs

* Address comments

Co-authored-by: Costa Huang <[email protected]>
vwxyzjn and Costa Huang authored Jun 3, 2022
1 parent 06a3319 commit c9575bd
Showing 8 changed files with 71 additions and 121 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -143,6 +143,14 @@ poetry run python runner.py --train --file rl_games/configs/atari/ppo_breakout_t
 ```
 
+## Multi GPU
+
+We use `torchrun` to orchestrate any multi-gpu runs.
+
+```bash
+torchrun --standalone --nnodes=1 --nproc_per_node=2 runner.py --train --file rl_games/configs/ppo_cartpole.yaml
+```
+
 ## Config Parameters
 
 | Field | Example Value | Default | Description |
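The README addition above covers a single-node launch. Multi-node runs should work with `torchrun`'s standard rendezvous flags; a sketch, assuming two nodes with two GPUs each (`$MASTER_ADDR` is a placeholder for the rank-0 host, not part of this commit):

```bash
# Hypothetical two-node launch (2 machines x 2 GPUs each). --rdzv_endpoint
# must be reachable from every node; 29500 is an arbitrary free port.
torchrun --nnodes=2 --nproc_per_node=2 \
    --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 \
    runner.py --train --file rl_games/configs/ppo_cartpole.yaml
```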
1 change: 0 additions & 1 deletion rl_games/algos_torch/a2c_continuous.py
@@ -48,7 +48,6 @@ def __init__(self, base_name, params):
             'writter' : self.writer,
             'max_epochs' : self.max_epochs,
             'multi_gpu' : self.multi_gpu,
-            'hvd': self.hvd if self.multi_gpu else None
         }
         self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)
 
1 change: 0 additions & 1 deletion rl_games/algos_torch/a2c_discrete.py
@@ -49,7 +49,6 @@ def __init__(self, base_name, params):
             'writter' : self.writer,
             'max_epochs' : self.max_epochs,
             'multi_gpu' : self.multi_gpu,
-            'hvd': self.hvd if self.multi_gpu else None
         }
         self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)
 
8 changes: 4 additions & 4 deletions rl_games/algos_torch/central_value.py
@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+import torch.distributed as dist
 import gym
 import numpy as np
 from rl_games.algos_torch import torch_ext
@@ -9,7 +10,7 @@
 from rl_games.common import schedulers
 
 class CentralValueTrain(nn.Module):
-    def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_length, num_actors, num_actions, seq_len, normalize_value,network, config, writter, max_epochs, multi_gpu, hvd):
+    def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_length, num_actors, num_actions, seq_len, normalize_value,network, config, writter, max_epochs, multi_gpu):
         nn.Module.__init__(self)
         self.ppo_device = ppo_device
         self.num_agents, self.horizon_length, self.num_actors, self.seq_len = num_agents, horizon_length, num_actors, seq_len
@@ -19,7 +20,6 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng
         self.value_size = value_size
         self.max_epochs = max_epochs
         self.multi_gpu = multi_gpu
-        self.hvd = hvd
         self.truncate_grads = config.get('truncate_grads', False)
         self.config = config
         self.normalize_input = config['normalize_input']
@@ -77,8 +77,8 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng
 
     def update_lr(self, lr):
         if self.multi_gpu:
-            lr_tensor = torch.tensor([lr])
-            self.hvd.broadcast_value(lr_tensor, 'cv_learning_rate')
+            lr_tensor = torch.tensor([lr], device=self.device)
+            dist.broadcast(lr_tensor, 0)
             lr = lr_tensor.item()
 
         for param_group in self.optimizer.param_groups:
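The `update_lr` change above swaps Horovod's named broadcast for a plain rank-0 `dist.broadcast`. A minimal standalone sketch of that pattern, using the `gloo` backend so it runs without GPUs (the script name is hypothetical):

```python
# Minimal sketch of the rank-0 broadcast pattern used in update_lr above.
# Launch with torchrun so MASTER_ADDR/MASTER_PORT are set, e.g.:
#   torchrun --standalone --nnodes=1 --nproc_per_node=2 broadcast_lr.py
import os

import torch
import torch.distributed as dist

rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
dist.init_process_group("gloo", rank=rank, world_size=world_size)

lr = 1e-4 if rank == 0 else 999.0   # ranks disagree on purpose
lr_tensor = torch.tensor([lr])
dist.broadcast(lr_tensor, src=0)    # rank 0's value overwrites all others
lr = lr_tensor.item()
print(f"rank {rank}: lr = {lr}")    # every rank prints 0.0001

dist.destroy_process_group()
```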
84 changes: 51 additions & 33 deletions rl_games/common/a2c_common.py
@@ -20,6 +20,7 @@
 from tensorboardX import SummaryWriter
 import torch
 from torch import nn
+import torch.distributed as dist
 
 from time import sleep
 
@@ -68,11 +69,15 @@ def __init__(self, base_name, params):
         self.rank_size = 1
         self.curr_frames = 0
         if self.multi_gpu:
-            from rl_games.distributed.hvd_wrapper import HorovodWrapper
-            self.hvd = HorovodWrapper()
-            self.config = self.hvd.update_algo_config(config)
-            self.rank = self.hvd.rank
-            self.rank_size = self.hvd.rank_size
+            self.rank = int(os.getenv("LOCAL_RANK", "0"))
+            self.rank_size = int(os.getenv("WORLD_SIZE", "1"))
+            dist.init_process_group("nccl", rank=self.rank, world_size=self.rank_size)
+
+            self.device_name = 'cuda:' + str(self.rank)
+            config['device'] = self.device_name
+            if self.rank != 0:
+                config['print_stats'] = False
+                config['lr_schedule'] = None
 
         self.use_diagnostics = config.get('use_diagnostics', False)

@@ -251,19 +256,27 @@ def __init__(self, base_name, params):
 
     def trancate_gradients_and_step(self):
         if self.multi_gpu:
-            self.optimizer.synchronize()
+            # batch allreduce ops: see https://github.com/entity-neural-network/incubator/pull/220
+            all_grads_list = []
+            for param in self.model.parameters():
+                if param.grad is not None:
+                    all_grads_list.append(param.grad.view(-1))
+            all_grads = torch.cat(all_grads_list)
+            dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
+            offset = 0
+            for param in self.model.parameters():
+                if param.grad is not None:
+                    param.grad.data.copy_(
+                        all_grads[offset : offset + param.numel()].view_as(param.grad.data) / self.rank_size
+                    )
+                    offset += param.numel()
 
         if self.truncate_grads:
             self.scaler.unscale_(self.optimizer)
             nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
 
-        if self.multi_gpu:
-            with self.optimizer.skip_synchronize():
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
-        else:
-            self.scaler.step(self.optimizer)
-            self.scaler.update()
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
 
     def load_networks(self, params):
         builder = model_builder.ModelBuilder()
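Horovod's `DistributedOptimizer` used to average gradients implicitly during `synchronize()`; the hunk above now does it by hand, flattening all gradients so one `all_reduce` replaces one collective per parameter (the motivation in the linked PR). A self-contained sketch of the same pattern, with a toy model standing in for `self.model` and a hypothetical script name:

```python
# Sketch of the batched gradient all_reduce above: flatten every gradient
# into one tensor, reduce once, then copy the averaged slices back in place.
# Launch: torchrun --standalone --nnodes=1 --nproc_per_node=2 avg_grads.py
import os

import torch
import torch.distributed as dist
from torch import nn

rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
dist.init_process_group("gloo", rank=rank, world_size=world_size)

model = nn.Linear(4, 2)                      # toy stand-in for self.model
model(torch.randn(8, 4)).sum().backward()

# One collective for all parameters instead of one per parameter tensor.
flat = torch.cat([p.grad.view(-1) for p in model.parameters() if p.grad is not None])
dist.all_reduce(flat, op=dist.ReduceOp.SUM)

offset = 0
for p in model.parameters():
    if p.grad is None:
        continue
    n = p.numel()
    # SUM divided by world size == mean gradient across ranks
    p.grad.data.copy_(flat[offset:offset + n].view_as(p.grad.data) / world_size)
    offset += n

dist.destroy_process_group()
```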
@@ -308,8 +321,8 @@ def set_train(self):
 
     def update_lr(self, lr):
         if self.multi_gpu:
-            lr_tensor = torch.tensor([lr])
-            self.hvd.broadcast_value(lr_tensor, 'learning_rate')
+            lr_tensor = torch.tensor([lr], device=self.device)
+            dist.broadcast(lr_tensor, 0)
             lr = lr_tensor.item()
 
         for param_group in self.optimizer.param_groups:
@@ -802,7 +815,8 @@ def train_epoch(self):
 
         av_kls = torch_ext.mean_list(ep_kls)
         if self.multi_gpu:
-            av_kls = self.hvd.average_value(av_kls, 'ep_kls')
+            dist.all_reduce(av_kls, op=dist.ReduceOp.SUM)
+            av_kls /= self.rank_size
 
         self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())
         self.update_lr(self.last_lr)
@@ -887,19 +901,20 @@ def train(self):
         self.obs = self.env_reset()
 
         if self.multi_gpu:
-            self.hvd.setup_algo(self)
+            torch.cuda.set_device(self.rank)
+            print("====================broadcasting parameters")
+            model_params = [self.model.state_dict()]
+            dist.broadcast_object_list(model_params, 0)
+            self.model.load_state_dict(model_params[0])
 
         while True:
             epoch_num = self.update_epoch()
             step_time, play_time, update_time, sum_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul = self.train_epoch()
 
-            if self.multi_gpu:
-                self.hvd.sync_stats(self)
-
             # cleaning memory to optimize space
             self.dataset.update_values_dict(None)
             total_time += sum_time
-            curr_frames = self.curr_frames
+            curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames
            self.frame += curr_frames
             should_exit = False
             if self.rank == 0:
@@ -959,9 +974,9 @@ def train(self):
                         should_exit = True
                     update_time = 0
             if self.multi_gpu:
-                should_exit_t = torch.tensor(should_exit).float()
-                self.hvd.broadcast_value(should_exit_t, 'should_exit')
-                should_exit = should_exit_t.bool().item()
+                should_exit_t = torch.tensor(should_exit).float()
+                dist.broadcast(should_exit_t, 0)
+                should_exit = should_exit_t.bool().item()
             if should_exit:
                 return self.last_mean_rewards, epoch_num

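`dist.broadcast_object_list` pickles arbitrary Python objects, which is slow but fine for a one-time weight sync at startup; broadcasting each tensor individually would avoid the pickling overhead. A sketch of the handshake the hunk above performs, assuming a `gloo` backend and a hypothetical script name:

```python
# Sketch of the startup weight sync above: rank 0's state_dict is pickled
# and broadcast so every worker begins from identical parameters.
# Launch: torchrun --standalone --nnodes=1 --nproc_per_node=2 sync_weights.py
import os

import torch
import torch.distributed as dist
from torch import nn

rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
dist.init_process_group("gloo", rank=rank, world_size=world_size)

torch.manual_seed(rank)            # deliberately different init per rank
model = nn.Linear(4, 2)

params = [model.state_dict()]      # list wrapper: the API fills it in place
dist.broadcast_object_list(params, src=0)
model.load_state_dict(params[0])   # now identical on every rank

dist.destroy_process_group()
```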
@@ -1040,9 +1055,10 @@ def train_epoch(self):
             self.dataset.update_mu_sigma(cmu, csigma)
 
         av_kls = torch_ext.mean_list(ep_kls)
 
         if self.multi_gpu:
-            av_kls = self.hvd.average_value(av_kls, 'ep_kls')
+            dist.all_reduce(av_kls, op=dist.ReduceOp.SUM)
+            av_kls /= self.rank_size
 
         self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())
         self.update_lr(self.last_lr)

@@ -1128,16 +1144,18 @@ def train(self):
         self.curr_frames = self.batch_size_envs
 
         if self.multi_gpu:
-            self.hvd.setup_algo(self)
+            #
+            print("====================broadcasting parameters")
+            model_params = [self.model.state_dict()]
+            dist.broadcast_object_list(model_params, 0)
+            self.model.load_state_dict(model_params[0])
 
         while True:
             epoch_num = self.update_epoch()
             step_time, play_time, update_time, sum_time, a_losses, c_losses, b_losses, entropies, kls, last_lr, lr_mul = self.train_epoch()
             total_time += sum_time
             frame = self.frame // self.num_agents
 
-            if self.multi_gpu:
-                self.hvd.sync_stats(self)
             # cleaning memory to optimize space
             self.dataset.update_values_dict(None)
             should_exit = False
@@ -1147,7 +1165,7 @@ def train(self):
                 # do we need scaled_time?
                 scaled_time = self.num_agents * sum_time
                 scaled_play_time = self.num_agents * play_time
-                curr_frames = self.curr_frames
+                curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames
                 self.frame += curr_frames
                 if self.print_stats:
                     fps_step = curr_frames / step_time
@@ -1204,9 +1222,9 @@ def train(self):
 
                 update_time = 0
             if self.multi_gpu:
-                should_exit_t = torch.tensor(should_exit).float()
-                self.hvd.broadcast_value(should_exit_t, 'should_exit')
-                should_exit = should_exit_t.float().item()
+                should_exit_t = torch.tensor(should_exit).float()
+                dist.broadcast(should_exit_t, 0)
+                should_exit = should_exit_t.float().item()
             if should_exit:
                 return self.last_mean_rewards, epoch_num

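Both KL-averaging hunks use a SUM reduction followed by a divide rather than an average op, which keeps the code backend-agnostic (`ReduceOp.AVG`, where available, is NCCL-only). A sketch with a scalar metric, under the same torchrun assumptions as above:

```python
# Sketch of the KL averaging above: SUM-all_reduce, then divide by world size.
# Launch: torchrun --standalone --nnodes=1 --nproc_per_node=2 avg_metric.py
import os

import torch
import torch.distributed as dist

rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
dist.init_process_group("gloo", rank=rank, world_size=world_size)

kl = torch.tensor(float(rank))          # pretend each rank measured its own KL
dist.all_reduce(kl, op=dist.ReduceOp.SUM)
kl /= world_size                        # with 2 ranks: (0 + 1) / 2 = 0.5 everywhere
print(f"rank {rank}: mean kl = {kl.item()}")

dist.destroy_process_group()
```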
75 changes: 0 additions & 75 deletions rl_games/distributed/hvd_wrapper.py

This file was deleted.

8 changes: 3 additions & 5 deletions rl_games/torch_runner.py
@@ -1,3 +1,4 @@
+import os
 import time
 import numpy as np
 import random
@@ -60,10 +61,7 @@ def load_config(self, params):
             self.seed = int(time.time())
 
         if params["config"].get('multi_gpu', False):
-            import horovod.torch as hvd
-
-            hvd.init()
-            self.seed += hvd.rank()
+            self.seed += int(os.getenv("LOCAL_RANK", "0"))
         print(f"self.seed = {self.seed}")
 
         self.algo_params = params['algo']
@@ -82,7 +80,7 @@ def load_config(self, params):
                 params['config']['env_config']['seed'] = self.seed
             else:
                 if params["config"].get('multi_gpu', False):
-                    params['config']['env_config']['seed'] += hvd.rank()
+                    params['config']['env_config']['seed'] += int(os.getenv("LOCAL_RANK", "0"))
 
         config = params['config']
         config['reward_shaper'] = tr_helpers.DefaultRewardsShaper(**config['reward_shaper'])
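Offsetting the seed by `LOCAL_RANK` gives each worker a distinct but reproducible RNG stream, so parallel environments do not collect identical trajectories. A sketch of the intent:

```python
# Sketch of the per-rank seeding above: a shared base seed offset by
# LOCAL_RANK decorrelates workers while staying deterministic per rank.
# (Caveat: time.time() can differ between processes; a fixed base seed
# from the config avoids that, as in the seeded branch of load_config.)
import os
import random
import time

import numpy as np
import torch

seed = int(time.time())                    # base seed, as in torch_runner.py
seed += int(os.getenv("LOCAL_RANK", "0"))  # rank 0 -> seed, rank 1 -> seed + 1, ...

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
```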
7 changes: 5 additions & 2 deletions runner.py
@@ -52,7 +52,8 @@
     except yaml.YAMLError as exc:
         print(exc)
 
-    if args["track"]:
+    rank = int(os.getenv("LOCAL_RANK", "0"))
+    if args["track"] and rank == 0:
         import wandb
 
         wandb.init(
@@ -67,4 +68,6 @@
     runner.run(args)
 
     ray.shutdown()
 
+    if args["track"] and rank == 0:
+        wandb.finish()

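Guarding `wandb` behind `rank == 0` means a multi-GPU run shows up as one experiment instead of `WORLD_SIZE` duplicates. A condensed sketch of the guard (the project name is an assumption):

```python
# Sketch of the rank-0 logging guard above: only one process talks to
# wandb, while every rank still trains.
import os

rank = int(os.getenv("LOCAL_RANK", "0"))
track = True                          # stands in for args["track"]

if track and rank == 0:
    import wandb
    wandb.init(project="rl_games")    # hypothetical project name

# ... training runs on every rank here; only rank 0 logged ...

if track and rank == 0:
    wandb.finish()
```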