diff --git a/README.md b/README.md
index 2069d00b..c3529e54 100644
--- a/README.md
+++ b/README.md
@@ -67,10 +67,10 @@ Explore RL Games quick and easily in colab notebooks:
 
 ## Installation
 
-For maximum training performance a preliminary installation of Pytorch 1.9+ with CUDA 11.1+ is highly recommended:
+For maximum training performance, a preliminary installation of PyTorch 2.2 or newer with CUDA 12.1 or newer is highly recommended:
 
-```conda install pytorch torchvision cudatoolkit=11.3 -c pytorch -c nvidia``` or:
-```pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html```
+```conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia``` or:
+```pip install torch torchvision```
 
 Then:
 
diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index 4a3aac3b..e0812d70 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -88,6 +88,10 @@ def restore(self, fn, set_epoch=True):
         checkpoint = torch_ext.load_checkpoint(fn)
         self.set_full_state_weights(checkpoint, set_epoch=set_epoch)
 
+    def restore_central_value_function(self, fn):
+        checkpoint = torch_ext.load_checkpoint(fn)
+        self.set_central_value_function_weights(checkpoint)
+
     def get_masked_action_values(self, obs, action_masks):
         assert False
 
diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py
index f3ada733..8bc13729 100644
--- a/rl_games/algos_torch/central_value.py
+++ b/rl_games/algos_torch/central_value.py
@@ -1,4 +1,5 @@
 import os
+import copy
 import torch
 from torch import nn
 import torch.distributed as dist
@@ -226,6 +227,8 @@ def train_net(self):
         self.train()
         loss = 0
         for _ in range(self.mini_epoch):
+            if self.config.get('freeze_critic', False):
+                break
             for idx in range(len(self.dataset)):
                 loss += self.train_critic(self.dataset[idx])
         if self.normalize_input:
diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py
index ab047920..e5d625c0 100644
--- a/rl_games/algos_torch/network_builder.py
+++ b/rl_games/algos_torch/network_builder.py
@@ -8,6 +8,7 @@
 from rl_games.algos_torch.sac_helper import SquashedNormal
 from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
 from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
+from rl_games.algos_torch.spatial_softmax import SpatialSoftArgmax
 
 
 def _create_initializer(func, **kwargs):
@@ -130,12 +131,17 @@ def _build_conv(self, ctype, **kwargs):
         if ctype == 'conv2d':
             return self._build_cnn2d(**kwargs)
+        if ctype == 'conv2d_spatial_softargmax':
+            return self._build_cnn2d(add_spatial_softmax=True, **kwargs)
+        if ctype == 'conv2d_flatten':
+            return self._build_cnn2d(add_flatten=True, **kwargs)
         if ctype == 'coord_conv2d':
             return self._build_cnn2d(conv_func=torch_ext.CoordConv2d, **kwargs)
         if ctype == 'conv1d':
             return self._build_cnn1d(**kwargs)
 
-    def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d, norm_func_name=None):
+    def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d, norm_func_name=None,
+                     add_spatial_softmax=False, add_flatten=False):
         in_channels = input_shape[0]
         layers = []
         for conv in convs:
@@ -150,7 +156,11 @@ def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d
             if norm_func_name == 'layer_norm':
                 layers.append(torch_ext.LayerNorm2d(in_channels))
             elif norm_func_name == 'batch_norm':
-                layers.append(torch.nn.BatchNorm2d(in_channels))
+                layers.append(torch.nn.BatchNorm2d(in_channels))
+        if add_spatial_softmax:
+            layers.append(SpatialSoftArgmax(normalize=True))
+        if add_flatten:
+            layers.append(torch.nn.Flatten())
         return nn.Sequential(*layers)
 
     def _build_cnn1d(self, input_shape, convs, activation, norm_func_name=None):
diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py
index 909ae3e5..46c7c8f7 100644
--- a/rl_games/algos_torch/sac_agent.py
+++ b/rl_games/algos_torch/sac_agent.py
@@ -444,7 +444,7 @@ def clear_stats(self):
         self.algo_observer.after_clear_stats()
 
     def play_steps(self, random_exploration = False):
-        total_time_start = time.time()
+        total_time_start = time.perf_counter()
         total_update_time = 0
         total_time = 0
         step_time = 0.0
@@ -469,11 +469,10 @@
                 with torch.no_grad():
                     action = self.act(obs.float(), self.env_info["action_space"].shape, sample=True)
 
-            step_start = time.time()
-
+            step_start = time.perf_counter()
             with torch.no_grad():
                 next_obs, rewards, dones, infos = self.env_step(action)
-            step_end = time.time()
+            step_end = time.perf_counter()
 
             self.current_rewards += rewards
             self.current_lengths += 1
@@ -503,7 +502,6 @@
             self.obs = next_obs.clone()
 
             rewards = self.rewards_shaper(rewards)
-
             self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs_processed, torch.unsqueeze(dones, 1))
 
             if isinstance(obs, dict):
@@ -511,9 +509,10 @@
             if not random_exploration:
                 self.set_train()
-                update_time_start = time.time()
+
+                update_time_start = time.perf_counter()
                 actor_loss_info, critic1_loss, critic2_loss = self.update(self.epoch_num)
-                update_time_end = time.time()
+                update_time_end = time.perf_counter()
                 update_time = update_time_end - update_time_start
 
                 self.extract_actor_stats(actor_losses, entropies, alphas, alpha_losses, actor_loss_info)
@@ -524,7 +523,7 @@
                 total_update_time += update_time
 
-        total_time_end = time.time()
+        total_time_end = time.perf_counter()
         total_time = total_time_end - total_time_start
         play_time = total_time - total_update_time
diff --git a/rl_games/algos_torch/spatial_softmax.py b/rl_games/algos_torch/spatial_softmax.py
new file mode 100644
index 00000000..862efed9
--- /dev/null
+++ b/rl_games/algos_torch/spatial_softmax.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# Adapted from https://gist.github.com/kevinzakka/dd9fa5177cda13593524f4d71eb38ad5
+class SpatialSoftArgmax(nn.Module):
+    """Spatial softmax as defined in [1].
+
+    Concretely, the spatial softmax of each feature
+    map is used to compute a weighted mean of the pixel
+    locations, effectively performing a soft arg-max
+    over the feature dimension.
+
+    References:
+        [1]: End-to-End Training of Deep Visuomotor Policies,
+        https://arxiv.org/abs/1504.00702
+    """
+
+    def __init__(self, normalize=False):
+        """Constructor.
+
+        Args:
+            normalize (bool): Whether to use normalized
+                image coordinates, i.e. coordinates in
+                the range `[-1, 1]`.
+        """
+        super().__init__()
+
+        self.normalize = normalize
+
+    def _coord_grid(self, h, w, device):
+        if self.normalize:
+            return torch.stack(
+                torch.meshgrid(
+                    torch.linspace(-1, 1, w, device=device),
+                    torch.linspace(-1, 1, h, device=device),
+                )
+            )
+        return torch.stack(
+            torch.meshgrid(
+                torch.arange(0, w, device=device),
+                torch.arange(0, h, device=device),
+            )
+        )
+
+    def forward(self, x):
+        assert x.ndim == 4, "Expecting a tensor of shape (B, C, H, W)."
+
+        # compute a spatial softmax over the input:
+        # given an input of shape (B, C, H, W),
+        # reshape it to (B*C, H*W) then apply
+        # the softmax operator over the last dimension
+        b, c, h, w = x.shape
+        softmax = F.softmax(x.reshape(-1, h * w), dim=-1)
+
+        # create a meshgrid of pixel coordinates
+        # both in the x and y axes
+        xc, yc = self._coord_grid(h, w, x.device)
+
+        # element-wise multiply the x and y coordinates
+        # with the softmax, then sum over the h*w dimension
+        # this effectively computes the weighted mean of x
+        # and y locations
+        x_mean = (softmax * xc.flatten()).sum(dim=1, keepdims=True)
+        y_mean = (softmax * yc.flatten()).sum(dim=1, keepdims=True)
+
+        # concatenate and reshape the result
+        # to (B, C*2) where for every feature
+        # we have the expected x and y pixel
+        # locations
+        return torch.cat([x_mean, y_mean], dim=1).view(-1, c * 2)
+
+
+if __name__ == "__main__":
+    b, c, h, w = 32, 64, 12, 12
+    x = torch.zeros(b, c, h, w)
+    true_max = torch.randint(0, 10, size=(b, c, 2))
+    for i in range(b):
+        for j in range(c):
+            x[i, j, true_max[i, j, 0], true_max[i, j, 1]] = 1000
+    soft_max = SpatialSoftArgmax()(x).reshape(b, c, 2)
+    assert torch.allclose(true_max.float(), soft_max)
\ No newline at end of file
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index bda3d38f..e0c0cb28 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -646,6 +646,9 @@ def set_full_state_weights(self, weights, set_epoch=True):
         env_state = weights.get('env_state', None)
         self.vec_env.set_env_state(env_state)
 
+    def set_central_value_function_weights(self, weights):
+        self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])
+
     def get_weights(self):
         state = self.get_stats_weights()
         state['model'] = self.model.state_dict()
@@ -760,9 +763,9 @@
             if self.has_central_value:
                 self.experience_buffer.update_data('states', n, self.obs['states'])
 
-            step_time_start = time.time()
+            step_time_start = time.perf_counter()
             self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])
-            step_time_end = time.time()
+            step_time_end = time.perf_counter()
 
             step_time += (step_time_end - step_time_start)
@@ -833,9 +836,9 @@
             if self.has_central_value:
                 self.experience_buffer.update_data('states', n, self.obs['states'])
 
-            step_time_start = time.time()
+            step_time_start = time.perf_counter()
             self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])
-            step_time_end = time.time()
+            step_time_end = time.perf_counter()
 
             step_time += (step_time_end - step_time_start)
@@ -923,7 +926,7 @@ def train_epoch(self):
         super().train_epoch()
 
         self.set_eval()
-        play_time_start = time.time()
+        play_time_start = time.perf_counter()
 
         with torch.no_grad():
             if self.is_rnn:
@@ -933,8 +936,8 @@
         self.set_train()
 
-        play_time_end = time.time()
-        update_time_start = time.time()
+        play_time_end = time.perf_counter()
+        update_time_start = time.perf_counter()
         rnn_masks = batch_dict.get('rnn_masks', None)
 
         self.curr_frames = batch_dict.pop('played_frames')
@@ -969,7 +972,7 @@ def train_epoch(self):
         if self.normalize_input:
             self.model.running_mean_std.eval() # don't need to update statstics more than one miniepoch
 
-        update_time_end = time.time()
+        update_time_end = time.perf_counter()
         play_time = play_time_end - play_time_start
         update_time = update_time_end - update_time_start
         total_time = update_time_end - play_time_start
@@ -1037,7 +1040,7 @@ def prepare_dataset(self, batch_dict):
 
     def train(self):
         self.init_tensors()
         self.mean_rewards = self.last_mean_rewards = -100500
-        start_time = time.time()
+        start_time = time.perf_counter()
         total_time = 0
         rep_count = 0
         # self.frame = 0 # loading from checkpoint
@@ -1186,15 +1189,15 @@ def train_epoch(self):
         super().train_epoch()
 
         self.set_eval()
-        play_time_start = time.time()
+        play_time_start = time.perf_counter()
         with torch.no_grad():
             if self.is_rnn:
                 batch_dict = self.play_steps_rnn()
             else:
                 batch_dict = self.play_steps()
 
-        play_time_end = time.time()
-        update_time_start = time.time()
+        play_time_end = time.perf_counter()
+        update_time_start = time.perf_counter()
         rnn_masks = batch_dict.get('rnn_masks', None)
 
         self.set_train()
@@ -1243,7 +1246,7 @@ def train_epoch(self):
         if self.normalize_input:
             self.model.running_mean_std.eval() # don't need to update statstics more than one miniepoch
 
-        update_time_end = time.time()
+        update_time_end = time.perf_counter()
         play_time = play_time_end - play_time_start
         update_time = update_time_end - update_time_start
         total_time = update_time_end - play_time_start
@@ -1265,7 +1268,10 @@ def prepare_dataset(self, batch_dict):
         advantages = returns - values
 
         if self.normalize_value:
-            self.value_mean_std.train()
+            if self.config.get('freeze_critic', False):
+                self.value_mean_std.eval()
+            else:
+                self.value_mean_std.train()
             values = self.value_mean_std(values)
             returns = self.value_mean_std(returns)
             self.value_mean_std.eval()
@@ -1313,7 +1319,7 @@ def prepare_dataset(self, batch_dict):
 
     def train(self):
         self.init_tensors()
         self.last_mean_rewards = -100500
-        start_time = time.time()
+        start_time = time.perf_counter()
         total_time = 0
         rep_count = 0
         self.obs = self.env_reset()
diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py
index 4377d29b..0f7a9ac8 100644
--- a/rl_games/torch_runner.py
+++ b/rl_games/torch_runner.py
@@ -17,6 +17,10 @@
 
 def _restore(agent, args):
     if 'checkpoint' in args and args['checkpoint'] is not None and args['checkpoint'] !='':
+        if args['train'] and args.get('load_critic_only', False):
+            assert agent.has_central_value, 'This should only work for asymmetric actor critic'
+            agent.restore_central_value_function(args['checkpoint'])
+            return
         agent.restore(args['checkpoint'])
 
 def _override_sigma(agent, args):
@@ -63,7 +67,7 @@ def __init__(self, algo_observer=None):
 
         self.algo_observer = algo_observer if algo_observer else DefaultAlgoObserver()
         torch.backends.cudnn.benchmark = True
-        ### it didnot help for lots for openai gym envs anyway :(
+        ### it did not help for lots of openai gym envs anyway :(
        #torch.backends.cudnn.deterministic = True
        #torch.use_deterministic_algorithms(True)
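
A minimal usage sketch of the new `SpatialSoftArgmax` layer introduced in this patch (the batch size, channel count, and resolution below are arbitrary illustration values, not taken from any config): each feature map is collapsed to the softmax-weighted mean of its pixel coordinates, so a `(B, C, H, W)` input becomes a `(B, C * 2)` tensor of expected (x, y) locations, which is what the `conv2d_spatial_softargmax` cnn type appends to the end of the conv stack.

```python
import torch

from rl_games.algos_torch.spatial_softmax import SpatialSoftArgmax

# arbitrary example shapes: batch of 8, 32 feature maps of 16x16
feats = torch.randn(8, 32, 16, 16)                    # (B, C, H, W)
keypoints = SpatialSoftArgmax(normalize=True)(feats)  # coordinates in [-1, 1]
print(keypoints.shape)                                # torch.Size([8, 64]) == (B, C * 2)
```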