Skip to content

Commit

Permalink
Added support for more visual backbones.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Aug 17, 2024
1 parent 04d653a commit 637d2bc
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 47 deletions.
16 changes: 3 additions & 13 deletions rl_games/algos_torch/a2c_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, base_name, params):
'normalize_value' : self.normalize_value,
'normalize_input': self.normalize_input,
}

self.model = self.network.build(build_config)
self.model.to(self.ppo_device)
self.states = None
Expand Down Expand Up @@ -79,7 +79,7 @@ def __init__(self, base_name, params):
def update_epoch(self):
    """Advance the epoch counter by one and return the updated count."""
    self.epoch_num += 1
    current_epoch = self.epoch_num
    return current_epoch

def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)
Expand Down Expand Up @@ -144,14 +144,6 @@ def calc_gradients(self, input_dict):
'obs' : obs_batch,
}

# print("TEST")
# print("----------------")
# for key in input_dict:
# print(key)

# if "proprio" in input_dict:
# batch_dict['proprio'] = input_dict['proprio']

rnn_masks = None
if self.is_rnn:
rnn_masks = input_dict['rnn_masks']
Expand All @@ -172,7 +164,7 @@ def calc_gradients(self, input_dict):
loss, a_loss, c_loss, entropy, b_loss, sum_mask = self.calc_losses(self.actor_loss_func,
old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip,
value_preds_batch, values, return_batch, mu, entropy, rnn_masks)

if self.multi_gpu:
self.optimizer.zero_grad()
else:
Expand Down Expand Up @@ -222,5 +214,3 @@ def bound_loss(self, mu):
else:
b_loss = 0
return b_loss


1 change: 1 addition & 0 deletions rl_games/envs/test/test_asymmetric_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from rl_games.common.wrappers import MaskVelocityWrapper


class TestAsymmetricCritic(gym.Env):
def __init__(self, wrapped_env_name, **kwargs):
gym.Env.__init__(self)
Expand Down
63 changes: 29 additions & 34 deletions rl_games/networks/vision_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rl_games.algos_torch.network_builder import NetworkBuilder, ImpalaSequential


class A2CVisionBuilder(NetworkBuilder):
class VisionImpalaBuilder(NetworkBuilder):
def __init__(self, **kwargs):
NetworkBuilder.__init__(self)

Expand Down Expand Up @@ -225,11 +225,13 @@ def get_default_rnn_state(self):
return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))

def build(self, name, **kwargs):
net = A2CVisionBuilder.Network(self.params, **kwargs)
net = VisionImpalaBuilder.Network(self.params, **kwargs)
return net


class A2CVisionBackboneBuilder(NetworkBuilder):
from timm import create_model # timm is required for ConvNeXt and ViT

class VisionBackboneBuilder(NetworkBuilder):
def __init__(self, **kwargs):
NetworkBuilder.__init__(self)

Expand All @@ -241,7 +243,7 @@ def __init__(self, params, **kwargs):
self.actions_num = kwargs.pop('actions_num')
full_input_shape = kwargs.pop('input_shape')
proprio_size = 0 # Number of proprioceptive features
if type(full_input_shape) is dict:
if isinstance(full_input_shape, dict):
input_shape = full_input_shape['camera']
proprio_shape = full_input_shape['proprio']
proprio_size = proprio_shape[0]
Expand All @@ -256,9 +258,8 @@ def __init__(self, params, **kwargs):
if self.permute_input:
input_shape = torch_ext.shape_whc_to_cwh(input_shape)

self.cnn = self._build_resnet(input_shape, self.params['cnn']['pretrained'])
cnn_output_size = self.cnn.fc.in_features # Output size after ResNet
proprio_size = proprio_shape[0] # Number of proprioceptive features
self.cnn = self._build_backbone(input_shape, self.params['backbone'])
cnn_output_size = self.cnn_output_size

