diff --git a/rl_games/envs/maniskill.py b/rl_games/envs/maniskill.py
index 568b2de3..ab6b89aa 100644
--- a/rl_games/envs/maniskill.py
+++ b/rl_games/envs/maniskill.py
@@ -7,6 +7,7 @@
 import gymnasium as gym2
 import gymnasium.spaces.utils
 from gymnasium.vector.utils import batch_space
+
 from mani_skill.utils import common
diff --git a/rl_games/networks/vision_networks.py b/rl_games/networks/vision_networks.py
index ff491d8e..db039ae1 100644
--- a/rl_games/networks/vision_networks.py
+++ b/rl_games/networks/vision_networks.py
@@ -1,7 +1,6 @@
 import torch
 from torch import nn
 from torchvision import models
-from rl_games.algos_torch.running_mean_std import RunningMeanStd, RunningMeanStdObs
 import torch.nn.functional as F
 from rl_games.algos_torch import torch_ext
 from rl_games.algos_torch.network_builder import NetworkBuilder, ImpalaSequential
@@ -18,10 +17,24 @@ class Network(NetworkBuilder.BaseNetwork):
     def __init__(self, params, **kwargs):
         self.actions_num = actions_num = kwargs.pop('actions_num')
         full_input_shape = kwargs.pop('input_shape')
-        self.use_aux_loss = kwargs.pop('use_aux_loss', True)
         print(kwargs)
         print("params: ", params)
-        self.aux_loss_weight = 100.0 # kwargs.pop('aux_loss_weight', 1.0)
+
+        aux_loss_params = params.get('aux_loss', None)
+        if aux_loss_params is not None:
+            print("Auxiliary loss params: ", aux_loss_params)
+            self.use_aux_loss = aux_loss_params.get('use_aux_loss', False)
+            self.aux_loss_weight = aux_loss_params.get('weight', 1.0)
+            self.aux_concat = aux_loss_params.get('concat', False)
+            # Test the policy with the auxiliary targets fed in as input
+            self.test_baseline = aux_loss_params.get('test_baseline', False)
+        else:
+            self.use_aux_loss = False
+            self.aux_loss_weight = 1.0
+            self.aux_concat = False
+            self.test_baseline = False
+
+        print("Full input shape: ", full_input_shape)
         if self.use_aux_loss:
             self.target_key = 'aux_target'
             if 'aux_target' in full_input_shape:
@@ -49,7 +62,7 @@ def __init__(self, params, **kwargs):
         cnn_output_size = self._calc_input_size(input_shape, self.cnn)
 
         mlp_input_size = cnn_output_size + self.proprio_size
-        if self.use_aux_loss:
+        if self.use_aux_loss and self.aux_concat:
             mlp_input_size += self.target_shape[0]
 
         if len(self.units) == 0:
@@ -57,8 +70,9 @@ def __init__(self, params, **kwargs):
         else:
             out_size = self.units[-1]
 
+        if self.test_baseline:
+            mlp_input_size = self.proprio_size + self.target_shape[0]
         self.layer_norm_emb = torch.nn.LayerNorm(mlp_input_size)
-        #self.layer_norm_emb = torch.nn.RMSNorm(mlp_input_size)
 
         if self.has_rnn:
             if not self.is_rnn_before_mlp:
@@ -118,6 +132,10 @@ def __init__(self, params, **kwargs):
             if isinstance(m, nn.Linear):
                 mlp_init(m.weight)
 
+        if self.use_aux_loss:
+            # mlp_init(self.aux_loss_linear.weight)
+            torch.nn.init.xavier_normal_(self.aux_loss_linear.weight, gain=0.02)
+
         if self.is_discrete:
             mlp_init(self.logits.weight)
         if self.is_continuous:
@@ -139,9 +157,6 @@ def forward(self, obs_dict):
         else:
             obs = obs_dict['obs']
 
-        if self.use_aux_loss:
-            target_obs = obs_dict['obs'][self.target_key]
-
         if self.permute_input:
             obs = obs.permute((0, 3, 1, 2))
@@ -154,18 +169,28 @@ def forward(self, obs_dict):
             out = out.flatten(1)
             cnn_out = self.flatten_act(out)
 
-        if self.proprio_size > 0:
+        # TODO: Add support for the case without proprioception obs
+        if self.proprio_size > 0 and not self.test_baseline:
             out = torch.cat([cnn_out, proprio], dim=1)
 
         if self.use_aux_loss:
             y = self.aux_loss_linear(cnn_out)
-            out = torch.cat([out, y], dim=1)
+            target_obs = obs_dict['obs'][self.target_key]
+
+            if self.aux_concat:
+                out = torch.cat([out, y.detach()], dim=1)
+            elif self.test_baseline:
+                out = torch.cat([proprio, target_obs], dim=1)
+            # print("aux predicted shape: ", y.shape)
+            # print("target shape: ", target_obs.shape)
+            # print("aux predicted: ", y[0])
+            # print("aux target: ", target_obs[0])
             self.aux_loss_map['aux_dist_loss'] = self.aux_loss_weight * torch.nn.functional.mse_loss(y, target_obs)
             # print("aux predicted shape: ", y.shape)
             # print("aux predicted: ", y)
             # print("aux target: ", target_obs)
             # print("delta: ", y - target_obs)
             # print("aux loss: ", self.aux_loss_map['aux_dist_loss'])
 
         out = self.layer_norm_emb(out)
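Note: the `aux_loss` block above is read from the network section of the training params. A minimal sketch of how such a params fragment might look (the key names follow the parsing above; the values shown are illustrative, not shipped defaults):

```python
# Hypothetical params fragment for the aux_loss parsing above.
# Keys mirror what the network reads; values are made up for illustration.
params = {
    'aux_loss': {
        'use_aux_loss': True,    # enable the auxiliary prediction head
        'weight': 1.0,           # multiplier on the 'aux_dist_loss' MSE term
        'concat': False,         # append detached predictions to the MLP input
        'test_baseline': False,  # debug mode: feed ground-truth targets instead
    },
}

aux_loss_params = params.get('aux_loss', None)
print(aux_loss_params.get('weight', 1.0))  # -> 1.0
```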
@@ -302,7 +327,7 @@ def __init__(self, params, **kwargs):
         self.actions_num = kwargs.pop('actions_num')
         full_input_shape = kwargs.pop('input_shape')
 
-        self.use_aux_loss = kwargs.pop('use_aux_loss', False)
+        self.use_aux_loss = kwargs.pop('use_aux_loss', True)
         self.aux_loss_weight = 100.0
         if self.use_aux_loss:
             self.target_key = 'aux_target'
@@ -420,9 +445,6 @@ def forward(self, obs_dict):
         else:
             obs = obs_dict['obs']
 
-        if self.use_aux_loss:
-            target_obs = obs_dict['obs'][self.target_key]
-
         if self.permute_input:
             obs = obs.permute((0, 3, 1, 2))
@@ -442,6 +464,7 @@ def forward(self, obs_dict):
             out = torch.cat([vis_out, proprio], dim=1)
 
         if self.use_aux_loss:
+            target_obs = obs_dict['obs'][self.target_key]
             y = self.aux_loss_linear(vis_out)
             out = torch.cat([out, y], dim=1)
             self.aux_loss_map['aux_dist_loss'] = self.aux_loss_weight * torch.nn.functional.mse_loss(y, target_obs)
@@ -530,18 +553,36 @@ def _build_backbone(self, input_shape, backbone_params):
         modify_first_layer = backbone_params.get('modify_first_layer', False)
         self.preprocess_image = backbone_params.get('preprocess_image', False)
 
-        # Define a mapping from backbone type to required resize size
+        # Mapping from backbone_type to required input (resize) size
         resize_size_map = {
-            'vit_b_16': 224,  # Must be divisible by 16
-            'dinov2_vits14_reg': 196,  # Must be divisible by 14
-            'vit_mae': 224,  # Must be divisible by 16
-            'vit_tiny': 224,  # ViT-Tiny adjusted for 16x16 patch size and smaller input
-            'deit_tiny': 224,  # DeiT-Tiny adjusted similarly
-            'deit_tiny_distilled': 224,  # DeiT-Tiny distilled version
-            'mobilevit_s': 256,  # MobileViT-S
-            'efficientformer_l2': 224,  # EfficientFormer-L2
-            'swinv2': 224,  # Swin Transformer
-            # Add other ViT variants as needed
+            'vit_b_16': 224,
+            'vit_mae': 224,
+            'vit_tiny': 224,
+            'deit_tiny': 224,
+            'deit_tiny_distilled': 224,
+            'mobilevit_s': 256,
+            'efficientformerv2_l': 224,
+            'efficientformerv2_s0': 224,
+            'efficientformerv2_s1': 224,
+            'efficientformerv2_s2': 224,
+            'swinv2': 256,  # SwinV2 typically uses a 256x256 input size
+            # Add other models as needed
+        }
+
+        # Mapping from backbone_type to the actual model name in timm
+        model_name_map = {
+            'vit_tiny': 'vit_tiny_patch16_224',
+            'deit_tiny': 'deit_tiny_patch16_224',
+            'deit_tiny_distilled': 'deit_tiny_distilled_patch16_224',
+            # 'mobilevit_s': 'mobilevit_s',
+            'efficientformerv2_l': 'efficientformerv2_l.snap_dist_in1k',
+            'efficientformerv2_s0': 'efficientformerv2_s0.snap_dist_in1k',
+            'efficientformerv2_s1': 'efficientformerv2_s1.snap_dist_in1k',
+            'efficientformerv2_s2': 'efficientformerv2_s2.snap_dist_in1k',
+            # 'vit_b_16': 'vit_base_patch16_224',
+            'vit_mae': 'vit_base_patch16_224.mae',
+            # 'swinv2': 'swinv2_tiny_window8_256',
+            # Add other mappings as needed
         }
 
         backbone = None
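Note: the `model_name_map` branch added later in this diff resolves these names with `timm.create_model` and infers the feature width from `embed_dim`/`num_features`. A minimal standalone sketch of that pattern (the model tag is taken from the map above; `pretrained=False` keeps the sketch offline):

```python
# Sketch of the generic timm path: create a model by its mapped name and
# read its feature dimension the same way the diff does.
import timm

model_name = 'efficientformerv2_s0.snap_dist_in1k'
backbone = timm.create_model(model_name, pretrained=False)

if hasattr(backbone, 'embed_dim'):
    backbone_output_size = backbone.embed_dim
elif hasattr(backbone, 'num_features'):
    backbone_output_size = backbone.num_features
else:
    raise ValueError(f"Unable to determine backbone output size for model '{model_name}'")

print(backbone_output_size)
```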
@@ -592,32 +633,11 @@ def _build_backbone(self, input_shape, backbone_params):
         elif backbone_type.lower() in resize_size_map:
-            import timm
             # Unified ViT handling for various ViT models
             resize_size = resize_size_map.get(backbone_type.lower(), 224)  # Default to 224 if not specified
 
-            if backbone_type.lower() == 'vit_tiny':
-                backbone = timm.create_model('vit_tiny_patch16_224', pretrained=pretrained)
-                backbone_output_size = backbone.head.in_features
-            elif backbone_type.lower() == 'deit_tiny':
-                backbone = timm.create_model('deit_tiny_patch16_224', pretrained=pretrained)
-                backbone_output_size = backbone.head.in_features
-            elif backbone_type.lower() == 'deit_tiny_distilled':
-                backbone = timm.create_model('deit_tiny_distilled_patch16_224', pretrained=pretrained,)
-                backbone_output_size = backbone.head.in_features
-            elif backbone_type.lower() == 'mobilevit_s':
-                print("Not working")
-                backbone = timm.create_model('mobilevit_s', pretrained=pretrained)
-                print(backbone)
-                backbone_output_size = backbone.head.fc.in_features
-                print(backbone_output_size)
-            elif backbone_type.lower() == 'efficientformer_l2':
-                backbone = timm.create_model('efficientformerv2_s1.snap_dist_in1k', pretrained=pretrained)
-                # print(backbone)
-                # print(backbone.pos_embed.shape)
-                backbone_output_size = backbone.head.in_features
-            elif backbone_type.lower() == 'vit_b_16':
+            if backbone_type.lower() == 'vit_b_16':
                 backbone = models.vision_transformer.vit_b_16(pretrained=pretrained)
                 backbone_output_size = backbone.heads.head.in_features
             elif backbone_type.lower() == 'dinov2_vits14_reg':
@@ -626,12 +646,51 @@ def _build_backbone(self, input_shape, backbone_params):
                 except Exception as e:
                     raise ValueError(f"Failed to load dinov2_vits14_reg: {e}")
                 backbone_output_size = 384  # As per Dinov2 ViT-S14 regression output
-            elif backbone_type.lower() == 'vit_mae':
+            elif backbone_type.lower() == 'mobilevit_s':
+                import timm
+                mobilevit_models = [model_name for model_name in timm.list_models() if 'mobilevit' in model_name.lower()]
+                print("Available MobileViT models in timm:")
+                print(mobilevit_models)
+                model_name = 'mobilevitv2_100'  # 'mobilevit_s'
+                try:
+                    backbone = timm.create_model(
+                        model_name,
+                        pretrained=pretrained,
+                        num_classes=0,  # Removes the classification head
+                        # global_pool='avg'  # Adds global average pooling
+                    )
+                    backbone_output_size = backbone.num_features
+                except Exception as e:
+                    raise ValueError(f"Failed to create model '{model_name}': {e}")
+
+                # # Manually add global average pooling and flatten layers
+                # backbone = nn.Sequential(
+                #     backbone,
+                #     nn.AdaptiveAvgPool2d((1, 1)),  # Global average pooling
+                #     nn.Flatten(1)  # Flatten from the first dimension onward
+                # )
+
+                # Add resize layer if needed
+                # resize_size = resize_size_map.get(backbone_type.lower(), 256)
+                resize_size = 160
+                resize_layer = nn.Upsample(size=(resize_size, resize_size), mode='bilinear', align_corners=False)
+                backbone = nn.Sequential(resize_layer, backbone)
+            elif backbone_type.lower() in model_name_map:
+                import timm
+                model_name = model_name_map.get(backbone_type.lower())
+                if model_name is None:
+                    raise ValueError(f'Unknown backbone type: {backbone_type}')
                 try:
-                    backbone = timm.create_model('vit_base_patch16_224.mae', pretrained=pretrained)
+                    backbone = timm.create_model(model_name, pretrained=pretrained)
+                    if hasattr(backbone, 'embed_dim'):
+                        backbone_output_size = backbone.embed_dim
+                    elif hasattr(backbone, 'num_features'):
+                        backbone_output_size = backbone.num_features
+                    else:
+                        raise ValueError(f"Unable to determine backbone output size for model '{model_name}'")
                 except Exception as e:
-                    raise ValueError(f"Failed to load vit_mae: {e}")
-                backbone_output_size = 768  # Typically 768 for ViT-Base
+                    raise ValueError(f"Failed to create model '{model_name}': {e}")
+
             elif backbone_type.lower() == 'swinv2':
                 print("Not working")
                 try:
@@ -642,20 +701,44 @@ def _build_backbone(self, input_shape, backbone_params):
             else:
                 raise ValueError(f'Unknown ViT model type: {backbone_type}')
 
-            # Remove the classification/regression head if present
-            if hasattr(backbone, 'heads') and hasattr(backbone.heads, 'head'):
-                backbone.heads.head = nn.Identity()
-            elif hasattr(backbone, 'head'):
+            # # Remove the classification/regression head if present
+            # if hasattr(backbone, 'heads') and hasattr(backbone.heads, 'head'):
+            #     backbone.heads.head = nn.Identity()
+            # elif hasattr(backbone, 'head'):
+            #     backbone.head = nn.Identity()
+            # elif hasattr(backbone, 'decoder'):  # For MAE models
+            #     backbone.decoder = nn.Identity()
+            # else:
+            #     print(f"Unable to locate the classification/regression head in {backbone_type} model.")
+
+            # Get backbone output size
+            # if hasattr(backbone, 'embed_dim'):
+            #     backbone_output_size = backbone.embed_dim
+            # elif hasattr(backbone, 'num_features'):
+            #     backbone_output_size = backbone.num_features
+            # elif hasattr(backbone, 'head') and hasattr(backbone.head, 'in_features'):
+            #     backbone_output_size = backbone.head.in_features
+            # else:
+            #     raise ValueError(f"Unable to determine backbone output size for model '{model_name}'")
+
+            # # Add a resize layer to ensure the input is correctly sized for the specific ViT model
+            # resize_layer = nn.Upsample(size=(resize_size, resize_size), mode='bilinear', align_corners=False)
+
+            # # Combine the resize layer and the backbone into a sequential model
+            # backbone = nn.Sequential(resize_layer, backbone)
+
+            # Remove classification head
+            if hasattr(backbone, 'head'):
                 backbone.head = nn.Identity()
-            elif hasattr(backbone, 'decoder'):  # For MAE models
-                backbone.decoder = nn.Identity()
+            elif hasattr(backbone, 'fc'):
+                backbone.fc = nn.Identity()
             else:
-                print(f"Unable to locate the classification/regression head in {backbone_type} model.")
+                # TODO: Double-check with ResNet
+                print(f"Unable to locate the classification head in {backbone_type} model.")
 
-            # Add a resize layer to ensure the input is correctly sized for the specific ViT model
+            # Add resize layer if needed
+            resize_size = resize_size_map.get(backbone_type.lower(), 224)
             resize_layer = nn.Upsample(size=(resize_size, resize_size), mode='bilinear', align_corners=False)
-
-            # Combine the resize layer and the backbone into a sequential model
             backbone = nn.Sequential(resize_layer, backbone)
 
         else:
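Note: the head-removal logic above relies on most torchvision/timm classification models exposing their classifier as either `.head` or `.fc`. A minimal sketch of the pattern with an illustrative ResNet-18 (which uses `.fc`):

```python
# Replace whatever classifier attribute the backbone exposes with an
# Identity so the module outputs features instead of logits.
from torch import nn
from torchvision import models

backbone = models.resnet18(weights=None)  # illustrative choice
if hasattr(backbone, 'head'):
    backbone.head = nn.Identity()
elif hasattr(backbone, 'fc'):
    backbone.fc = nn.Identity()
else:
    print("Unable to locate the classification head.")

print(backbone.fc)  # -> Identity()
```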
@@ -671,7 +754,7 @@ def _build_backbone(self, input_shape, backbone_params):
                     param.requires_grad = False
             else:
                 # General case for other backbones (e.g., ResNet, ConvNeXt)
-                if 'conv1' not in name and 'features.0.0' not in name:
+                if 'conv1' not in name and 'features.0.0' not in name and 'patch_embed.proj' not in name and 'cls_token' not in name:
                     param.requires_grad = False
 
         return backbone, backbone_output_size
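Note: the final hunk widens the set of parameters left trainable when the backbone is frozen: besides the first conv layer (`conv1` / `features.0.0`), ViT-style patch embeddings (`patch_embed.proj`) and the class token (`cls_token`) now stay trainable. A minimal sketch of the equivalent freezing policy (ResNet-18 is an illustrative stand-in):

```python
# Freeze everything except input-facing parameters, mirroring the
# substring checks in the hunk above.
from torchvision import models

backbone = models.resnet18(weights=None)  # illustrative stand-in
first_layer_keys = ('conv1', 'features.0.0', 'patch_embed.proj', 'cls_token')

for name, param in backbone.named_parameters():
    param.requires_grad = any(key in name for key in first_layer_keys)

print([n for n, p in backbone.named_parameters() if p.requires_grad])
# -> ['conv1.weight'] for ResNet-18
```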