From 0dc5b439774b670ad41e4724cb75853d4e916f23 Mon Sep 17 00:00:00 2001 From: ViktorM Date: Sat, 17 Aug 2024 18:32:36 -0700 Subject: [PATCH] WIP: network with different vision backbones. --- rl_games/algos_torch/model_builder.py | 4 +- rl_games/algos_torch/network_builder.py | 233 +++++++++++++++++++++++- rl_games/networks/vision_networks.py | 31 ++-- 3 files changed, 253 insertions(+), 15 deletions(-) diff --git a/rl_games/algos_torch/model_builder.py b/rl_games/algos_torch/model_builder.py index e33bd86f..c22fca0e 100644 --- a/rl_games/algos_torch/model_builder.py +++ b/rl_games/algos_torch/model_builder.py @@ -22,8 +22,8 @@ def __init__(self): lambda **kwargs: network_builder.A2CResnetBuilder()) self.network_factory.register_builder('vision_actor_critic', lambda **kwargs: network_builder.A2CVisionBuilder()) - # self.network_factory.register_builder('e2e_vision_actor_critic', - # lambda **kwargs: vision_networks.A2CVisionBackboneBuilder()) + self.network_factory.register_builder('e2e_vision_actor_critic', + lambda **kwargs: network_builder.VisionBackboneBuilder()) self.network_factory.register_builder('rnd_curiosity', lambda **kwargs: network_builder.RNDCuriosityBuilder()) self.network_factory.register_builder('soft_actor_critic', lambda **kwargs: network_builder.SACBuilder()) diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index 97814c40..29968f40 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -871,8 +871,6 @@ def __init__(self, params, **kwargs): self.num_seqs = kwargs.pop('num_seqs', 1) self.value_size = kwargs.pop('value_size', 1) - print(params) - NetworkBuilder.BaseNetwork.__init__(self) self.load(params) if self.permute_input: @@ -1074,6 +1072,237 @@ def build(self, name, **kwargs): return net +from torchvision import models +from timm import create_model # timm is required for ConvNeXt and ViT + +class VisionBackboneBuilder(NetworkBuilder): + def __init__(self, **kwargs): + NetworkBuilder.__init__(self) + + def load(self, params): + self.params = params + + class Network(NetworkBuilder.BaseNetwork): + def __init__(self, params, **kwargs): + self.actions_num = kwargs.pop('actions_num') + full_input_shape = kwargs.pop('input_shape') + + self.proprio_size = 0 # Number of proprioceptive features + if isinstance(full_input_shape, dict): + input_shape = full_input_shape['camera'] + proprio_shape = full_input_shape['proprio'] + self.proprio_size = proprio_shape[0] + else: + input_shape = full_input_shape + + self.num_seqs = kwargs.pop('num_seqs', 1) + self.value_size = kwargs.pop('value_size', 1) + + NetworkBuilder.BaseNetwork.__init__(self) + self.load(params) + if self.permute_input: + input_shape = torch_ext.shape_whc_to_cwh(input_shape) + + self.cnn = self._build_backbone(input_shape, self.params['backbone']) + cnn_output_size = self.cnn_output_size + + mlp_input_size = cnn_output_size + self.proprio_size + if len(self.units) == 0: + out_size = cnn_output_size + else: + out_size = self.units[-1] + + if self.has_rnn: + if not self.is_rnn_before_mlp: + rnn_in_size = out_size + out_size = self.rnn_units + else: + rnn_in_size = mlp_input_size + mlp_input_size = self.rnn_units + + self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers) + self.layer_norm = torch.nn.LayerNorm(self.rnn_units) + + mlp_args = { + 'input_size': mlp_input_size, + 'units': self.units, + 'activation': self.activation, + 'norm_func_name': self.normalization, + 'dense_func': torch.nn.Linear + } + + self.mlp = self._build_mlp(**mlp_args) + + self.value = self._build_value_layer(out_size, self.value_size) + self.value_act = self.activations_factory.create(self.value_activation) + self.flatten_act = self.activations_factory.create(self.activation) + + if self.is_discrete: + self.logits = torch.nn.Linear(out_size, self.actions_num) + if self.is_continuous: + self.mu = torch.nn.Linear(out_size, self.actions_num) + self.mu_act = self.activations_factory.create(self.space_config['mu_activation']) + mu_init = self.init_factory.create(**self.space_config['mu_init']) + self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation']) + sigma_init = self.init_factory.create(**self.space_config['sigma_init']) + + if self.fixed_sigma: + self.sigma = nn.Parameter(torch.zeros(self.actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True) + else: + self.sigma = torch.nn.Linear(out_size, self.actions_num) + + mlp_init = self.init_factory.create(**self.initializer) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + for m in self.mlp: + if isinstance(m, nn.Linear): + mlp_init(m.weight) + + if self.is_discrete: + mlp_init(self.logits.weight) + if self.is_continuous: + mu_init(self.mu.weight) + if self.fixed_sigma: + sigma_init(self.sigma) + else: + sigma_init(self.sigma.weight) + + mlp_init(self.value.weight) + + def forward(self, obs_dict): + if self.proprio_size > 0: + obs = obs_dict['camera'] + proprio = obs_dict['proprio'] + else: + obs = obs_dict['obs'] + if self.permute_input: + obs = obs.permute((0, 3, 1, 2)) + + dones = obs_dict.get('dones', None) + bptt_len = obs_dict.get('bptt_len', 0) + states = obs_dict.get('rnn_states', None) + + out = obs + out = self.cnn(out) + out = out.flatten(1) + out = self.flatten_act(out) + + if self.proprio_size > 0: + out = torch.cat([out, proprio], dim=1) + + if self.has_rnn: + seq_length = obs_dict.get('seq_length', 1) + + out_in = out + if not self.is_rnn_before_mlp: + out_in = out + out = self.mlp(out) + + batch_size = out.size()[0] + num_seqs = batch_size // seq_length + out = out.reshape(num_seqs, seq_length, -1) + + if len(states) == 1: + states = states[0] + + out = out.transpose(0, 1) + if dones is not None: + dones = dones.reshape(num_seqs, seq_length, -1) + dones = dones.transpose(0, 1) + out, states = self.rnn(out, states, dones, bptt_len) + out = out.transpose(0, 1) + out = out.contiguous().reshape(out.size()[0] * out.size()[1], -1) + + if self.rnn_ln: + out = self.layer_norm(out) + if self.is_rnn_before_mlp: + out = self.mlp(out) + if not isinstance(states, tuple): + states = (states,) + else: + out = self.mlp(out) + + value = self.value_act(self.value(out)) + + if self.is_discrete: + logits = self.logits(out) + return logits, value, states + + if self.is_continuous: + mu = self.mu_act(self.mu(out)) + if self.fixed_sigma: + sigma = self.sigma_act(self.sigma) + else: + sigma = self.sigma_act(self.sigma(out)) + return mu, mu * 0 + sigma, value, states + + def load(self, params): + self.separate = False + self.units = params['mlp']['units'] + self.activation = params['mlp']['activation'] + self.initializer = params['mlp']['initializer'] + self.is_discrete = 'discrete' in params['space'] + self.is_continuous = 'continuous' in params['space'] + self.is_multi_discrete = 'multi_discrete' in params['space'] + self.value_activation = params.get('value_activation', 'None') + self.normalization = params.get('normalization', None) + + if self.is_continuous: + self.space_config = params['space']['continuous'] + self.fixed_sigma = self.space_config['fixed_sigma'] + elif self.is_discrete: + self.space_config = params['space']['discrete'] + elif self.is_multi_discrete: + self.space_config = params['space']['multi_discrete'] + + self.has_rnn = 'rnn' in params + if self.has_rnn: + self.rnn_units = params['rnn']['units'] + self.rnn_layers = params['rnn']['layers'] + self.rnn_name = params['rnn']['name'] + self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False) + self.rnn_ln = params['rnn'].get('layer_norm', False) + + self.has_cnn = True + self.permute_input = params['backbone'].get('permute_input', True) + self.require_rewards = params.get('require_rewards') + self.require_last_actions = params.get('require_last_actions') + + def _build_backbone(self, input_shape, backbone_params): + print("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + print(backbone_params) + backbone_type = backbone_params['type'] + pretrained = backbone_params.get('pretrained', False) + + if backbone_type == 'resnet18': + model = models.resnet18(pretrained=pretrained) + # Modify the first convolution layer to match input shape if needed + if input_shape[0] != 3: + model.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False) + # Remove the fully connected layer + self.cnn_output_size = model.fc.in_features + model = nn.Sequential(*list(model.children())[:-1]) + elif backbone_type == 'convnext_tiny': + model = create_model('convnext_tiny', pretrained=pretrained) + # Remove the fully connected layer + self.cnn_output_size = model.head.fc.in_features + model = nn.Sequential(*list(model.children())[:-1]) + # elif backbone_type == 'vit_tiny_patch16_224': + # model = create_model('vit_tiny_patch16_224', pretrained=pretrained) + # # ViT outputs a single token, so no need to remove layers + # self.cnn_output + else: + raise ValueError(f'Unknown backbone type: {backbone_type}') + + return model + + def build(self, name, **kwargs): + net = VisionBackboneBuilder.Network(self.params, **kwargs) + return net + + class DiagGaussianActor(NetworkBuilder.BaseNetwork): """torch.distributions implementation of an diagonal Gaussian policy.""" def __init__(self, output_dim, log_std_bounds, **mlp_args): diff --git a/rl_games/networks/vision_networks.py b/rl_games/networks/vision_networks.py index a3431630..fc4a7772 100644 --- a/rl_games/networks/vision_networks.py +++ b/rl_games/networks/vision_networks.py @@ -242,11 +242,12 @@ class Network(NetworkBuilder.BaseNetwork): def __init__(self, params, **kwargs): self.actions_num = kwargs.pop('actions_num') full_input_shape = kwargs.pop('input_shape') - proprio_size = 0 # Number of proprioceptive features + + self.proprio_size = 0 # Number of proprioceptive features if isinstance(full_input_shape, dict): input_shape = full_input_shape['camera'] proprio_shape = full_input_shape['proprio'] - proprio_size = proprio_shape[0] + self.proprio_size = proprio_shape[0] else: input_shape = full_input_shape @@ -261,7 +262,7 @@ def __init__(self, params, **kwargs): self.cnn = self._build_backbone(input_shape, self.params['backbone']) cnn_output_size = self.cnn_output_size - mlp_input_size = cnn_output_size + proprio_size + mlp_input_size = cnn_output_size + self.proprio_size if len(self.units) == 0: out_size = cnn_output_size else: @@ -308,9 +309,9 @@ def __init__(self, params, **kwargs): mlp_init = self.init_factory.create(**self.initializer) - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') for m in self.mlp: if isinstance(m, nn.Linear): mlp_init(m.weight) @@ -327,8 +328,11 @@ def __init__(self, params, **kwargs): mlp_init(self.value.weight) def forward(self, obs_dict): - obs = obs_dict['camera'] - proprio = obs_dict['proprio'] + if self.proprio_size > 0: + obs = obs_dict['camera'] + proprio = obs_dict['proprio'] + else: + obs = obs_dict['obs'] if self.permute_input: obs = obs.permute((0, 3, 1, 2)) @@ -341,7 +345,8 @@ def forward(self, obs_dict): out = out.flatten(1) out = self.flatten_act(out) - out = torch.cat([out, proprio], dim=1) + if self.proprio_size > 0: + out = torch.cat([out, proprio], dim=1) if self.has_rnn: seq_length = obs_dict.get('seq_length', 1) @@ -417,7 +422,7 @@ def load(self, params): self.rnn_ln = params['rnn'].get('layer_norm', False) self.has_cnn = True - self.permute_input = params['cnn'].get('permute_input', True) + self.permute_input = params['backbone'].get('permute_input', True) self.require_rewards = params.get('require_rewards') self.require_last_actions = params.get('require_last_actions') @@ -441,4 +446,8 @@ def _build_backbone(self, input_shape, backbone_params): elif backbone_type == 'vit_tiny_patch16_224': model = create_model('vit_tiny_patch16_224', pretrained=pretrained) # ViT outputs a single token, so no need to remove layers - self.cnn_output \ No newline at end of file + self.cnn_output + + def build(self, name, **kwargs): + net = VisionBackboneBuilder.Network(self.params, **kwargs) + return net \ No newline at end of file