Add spatial soft-argmax / flatten conv heads, critic freezing and
critic-only checkpoint restore.

* a2c_continuous.py, a2c_common.py: restore_central_value_function() loads a
  checkpoint and passes its 'assymetric_vf_nets' entry to
  central_value_net.load_state_dict().
* central_value.py, a2c_common.py: when the 'freeze_critic' config flag is
  set, train_net() leaves the mini-epoch loop before any train_critic() call
  and prepare_dataset() keeps value_mean_std in eval mode.
* network_builder.py: new conv types 'conv2d_spatial_softargmax' and
  'conv2d_flatten', implemented as optional tail layers of _build_cnn2d().
* spatial_softmax.py: new SpatialSoftArgmax module. The coordinate vectors
  are built directly in the (H, W) row-major order used by forward(), so
  non-square feature maps give correct expectations; square inputs produce
  the same values as the meshgrid-based version.
* torch_runner.py: with 'load_critic_only' set while training, _restore()
  loads only the central value function and returns.

NOTE(review): line breaks and indentation of this patch were reconstructed
from a whitespace-collapsed copy. The BatchNorm2d -/+ pair in
network_builder.py is a whitespace-only change whose exact whitespace could
not be recovered, and the blob hash on the spatial_softmax.py index line
predates the fixes listed above. Verify with `git apply --check`.

diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index 78795a2b..a64de95f 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -83,6 +83,10 @@ def restore(self, fn, set_epoch=True):
         checkpoint = torch_ext.load_checkpoint(fn)
         self.set_full_state_weights(checkpoint, set_epoch=set_epoch)
 
+    def restore_central_value_function(self, fn):
+        checkpoint = torch_ext.load_checkpoint(fn)
+        self.set_central_value_function_weights(checkpoint)
+
     def get_masked_action_values(self, obs, action_masks):
         assert False
 
diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py
index d75c687c..c06d9a18 100644
--- a/rl_games/algos_torch/central_value.py
+++ b/rl_games/algos_torch/central_value.py
@@ -1,4 +1,5 @@
 import os
+import copy
 import torch
 from torch import nn
 import torch.distributed as dist
@@ -219,6 +220,8 @@ def train_net(self):
         self.train()
         loss = 0
         for _ in range(self.mini_epoch):
+            if self.config.get('freeze_critic', False):
+                break
             for idx in range(len(self.dataset)):
                 loss += self.train_critic(self.dataset[idx])
             if self.normalize_input:
diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py
index ab047920..e5d625c0 100644
--- a/rl_games/algos_torch/network_builder.py
+++ b/rl_games/algos_torch/network_builder.py
@@ -8,6 +8,7 @@
 from rl_games.algos_torch.sac_helper import SquashedNormal
 from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
 from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
+from rl_games.algos_torch.spatial_softmax import SpatialSoftArgmax
 
 
 def _create_initializer(func, **kwargs):
@@ -130,12 +131,17 @@ def _build_conv(self, ctype, **kwargs):
 
             if ctype == 'conv2d':
                 return self._build_cnn2d(**kwargs)
+            if ctype == 'conv2d_spatial_softargmax':
+                return self._build_cnn2d(add_spatial_softmax=True, **kwargs)
+            if ctype == 'conv2d_flatten':
+                return self._build_cnn2d(add_flatten=True, **kwargs)
             if ctype == 'coord_conv2d':
                 return self._build_cnn2d(conv_func=torch_ext.CoordConv2d, **kwargs)
             if ctype == 'conv1d':
                 return self._build_cnn1d(**kwargs)
 
-        def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d, norm_func_name=None):
+        def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d, norm_func_name=None,
+                         add_spatial_softmax=False, add_flatten=False):
             in_channels = input_shape[0]
             layers = []
             for conv in convs:
@@ -150,7 +156,11 @@ def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d
                 if norm_func_name == 'layer_norm':
                     layers.append(torch_ext.LayerNorm2d(in_channels))
                 elif norm_func_name == 'batch_norm':
-                    layers.append(torch.nn.BatchNorm2d(in_channels))
+                    layers.append(torch.nn.BatchNorm2d(in_channels))
+            if add_spatial_softmax:
+                layers.append(SpatialSoftArgmax(normalize=True))
+            if add_flatten:
+                layers.append(torch.nn.Flatten())
             return nn.Sequential(*layers)
 
         def _build_cnn1d(self, input_shape, convs, activation, norm_func_name=None):
