
Commit

WIP: e2e. More ViT architectures. Better aux loss support from yaml. TODO: unify impala and vision backbones.
ViktorM committed Sep 30, 2024
1 parent ed14b38 commit e2c6bd0
Showing 2 changed files with 145 additions and 63 deletions.
1 change: 1 addition & 0 deletions rl_games/envs/maniskill.py
@@ -7,6 +7,7 @@
import gymnasium as gym2
import gymnasium.spaces.utils
from gymnasium.vector.utils import batch_space

from mani_skill.utils import common


207 changes: 144 additions & 63 deletions rl_games/networks/vision_networks.py
@@ -1,7 +1,6 @@
import torch
from torch import nn
from torchvision import models
from rl_games.algos_torch.running_mean_std import RunningMeanStd, RunningMeanStdObs
import torch.nn.functional as F
from rl_games.algos_torch import torch_ext
from rl_games.algos_torch.network_builder import NetworkBuilder, ImpalaSequential
@@ -18,10 +17,22 @@ class Network(NetworkBuilder.BaseNetwork):
def __init__(self, params, **kwargs):
self.actions_num = actions_num = kwargs.pop('actions_num')
full_input_shape = kwargs.pop('input_shape')
self.use_aux_loss = kwargs.pop('use_aux_loss', True)
print(kwargs)
print("params: ", params)
self.aux_loss_weight = 100.0 # kwargs.pop('aux_loss_weight', 1.0)

aux_loss_params = params.get('aux_loss', None)
if aux_loss_params is not None:
print("Auxilliary losses params: ", aux_loss_params)
self.use_aux_loss = aux_loss_params.get('use_aux_loss', False)
self.aux_loss_weight = aux_loss_params.get('weight', 1.0)
self.aux_concat = aux_loss_params.get('concat', False)
# test policy with auxiliary targets as input
self.test_baseline = aux_loss_params.get('test_baseline', False)
else:
self.use_aux_loss = False
self.aux_loss_weight = 1.0

print("Full input shape: ", full_input_shape)
if self.use_aux_loss:
self.target_key = 'aux_target'
if 'aux_target' in full_input_shape:
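
(Aside, not part of this diff: a minimal sketch of how the aux_loss block read above might be supplied from a YAML config. The nesting under network: is an assumption; only the key names use_aux_loss, weight, concat and test_baseline come from the code in this hunk.)

import yaml

# Hypothetical config snippet; only the aux_loss keys mirror the diff above.
config_text = """
network:
  aux_loss:
    use_aux_loss: True
    weight: 1.0
    concat: False
    test_baseline: False
"""

params = yaml.safe_load(config_text)['network']

aux_loss_params = params.get('aux_loss', None)
if aux_loss_params is not None:
    use_aux_loss = aux_loss_params.get('use_aux_loss', False)
    aux_loss_weight = aux_loss_params.get('weight', 1.0)
    aux_concat = aux_loss_params.get('concat', False)
    print(use_aux_loss, aux_loss_weight, aux_concat)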
@@ -49,16 +60,17 @@ def __init__(self, params, **kwargs):
cnn_output_size = self._calc_input_size(input_shape, self.cnn)

mlp_input_size = cnn_output_size + self.proprio_size
if self.use_aux_loss:
if self.use_aux_loss and self.aux_concat:
mlp_input_size += self.target_shape[0]

if len(self.units) == 0:
out_size = cnn_output_size
else:
out_size = self.units[-1]

if self.test_baseline:
mlp_input_size = self.proprio_size + self.target_shape[0]
self.layer_norm_emb = torch.nn.LayerNorm(mlp_input_size)
#self.layer_norm_emb = torch.nn.RMSNorm(mlp_input_size)

if self.has_rnn:
if not self.is_rnn_before_mlp:
@@ -118,6 +130,10 @@ def __init__(self, params, **kwargs):
if isinstance(m, nn.Linear):
mlp_init(m.weight)

if self.use_aux_loss:
#mlp_init(self.aux_loss_linear.weight)
torch.nn.init.xavier_normal_(self.aux_loss_linear.weight, gain=0.02)

if self.is_discrete:
mlp_init(self.logits.weight)
if self.is_continuous:
@@ -139,9 +155,6 @@ def forward(self, obs_dict):
else:
obs = obs_dict['obs']

if self.use_aux_loss:
target_obs = obs_dict['obs'][self.target_key]

if self.permute_input:
obs = obs.permute((0, 3, 1, 2))

@@ -154,18 +167,28 @@
out = out.flatten(1)
cnn_out = self.flatten_act(out)

if self.proprio_size > 0:
# TODO: Add support for the case without proprioception obs
if self.proprio_size > 0 and not self.test_baseline:
out = torch.cat([cnn_out, proprio], dim=1)

if self.use_aux_loss:
y = self.aux_loss_linear(cnn_out)
out = torch.cat([out, y], dim=1)
target_obs = obs_dict['obs'][self.target_key]

if self.aux_concat:
out = torch.cat([out, y.detach()], dim=1)
elif self.test_baseline:
out = torch.cat([proprio, target_obs], dim=1)
# print("aux predicted shape: ", y.shape)
# print("target shape: ", target_obs.shape)
# print("aux predicted: ", y[0])
# print("aux target: ", target_obs[0])
self.aux_loss_map['aux_dist_loss'] = self.aux_loss_weight * torch.nn.functional.mse_loss(y, target_obs)
# print("aux predicted shape: ", y.shape)
# print("aux predicted: ", y)
# print("aux target: ", target_obs)
# print("delta: ", y - target_obs)
# print("aux loss: ", self.aux_loss_map['aux_dist_loss'])
#print("aux loss: ", self.aux_loss_map['aux_dist_loss'])

out = self.layer_norm_emb(out)

@@ -302,7 +325,7 @@ def __init__(self, params, **kwargs):
self.actions_num = kwargs.pop('actions_num')
full_input_shape = kwargs.pop('input_shape')

self.use_aux_loss = kwargs.pop('use_aux_loss', False)
self.use_aux_loss = kwargs.pop('use_aux_loss', True)
self.aux_loss_weight = 100.0
if self.use_aux_loss:
self.target_key = 'aux_target'
@@ -420,9 +443,6 @@ def forward(self, obs_dict):
else:
obs = obs_dict['obs']

if self.use_aux_loss:
target_obs = obs_dict['obs'][self.target_key]

if self.permute_input:
obs = obs.permute((0, 3, 1, 2))

@@ -442,6 +462,7 @@
out = torch.cat([vis_out, proprio], dim=1)

if self.use_aux_loss:
target_obs = obs_dict['obs'][self.target_key]
y = self.aux_loss_linear(vis_out)
out = torch.cat([out, y], dim=1)
self.aux_loss_map['aux_dist_loss'] = self.aux_loss_weight * torch.nn.functional.mse_loss(y, target_obs)
@@ -530,18 +551,36 @@ def _build_backbone(self, input_shape, backbone_params):
modify_first_layer = backbone_params.get('modify_first_layer', False)
self.preprocess_image = backbone_params.get('preprocess_image', False)

# Define a mapping from backbone type to required resize size
# Mapping from backbone_type to required resize size
resize_size_map = {
'vit_b_16': 224, # Must be divisible by 16
'dinov2_vits14_reg': 196, # Must be divisible by 14
'vit_mae': 224, # Must be divisible by 16
'vit_tiny': 224, # ViT-Tiny adjusted for 16x16 patch size and smaller input
'deit_tiny': 224, # DeiT-Tiny adjusted similarly
'deit_tiny_distilled': 224, # DeiT-Tiny distilled version
'mobilevit_s': 256, # MobileViT-S
'efficientformer_l2': 224, # EfficientFormer-L2
'swinv2': 224, # Swin Transformer
# Add other ViT variants as needed
'vit_b_16': 224,
'vit_mae': 224,
'vit_tiny': 224,
'deit_tiny': 224,
'deit_tiny_distilled': 224,
'mobilevit_s': 256,
'efficientformerv2_l': 224,
'efficientformerv2_s0': 224,
'efficientformerv2_s1': 224,
'efficientformerv2_s2': 224,
'swinv2': 256, # SwinV2 typically uses 256x256 input size
# Add other models as needed
}

# Mapping from backbone_type to actual model name in TIMM
model_name_map = {
'vit_tiny': 'vit_tiny_patch16_224',
'deit_tiny': 'deit_tiny_patch16_224',
'deit_tiny_distilled': 'deit_tiny_distilled_patch16_224',
# 'mobilevit_s': 'mobilevit_s',
'efficientformerv2_l': 'efficientformerv2_l.snap_dist_in1k',
'efficientformerv2_s0': 'efficientformerv2_s0.snap_dist_in1k',
'efficientformerv2_s1': 'efficientformerv2_s1.snap_dist_in1k',
'efficientformerv2_s2': 'efficientformerv2_s2.snap_dist_in1k',
# 'vit_b_16': 'vit_base_patch16_224',
'vit_mae': 'vit_base_patch16_224_mae',
# 'swinv2': 'swinv2_tiny_window8_256',
# Add other mappings as needed
}
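
(Aside, not part of the diff: a hedged sketch of how one of the mapped names would be instantiated with timm and how the feature size used below as backbone_output_size can be read off. Assumes the timm package is available; pretrained=False avoids a weight download.)

import timm
import torch

model_name = 'deit_tiny_patch16_224'   # taken from model_name_map above
backbone = timm.create_model(model_name, pretrained=False, num_classes=0)  # num_classes=0 drops the classifier head

# Mirror the attribute checks used further down in this diff
if hasattr(backbone, 'embed_dim'):
    backbone_output_size = backbone.embed_dim
elif hasattr(backbone, 'num_features'):
    backbone_output_size = backbone.num_features

x = torch.randn(2, 3, 224, 224)        # 224 matches resize_size_map for this model
features = backbone(x)
print(backbone_output_size, features.shape)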

backbone = None
@@ -592,32 +631,11 @@ def _build_backbone(self, input_shape, backbone_params):
)

elif backbone_type.lower() in resize_size_map:
import timm

# Unified ViT handling for various ViT models
resize_size = resize_size_map.get(backbone_type.lower(), 224) # Default to 224 if not specified

if backbone_type.lower() == 'vit_tiny':
backbone = timm.create_model('vit_tiny_patch16_224', pretrained=pretrained)
backbone_output_size = backbone.head.in_features
elif backbone_type.lower() == 'deit_tiny':
backbone = timm.create_model('deit_tiny_patch16_224', pretrained=pretrained)
backbone_output_size = backbone.head.in_features
elif backbone_type.lower() == 'deit_tiny_distilled':
backbone = timm.create_model('deit_tiny_distilled_patch16_224', pretrained=pretrained,)
backbone_output_size = backbone.head.in_features
elif backbone_type.lower() == 'mobilevit_s':
print("Not working")
backbone = timm.create_model('mobilevit_s', pretrained=pretrained)
print(backbone)
backbone_output_size = backbone.head.fc.in_features
print(backbone_output_size)
elif backbone_type.lower() == 'efficientformer_l2':
backbone = timm.create_model('efficientformerv2_s1.snap_dist_in1k', pretrained=pretrained)
# print(backbone)
# print(backbone.pos_embed.shape)
backbone_output_size = backbone.head.in_features
elif backbone_type.lower() == 'vit_b_16':
if backbone_type.lower() == 'vit_b_16':
backbone = models.vision_transformer.vit_b_16(pretrained=pretrained)
backbone_output_size = backbone.heads.head.in_features
elif backbone_type.lower() == 'dinov2_vits14_reg':
@@ -626,12 +644,51 @@ def _build_backbone(self, input_shape, backbone_params):
except Exception as e:
raise ValueError(f"Failed to load dinov2_vits14_reg: {e}")
backbone_output_size = 384 # Embedding dim of DINOv2 ViT-S/14 with registers
elif backbone_type.lower() == 'vit_mae':
elif backbone_type == 'mobilevit_s':
import timm
mobilevit_models = [model_name for model_name in timm.list_models() if 'mobilevit' in model_name.lower()]
print("Available MobileViT Models in timm:")
print(mobilevit_models)
model_name = 'mobilevitv2_100' #'mobilevit_s'
try:
backbone = timm.create_model(
model_name,
pretrained=pretrained,
num_classes=0, # Removes the classification head
# global_pool='avg' # Adds global average pooling
)
backbone_output_size = backbone.num_features
except Exception as e:
raise ValueError(f"Failed to create model '{model_name}': {e}")

# # Manually add global average pooling and flatten layers
# backbone = nn.Sequential(
# backbone,
# nn.AdaptiveAvgPool2d((1, 1)), # Global average pooling
# nn.Flatten(1) # Flatten from the first dimension onward
# )

# Add resize layer if needed
# resize_size = resize_size_map.get(backbone_type, 256)
resize_size = 160
resize_layer = nn.Upsample(size=(resize_size, resize_size), mode='bilinear', align_corners=False)
backbone = nn.Sequential(resize_layer, backbone)
elif backbone_type.lower() in model_name_map:
import timm
model_name = model_name_map.get(backbone_type)
if model_name is None:
raise ValueError(f'Unknown backbone type: {backbone_type}')
try:
backbone = timm.create_model('vit_base_patch16_224.mae', pretrained=pretrained)
backbone = timm.create_model(model_name, pretrained=pretrained)
if hasattr(backbone, 'embed_dim'):
backbone_output_size = backbone.embed_dim
elif hasattr(backbone, 'num_features'):
backbone_output_size = backbone.num_features
else:
raise ValueError(f"Unable to determine backbone output size for model '{model_name}'")
except Exception as e:
raise ValueError(f"Failed to load vit_mae: {e}")
backbone_output_size = 768 # Typically 768 for ViT-Base
raise ValueError(f"Failed to create model '{model_name}': {e}")

elif backbone_type.lower() == 'swinv2':
print("Not working")
try:
@@ -642,20 +699,44 @@ def _build_backbone(self, input_shape, backbone_params):
else:
raise ValueError(f'Unknown ViT model type: {backbone_type}')

# Remove the classification/regression head if present
if hasattr(backbone, 'heads') and hasattr(backbone.heads, 'head'):
backbone.heads.head = nn.Identity()
elif hasattr(backbone, 'head'):
# # Remove the classification/regression head if present
# if hasattr(backbone, 'heads') and hasattr(backbone.heads, 'head'):
# backbone.heads.head = nn.Identity()
# elif hasattr(backbone, 'head'):
# backbone.head = nn.Identity()
# elif hasattr(backbone, 'decoder'): # For MAE models
# backbone.decoder = nn.Identity()
# else:
# print(f"Unable to locate the classification/regression head in {backbone_type} model.")

# Get backbone output size
# if hasattr(backbone, 'embed_dim'):
# backbone_output_size = backbone.embed_dim
# elif hasattr(backbone, 'num_features'):
# backbone_output_size = backbone.num_features
# elif hasattr(backbone, 'head') and hasattr(backbone.head, 'in_features'):
# backbone_output_size = backbone.head.in_features
# else:
# raise ValueError(f"Unable to determine backbone output size for model '{model_name}'")

# # Add a resize layer to ensure the input is correctly sized for the specific ViT model
# resize_layer = nn.Upsample(size=(resize_size, resize_size), mode='bilinear', align_corners=False)

# # Combine the resize layer and the backbone into a sequential model
# backbone = nn.Sequential(resize_layer, backbone)

# Remove classification head
if hasattr(backbone, 'head'):
backbone.head = nn.Identity()
elif hasattr(backbone, 'decoder'): # For MAE models
backbone.decoder = nn.Identity()
elif hasattr(backbone, 'fc'):
backbone.fc = nn.Identity()
else:
print(f"Unable to locate the classification/regression head in {backbone_type} model.")
# TODO: Double-check with ResNet
print(f"Unable to locate the classification head in {backbone_type} model.")

# Add a resize layer to ensure the input is correctly sized for the specific ViT model
# Add resize layer if needed
resize_size = resize_size_map.get(backbone_type, 224)
resize_layer = nn.Upsample(size=(resize_size, resize_size), mode='bilinear', align_corners=False)

# Combine the resize layer and the backbone into a sequential model
backbone = nn.Sequential(resize_layer, backbone)

else:
@@ -671,7 +752,7 @@ def _build_backbone(self, input_shape, backbone_params):
param.requires_grad = False
else:
# General case for other backbones (e.g., ResNet, ConvNeXt)
if 'conv1' not in name and 'features.0.0' not in name:
if 'conv1' not in name and 'features.0.0' not in name and 'patch_embed.proj' not in name and 'cls_token' not in name:
param.requires_grad = False

return backbone, backbone_output_size
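
(Aside, not part of the diff: a standalone sketch of the name-based freezing rule from the hunk above, applied to a timm ViT-Tiny as a stand-in backbone. Which earlier branch handles ViTs in the full file is not visible here, so treat this purely as an illustration of the substring test that the commit extends with 'patch_embed.proj' and 'cls_token'.)

import timm

# Stand-in backbone; the real code applies the same rule to whatever was built above.
backbone = timm.create_model('vit_tiny_patch16_224', pretrained=False, num_classes=0)

for name, param in backbone.named_parameters():
    # Same substring test as the updated condition above: freeze everything except
    # the first conv / patch embedding projection / class token.
    if 'conv1' not in name and 'features.0.0' not in name and 'patch_embed.proj' not in name and 'cls_token' not in name:
        param.requires_grad = False

trainable = [n for n, p in backbone.named_parameters() if p.requires_grad]
print(trainable)   # expected: the patch embedding projection and the class token stay trainable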