mlp_input_size = cnn_output_size + proprio_size
if len(self.units) == 0:
Expand Down Expand Up @@ -326,7 +327,6 @@ def __init__(self, params, **kwargs):
mlp_init(self.value.weight)

def forward(self, obs_dict):
# TODO: Add resnet preprocessing
obs = obs_dict['camera']
proprio = obs_dict['proprio']
if self.permute_input:
Expand Down Expand Up @@ -421,29 +421,24 @@ def load(self, params):
self.require_rewards = params.get('require_rewards')
self.require_last_actions = params.get('require_last_actions')

def _build_resnet(self, input_shape, pretrained):
    """Build a ResNet-18 feature extractor with its last child module removed.

    Args:
        input_shape: Shape of the image input; only ``input_shape[0]`` is
            read here and it is used as the channel count of the first
            convolution.
        pretrained: Forwarded to ``models.resnet18`` as ``pretrained``.

    Returns:
        An ``nn.Sequential`` made of every child of the ResNet except the
        last one.
    """
    backbone = models.resnet18(pretrained=pretrained)
    in_channels = input_shape[0]
    # The stock first convolution expects 3 channels; swap it out for any
    # other channel count. The replacement layer is freshly initialised.
    if in_channels != 3:
        backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Drop the final child (the fully connected layer).
    feature_layers = list(backbone.children())[:-1]
    return nn.Sequential(*feature_layers)

def is_separate_critic(self):
    """Report whether a separate critic is used; always ``False`` here."""
    return False

def is_rnn(self):
    """Return the ``has_rnn`` flag stored on this instance."""
    uses_rnn = self.has_rnn
    return uses_rnn

def get_default_rnn_state(self):
    """Create zero-initialised recurrent state tensors.

    Reads ``self.rnn_layers``, ``self.num_seqs``, ``self.rnn_units`` and
    ``self.rnn_name``.

    Returns:
        A tuple of zero tensors, each of shape
        ``(rnn_layers, num_seqs, rnn_units)``: two tensors when
        ``rnn_name == 'lstm'`` (presumably hidden and cell state), and a
        single tensor otherwise.
    """
    state_shape = (self.rnn_layers, self.num_seqs, self.rnn_units)
    if self.rnn_name == 'lstm':
        return (torch.zeros(state_shape), torch.zeros(state_shape))
    # The trailing comma matters: ``(tensor)`` is just the tensor, so the
    # previous code returned a bare tensor here while the LSTM branch
    # returned a tuple. Returning a 1-tuple keeps both branches consistent
    # for callers that iterate over, or take ``len()`` of, the state.
    return (torch.zeros(state_shape),)

def build(self, name, **kwargs):
    """Construct and return the network for this builder.

    ``name`` is accepted but not used in this method. The remaining keyword
    arguments are forwarded, together with ``self.params``, to
    ``A2CVisionBackboneBuilder.Network``.
    """
    return A2CVisionBackboneBuilder.Network(self.params, **kwargs)
def _build_backbone(self, input_shape, backbone_params):
backbone_type = backbone_params['type']
pretrained = backbone_params.get('pretrained', False)

if backbone_type == 'resnet18':
model = models.resnet18(pretrained=pretrained)
# Modify the first convolution layer to match input shape if needed
if input_shape[0] != 3:
model.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
# Remove the fully connected layer
self.cnn_output_size = model.fc.in_features
model = nn.Sequential(*list(model.children())[:-1])
elif backbone_type == 'convnext_tiny':
model = create_model('convnext_tiny', pretrained=pretrained)
# Remove the fully connected layer
self.cnn_output_size = model.head.fc.in_features
model = nn.Sequential(*list(model.children())[:-1])
elif backbone_type == 'vit_tiny_patch16_224':
model = create_model('vit_tiny_patch16_224', pretrained=pretrained)
# ViT outputs a single token, so no need to remove layers
self.cnn_output

0 comments on commit 637d2bc

Please sign in to comment.