diff --git a/rl_games/algos_torch/spatial_softmax.py b/rl_games/algos_torch/spatial_softmax.py
new file mode 100644
index 00000000..862efed9
--- /dev/null
+++ b/rl_games/algos_torch/spatial_softmax.py
@@ -0,0 +1,93 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# Adapted from https://gist.github.com/kevinzakka/dd9fa5177cda13593524f4d71eb38ad5
+class SpatialSoftArgmax(nn.Module):
+    """Spatial softmax as defined in [1].
+
+    Concretely, the spatial softmax of each feature
+    map is used to compute a weighted mean of the pixel
+    locations, effectively performing a soft arg-max
+    over the feature dimension.
+
+    The output has shape (B, C*2): for every feature map, the expected
+    coordinate along H (axis 2) followed by the one along W (axis 3).
+
+    References:
+        [1]: End-to-End Training of Deep Visuomotor Policies,
+        https://arxiv.org/abs/1504.00702
+    """
+
+    def __init__(self, normalize=False):
+        """Constructor.
+
+        Args:
+            normalize (bool): Whether to use normalized
+                image coordinates, i.e. coordinates in
+                the range `[-1, 1]`.
+        """
+        super().__init__()
+
+        self.normalize = normalize
+
+    def _coord_grid(self, h, w, device):
+        """Return flattened coordinates for an (h, w) feature map.
+
+        Both returned vectors have h * w entries in the same row-major
+        order as `x.reshape(-1, h * w)` in `forward`: entry k describes
+        pixel (k // w, k % w). The first vector holds the coordinate
+        along H, the second the coordinate along W.
+        """
+        if self.normalize:
+            along_h = torch.linspace(-1, 1, h, device=device)
+            along_w = torch.linspace(-1, 1, w, device=device)
+        else:
+            along_h = torch.arange(0, h, device=device)
+            along_w = torch.arange(0, w, device=device)
+        # repeat_interleave/repeat reproduce an 'ij' meshgrid of shape
+        # (h, w) already flattened, without relying on meshgrid defaults
+        return along_h.repeat_interleave(w), along_w.repeat(h)
+
+    def forward(self, x):
+        assert x.ndim == 4, "Expecting a tensor of shape (B, C, H, W)."
+
+        # compute a spatial softmax over the input:
+        # given an input of shape (B, C, H, W),
+        # reshape it to (B*C, H*W) then apply
+        # the softmax operator over the last dimension
+        b, c, h, w = x.shape
+        softmax = F.softmax(x.reshape(-1, h * w), dim=-1)
+
+        # flattened pixel coordinates along H and W, laid out
+        # in the same order as the columns of the softmax
+        hc, wc = self._coord_grid(h, w, x.device)
+
+        # element-wise multiply the coordinates with the
+        # softmax, then sum over the h*w dimension
+        # this effectively computes the weighted mean of
+        # the pixel locations along each axis
+        h_mean = (softmax * hc).sum(dim=1, keepdim=True)
+        w_mean = (softmax * wc).sum(dim=1, keepdim=True)
+
+        # concatenate and reshape the result
+        # to (B, C*2) where for every feature
+        # we have the expected pixel location
+        # along H followed by the one along W
+        return torch.cat([h_mean, w_mean], dim=1).view(-1, c * 2)
+
+
+if __name__ == "__main__":
+    # non-square feature maps exercise both axes independently
+    b, c, h, w = 32, 64, 12, 16
+    x = torch.zeros(b, c, h, w)
+    true_max = torch.stack(
+        [torch.randint(0, h, size=(b, c)), torch.randint(0, w, size=(b, c))],
+        dim=-1,
+    )
+    for i in range(b):
+        for j in range(c):
+            x[i, j, true_max[i, j, 0], true_max[i, j, 1]] = 1000
+    soft_max = SpatialSoftArgmax()(x).reshape(b, c, 2)
+    assert torch.allclose(true_max.float(), soft_max)
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 19b95985..224bca6b 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -643,6 +643,9 @@ def set_full_state_weights(self, weights, set_epoch=True):
         env_state = weights.get('env_state', None)
         self.vec_env.set_env_state(env_state)
 
+    def set_central_value_function_weights(self, weights):
+        self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])
+
     def get_weights(self):
         state = self.get_stats_weights()
         state['model'] = self.model.state_dict()
@@ -1262,7 +1265,10 @@ def prepare_dataset(self, batch_dict):
         advantages = returns - values
 
         if self.normalize_value:
-            self.value_mean_std.train()
+            if self.config.get('freeze_critic', False):
+                self.value_mean_std.eval()
+            else:
+                self.value_mean_std.train()
             values = self.value_mean_std(values)
             returns = self.value_mean_std(returns)
             self.value_mean_std.eval()
diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py
index 86be48ac..0f7a9ac8 100644
--- a/rl_games/torch_runner.py
+++ b/rl_games/torch_runner.py
@@ -17,6 +17,10 @@
 
 def _restore(agent, args):
     if 'checkpoint' in args and args['checkpoint'] is not None and args['checkpoint'] !='':
+        if args['train'] and args.get('load_critic_only', False):
+            assert agent.has_central_value, 'This should only work for asymmetric actor critic'
+            agent.restore_central_value_function(args['checkpoint'])
+            return
         agent.restore(args['checkpoint'])
 
 def _override_sigma(agent, args):