Skip to content

Commit

Permalink
Update for tacsl release: CNN tower processing, critic weights loadin…
Browse files Browse the repository at this point in the history
…g and freezing. (#298)

* fix missing import copy

* adding ability to post-process the output of a conv tower with the spatial soft argmax or flatten layer

* enable loading the weights of the critic network from a PPO checkpoint, without the actor weights

* add flag to freeze critic while training actor
  • Loading branch information
iakinola23 authored Jul 12, 2024
1 parent 7a2b25f commit 2606eff
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 3 deletions.
4 changes: 4 additions & 0 deletions rl_games/algos_torch/a2c_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def restore(self, fn, set_epoch=True):
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

def restore_central_value_function(self, fn):
checkpoint = torch_ext.load_checkpoint(fn)
self.set_central_value_function_weights(checkpoint)

def get_masked_action_values(self, obs, action_masks):
assert False

Expand Down
3 changes: 3 additions & 0 deletions rl_games/algos_torch/central_value.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import copy
import torch
from torch import nn
import torch.distributed as dist
Expand Down Expand Up @@ -219,6 +220,8 @@ def train_net(self):
self.train()
loss = 0
for _ in range(self.mini_epoch):
if self.config.get('freeze_critic', False):
break
for idx in range(len(self.dataset)):
loss += self.train_critic(self.dataset[idx])
if self.normalize_input:
Expand Down
14 changes: 12 additions & 2 deletions rl_games/algos_torch/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from rl_games.algos_torch.sac_helper import SquashedNormal
from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
from rl_games.algos_torch.spatial_softmax import SpatialSoftArgmax


def _create_initializer(func, **kwargs):
Expand Down Expand Up @@ -130,12 +131,17 @@ def _build_conv(self, ctype, **kwargs):

if ctype == 'conv2d':
return self._build_cnn2d(**kwargs)
if ctype == 'conv2d_spatial_softargmax':
return self._build_cnn2d(add_spatial_softmax=True, **kwargs)
if ctype == 'conv2d_flatten':
return self._build_cnn2d(add_flatten=True, **kwargs)
if ctype == 'coord_conv2d':
return self._build_cnn2d(conv_func=torch_ext.CoordConv2d, **kwargs)
if ctype == 'conv1d':
return self._build_cnn1d(**kwargs)

def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d, norm_func_name=None):
def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d, norm_func_name=None,
add_spatial_softmax=False, add_flatten=False):
in_channels = input_shape[0]
layers = []
for conv in convs:
Expand All @@ -150,7 +156,11 @@ def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d
if norm_func_name == 'layer_norm':
layers.append(torch_ext.LayerNorm2d(in_channels))
elif norm_func_name == 'batch_norm':
layers.append(torch.nn.BatchNorm2d(in_channels))
layers.append(torch.nn.BatchNorm2d(in_channels))
if add_spatial_softmax:
layers.append(SpatialSoftArgmax(normalize=True))
if add_flatten:
layers.append(torch.nn.Flatten())
return nn.Sequential(*layers)

def _build_cnn1d(self, input_shape, convs, activation, norm_func_name=None):
Expand Down
83 changes: 83 additions & 0 deletions rl_games/algos_torch/spatial_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


# Adopted from https://gist.github.com/kevinzakka/dd9fa5177cda13593524f4d71eb38ad5
class SpatialSoftArgmax(nn.Module):
"""Spatial softmax as defined in [1].
Concretely, the spatial softmax of each feature
map is used to compute a weighted mean of the pixel
locations, effectively performing a soft arg-max
over the feature dimension.
References:
[1]: End-to-End Training of Deep Visuomotor Policies,
https://arxiv.org/abs/1504.00702
"""

def __init__(self, normalize=False):
"""Constructor.
Args:
normalize (bool): Whether to use normalized
image coordinates, i.e. coordinates in
the range `[-1, 1]`.
"""
super().__init__()

self.normalize = normalize

def _coord_grid(self, h, w, device):
if self.normalize:
return torch.stack(
torch.meshgrid(
torch.linspace(-1, 1, w, device=device),
torch.linspace(-1, 1, h, device=device),
)
)
return torch.stack(
torch.meshgrid(
torch.arange(0, w, device=device),
torch.arange(0, h, device=device),
)
)

def forward(self, x):
assert x.ndim == 4, "Expecting a tensor of shape (B, C, H, W)."

# compute a spatial softmax over the input:
# given an input of shape (B, C, H, W),
# reshape it to (B*C, H*W) then apply
# the softmax operator over the last dimension
b, c, h, w = x.shape
softmax = F.softmax(x.reshape(-1, h * w), dim=-1)

# create a meshgrid of pixel coordinates
# both in the x and y axes
xc, yc = self._coord_grid(h, w, x.device)

# element-wise multiply the x and y coordinates
# with the softmax, then sum over the h*w dimension
# this effectively computes the weighted mean of x
# and y locations
x_mean = (softmax * xc.flatten()).sum(dim=1, keepdims=True)
y_mean = (softmax * yc.flatten()).sum(dim=1, keepdims=True)

# concatenate and reshape the result
# to (B, C*2) where for every feature
# we have the expected x and y pixel
# locations
return torch.cat([x_mean, y_mean], dim=1).view(-1, c * 2)


if __name__ == "__main__":
b, c, h, w = 32, 64, 12, 12
x = torch.zeros(b, c, h, w)
true_max = torch.randint(0, 10, size=(b, c, 2))
for i in range(b):
for j in range(c):
x[i, j, true_max[i, j, 0], true_max[i, j, 1]] = 1000
soft_max = SpatialSoftArgmax()(x).reshape(b, c, 2)
assert torch.allclose(true_max.float(), soft_max)
8 changes: 7 additions & 1 deletion rl_games/common/a2c_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,9 @@ def set_full_state_weights(self, weights, set_epoch=True):
env_state = weights.get('env_state', None)
self.vec_env.set_env_state(env_state)

def set_central_value_function_weights(self, weights):
self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])

def get_weights(self):
state = self.get_stats_weights()
state['model'] = self.model.state_dict()
Expand Down Expand Up @@ -1262,7 +1265,10 @@ def prepare_dataset(self, batch_dict):
advantages = returns - values

if self.normalize_value:
self.value_mean_std.train()
if self.config.get('freeze_critic', False):
self.value_mean_std.eval()
else:
self.value_mean_std.train()
values = self.value_mean_std(values)
returns = self.value_mean_std(returns)
self.value_mean_std.eval()
Expand Down
4 changes: 4 additions & 0 deletions rl_games/torch_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

def _restore(agent, args):
if 'checkpoint' in args and args['checkpoint'] is not None and args['checkpoint'] !='':
if args['train'] and args.get('load_critic_only', False):
assert agent.has_central_value, 'This should only work for asymmetric actor critic'
agent.restore_central_value_function(args['checkpoint'])
return
agent.restore(args['checkpoint'])

def _override_sigma(agent, args):
Expand Down

0 comments on commit 2606eff

Please sign in to comment.