From d073f4dae8ff2cdc3e8200b9820bfdb37af8d34e Mon Sep 17 00:00:00 2001 From: ViktorM Date: Fri, 30 Aug 2024 22:36:45 -0700 Subject: [PATCH] Refactored vision backbone networks. Maniskill first pass. --- rl_games/algos_torch/model_builder.py | 6 - rl_games/algos_torch/network_builder.py | 525 +----------------- rl_games/common/a2c_common.py | 1 + rl_games/common/env_configurations.py | 27 +- rl_games/common/vecenv.py | 4 + rl_games/common/wrappers.py | 10 +- .../atari/ppo_pacman_envpool_resnet.yaml | 28 +- rl_games/configs/atari/ppo_pong_envpool.yaml | 5 +- rl_games/configs/maniskill/maniskill.yaml | 62 +++ .../maniskill_resnet.yaml} | 46 +- rl_games/envs/envpool.py | 9 +- rl_games/envs/maniskill.py | 216 +++++++ rl_games/networks/__init__.py | 6 +- rl_games/networks/vision_networks.py | 147 +++-- 14 files changed, 457 insertions(+), 635 deletions(-) create mode 100644 rl_games/configs/maniskill/maniskill.yaml rename rl_games/configs/{atari/ppo_pong_envpool_backbone.yaml => maniskill/maniskill_resnet.yaml} (64%) create mode 100644 rl_games/envs/maniskill.py diff --git a/rl_games/algos_torch/model_builder.py b/rl_games/algos_torch/model_builder.py index c22fca0e..63c5908b 100644 --- a/rl_games/algos_torch/model_builder.py +++ b/rl_games/algos_torch/model_builder.py @@ -2,7 +2,6 @@ import rl_games.algos_torch from rl_games.algos_torch import network_builder, models - NETWORK_REGISTRY = {} MODEL_REGISTRY = {} @@ -20,12 +19,7 @@ def __init__(self): self.network_factory.register_builder('actor_critic', lambda **kwargs: network_builder.A2CBuilder()) self.network_factory.register_builder('resnet_actor_critic', lambda **kwargs: network_builder.A2CResnetBuilder()) - self.network_factory.register_builder('vision_actor_critic', - lambda **kwargs: network_builder.A2CVisionBuilder()) - self.network_factory.register_builder('e2e_vision_actor_critic', - lambda **kwargs: network_builder.VisionBackboneBuilder()) - self.network_factory.register_builder('rnd_curiosity', lambda **kwargs: network_builder.RNDCuriosityBuilder()) self.network_factory.register_builder('soft_actor_critic', lambda **kwargs: network_builder.SACBuilder()) def load(self, params): diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index 896555ca..4b0acfa3 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -217,7 +217,7 @@ def __init__(self, params, **kwargs): cnn_args = { 'ctype' : self.cnn['type'], 'input_shape' : input_shape, - 'convs' :self.cnn['convs'], + 'convs' : self.cnn['convs'], 'activation' : self.cnn['activation'], 'norm_func_name' : self.normalization, } @@ -848,529 +848,6 @@ def build(self, name, **kwargs): return net -class A2CVisionBuilder(NetworkBuilder): - def __init__(self, **kwargs): - NetworkBuilder.__init__(self) - - def load(self, params): - self.params = params - - class Network(NetworkBuilder.BaseNetwork): - def __init__(self, params, **kwargs): - self.actions_num = actions_num = kwargs.pop('actions_num') - full_input_shape = kwargs.pop('input_shape') - proprio_size = 0 # Number of proprioceptive features - if type(full_input_shape) is dict: - input_shape = full_input_shape['camera'] - proprio_shape = full_input_shape['proprio'] - proprio_size = proprio_shape[0] - else: - input_shape = full_input_shape - - self.num_seqs = kwargs.pop('num_seqs', 1) - self.value_size = kwargs.pop('value_size', 1) - - NetworkBuilder.BaseNetwork.__init__(self) - self.load(params) - if self.permute_input: - input_shape = 
torch_ext.shape_whc_to_cwh(input_shape) - - self.cnn = self._build_impala(input_shape, self.conv_depths) - cnn_output_size = self._calc_input_size(input_shape, self.cnn) - - mlp_input_size = cnn_output_size + proprio_size - if len(self.units) == 0: - out_size = cnn_output_size - else: - out_size = self.units[-1] - - if self.has_rnn: - if not self.is_rnn_before_mlp: - rnn_in_size = out_size - out_size = self.rnn_units - else: - rnn_in_size = mlp_input_size - mlp_input_size = self.rnn_units - - self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers) - self.layer_norm = torch.nn.LayerNorm(self.rnn_units) - - mlp_args = { - 'input_size' : mlp_input_size, - 'units' :self.units, - 'activation' : self.activation, - 'norm_func_name' : self.normalization, - 'dense_func' : torch.nn.Linear - } - - self.mlp = self._build_mlp(**mlp_args) - - self.value = self._build_value_layer(out_size, self.value_size) - self.value_act = self.activations_factory.create(self.value_activation) - self.flatten_act = self.activations_factory.create(self.activation) - - if self.is_discrete: - self.logits = torch.nn.Linear(out_size, actions_num) - if self.is_continuous: - self.mu = torch.nn.Linear(out_size, actions_num) - self.mu_act = self.activations_factory.create(self.space_config['mu_activation']) - mu_init = self.init_factory.create(**self.space_config['mu_init']) - self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation']) - sigma_init = self.init_factory.create(**self.space_config['sigma_init']) - - if self.fixed_sigma: - self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True) - else: - self.sigma = torch.nn.Linear(out_size, actions_num) - - mlp_init = self.init_factory.create(**self.initializer) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - for m in self.mlp: - if isinstance(m, nn.Linear): - mlp_init(m.weight) - - if self.is_discrete: - mlp_init(self.logits.weight) - if self.is_continuous: - mu_init(self.mu.weight) - if self.fixed_sigma: - sigma_init(self.sigma) - else: - sigma_init(self.sigma.weight) - - mlp_init(self.value.weight) - - def forward(self, obs_dict): - # print(obs_dict.keys()) - # print(obs_dict['obs'].keys()) - # currently works only dictinary of camera and proprio observations - obs = obs_dict['obs']['camera'] - proprio = obs_dict['obs']['proprio'] - if self.permute_input: - obs = obs.permute((0, 3, 1, 2)) - - dones = obs_dict.get('dones', None) - bptt_len = obs_dict.get('bptt_len', 0) - states = obs_dict.get('rnn_states', None) - - out = obs - out = self.cnn(out) - out = out.flatten(1) - out = self.flatten_act(out) - - out = torch.cat([out, proprio], dim=1) - - if self.has_rnn: - # TODO: Double check, it's not lways present!!! 
- #seq_length = obs_dict['seq_length'] - seq_length = obs_dict.get('seq_length', 1) - - out_in = out - if not self.is_rnn_before_mlp: - out_in = out - out = self.mlp(out) - - batch_size = out.size()[0] - num_seqs = batch_size // seq_length - out = out.reshape(num_seqs, seq_length, -1) - - if len(states) == 1: - states = states[0] - - out = out.transpose(0, 1) - if dones is not None: - dones = dones.reshape(num_seqs, seq_length, -1) - dones = dones.transpose(0, 1) - out, states = self.rnn(out, states, dones, bptt_len) - out = out.transpose(0, 1) - out = out.contiguous().reshape(out.size()[0] * out.size()[1], -1) - - if self.rnn_ln: - out = self.layer_norm(out) - if self.is_rnn_before_mlp: - out = self.mlp(out) - if type(states) is not tuple: - states = (states,) - else: - out = self.mlp(out) - - value = self.value_act(self.value(out)) - - if self.is_discrete: - logits = self.logits(out) - return logits, value, states - - if self.is_continuous: - mu = self.mu_act(self.mu(out)) - if self.fixed_sigma: - sigma = self.sigma_act(self.sigma) - else: - sigma = self.sigma_act(self.sigma(out)) - return mu, mu*0 + sigma, value, states - - def load(self, params): - self.separate = False - self.units = params['mlp']['units'] - self.activation = params['mlp']['activation'] - self.initializer = params['mlp']['initializer'] - self.is_discrete = 'discrete' in params['space'] - self.is_continuous = 'continuous' in params['space'] - self.is_multi_discrete = 'multi_discrete'in params['space'] - self.value_activation = params.get('value_activation', 'None') - self.normalization = params.get('normalization', None) - - if self.is_continuous: - self.space_config = params['space']['continuous'] - self.fixed_sigma = self.space_config['fixed_sigma'] - elif self.is_discrete: - self.space_config = params['space']['discrete'] - elif self.is_multi_discrete: - self.space_config = params['space']['multi_discrete'] - - self.has_rnn = 'rnn' in params - if self.has_rnn: - self.rnn_units = params['rnn']['units'] - self.rnn_layers = params['rnn']['layers'] - self.rnn_name = params['rnn']['name'] - self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False) - self.rnn_ln = params['rnn'].get('layer_norm', False) - - self.has_cnn = True - self.permute_input = params['cnn'].get('permute_input', True) - self.conv_depths = params['cnn']['conv_depths'] - self.require_rewards = params.get('require_rewards') - self.require_last_actions = params.get('require_last_actions') - - def _build_impala(self, input_shape, depths): - in_channels = input_shape[0] - layers = nn.ModuleList() - for d in depths: - layers.append(ImpalaSequential(in_channels, d)) - in_channels = d - return nn.Sequential(*layers) - - def is_separate_critic(self): - return False - - def is_rnn(self): - return self.has_rnn - - def get_default_rnn_state(self): - num_layers = self.rnn_layers - if self.rnn_name == 'lstm': - return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)), - torch.zeros((num_layers, self.num_seqs, self.rnn_units))) - else: - return (torch.zeros((num_layers, self.num_seqs, self.rnn_units))) - - def build(self, name, **kwargs): - net = A2CVisionBuilder.Network(self.params, **kwargs) - return net - - -from torchvision import models, transforms - -def preprocess_image(image): - # Normalize the image using ImageNet's mean and standard deviation - normalize = transforms.Normalize( - mean=[0.485, 0.456, 0.406], # Mean of ImageNet dataset - std=[0.229, 0.224, 0.225] # Std of ImageNet dataset - ) - - # Apply the normalization - image = 
normalize(image) - - return image - - -class VisionBackboneBuilder(NetworkBuilder): - def __init__(self, **kwargs): - NetworkBuilder.__init__(self) - - def load(self, params): - self.params = params - - class Network(NetworkBuilder.BaseNetwork): - def __init__(self, params, **kwargs): - self.actions_num = kwargs.pop('actions_num') - full_input_shape = kwargs.pop('input_shape') - - print("Observations shape: ", full_input_shape) - - self.proprio_size = 0 # Number of proprioceptive features - if isinstance(full_input_shape, dict): - input_shape = full_input_shape['camera'] - proprio_shape = full_input_shape['proprio'] - self.proprio_size = proprio_shape[0] - else: - input_shape = full_input_shape - - self.num_seqs = kwargs.pop('num_seqs', 1) - self.value_size = kwargs.pop('value_size', 1) - - NetworkBuilder.BaseNetwork.__init__(self) - self.load(params) - if self.permute_input: - input_shape = torch_ext.shape_whc_to_cwh(input_shape) - - self.cnn, self.cnn_output_size = self._build_backbone(input_shape, params['backbone']) - - mlp_input_size = self.cnn_output_size + self.proprio_size - if len(self.units) == 0: - out_size = self.cnn_output_size - else: - out_size = self.units[-1] - - if self.has_rnn: - if not self.is_rnn_before_mlp: - rnn_in_size = out_size - out_size = self.rnn_units - else: - rnn_in_size = mlp_input_size - mlp_input_size = self.rnn_units - - self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers) - self.layer_norm = torch.nn.LayerNorm(self.rnn_units) - - mlp_args = { - 'input_size': mlp_input_size, - 'units': self.units, - 'activation': self.activation, - 'norm_func_name': self.normalization, - 'dense_func': torch.nn.Linear - } - - self.mlp = self._build_mlp(**mlp_args) - - self.value = self._build_value_layer(out_size, self.value_size) - self.value_act = self.activations_factory.create(self.value_activation) - self.flatten_act = self.activations_factory.create(self.activation) - - if self.is_discrete: - self.logits = torch.nn.Linear(out_size, self.actions_num) - if self.is_continuous: - self.mu = torch.nn.Linear(out_size, self.actions_num) - self.mu_act = self.activations_factory.create(self.space_config['mu_activation']) - mu_init = self.init_factory.create(**self.space_config['mu_init']) - self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation']) - sigma_init = self.init_factory.create(**self.space_config['sigma_init']) - - if self.fixed_sigma: - self.sigma = nn.Parameter(torch.zeros(self.actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True) - else: - self.sigma = torch.nn.Linear(out_size, self.actions_num) - - mlp_init = self.init_factory.create(**self.initializer) - - for m in self.mlp: - if isinstance(m, nn.Linear): - mlp_init(m.weight) - - if self.is_discrete: - mlp_init(self.logits.weight) - if self.is_continuous: - mu_init(self.mu.weight) - if self.fixed_sigma: - sigma_init(self.sigma) - else: - sigma_init(self.sigma.weight) - - mlp_init(self.value.weight) - - def forward(self, obs_dict): - if self.proprio_size > 0: - obs = obs_dict['obs']['camera'] - proprio = obs_dict['obs']['proprio'] - else: - obs = obs_dict['obs'] - - if self.permute_input: - obs = obs.permute((0, 3, 1, 2)) - - if self.preprocess_image: - obs = preprocess_image(obs) - - dones = obs_dict.get('dones', None) - bptt_len = obs_dict.get('bptt_len', 0) - states = obs_dict.get('rnn_states', None) - - out = obs - out = self.cnn(out) - out = out.flatten(1) - out = self.flatten_act(out) - - if self.proprio_size > 0: - out = 
torch.cat([out, proprio], dim=1) - - if self.has_rnn: - seq_length = obs_dict.get('seq_length', 1) - - out_in = out - if not self.is_rnn_before_mlp: - out_in = out - out = self.mlp(out) - - batch_size = out.size()[0] - num_seqs = batch_size // seq_length - out = out.reshape(num_seqs, seq_length, -1) - - if len(states) == 1: - states = states[0] - - out = out.transpose(0, 1) - if dones is not None: - dones = dones.reshape(num_seqs, seq_length, -1) - dones = dones.transpose(0, 1) - out, states = self.rnn(out, states, dones, bptt_len) - out = out.transpose(0, 1) - out = out.contiguous().reshape(out.size()[0] * out.size()[1], -1) - - if self.rnn_ln: - out = self.layer_norm(out) - if self.is_rnn_before_mlp: - out = self.mlp(out) - if not isinstance(states, tuple): - states = (states,) - else: - out = self.mlp(out) - - value = self.value_act(self.value(out)) - - if self.is_discrete: - logits = self.logits(out) - return logits, value, states - - if self.is_continuous: - mu = self.mu_act(self.mu(out)) - if self.fixed_sigma: - sigma = self.sigma_act(self.sigma) - else: - sigma = self.sigma_act(self.sigma(out)) - return mu, mu * 0 + sigma, value, states - - def load(self, params): - self.separate = False - self.units = params['mlp']['units'] - self.activation = params['mlp']['activation'] - self.initializer = params['mlp']['initializer'] - self.is_discrete = 'discrete' in params['space'] - self.is_continuous = 'continuous' in params['space'] - self.is_multi_discrete = 'multi_discrete' in params['space'] - self.value_activation = params.get('value_activation', 'None') - self.normalization = params.get('normalization', None) - - if self.is_continuous: - self.space_config = params['space']['continuous'] - self.fixed_sigma = self.space_config['fixed_sigma'] - elif self.is_discrete: - self.space_config = params['space']['discrete'] - elif self.is_multi_discrete: - self.space_config = params['space']['multi_discrete'] - - self.has_rnn = 'rnn' in params - if self.has_rnn: - self.rnn_units = params['rnn']['units'] - self.rnn_layers = params['rnn']['layers'] - self.rnn_name = params['rnn']['name'] - self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False) - self.rnn_ln = params['rnn'].get('layer_norm', False) - - self.has_cnn = True - self.permute_input = params['backbone'].get('permute_input', True) - self.require_rewards = params.get('require_rewards') - self.require_last_actions = params.get('require_last_actions') - - def _build_backbone(self, input_shape, backbone_params): - backbone_type = backbone_params['type'] - pretrained = backbone_params.get('pretrained', False) - self.preprocess_image = backbone_params.get('preprocess_image', False) - - if backbone_type == 'resnet18' or backbone_type == 'resnet34': - if backbone_type == 'resnet18': - backbone = models.resnet18(pretrained=pretrained, zero_init_residual=True) - else: - backbone = models.resnet34(pretrained=pretrained, zero_init_residual=True) - - # Modify the first convolution layer to match input shape if needed - # TODO: add low-res parameter - backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False) - # backbone.maxpool = nn.Identity() - # if input_shape[0] != 3: - # backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False) - # Remove the fully connected layer - backbone_output_size = backbone.fc.in_features - print('backbone_output_size: ', backbone_output_size) - backbone = nn.Sequential(*list(backbone.children())[:-1]) - elif backbone_type == 'convnext_tiny': - 
backbone = models.convnext_tiny(pretrained=pretrained) - backbone_output_size = backbone.classifier[2].in_features - backbone.classifier = nn.Identity() - - # Modify the first convolutional layer to work with smaller resolutions - backbone.features[0][0] = nn.Conv2d( - in_channels=input_shape[0], - out_channels=backbone.features[0][0].out_channels, - kernel_size=3, # Reduce kernel size to 3x3 - stride=1, # Reduce stride to 1 to preserve spatial resolution - padding=1, # Add padding to preserve dimensions after convolution - bias=True # False - ) - elif backbone_type == 'efficientnet_v2_s': - backbone = models.efficientnet_v2_s(pretrained=pretrained) - backbone.features[0][0] = nn.Conv2d(input_shape[0], 24, kernel_size=3, stride=1, padding=1, bias=False) - backbone_output_size = backbone.classifier[1].in_features - backbone.classifier = nn.Identity() - elif backbone_type == 'vit_b_16': - backbone = models.vision_transformer.vit_b_16(pretrained=pretrained) - - # Add a resize layer to ensure the input is correctly sized for ViT - resize_layer = nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False) - - backbone_output_size = backbone.heads.head.in_features - backbone.heads.head = nn.Identity() - - # Combine the resize layer and the backbone into a sequential model - backbone = nn.Sequential(resize_layer, backbone) - # # Assuming your input image is a tensor or PIL image, resize it to 224x224 - # #obs = self.resize_transform(obs) - # backbone = models.vision_transformer.vit_b_16(pretrained=pretrained) - - # backbone_output_size = backbone.heads.head.in_features - # backbone.heads.head = nn.Identity() - else: - raise ValueError(f'Unknown backbone type: {backbone_type}') - - # Optionally freeze the follow-up layers, leaving the first convolutional layer unfrozen - if backbone_params.get('freeze', False): - print('Freezing backbone') - for name, param in backbone.named_parameters(): - if 'conv1' not in name: # Ensure the first conv layer is not frozen - param.requires_grad = False - - return backbone, backbone_output_size - - def is_separate_critic(self): - return False - - def is_rnn(self): - return self.has_rnn - - def get_default_rnn_state(self): - num_layers = self.rnn_layers - if self.rnn_name == 'lstm': - return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)), - torch.zeros((num_layers, self.num_seqs, self.rnn_units))) - else: - return (torch.zeros((num_layers, self.num_seqs, self.rnn_units))) - - def build(self, name, **kwargs): - net = VisionBackboneBuilder.Network(self.params, **kwargs) - return net - - class DiagGaussianActor(NetworkBuilder.BaseNetwork): """torch.distributions implementation of an diagonal Gaussian policy.""" def __init__(self, output_dim, log_std_bounds, **mlp_args): diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 5f484860..1870f316 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -12,6 +12,7 @@ from rl_games.common.diagnostics import DefaultDiagnostics, PpoDiagnostics from rl_games.algos_torch import model_builder from rl_games.interfaces.base_algorithm import BaseAlgorithm +from rl_games import networks import numpy as np import time import gym diff --git a/rl_games/common/env_configurations.py b/rl_games/common/env_configurations.py index 43c8ebe1..98480c0d 100644 --- a/rl_games/common/env_configurations.py +++ b/rl_games/common/env_configurations.py @@ -3,6 +3,7 @@ from rl_games.common import tr_helpers from rl_games.envs.brax import create_brax_env from rl_games.envs.envpool 
import create_envpool +from rl_games.envs.maniskill import create_maniskill from rl_games.envs.cule import create_cule import gym from gym.wrappers import FlattenObservation, FilterObservation @@ -10,7 +11,6 @@ import math - class HCRewardEnv(gym.RewardWrapper): def __init__(self, env): gym.RewardWrapper.__init__(self, env) @@ -34,8 +34,6 @@ def step(self, action): return observation, reward, done, info - - class DMControlObsWrapper(gym.ObservationWrapper): def __init__(self, env): gym.RewardWrapper.__init__(self, env) @@ -86,14 +84,16 @@ def create_slime_gym_env(**kwargs): env = gym.make(name, **kwargs) return env - def create_atari_gym_env(**kwargs): #frames = kwargs.pop('frames', 1) name = kwargs.pop('name') skip = kwargs.pop('skip',4) episode_life = kwargs.pop('episode_life',True) wrap_impala = kwargs.pop('wrap_impala', False) - env = wrappers.make_atari_deepmind(name, skip=skip,episode_life=episode_life, wrap_impala=wrap_impala, **kwargs) + gray_scale = kwargs.pop('gray_scale',True) + frame_stack = kwargs.pop('frame_stack', True) + env = wrappers.make_atari_deepmind(name, skip=skip,episode_life=episode_life, wrap_impala=wrap_impala, + gray_scale=gray_scale, frame_stack=frame_stack, **kwargs) return env def create_dm_control_env(**kwargs): @@ -154,14 +154,15 @@ def create_roboschool_env(name): def create_smac(name, **kwargs): from rl_games.envs.smac_env import SMACEnv, MultiDiscreteSmacWrapper + frames = kwargs.pop('frames', 1) transpose = kwargs.pop('transpose', False) flatten = kwargs.pop('flatten', True) has_cv = kwargs.get('central_value', False) as_single_agent = kwargs.pop('as_single_agent', False) + env = SMACEnv(name, **kwargs) - - + if frames > 1: if has_cv: env = wrappers.BatchedFrameStackWithStates(env, frames, transpose=False, flatten=flatten) @@ -170,6 +171,7 @@ def create_smac(name, **kwargs): if as_single_agent: env = MultiDiscreteSmacWrapper(env) + return env def create_smac_v2(name, **kwargs): @@ -179,7 +181,7 @@ def create_smac_v2(name, **kwargs): flatten = kwargs.pop('flatten', True) has_cv = kwargs.get('central_value', False) env = SMACEnvV2(name, **kwargs) - + if frames > 1: if has_cv: env = wrappers.BatchedFrameStackWithStates(env, frames, transpose=False, flatten=flatten) @@ -192,6 +194,7 @@ def create_smac_cnn(name, **kwargs): has_cv = kwargs.get('central_value', False) frames = kwargs.pop('frames', 4) transpose = kwargs.pop('transpose', False) + as_single_agent = kwargs.pop('as_single_agent', False) env = SMACEnv(name, **kwargs) if has_cv: @@ -200,6 +203,7 @@ def create_smac_cnn(name, **kwargs): env = wrappers.BatchedFrameStack(env, frames, transpose=transpose) if as_single_agent: env = MultiDiscreteSmacWrapper(env) + return env def create_test_env(name, **kwargs): @@ -211,7 +215,6 @@ def create_minigrid_env(name, **kwargs): import gym_minigrid import gym_minigrid.wrappers - state_bonus = kwargs.pop('state_bonus', False) action_bonus = kwargs.pop('action_bonus', False) rgb_fully_obs = kwargs.pop('rgb_fully_obs', False) @@ -219,7 +222,6 @@ def create_minigrid_env(name, **kwargs): view_size = kwargs.pop('view_size', 3) env = gym.make(name, **kwargs) - if state_bonus: env = gym_minigrid.wrappers.StateBonus(env) if action_bonus: @@ -423,6 +425,10 @@ def create_env(name, **kwargs): 'env_creator': lambda **kwargs: create_envpool(**kwargs), 'vecenv_type': 'ENVPOOL' }, + 'maniskill': { + 'env_creator': lambda **kwargs: create_maniskill(**kwargs), + 'vecenv_type': 'MANISKILL' + }, 'cule': { 'env_creator': lambda **kwargs: create_cule(**kwargs), 'vecenv_type': 'CULE' @@ 
-457,7 +463,6 @@ def get_obs_and_action_spaces_from_config(config): env.close() return result_shapes - def register(name, config): """Add a new key-value pair to the known environments (configurations dict). diff --git a/rl_games/common/vecenv.py b/rl_games/common/vecenv.py index 1fc37e9c..89bd8096 100644 --- a/rl_games/common/vecenv.py +++ b/rl_games/common/vecenv.py @@ -7,6 +7,7 @@ from time import sleep import torch + class RayWorker: """Wrapper around a third-party (gym for example) environment class that enables parallel training. @@ -282,5 +283,8 @@ def create_vec_env(config_name, num_actors, **kwargs): from rl_games.envs.envpool import Envpool register('ENVPOOL', lambda config_name, num_actors, **kwargs: Envpool(config_name, num_actors, **kwargs)) +from rl_games.envs.maniskill import Maniskill +register('MANISKILL', lambda config_name, num_actors, **kwargs: Maniskill(config_name, num_actors, **kwargs)) + from rl_games.envs.cule import CuleEnv register('CULE', lambda config_name, num_actors, **kwargs: CuleEnv(config_name, num_actors, **kwargs)) \ No newline at end of file diff --git a/rl_games/common/wrappers.py b/rl_games/common/wrappers.py index a62e0855..23696ed4 100644 --- a/rl_games/common/wrappers.py +++ b/rl_games/common/wrappers.py @@ -646,14 +646,15 @@ def make_atari(env_id, timelimit=True, noop_max=0, skip=4, sticky=False, directo #env = EpisodeStackedEnv(env) return env -def wrap_deepmind(env, episode_life=False, clip_rewards=True, frame_stack=True, scale =False, wrap_impala=False): +def wrap_deepmind(env, episode_life=False, clip_rewards=True, scale=False, + wrap_impala=False, frame_stack=True, gray_scale=True): """Configure environment for DeepMind-style Atari. """ if episode_life: env = EpisodicLifeEnv(env) if 'FIRE' in env.unwrapped.get_action_meanings(): env = FireResetEnv(env) - env = WarpFrame(env) + env = WarpFrame(env, grayscale=gray_scale) if scale: env = ScaledFloatFrame(env) if clip_rewards: @@ -680,7 +681,8 @@ def make_car_racing(env_id, skip=4): env = make_atari(env_id, noop_max=0, skip=skip) return wrap_carracing(env, clip_rewards=False) -def make_atari_deepmind(env_id, noop_max=30, skip=4, sticky=False, episode_life=True, wrap_impala=False, **kwargs): +def make_atari_deepmind(env_id, noop_max=30, skip=4, sticky=False, episode_life=True, + wrap_impala=False, frame_stack=True, gray_scale=True, **kwargs): env = make_atari(env_id, noop_max=noop_max, skip=skip, sticky=sticky, **kwargs) - return wrap_deepmind(env, episode_life=episode_life, clip_rewards=False, wrap_impala=wrap_impala) + return wrap_deepmind(env, episode_life=episode_life, clip_rewards=False, wrap_impala=wrap_impala, frame_stack=frame_stack, gray_scale=gray_scale) diff --git a/rl_games/configs/atari/ppo_pacman_envpool_resnet.yaml b/rl_games/configs/atari/ppo_pacman_envpool_resnet.yaml index c1cf0178..dc8754d2 100644 --- a/rl_games/configs/atari/ppo_pacman_envpool_resnet.yaml +++ b/rl_games/configs/atari/ppo_pacman_envpool_resnet.yaml @@ -24,38 +24,38 @@ params: norm_layer: None mlp: - units: [256] - activation: relu #elu + units: [512] + activation: relu regularizer: name: None initializer: name: default - rnn: name: lstm units: 512 layers: 1 before_mlp: True + concat_output: True + config: - name: Pacman_resnet18_LSTM_before_MLP_rew_shaper_100 + name: Pacman_resnet18_LSTM_before_MLP_concat_output_rew_shaper_100_norm env_name: envpool - reward_shaper: - min_val: -100 - max_val: 100 - mixed_precision: True - normalize_input: False + normalize_input: True normalize_value: True normalize_advantage: 
True + reward_shaper: + min_val: -100 + max_val: 100 gamma: 0.99 tau: 0.95 - learning_rate: 2e-4 - score_to_win: 100000 grad_norm: 1.0 entropy_coef: 0.01 truncate_grads: True - + learning_rate: 2e-4 + lr_schedule: linear + kl_threshold: 0.01 e_clip: 0.2 clip_value: True save_best_after: 25 @@ -65,11 +65,9 @@ params: minibatch_size: 2048 mini_epochs: 2 critic_coef: 1 - lr_schedule: linear - kl_threshold: 0.01 use_diagnostics: True seq_length: 8 - max_epochs: 10000 + max_epochs: 20000 #weight_decay: 0.001 env_config: diff --git a/rl_games/configs/atari/ppo_pong_envpool.yaml b/rl_games/configs/atari/ppo_pong_envpool.yaml index bd0844df..b3d7844e 100644 --- a/rl_games/configs/atari/ppo_pong_envpool.yaml +++ b/rl_games/configs/atari/ppo_pong_envpool.yaml @@ -47,12 +47,13 @@ params: name: Pong-v5_envpool env_name: envpool score_to_win: 20.0 - normalize_value: True + mixed_precision: True + normalize_value: True normalize_input: True + normalize_advantage: True reward_shaper: min_val: -1 max_val: 1 - normalize_advantage: True gamma: 0.99 tau: 0.95 learning_rate: 3e-4 diff --git a/rl_games/configs/maniskill/maniskill.yaml b/rl_games/configs/maniskill/maniskill.yaml new file mode 100644 index 00000000..59367677 --- /dev/null +++ b/rl_games/configs/maniskill/maniskill.yaml @@ -0,0 +1,62 @@ +params: + seed: 5 + algo: + name: a2c_continuous + + model: + name: continuous_a2c_logstd + + network: + name: actor_critic + separate: False + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: 0 + fixed_sigma: True + mlp: + units: [512, 256, 128] + activation: elu + initializer: + name: default + + config: + name: Maniskill + env_name: maniskill + normalize_input: True + normalize_value: True + value_bootstrap: True + reward_shaper: + scale_value: 1.0 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + + learning_rate: 3e-4 + lr_schedule: adaptive + kl_threshold: 0.008 + grad_norm: 1.0 + entropy_coef: 0.0 + truncate_grads: True + e_clip: 0.2 + clip_value: True + use_smooth_clamp: True + bound_loss_type: regularisation + bounds_loss_coef: 0.0005 + max_epochs: 2000 + num_actors: 4096 + horizon_length: 64 + minibatch_size: 16384 + mini_epochs: 5 + critic_coef: 2 + + env_config: + env_name: PickCube-v1 # todo: add list of all envs + + player: + render: True \ No newline at end of file diff --git a/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml b/rl_games/configs/maniskill/maniskill_resnet.yaml similarity index 64% rename from rl_games/configs/atari/ppo_pong_envpool_backbone.yaml rename to rl_games/configs/maniskill/maniskill_resnet.yaml index f43b00de..f36f4e6c 100644 --- a/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml +++ b/rl_games/configs/maniskill/maniskill_resnet.yaml @@ -10,75 +10,65 @@ params: separate: False value_shape: 1 space: - discrete: + continuous: backbone: - type: vit_b_16 #efficientnet_v2_s #convnext_tiny #vit_b_16 #resnet18 #resnet34 + type: resnet18 # can be efficientnet_v2_s #convnext_tiny #vit_b_16 #resnet18 #resnet34 pretrained: True permute_input: False - freeze: True + freeze: False preprocess_image: True - args: zero_init_residual: True norm_layer: None mlp: - units: [256] - activation: relu #elu + units: [512, 256] + activation: elu regularizer: name: None initializer: name: default - rnn: name: lstm units: 512 layers: 1 before_mlp: True - config: - name: Pong_vit_b_16_color - #name: Pong_resnet18_maxpool_LSTM_before_MLP_ELU - env_name: envpool - reward_shaper: - min_val: -1 - 
max_val: 1 + concat_output: True + config: + name: Maniskill_resnet18 + env_name: maniskill + score_to_win: 20.0 mixed_precision: True normalize_input: False normalize_value: True normalize_advantage: True gamma: 0.99 tau: 0.95 - learning_rate: 2e-4 - - score_to_win: 100000 grad_norm: 1.0 entropy_coef: 0.01 truncate_grads: True - e_clip: 0.2 clip_value: True save_best_after: 25 - save_frequency: 500 - num_actors: 32 - horizon_length: 64 - minibatch_size: 512 + save_frequency: 200 + num_actors: 64 + horizon_length: 128 + minibatch_size: 2048 mini_epochs: 2 critic_coef: 1 + learning_rate: 2e-4 lr_schedule: linear kl_threshold: 0.01 use_diagnostics: True seq_length: 8 - max_epochs: 10000 + max_epochs: 500 #weight_decay: 0.001 env_config: - env_name: Pong-v5 - has_lives: False - use_dict_obs_space: False #True - stack_num: 1 - gray_scale: False + env_name: PickCube-v1 + player: render: True games_num: 10 diff --git a/rl_games/envs/envpool.py b/rl_games/envs/envpool.py index e89d7403..9ddbab68 100644 --- a/rl_games/envs/envpool.py +++ b/rl_games/envs/envpool.py @@ -11,6 +11,7 @@ def flatten_dict(obs): res = np.column_stack(res) return res + class Envpool(IVecEnv): def __init__(self, config_name, num_actors, **kwargs): import envpool @@ -26,7 +27,7 @@ def __init__(self, config_name, num_actors, **kwargs): batch_size=self.batch_size, **kwargs ) - + if self.use_dict_obs_space: self.observation_space = gym.spaces.Dict({ 'observation' : self.env.observation_space, @@ -63,7 +64,7 @@ def _set_scores(self, infos, dones): self.scores *= 1 - dones def step(self, action): - next_obs, reward, is_done, info = self.env.step(action , self.ids) + next_obs, reward, is_done, info = self.env.step(action, self.ids) info['time_outs'] = info['TimeLimit.truncated'] self._set_scores(info, is_done) if self.flatten_obs: @@ -86,7 +87,7 @@ def reset(self): 'reward': np.zeros(obs.shape[0]), 'last_action': np.zeros(obs.shape[0]), } - + return obs def get_number_of_agents(self): @@ -99,7 +100,5 @@ def get_env_info(self): return info - - def create_envpool(**kwargs): return Envpool("", kwargs.pop('num_actors', 16), **kwargs) \ No newline at end of file diff --git a/rl_games/envs/maniskill.py b/rl_games/envs/maniskill.py new file mode 100644 index 00000000..c0e78640 --- /dev/null +++ b/rl_games/envs/maniskill.py @@ -0,0 +1,216 @@ +from rl_games.common.ivecenv import IVecEnv +#import gym +import numpy as np + +import torch +from typing import Dict, Literal + + +# def flatten_dict(obs): +# res = [] +# for k,v in obs.items(): +# res.append(v.reshape(v.shape[0], -1)) + +# res = np.column_stack(res) +# return res + + + + +# # create an environment with our configs and then reset to a clean state +# env = gym.make(env_id, +# num_envs=4, +# obs_mode=obs_mode, +# reward_mode=reward_mode, +# control_mode=control_mode, +# robot_uids=robot_uids, +# enable_shadow=True # this makes the default lighting cast shadows +# ) +# obs, _ = env.reset() +# print("Action Space:", env.action_space) + + +VecEnvObs = Dict[str, torch.Tensor | Dict[str, torch.Tensor]] + +def _process_obs(self, obs_dict: VecEnvObs) -> torch.Tensor | dict[str, torch.Tensor]: + + # process policy obs + obs = obs_dict["policy"] + + # TODO: add state processing for asymmetric case + # TODO: add clamping? 
+ # currently supported only single-gpu case + + if not isinstance(obs, dict): + # clip the observations + obs = torch.clamp(obs, -self._clip_obs, self._clip_obs) + # move the buffer to rl-device + obs = obs.to(device=self._rl_device).clone() + + return obs + else: + # clip the observations + for key in obs.keys(): + obs[key] = torch.clamp(obs[key], -self._clip_obs, self._clip_obs) + # move the buffer to rl-device + obs[key] = obs[key].to(device=self._rl_device).clone() + # TODO: add state processing for asymmetric case + return obs + + +class Maniskill(IVecEnv): + def __init__(self, config_name, num_envs, **kwargs): + import gym.spaces + import gymnasium + import gymnasium as gym2 + import mani_skill.envs + + # Can be any env_id from the list of Rigid-Body envs: https://maniskill.readthedocs.io/en/latest/tasks/index.html + self.env_name = kwargs.pop('env_name', 'PickCube-v1') # can be one of ['PickCube-v1', 'PegInsertionSide-v1', 'StackCube-v1'] + + # an observation type and space, see https://maniskill.readthedocs.io/en/latest/user_guide/concepts/observation.html for details + self.obs_mode = kwargs.pop('obs_mode', 'state') # can be one of ['pointcloud', 'rgbd', 'state_dict', 'state'] + + # a controller type / action space, see https://maniskill.readthedocs.io/en/latest/user_guide/concepts/controllers.html for a full list + self.control_mode = "pd_joint_delta_pos" # can be one of ['pd_ee_delta_pose', 'pd_ee_delta_pos', 'pd_joint_delta_pos', 'arm_pd_joint_pos_vel'] + + self.reward_mode = "dense" # can be one of ['sparse', 'dense'] + self.robot_uids = "panda" # can be one of ['panda', 'fetch'] + + #self.batch_size = num_envs # ??? + + #self.use_dict_obs_space = kwargs.pop('use_dict_obs_space', True) + + # self.env = gym2.make( self.env_name, + # env_type=kwargs.pop('env_type', 'gym'), + # num_envs=num_envs, + # batch_size=self.batch_size, + # **kwargs + # ) + self.env = gym2.make(self.env_name, + num_envs=num_envs, + obs_mode=self.obs_mode, + reward_mode=self.reward_mode, + control_mode=self.control_mode, + robot_uids=self.robot_uids, + enable_shadow=True # this makes the default lighting cast shadows + ) + + # if self.use_dict_obs_space: + # self.observation_space = gym.spaces.Dict({ + # 'observation' : self.env.observation_space, + # 'reward' : gym.spaces.Box(low=0, high=1, shape=( ), dtype=np.float32), + # 'last_action': gym.spaces.Box(low=0, high=self.env.action_space.n, shape=(), dtype=int) + # }) + # else: + # self.observation_space = self.env.observation_space + + # if self.flatten_obs: + # self.orig_observation_space = self.observation_space + # self.observation_space = gym.spaces.flatten_space(self.observation_space) + + print("Observation Space:", self.env.observation_space) + policy_obs_space = self.env.unwrapped.single_observation_space + print("Observation Space Unwrapped:", policy_obs_space) + + self._clip_obs = np.inf + + # TODO: single function + if isinstance(policy_obs_space, gymnasium.spaces.Dict): + # check if we have a dictionary of observations + for key in policy_obs_space.keys(): + if not isinstance(policy_obs_space[key], gymnasium.spaces.Box): + raise NotImplementedError( + f"Dictinary of dictinary observations support was not testes: '{type(policy_obs_space[key])}'." 
+ ) + self.observation_space = gym.spaces.Dict( + { + key: gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space[key].shape) + for key in policy_obs_space.keys() + } + ) + else: + self.observation_space = gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape) + + # if isinstance(critic_obs_space, gymnasium.spaces.Dict): + # # check if we have a dictionary of observations + # for key in critic_obs_space.keys(): + # if not isinstance(critic_obs_space[key], gymnasium.spaces.Box): + # raise NotImplementedError( + # f"Dictinary of dictinary observations support has not been tested yet: '{type(policy_obs_space[key])}'." + # ) + # self.state_observation_space = gym.spaces.Dict( + # { + # key: gym.spaces.Box(-self._clip_obs, self._clip_obs, critic_obs_space[key].shape) + # for key in critic_obs_space.keys() + # } + # ) + # else: + # self.observation_space = gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape) + + self.action_space = self.env.unwrapped.single_action_space + + def step(self, action): + # # move actions to sim-device + # actions = actions.detach().clone().to(device=self._sim_device) + # # clip the actions + # actions = torch.clamp(actions, -self._clip_actions, self._clip_actions) + + obs_dict, rew, terminated, truncated, extras = self.env.step(action) + # move time out information to the extras dict + # this is only needed for infinite horizon tasks + # note: only useful when `value_bootstrap` is True in the agent configuration + extras["time_outs"] = truncated #truncated.to(device=self._rl_device) + # process observations and states + #obs_and_states = self._process_obs(obs_dict) + obs_and_states = obs_dict + # move buffers to rl-device + # note: we perform clone to prevent issues when rl-device and sim-device are the same. 
+ #rew = rew.to(device=self._rl_device) + #dones = (terminated | truncated).to(device=self._rl_device) + dones = (terminated | truncated).any() # stop if any environment terminates/truncates + # extras = { + # k: v.to(device=self._rl_device, non_blocking=True) if hasattr(v, "to") else v for k, v in extras.items() + # } + + # remap extras from "log" to "episode" + if "log" in extras: + extras["episode"] = extras.pop("log") + + # done = (terminated | truncated).any() # stop if any environment terminates/truncates + # info['time_outs'] = truncated + + # if self.obs_mode == 'state_dict': + # next_obs = obs + + # if self.flatten_obs: + # next_obs = flatten_dict(next_obs) + + # if self.use_dict_obs_space: + # next_obs = { + # 'observation': next_obs, + # 'reward': np.clip(reward, -1, 1), + # 'last_action': action + # } + #return next_obs, reward, is_done, info + return obs_and_states, rew, dones, extras + + def reset(self): + obs = self.env.reset() + # if self.flatten_obs: + # obs = flatten_dict(obs) + + return obs + + def get_number_of_agents(self): + return 1 + + def get_env_info(self): + info = {} + info['action_space'] = self.action_space + info['observation_space'] = self.observation_space + return info + + +def create_maniskill(**kwargs): + return Maniskill("", num_envs=kwargs.pop('num_actors', 16), **kwargs) \ No newline at end of file diff --git a/rl_games/networks/__init__.py b/rl_games/networks/__init__.py index ca68c624..52b71133 100644 --- a/rl_games/networks/__init__.py +++ b/rl_games/networks/__init__.py @@ -1,7 +1,7 @@ from rl_games.networks.tcnn_mlp import TcnnNetBuilder -#from rl_games.networks.vision_networks import A2CVisionBuilder, A2CVisionBackboneBuilder +from rl_games.networks.vision_networks import VisionImpalaBuilder, VisionBackboneBuilder from rl_games.algos_torch import model_builder model_builder.register_network('tcnnnet', TcnnNetBuilder) -# model_builder.register_network('vision_actor_critic', A2CVisionBuilder) -# model_builder.register_network('e2e_vision_actor_critic', A2CVisionBackboneBuilder) \ No newline at end of file +model_builder.register_network('vision_actor_critic', VisionImpalaBuilder) +model_builder.register_network('e2e_vision_actor_critic', VisionBackboneBuilder) \ No newline at end of file diff --git a/rl_games/networks/vision_networks.py b/rl_games/networks/vision_networks.py index ba98645f..634745d6 100644 --- a/rl_games/networks/vision_networks.py +++ b/rl_games/networks/vision_networks.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from rl_games.algos_torch import torch_ext from rl_games.algos_torch.network_builder import NetworkBuilder, ImpalaSequential - + class VisionImpalaBuilder(NetworkBuilder): def __init__(self, **kwargs): @@ -15,7 +15,7 @@ def load(self, params): class Network(NetworkBuilder.BaseNetwork): def __init__(self, params, **kwargs): - self.actions_num = kwargs.pop('actions_num') + self.actions_num = actions_num = kwargs.pop('actions_num') full_input_shape = kwargs.pop('input_shape') proprio_size = 0 # Number of proprioceptive features if type(full_input_shape) is dict: @@ -68,18 +68,18 @@ def __init__(self, params, **kwargs): self.flatten_act = self.activations_factory.create(self.activation) if self.is_discrete: - self.logits = torch.nn.Linear(out_size, self.actions_num) + self.logits = torch.nn.Linear(out_size, actions_num) if self.is_continuous: - self.mu = torch.nn.Linear(out_size, self.actions_num) + self.mu = torch.nn.Linear(out_size, actions_num) self.mu_act = 
self.activations_factory.create(self.space_config['mu_activation']) mu_init = self.init_factory.create(**self.space_config['mu_init']) self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation']) sigma_init = self.init_factory.create(**self.space_config['sigma_init']) if self.fixed_sigma: - self.sigma = nn.Parameter(torch.zeros(self.actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True) + self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True) else: - self.sigma = torch.nn.Linear(out_size, self.actions_num) + self.sigma = torch.nn.Linear(out_size, actions_num) mlp_init = self.init_factory.create(**self.initializer) @@ -102,9 +102,6 @@ def __init__(self, params, **kwargs): mlp_init(self.value.weight) def forward(self, obs_dict): - # print(obs_dict.keys()) - # print(obs_dict['obs'].keys()) - # currently works only dictinary of camera and proprio observations obs = obs_dict['obs']['camera'] proprio = obs_dict['obs']['proprio'] if self.permute_input: @@ -227,7 +224,20 @@ def build(self, name, **kwargs): return net -from timm import create_model # timm is required for ConvNeXt and ViT +from torchvision import models, transforms + +def preprocess_image(image): + # Normalize the image using ImageNet's mean and standard deviation + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], # Mean of ImageNet dataset + std=[0.229, 0.224, 0.225] # Std of ImageNet dataset + ) + + # Apply the normalization + image = normalize(image) + + return image + class VisionBackboneBuilder(NetworkBuilder): def __init__(self, **kwargs): @@ -241,6 +251,8 @@ def __init__(self, params, **kwargs): self.actions_num = kwargs.pop('actions_num') full_input_shape = kwargs.pop('input_shape') + print("Observations shape: ", full_input_shape) + self.proprio_size = 0 # Number of proprioceptive features if isinstance(full_input_shape, dict): input_shape = full_input_shape['camera'] @@ -257,12 +269,11 @@ def __init__(self, params, **kwargs): if self.permute_input: input_shape = torch_ext.shape_whc_to_cwh(input_shape) - self.cnn = self._build_backbone(input_shape, self.params['backbone']) - cnn_output_size = self.cnn_output_size + self.cnn, self.cnn_output_size = self._build_backbone(input_shape, params['backbone']) - mlp_input_size = cnn_output_size + self.proprio_size + mlp_input_size = self.cnn_output_size + self.proprio_size if len(self.units) == 0: - out_size = cnn_output_size + out_size = self.cnn_output_size else: out_size = self.units[-1] @@ -307,9 +318,6 @@ def __init__(self, params, **kwargs): mlp_init = self.init_factory.create(**self.initializer) - # for m in self.modules(): - # if isinstance(m, nn.Conv2d): - # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') for m in self.mlp: if isinstance(m, nn.Linear): mlp_init(m.weight) @@ -327,13 +335,17 @@ def __init__(self, params, **kwargs): def forward(self, obs_dict): if self.proprio_size > 0: - obs = obs_dict['camera'] - proprio = obs_dict['proprio'] + obs = obs_dict['obs']['camera'] + proprio = obs_dict['obs']['proprio'] else: obs = obs_dict['obs'] + if self.permute_input: obs = obs.permute((0, 3, 1, 2)) + if self.preprocess_image: + obs = preprocess_image(obs) + dones = obs_dict.get('dones', None) bptt_len = obs_dict.get('bptt_len', 0) states = obs_dict.get('rnn_states', None) @@ -427,25 +439,86 @@ def load(self, params): def _build_backbone(self, input_shape, backbone_params): backbone_type = backbone_params['type'] pretrained = 
backbone_params.get('pretrained', False) + self.preprocess_image = backbone_params.get('preprocess_image', False) + + if backbone_type == 'resnet18' or backbone_type == 'resnet34': + if backbone_type == 'resnet18': + backbone = models.resnet18(pretrained=pretrained, zero_init_residual=True) + else: + backbone = models.resnet34(pretrained=pretrained, zero_init_residual=True) - if backbone_type == 'resnet18': - model = models.resnet18(pretrained=pretrained) # Modify the first convolution layer to match input shape if needed - if input_shape[0] != 3: - model.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False) + # TODO: add low-res parameter + backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False) + # backbone.maxpool = nn.Identity() + # if input_shape[0] != 3: + # backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False) # Remove the fully connected layer - self.cnn_output_size = model.fc.in_features - model = nn.Sequential(*list(model.children())[:-1]) + backbone_output_size = backbone.fc.in_features + print('backbone_output_size: ', backbone_output_size) + backbone = nn.Sequential(*list(backbone.children())[:-1]) elif backbone_type == 'convnext_tiny': - model = create_model('convnext_tiny', pretrained=pretrained) - # Remove the fully connected layer - self.cnn_output_size = model.head.fc.in_features - model = nn.Sequential(*list(model.children())[:-1]) - elif backbone_type == 'vit_tiny_patch16_224': - model = create_model('vit_tiny_patch16_224', pretrained=pretrained) - # ViT outputs a single token, so no need to remove layers - self.cnn_output - - def build(self, name, **kwargs): - net = VisionBackboneBuilder.Network(self.params, **kwargs) - return net \ No newline at end of file + backbone = models.convnext_tiny(pretrained=pretrained) + backbone_output_size = backbone.classifier[2].in_features + backbone.classifier = nn.Identity() + + # Modify the first convolutional layer to work with smaller resolutions + backbone.features[0][0] = nn.Conv2d( + in_channels=input_shape[0], + out_channels=backbone.features[0][0].out_channels, + kernel_size=3, # Reduce kernel size to 3x3 + stride=1, # Reduce stride to 1 to preserve spatial resolution + padding=1, # Add padding to preserve dimensions after convolution + bias=True # False + ) + elif backbone_type == 'efficientnet_v2_s': + backbone = models.efficientnet_v2_s(pretrained=pretrained) + backbone.features[0][0] = nn.Conv2d(input_shape[0], 24, kernel_size=3, stride=1, padding=1, bias=False) + backbone_output_size = backbone.classifier[1].in_features + backbone.classifier = nn.Identity() + elif backbone_type == 'vit_b_16': + backbone = models.vision_transformer.vit_b_16(pretrained=pretrained) + + # Add a resize layer to ensure the input is correctly sized for ViT + resize_layer = nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False) + + backbone_output_size = backbone.heads.head.in_features + backbone.heads.head = nn.Identity() + + # Combine the resize layer and the backbone into a sequential model + backbone = nn.Sequential(resize_layer, backbone) + # # Assuming your input image is a tensor or PIL image, resize it to 224x224 + # #obs = self.resize_transform(obs) + # backbone = models.vision_transformer.vit_b_16(pretrained=pretrained) + + # backbone_output_size = backbone.heads.head.in_features + # backbone.heads.head = nn.Identity() + else: + raise ValueError(f'Unknown backbone type: {backbone_type}') + + # Optionally freeze 
the follow-up layers, leaving the first convolutional layer unfrozen + if backbone_params.get('freeze', False): + print('Freezing backbone') + for name, param in backbone.named_parameters(): + if 'conv1' not in name: # Ensure the first conv layer is not frozen + param.requires_grad = False + + return backbone, backbone_output_size + + def is_separate_critic(self): + return False + + def is_rnn(self): + return self.has_rnn + + def get_default_rnn_state(self): + num_layers = self.rnn_layers + if self.rnn_name == 'lstm': + return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)), + torch.zeros((num_layers, self.num_seqs, self.rnn_units))) + else: + return (torch.zeros((num_layers, self.num_seqs, self.rnn_units))) + + def build(self, name, **kwargs): + net = VisionBackboneBuilder.Network(self.params, **kwargs) + return net \ No newline at end of file
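
A minimal usage sketch of the ManiSkill wrapper introduced in this patch (an assumption-laden illustration, not part of the patch: it assumes `mani_skill` is installed, that the standard rl_games runner entry point is used, and that a batched NumPy action array is accepted by the vectorized env; the `num_actors`, `env_name` and `obs_mode` keyword arguments mirror those handled in rl_games/envs/maniskill.py):

    # Training with the new config would go through the usual rl_games runner, e.g.:
    #   python runner.py --train --file rl_games/configs/maniskill/maniskill.yaml
    import numpy as np
    from rl_games.envs.maniskill import create_maniskill

    # create_maniskill forwards num_actors as num_envs to the batched ManiSkill env
    env = create_maniskill(num_actors=16, env_name='PickCube-v1', obs_mode='state')
    info = env.get_env_info()
    print(info['observation_space'], info['action_space'])

    obs = env.reset()
    # Step all 16 environments at once with a batched action array sampled
    # from the single-env action space exposed by get_env_info()
    actions = np.stack([info['action_space'].sample() for _ in range(16)])
    obs, rew, dones, extras = env.step(actions)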