Deprecate horovod in favor of torch.distributed #171

Merged on Jun 3, 2022 (4 commits). Changes shown from 3 commits.
8 changes: 8 additions & 0 deletions README.md
@@ -143,6 +143,14 @@ poetry run python runner.py --train --file rl_games/configs/atari/ppo_breakout_t
```


## Multi GPU

We use `torchrun` to orchestrate any multi-gpu runs.

```bash
torchrun --standalone --nnodes=1 --nproc_per_node=2 runner.py --train --file rl_games/configs/ppo_cartpole.yaml
```

## Config Parameters

| Field | Example Value | Default | Description |
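For reference, every process that `torchrun` launches receives its coordinates through environment variables, which is exactly what the updated training code reads. A minimal sanity-check script one could launch the same way (hypothetical `check_ranks.py`, not part of the repo):

```python
# check_ranks.py -- hypothetical helper, not part of rl_games.
# Launch with: torchrun --standalone --nnodes=1 --nproc_per_node=2 check_ranks.py
import os

import torch

local_rank = int(os.getenv("LOCAL_RANK", "0"))   # per-process index on this node, set by torchrun
world_size = int(os.getenv("WORLD_SIZE", "1"))   # total number of launched processes

if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)            # pin this process to its own GPU

print(f"worker {local_rank} of {world_size} on cuda:{local_rank}")
```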
1 change: 0 additions & 1 deletion rl_games/algos_torch/a2c_continuous.py
@@ -48,7 +48,6 @@ def __init__(self, base_name, params):
'writter' : self.writer,
'max_epochs' : self.max_epochs,
'multi_gpu' : self.multi_gpu,
'hvd': self.hvd if self.multi_gpu else None
}
self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)

1 change: 0 additions & 1 deletion rl_games/algos_torch/a2c_discrete.py
@@ -49,7 +49,6 @@ def __init__(self, base_name, params):
'writter' : self.writer,
'max_epochs' : self.max_epochs,
'multi_gpu' : self.multi_gpu,
'hvd': self.hvd if self.multi_gpu else None
}
self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)

8 changes: 4 additions & 4 deletions rl_games/algos_torch/central_value.py
@@ -1,5 +1,6 @@
import torch
from torch import nn
import torch.distributed as dist
import gym
import numpy as np
from rl_games.algos_torch import torch_ext
@@ -9,7 +10,7 @@
from rl_games.common import schedulers

class CentralValueTrain(nn.Module):
def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_length, num_actors, num_actions, seq_len, normalize_value,network, config, writter, max_epochs, multi_gpu, hvd):
def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_length, num_actors, num_actions, seq_len, normalize_value,network, config, writter, max_epochs, multi_gpu):
nn.Module.__init__(self)
self.ppo_device = ppo_device
self.num_agents, self.horizon_length, self.num_actors, self.seq_len = num_agents, horizon_length, num_actors, seq_len
@@ -19,7 +20,6 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng
self.value_size = value_size
self.max_epochs = max_epochs
self.multi_gpu = multi_gpu
self.hvd = hvd
self.truncate_grads = config.get('truncate_grads', False)
self.config = config
self.normalize_input = config['normalize_input']
@@ -77,8 +77,8 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng

def update_lr(self, lr):
if self.multi_gpu:
lr_tensor = torch.tensor([lr])
self.hvd.broadcast_value(lr_tensor, 'cv_learning_rate')
lr_tensor = torch.tensor([lr], device=self.device)
dist.broadcast(lr_tensor, 0)
lr = lr_tensor.item()

for param_group in self.optimizer.param_groups:
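Horovod's named `broadcast_value` is replaced here by a plain `dist.broadcast` from rank 0. The pattern in isolation, as a rough sketch that assumes the process group is already initialized and `device` is this rank's GPU:

```python
import torch
import torch.distributed as dist

def broadcast_scalar(value: float, device: torch.device) -> float:
    """Return rank 0's value on every rank (used in the PR for the learning rate)."""
    t = torch.tensor([value], device=device)
    dist.broadcast(t, src=0)  # in-place: ranks != 0 overwrite t with rank 0's tensor
    return t.item()
```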
88 changes: 53 additions & 35 deletions rl_games/common/a2c_common.py
@@ -20,6 +20,7 @@
from tensorboardX import SummaryWriter
import torch
from torch import nn
import torch.distributed as dist

from time import sleep

@@ -68,11 +69,15 @@ def __init__(self, base_name, params):
self.rank_size = 1
self.curr_frames = 0
if self.multi_gpu:
from rl_games.distributed.hvd_wrapper import HorovodWrapper
self.hvd = HorovodWrapper()
self.config = self.hvd.update_algo_config(config)
self.rank = self.hvd.rank
self.rank_size = self.hvd.rank_size
self.rank = int(os.getenv("LOCAL_RANK", "0"))
self.rank_size = int(os.getenv("WORLD_SIZE", "1"))
dist.init_process_group("nccl", rank=self.rank, world_size=self.rank_size)

self.device_name = 'cuda:' + str(self.rank)
config['device'] = self.device_name
if self.rank != 0:
config['print_stats'] = False
config['lr_schedule'] = None

self.use_diagnostics = config.get('use_diagnostics', False)

@@ -251,19 +256,27 @@

def trancate_gradients_and_step(self):
if self.multi_gpu:
self.optimizer.synchronize()
# batch allreduce ops: see https://github.com/entity-neural-network/incubator/pull/220
all_grads_list = []
for param in self.model.parameters():
if param.grad is not None:
all_grads_list.append(param.grad.view(-1))
all_grads = torch.cat(all_grads_list)
dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
offset = 0
for param in self.model.parameters():
if param.grad is not None:
param.grad.data.copy_(
all_grads[offset : offset + param.numel()].view_as(param.grad.data) / self.rank_size
)
offset += param.numel()

if self.truncate_grads:
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)

if self.multi_gpu:
with self.optimizer.skip_synchronize():
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.scaler.step(self.optimizer)
self.scaler.update()
self.scaler.step(self.optimizer)
self.scaler.update()
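The Horovod `synchronize`/`skip_synchronize` dance is replaced by a single batched all-reduce: every gradient is flattened into one buffer, summed across ranks, and the average written back before gradient clipping and the optimizer step. The same idea as a standalone sketch (assumes an initialized process group):

```python
import torch
import torch.distributed as dist

def allreduce_mean_grads(model: torch.nn.Module, world_size: int) -> None:
    """Average gradients across all ranks with one collective call."""
    grads = [p.grad.view(-1) for p in model.parameters() if p.grad is not None]
    if not grads:
        return
    flat = torch.cat(grads)
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)  # one all-reduce instead of one per tensor
    flat /= world_size
    offset = 0
    for p in model.parameters():
        if p.grad is None:
            continue
        n = p.numel()
        p.grad.copy_(flat[offset:offset + n].view_as(p.grad))
        offset += n
```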

def load_networks(self, params):
builder = model_builder.ModelBuilder()
@@ -308,8 +321,8 @@ def set_train(self):

def update_lr(self, lr):
if self.multi_gpu:
lr_tensor = torch.tensor([lr])
self.hvd.broadcast_value(lr_tensor, 'learning_rate')
lr_tensor = torch.tensor([lr], device=self.device)
dist.broadcast(lr_tensor, 0)
lr = lr_tensor.item()

for param_group in self.optimizer.param_groups:
@@ -802,7 +815,8 @@ def train_epoch(self):

av_kls = torch_ext.mean_list(ep_kls)
if self.multi_gpu:
av_kls = self.hvd.average_value(av_kls, 'ep_kls')
dist.all_reduce(av_kls, op=dist.ReduceOp.SUM)
av_kls /= self.rank_size

self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())
self.update_lr(self.last_lr)
@@ -887,19 +901,20 @@ def train(self):
self.obs = self.env_reset()

if self.multi_gpu:
self.hvd.setup_algo(self)
torch.cuda.set_device(self.rank)
print("====================broadcasting parameters")
model_params = [self.model.state_dict()]
dist.broadcast_object_list(model_params, 0)
self.model.load_state_dict(model_params[0])
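Instead of the Horovod wrapper's `setup_algo`, the initial weights are now synchronized explicitly with `dist.broadcast_object_list`. Roughly, the pattern is (a sketch assuming an initialized process group):

```python
import torch
import torch.distributed as dist

def sync_initial_weights(model: torch.nn.Module) -> None:
    """Make every rank start training from rank 0's parameters."""
    state = [model.state_dict()]             # broadcast_object_list operates on a list of picklable objects
    dist.broadcast_object_list(state, src=0)
    model.load_state_dict(state[0])          # no-op on rank 0, overwrite everywhere else
```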

while True:
epoch_num = self.update_epoch()
step_time, play_time, update_time, sum_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul = self.train_epoch()

if self.multi_gpu:
self.hvd.sync_stats(self)

# cleaning memory to optimize space
self.dataset.update_values_dict(None)
total_time += sum_time
curr_frames = self.curr_frames
curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames
self.frame += curr_frames
should_exit = False
if self.rank == 0:
@@ -958,10 +973,10 @@ def train(self):
print('MAX EPOCHS NUM!')
should_exit = True
update_time = 0
if self.multi_gpu:
should_exit_t = torch.tensor(should_exit).float()
self.hvd.broadcast_value(should_exit_t, 'should_exit')
should_exit = should_exit_t.bool().item()
# if self.multi_gpu:
# should_exit_t = torch.tensor(should_exit).float()
# dist.broadcast(should_exit_t, 0)
# should_exit = should_exit_t.bool().item()

Owner commented on this change: I've found why I had this code: #95

if should_exit:
return self.last_mean_rewards, epoch_num

@@ -1040,9 +1055,10 @@ def train_epoch(self):
self.dataset.update_mu_sigma(cmu, csigma)

av_kls = torch_ext.mean_list(ep_kls)

if self.multi_gpu:
av_kls = self.hvd.average_value(av_kls, 'ep_kls')
dist.all_reduce(av_kls, op=dist.ReduceOp.SUM)
av_kls /= self.rank_size

self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())
self.update_lr(self.last_lr)

@@ -1128,16 +1144,18 @@ def train(self):
self.curr_frames = self.batch_size_envs

if self.multi_gpu:
self.hvd.setup_algo(self)
#
print("====================broadcasting parameters")
model_params = [self.model.state_dict()]
dist.broadcast_object_list(model_params, 0)
self.model.load_state_dict(model_params[0])

while True:
epoch_num = self.update_epoch()
step_time, play_time, update_time, sum_time, a_losses, c_losses, b_losses, entropies, kls, last_lr, lr_mul = self.train_epoch()
total_time += sum_time
frame = self.frame // self.num_agents

if self.multi_gpu:
self.hvd.sync_stats(self)
# cleaning memory to optimize space
self.dataset.update_values_dict(None)
should_exit = False
@@ -1147,7 +1165,7 @@ def train(self):
# do we need scaled_time?
scaled_time = self.num_agents * sum_time
scaled_play_time = self.num_agents * play_time
curr_frames = self.curr_frames
curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames
self.frame += curr_frames
if self.print_stats:
fps_step = curr_frames / step_time
@@ -1203,10 +1221,10 @@
should_exit = True

update_time = 0
if self.multi_gpu:
should_exit_t = torch.tensor(should_exit).float()
self.hvd.broadcast_value(should_exit_t, 'should_exit')
should_exit = should_exit_t.float().item()
# if self.multi_gpu:
# should_exit_t = torch.tensor(should_exit).float()
# dist.broadcast(should_exit_t, 0)
# should_exit = should_exit_t.float().item()
if should_exit:
return self.last_mean_rewards, epoch_num

75 changes: 0 additions & 75 deletions rl_games/distributed/hvd_wrapper.py

This file was deleted.

8 changes: 3 additions & 5 deletions rl_games/torch_runner.py
@@ -1,3 +1,4 @@
import os
import time
import numpy as np
import random
@@ -60,10 +61,7 @@ def load_config(self, params):
self.seed = int(time.time())

if params["config"].get('multi_gpu', False):
import horovod.torch as hvd

hvd.init()
self.seed += hvd.rank()
self.seed += int(os.getenv("LOCAL_RANK", "0"))
print(f"self.seed = {self.seed}")

self.algo_params = params['algo']
@@ -82,7 +80,7 @@ def load_config(self, params):
params['config']['env_config']['seed'] = self.seed
else:
if params["config"].get('multi_gpu', False):
params['config']['env_config']['seed'] += hvd.rank()
params['config']['env_config']['seed'] += int(os.getenv("LOCAL_RANK", "0"))

config = params['config']
config['reward_shaper'] = tr_helpers.DefaultRewardsShaper(**config['reward_shaper'])
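Offsetting the seed by `LOCAL_RANK` keeps runs reproducible while giving each worker different randomness. In isolation the pattern looks like this (a sketch; the base seed is arbitrary here, while the runner takes it from the config or the current time):

```python
import os
import random

import numpy as np
import torch

base_seed = 7                                          # arbitrary value for illustration
rank_seed = base_seed + int(os.getenv("LOCAL_RANK", "0"))

random.seed(rank_seed)
np.random.seed(rank_seed)
torch.manual_seed(rank_seed)
```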
7 changes: 5 additions & 2 deletions runner.py
@@ -52,7 +52,8 @@
except yaml.YAMLError as exc:
print(exc)

if args["track"]:
rank = int(os.getenv("LOCAL_RANK", "0"))
if args["track"] and rank == 0:
import wandb

wandb.init(
@@ -67,4 +68,6 @@
runner.run(args)

ray.shutdown()


if args["track"] and rank == 0:
wandb.finish()
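With several processes running the same script, only rank 0 should talk to Weights & Biases, which is what the `rank == 0` guard above does. The end-to-end pattern, stripped down to its essentials (the project name is a placeholder):

```python
import os

rank = int(os.getenv("LOCAL_RANK", "0"))

if rank == 0:
    import wandb
    wandb.init(project="my-project", sync_tensorboard=True)  # placeholder project name

# ... run training on every rank ...

if rank == 0:
    wandb.finish()
```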