Merge branch 'master' into VM/torch_upgrade
ViktorM committed Jul 12, 2024
2 parents 46c8c48 + 2606eff commit ae043b9
Showing 8 changed files with 138 additions and 29 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -67,10 +67,10 @@ Explore RL Games quick and easily in colab notebooks:

## Installation

For maximum training performance a preliminary installation of Pytorch 1.9+ with CUDA 11.1+ is highly recommended:
For maximum training performance, a preliminary installation of PyTorch 2.2 or newer with CUDA 12.1 or newer is highly recommended:

```conda install pytorch torchvision cudatoolkit=11.3 -c pytorch -c nvidia``` or:
```pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html```
```conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia``` or:
```pip install torch torchvision```

Then:

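A quick sanity check (not part of this commit) to confirm the CUDA-enabled build is active after either install path:

```python
import torch

# Confirm the installed build and that CUDA is visible; the versions are what the
# README recommends (PyTorch 2.2+, CUDA 12.1+), not hard requirements.
print(torch.__version__)           # e.g. '2.2.0+cu121'
print(torch.version.cuda)          # e.g. '12.1'
print(torch.cuda.is_available())   # True on a machine with a working CUDA setup
```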
4 changes: 4 additions & 0 deletions rl_games/algos_torch/a2c_continuous.py
@@ -88,6 +88,10 @@ def restore(self, fn, set_epoch=True):
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

def restore_central_value_function(self, fn):
checkpoint = torch_ext.load_checkpoint(fn)
self.set_central_value_function_weights(checkpoint)

def get_masked_action_values(self, obs, action_masks):
assert False

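A minimal usage sketch (not part of this commit) of the new restore path, assuming `agent` is an already-constructed asymmetric actor-critic A2C agent; the checkpoint path is hypothetical:

```python
# Reload only the central value function from a checkpoint, leaving the
# actor/policy weights untouched; this is what torch_runner triggers when
# 'load_critic_only' is set (see the torch_runner.py change below).
agent.restore_central_value_function('runs/my_run/nn/last_checkpoint.pth')
```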
3 changes: 3 additions & 0 deletions rl_games/algos_torch/central_value.py
@@ -1,4 +1,5 @@
import os
import copy
import torch
from torch import nn
import torch.distributed as dist
@@ -226,6 +227,8 @@ def train_net(self):
self.train()
loss = 0
for _ in range(self.mini_epoch):
if self.config.get('freeze_critic', False):
break
for idx in range(len(self.dataset)):
loss += self.train_critic(self.dataset[idx])
if self.normalize_input:
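The flag is read with `config.get('freeze_critic', False)` here and again in `a2c_common.py`, so enabling it is a single boolean in the training config. A hypothetical Python fragment illustrating the intent:

```python
# Hypothetical config fragment: with 'freeze_critic' enabled, the central value
# trainer breaks out of its mini-epoch loop before any critic update, and the
# A2C value normalizer is kept in eval mode (see the a2c_common.py change below).
config = {
    # ... the usual rl_games A2C / central value settings ...
    'freeze_critic': True,
}
```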
14 changes: 12 additions & 2 deletions rl_games/algos_torch/network_builder.py
@@ -8,6 +8,7 @@
from rl_games.algos_torch.sac_helper import SquashedNormal
from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
from rl_games.algos_torch.spatial_softmax import SpatialSoftArgmax


def _create_initializer(func, **kwargs):
@@ -130,12 +131,17 @@ def _build_conv(self, ctype, **kwargs):

if ctype == 'conv2d':
return self._build_cnn2d(**kwargs)
if ctype == 'conv2d_spatial_softargmax':
return self._build_cnn2d(add_spatial_softmax=True, **kwargs)
if ctype == 'conv2d_flatten':
return self._build_cnn2d(add_flatten=True, **kwargs)
if ctype == 'coord_conv2d':
return self._build_cnn2d(conv_func=torch_ext.CoordConv2d, **kwargs)
if ctype == 'conv1d':
return self._build_cnn1d(**kwargs)

def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d, norm_func_name=None):
def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d, norm_func_name=None,
add_spatial_softmax=False, add_flatten=False):
in_channels = input_shape[0]
layers = []
for conv in convs:
@@ -150,7 +156,11 @@ def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d
if norm_func_name == 'layer_norm':
layers.append(torch_ext.LayerNorm2d(in_channels))
elif norm_func_name == 'batch_norm':
layers.append(torch.nn.BatchNorm2d(in_channels))
layers.append(torch.nn.BatchNorm2d(in_channels))
if add_spatial_softmax:
layers.append(SpatialSoftArgmax(normalize=True))
if add_flatten:
layers.append(torch.nn.Flatten())
return nn.Sequential(*layers)

def _build_cnn1d(self, input_shape, convs, activation, norm_func_name=None):
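A small sketch (not part of this commit) of what the new `conv2d_spatial_softargmax` type produces: the usual conv stack with a `SpatialSoftArgmax` head appended, so `(B, C, H, W)` feature maps become `(B, 2*C)` expected pixel coordinates, while `conv2d_flatten` appends `nn.Flatten()` instead. The layer sizes below are illustrative; the builder normally reads them from the `convs` section of the network config.

```python
import torch
import torch.nn as nn
from rl_games.algos_torch.spatial_softmax import SpatialSoftArgmax

# Hand-built analogue of a two-layer 'conv2d_spatial_softargmax' backbone.
cnn = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=8, stride=4),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2),
    nn.ReLU(),
    SpatialSoftArgmax(normalize=True),
)
out = cnn(torch.zeros(1, 3, 84, 84))
print(out.shape)  # torch.Size([1, 128]): one (x, y) pair per channel
```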
15 changes: 7 additions & 8 deletions rl_games/algos_torch/sac_agent.py
@@ -444,7 +444,7 @@ def clear_stats(self):
self.algo_observer.after_clear_stats()

def play_steps(self, random_exploration = False):
total_time_start = time.time()
total_time_start = time.perf_counter()
total_update_time = 0
total_time = 0
step_time = 0.0
@@ -469,11 +469,10 @@ def play_steps(self, random_exploration = False):
with torch.no_grad():
action = self.act(obs.float(), self.env_info["action_space"].shape, sample=True)

step_start = time.time()

step_start = time.perf_counter()
with torch.no_grad():
next_obs, rewards, dones, infos = self.env_step(action)
step_end = time.time()
step_end = time.perf_counter()

self.current_rewards += rewards
self.current_lengths += 1
@@ -503,17 +502,17 @@ def play_steps(self, random_exploration = False):
self.obs = next_obs.clone()

rewards = self.rewards_shaper(rewards)

self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs_processed, torch.unsqueeze(dones, 1))

if isinstance(obs, dict):
obs = self.obs['obs']

if not random_exploration:
self.set_train()
update_time_start = time.time()

update_time_start = time.perf_counter()
actor_loss_info, critic1_loss, critic2_loss = self.update(self.epoch_num)
update_time_end = time.time()
update_time_end = time.perf_counter()
update_time = update_time_end - update_time_start

self.extract_actor_stats(actor_losses, entropies, alphas, alpha_losses, actor_loss_info)
@@ -524,7 +523,7 @@ def play_steps(self, random_exploration = False):

total_update_time += update_time

total_time_end = time.time()
total_time_end = time.perf_counter()
total_time = total_time_end - total_time_start
play_time = total_time - total_update_time

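The timing changes in this file (and in `a2c_common.py` below) swap `time.time()` for `time.perf_counter()`, which is monotonic and higher-resolution and therefore better suited to measuring short step and update intervals. A trivial illustration of the pattern:

```python
import time

start = time.perf_counter()
time.sleep(0.01)                       # stand-in for an env step or optimizer update
elapsed = time.perf_counter() - start
print(f"step took {elapsed * 1e3:.2f} ms")
```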
83 changes: 83 additions & 0 deletions rl_games/algos_torch/spatial_softmax.py
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


# Adapted from https://gist.github.com/kevinzakka/dd9fa5177cda13593524f4d71eb38ad5
class SpatialSoftArgmax(nn.Module):
"""Spatial softmax as defined in [1].
Concretely, the spatial softmax of each feature
map is used to compute a weighted mean of the pixel
locations, effectively performing a soft arg-max
over the feature dimension.
References:
[1]: End-to-End Training of Deep Visuomotor Policies,
https://arxiv.org/abs/1504.00702
"""

def __init__(self, normalize=False):
"""Constructor.
Args:
normalize (bool): Whether to use normalized
image coordinates, i.e. coordinates in
the range `[-1, 1]`.
"""
super().__init__()

self.normalize = normalize

def _coord_grid(self, h, w, device):
if self.normalize:
return torch.stack(
torch.meshgrid(
torch.linspace(-1, 1, w, device=device),
torch.linspace(-1, 1, h, device=device),
)
)
return torch.stack(
torch.meshgrid(
torch.arange(0, w, device=device),
torch.arange(0, h, device=device),
)
)

def forward(self, x):
assert x.ndim == 4, "Expecting a tensor of shape (B, C, H, W)."

# compute a spatial softmax over the input:
# given an input of shape (B, C, H, W),
# reshape it to (B*C, H*W) then apply
# the softmax operator over the last dimension
b, c, h, w = x.shape
softmax = F.softmax(x.reshape(-1, h * w), dim=-1)

# create a meshgrid of pixel coordinates
# both in the x and y axes
xc, yc = self._coord_grid(h, w, x.device)

# element-wise multiply the x and y coordinates
# with the softmax, then sum over the h*w dimension
# this effectively computes the weighted mean of x
# and y locations
x_mean = (softmax * xc.flatten()).sum(dim=1, keepdims=True)
y_mean = (softmax * yc.flatten()).sum(dim=1, keepdims=True)

# concatenate and reshape the result
# to (B, C*2) where for every feature
# we have the expected x and y pixel
# locations
return torch.cat([x_mean, y_mean], dim=1).view(-1, c * 2)


if __name__ == "__main__":
b, c, h, w = 32, 64, 12, 12
x = torch.zeros(b, c, h, w)
true_max = torch.randint(0, 10, size=(b, c, 2))
for i in range(b):
for j in range(c):
x[i, j, true_max[i, j, 0], true_max[i, j, 1]] = 1000
soft_max = SpatialSoftArgmax()(x).reshape(b, c, 2)
assert torch.allclose(true_max.float(), soft_max)
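In formulas (a restatement of the docstring above, not text from the commit): for each channel c with activations a_{c,ij} over an H x W grid with pixel coordinates (x_{ij}, y_{ij}), the layer computes

```math
s_{c,ij} = \frac{\exp(a_{c,ij})}{\sum_{i',j'} \exp(a_{c,i'j'})},
\qquad
(\bar{x}_c,\ \bar{y}_c) = \sum_{i,j} s_{c,ij}\,(x_{ij},\ y_{ij})
```

The output concatenates the per-channel means, giving shape `(B, 2*C)`; with `normalize=True` the coordinates lie in `[-1, 1]` rather than pixel units.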
36 changes: 21 additions & 15 deletions rl_games/common/a2c_common.py
@@ -646,6 +646,9 @@ def set_full_state_weights(self, weights, set_epoch=True):
env_state = weights.get('env_state', None)
self.vec_env.set_env_state(env_state)

def set_central_value_function_weights(self, weights):
self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])

def get_weights(self):
state = self.get_stats_weights()
state['model'] = self.model.state_dict()
@@ -760,9 +763,9 @@ def play_steps(self):
if self.has_central_value:
self.experience_buffer.update_data('states', n, self.obs['states'])

step_time_start = time.time()
step_time_start = time.perf_counter()
self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])
step_time_end = time.time()
step_time_end = time.perf_counter()

step_time += (step_time_end - step_time_start)

@@ -833,9 +836,9 @@ def play_steps_rnn(self):
if self.has_central_value:
self.experience_buffer.update_data('states', n, self.obs['states'])

step_time_start = time.time()
step_time_start = time.perf_counter()
self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])
step_time_end = time.time()
step_time_end = time.perf_counter()

step_time += (step_time_end - step_time_start)

@@ -923,7 +926,7 @@ def train_epoch(self):
super().train_epoch()

self.set_eval()
play_time_start = time.time()
play_time_start = time.perf_counter()

with torch.no_grad():
if self.is_rnn:
@@ -933,8 +936,8 @@ def train_epoch(self):

self.set_train()

play_time_end = time.time()
update_time_start = time.time()
play_time_end = time.perf_counter()
update_time_start = time.perf_counter()
rnn_masks = batch_dict.get('rnn_masks', None)

self.curr_frames = batch_dict.pop('played_frames')
@@ -969,7 +972,7 @@ def train_epoch(self):
if self.normalize_input:
self.model.running_mean_std.eval() # don't need to update statistics more than once per mini-epoch

update_time_end = time.time()
update_time_end = time.perf_counter()
play_time = play_time_end - play_time_start
update_time = update_time_end - update_time_start
total_time = update_time_end - play_time_start
@@ -1037,7 +1040,7 @@ def prepare_dataset(self, batch_dict):
def train(self):
self.init_tensors()
self.mean_rewards = self.last_mean_rewards = -100500
start_time = time.time()
start_time = time.perf_counter()
total_time = 0
rep_count = 0
# self.frame = 0 # loading from checkpoint
@@ -1186,15 +1189,15 @@ def train_epoch(self):
super().train_epoch()

self.set_eval()
play_time_start = time.time()
play_time_start = time.perf_counter()
with torch.no_grad():
if self.is_rnn:
batch_dict = self.play_steps_rnn()
else:
batch_dict = self.play_steps()

play_time_end = time.time()
update_time_start = time.time()
play_time_end = time.perf_counter()
update_time_start = time.perf_counter()
rnn_masks = batch_dict.get('rnn_masks', None)

self.set_train()
@@ -1243,7 +1246,7 @@ def train_epoch(self):
if self.normalize_input:
self.model.running_mean_std.eval() # don't need to update statistics more than once per mini-epoch

update_time_end = time.time()
update_time_end = time.perf_counter()
play_time = play_time_end - play_time_start
update_time = update_time_end - update_time_start
total_time = update_time_end - play_time_start
@@ -1265,7 +1268,10 @@ def prepare_dataset(self, batch_dict):
advantages = returns - values

if self.normalize_value:
self.value_mean_std.train()
if self.config.get('freeze_critic', False):
self.value_mean_std.eval()
else:
self.value_mean_std.train()
values = self.value_mean_std(values)
returns = self.value_mean_std(returns)
self.value_mean_std.eval()
@@ -1313,7 +1319,7 @@ def prepare_dataset(self, batch_dict):
def train(self):
self.init_tensors()
self.last_mean_rewards = -100500
start_time = time.time()
start_time = time.perf_counter()
total_time = 0
rep_count = 0
self.obs = self.env_reset()
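A condensed sketch (not part of this commit) of the `prepare_dataset` change: with `freeze_critic` set, the value normalizer stays in eval mode so its running statistics are not updated, but values and returns are still normalized. It assumes `RunningMeanStd` from `rl_games.algos_torch.running_mean_std`, as used elsewhere in the codebase:

```python
import torch
from rl_games.algos_torch.running_mean_std import RunningMeanStd

value_mean_std = RunningMeanStd((1,))          # normalizer over scalar values
values, returns = torch.randn(8, 1), torch.randn(8, 1)

freeze_critic = True                           # mirrors config.get('freeze_critic', False)
if freeze_critic:
    value_mean_std.eval()                      # keep running statistics frozen
else:
    value_mean_std.train()                     # update running statistics from this batch
values = value_mean_std(values)
returns = value_mean_std(returns)
value_mean_std.eval()
```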
6 changes: 5 additions & 1 deletion rl_games/torch_runner.py
@@ -17,6 +17,10 @@

def _restore(agent, args):
if 'checkpoint' in args and args['checkpoint'] is not None and args['checkpoint'] !='':
if args['train'] and args.get('load_critic_only', False):
assert agent.has_central_value, 'This should only work for asymmetric actor critic'
agent.restore_central_value_function(args['checkpoint'])
return
agent.restore(args['checkpoint'])

def _override_sigma(agent, args):
@@ -63,7 +67,7 @@ def __init__(self, algo_observer=None):

self.algo_observer = algo_observer if algo_observer else DefaultAlgoObserver()
torch.backends.cudnn.benchmark = True
### it didnot help for lots for openai gym envs anyway :(
### it did not help for lots of openai gym envs anyway :(
#torch.backends.cudnn.deterministic = True
#torch.use_deterministic_algorithms(True)

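A hedged example (not from the commit) of the runner arguments that would take the new branch in `_restore`; the keys match what the function reads, the checkpoint path is hypothetical:

```python
# With 'train' and 'load_critic_only' both set, only the central value function
# is restored from the checkpoint and the actor starts from scratch. The agent
# must have a central value network (asymmetric actor-critic), or the assert fires.
args = {
    'train': True,
    'load_critic_only': True,
    'checkpoint': 'runs/pretrained_critic/nn/last_checkpoint.pth',  # hypothetical path
}
```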
