From 092e5595b4269af0afd12fcd510681bd01822cdb Mon Sep 17 00:00:00 2001 From: ViktorM Date: Thu, 19 Sep 2024 13:39:26 -0700 Subject: [PATCH] Depth training config. Dino2 vit support. --- rl_games/common/experience.py | 10 +- .../maniskill/maniskill_pickcube_depth.yaml | 104 ++++++++++++++++++ .../maniskill/maniskill_pickcube_vision.yaml | 4 +- .../maniskill/maniskill_pickcube_vit.yaml | 104 ++++++++++++++++++ rl_games/envs/maniskill.py | 18 ++- rl_games/networks/vision_networks.py | 56 +++++++--- 6 files changed, 270 insertions(+), 26 deletions(-) create mode 100644 rl_games/configs/maniskill/maniskill_pickcube_depth.yaml create mode 100644 rl_games/configs/maniskill/maniskill_pickcube_vit.yaml diff --git a/rl_games/common/experience.py b/rl_games/common/experience.py index b298e602..3f79200f 100644 --- a/rl_games/common/experience.py +++ b/rl_games/common/experience.py @@ -360,14 +360,14 @@ def _create_tensor_from_space(self, space, base_shape): return torch.zeros(base_shape + space.shape, dtype=dtype, device=self.device) if type(space) is gym.spaces.Discrete: dtype = numpy_to_torch_dtype_dict[space.dtype] - return torch.zeros(base_shape, dtype=dtype, device = self.device) + return torch.zeros(base_shape, dtype=dtype, device=self.device) if type(space) is gym.spaces.Tuple: ''' assuming that tuple is only Discrete tuple ''' dtype = numpy_to_torch_dtype_dict[space.dtype] tuple_len = len(space) - return torch.zeros(base_shape +(tuple_len,), dtype= dtype, device = self.device) + return torch.zeros(base_shape +(tuple_len,), dtype=dtype, device=self.device) if type(space) is gym.spaces.Dict: t_dict = {} for k,v in space.spaces.items(): @@ -381,12 +381,12 @@ def update_data(self, name, index, val): else: self.tensor_dict[name][index,:] = val - def update_data_rnn(self, name, indices,play_mask, val): + def update_data_rnn(self, name, indices, play_mask, val): if type(val) is dict: for k,v in val: - self.tensor_dict[name][k][indices,play_mask] = v + 
self.tensor_dict[name][k][indices, play_mask] = v else: - self.tensor_dict[name][indices,play_mask] = val + self.tensor_dict[name][indices, play_mask] = val def get_transformed(self, transform_op): res_dict = {} diff --git a/rl_games/configs/maniskill/maniskill_pickcube_depth.yaml b/rl_games/configs/maniskill/maniskill_pickcube_depth.yaml new file mode 100644 index 00000000..58a1c4b4 --- /dev/null +++ b/rl_games/configs/maniskill/maniskill_pickcube_depth.yaml @@ -0,0 +1,104 @@ +params: + seed: 42 + + # environment wrapper clipping + env: + # added to the wrapper + clip_observations: 5.0 + # can make custom wrapper? + clip_actions: 1.0 + algo: + name: a2c_continuous + + model: + name: continuous_a2c_logstd + + network: + name: e2e_vision_actor_critic + separate: False + value_shape: 1 + space: + continuous: + mu_activation: None + sigma_activation: None + + mu_init: + name: default + scale: 0.02 + sigma_init: + name: const_initializer + val: 0 + fixed_sigma: True + + backbone: + type: resnet18 # can be efficientnet_v2_s #convnext_tiny #vit_b_16 #resnet18 #resnet34 + pretrained: True + permute_input: True + freeze: False + preprocess_image: False + args: + zero_init_residual: True + norm_layer: None + mlp: + units: [256, 128] + activation: elu + regularizer: + name: None + initializer: + name: default + rnn: + name: lstm + layer_norm: True + units: 512 + layers: 1 + before_mlp: True + concat_output: True + + config: + name: PickCube_RGB_resnet18_LSTM_256x128_MLP_depth + env_name: maniskill + reward_shaper: + scale_value: 1.0 + device: cuda:0 + device_name: cuda:0 + multi_gpu: False + ppo: True + mixed_precision: True + normalize_input: False + normalize_value: True + normalize_advantage: True + num_actors: 128 + reward_shaper: + scale_value: 1.0 + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-4 + lr_schedule: linear + kl_threshold: 0.008 + max_epochs: 10000 + save_best_after: 25 + save_frequency: 1000 + grad_norm: 1.0 + entropy_coef: 0.0 + truncate_grads: True + 
e_clip: 0.2 + horizon_length: 64 #128 + minibatch_size: 1024 + mini_epochs: 2 + critic_coef: 1 + clip_value: True + seq_length: 8 + bounds_loss_coef: 0.0001 + #weight_decay: 0.001 + + env_config: + env_name: PickCube-v1 + obs_mode: depth #rgbd #rgb + control_mode: pd_ee_delta_pose + reward_mode: dense + aux_loss: False + + player: + render: False + deterministic: True + diff --git a/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml b/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml index d6d2c2b6..ff712598 100644 --- a/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml +++ b/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml @@ -55,7 +55,7 @@ params: concat_output: True config: - name: PickCube_RGB_resnet18_LSTM_256x128_MLP + name: PickCube_RGB_resnet18_LSTM_256x128_MLP_frozen_RGB env_name: maniskill reward_shaper: scale_value: 1.0 @@ -82,7 +82,7 @@ params: entropy_coef: 0.0 truncate_grads: True e_clip: 0.2 - horizon_length: 64 #128 + horizon_length: 64 minibatch_size: 1024 mini_epochs: 2 critic_coef: 1 diff --git a/rl_games/configs/maniskill/maniskill_pickcube_vit.yaml b/rl_games/configs/maniskill/maniskill_pickcube_vit.yaml new file mode 100644 index 00000000..e5d490a3 --- /dev/null +++ b/rl_games/configs/maniskill/maniskill_pickcube_vit.yaml @@ -0,0 +1,104 @@ +params: + seed: 42 + + # environment wrapper clipping + env: + # added to the wrapper + clip_observations: 5.0 + # can make custom wrapper? 
+ clip_actions: 1.0 + algo: + name: a2c_continuous + + model: + name: continuous_a2c_logstd + + network: + name: e2e_vision_actor_critic + separate: False + value_shape: 1 + space: + continuous: + mu_activation: None + sigma_activation: None + + mu_init: + name: default + scale: 0.02 + sigma_init: + name: const_initializer + val: 0 + fixed_sigma: True + + backbone: + type: dinov2_vits14_reg #vit_b_16 # can be efficientnet_v2_s #convnext_tiny #vit_b_16 #resnet18 #resnet34 + pretrained: True + permute_input: True + freeze: True + preprocess_image: True + args: + zero_init_residual: True + norm_layer: None + mlp: + units: [256, 128] + activation: elu + regularizer: + name: None + initializer: + name: default + rnn: + name: lstm + layer_norm: True + units: 512 + layers: 1 + before_mlp: True + concat_output: True + + config: + name: PickCube_RGB_DINO2_VITS14_REG_LSTM_256x128_MLP_frozen_RGB + env_name: maniskill + reward_shaper: + scale_value: 1.0 + device: cuda:0 + device_name: cuda:0 + multi_gpu: False + ppo: True + mixed_precision: True + normalize_input: False + normalize_value: True + normalize_advantage: True + num_actors: 64 #128 + reward_shaper: + scale_value: 1.0 + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-4 + lr_schedule: linear + kl_threshold: 0.008 + max_epochs: 10000 + save_best_after: 25 + save_frequency: 1000 + grad_norm: 1.0 + entropy_coef: 0.0 + truncate_grads: True + e_clip: 0.2 + horizon_length: 64 + minibatch_size: 256 #512 #1024 + mini_epochs: 2 + critic_coef: 1 + clip_value: True + seq_length: 8 + bounds_loss_coef: 0.0001 + #weight_decay: 0.001 + + env_config: + env_name: PickCube-v1 + obs_mode: rgb + control_mode: pd_ee_delta_pose + reward_mode: dense + aux_loss: False + + player: + render: False + deterministic: True + diff --git a/rl_games/envs/maniskill.py b/rl_games/envs/maniskill.py index 81423e86..568b2de3 100644 --- a/rl_games/envs/maniskill.py +++ b/rl_games/envs/maniskill.py @@ -109,7 +109,11 @@ def observation(self, observation: Dict): 
ret["proprio"] = observation if self.aux_loss: ret['aux_target'] = aux_target - ret["camera"] = images + + if not self.include_rgb and self.include_depth: + ret["camera"] = images.float() / 32768.0 + else: + ret["camera"] = images return ret @@ -158,8 +162,10 @@ def __init__(self, config_name, num_envs, **kwargs): print("Observation Space Unwrapped Before:", policy_obs_space) # TODO: add pointcloud and Depth support - if self.obs_mode == 'rgb' or self.obs_mode == 'rgbd': - self.env = RlgFlattenRGBDObservationWrapper(self.env, aux_loss=self.aux_loss) + use_rgb = self.obs_mode == 'rgbd' or self.obs_mode == 'rgb' + use_depth = self.obs_mode == 'rgbd' or self.obs_mode == 'depth' + if self.obs_mode == 'rgb' or self.obs_mode == 'rgbd' or self.obs_mode == 'depth': + self.env = RlgFlattenRGBDObservationWrapper(self.env, aux_loss=self.aux_loss, rgb=use_rgb, depth=use_depth) policy_obs_space = self.env.unwrapped.single_observation_space print("Observation Space Unwrapped After:", policy_obs_space) @@ -213,7 +219,11 @@ def __init__(self, config_name, num_envs, **kwargs): val = policy_obs_space[key] if val.dtype == np.float16 or val.dtype == np.float32: - self.observation_space[key] = gym.spaces.Box(-self._clip_obs, self._clip_obs, val.shape) + self.observation_space[key] = gym.spaces.Box(-self._clip_obs, self._clip_obs, val.shape, dtype=val.dtype) + elif val.dtype == np.int16: + # to fix!!! 
+ #self.observation_space[key] = gym.spaces.Box(-32768, 32767, val.shape, dtype=np.int16) + self.observation_space[key] = gym.spaces.Box(-1.0, 1.0, val.shape, dtype=np.float32) # depth-only int16 images are divided by 32768.0 in the observation wrapper elif policy_obs_space[key].dtype == np.uint8: self.observation_space[key] = gym.spaces.Box(0, 255, val.shape, dtype=np.uint8) else: diff --git a/rl_games/networks/vision_networks.py b/rl_games/networks/vision_networks.py index 64b7ff80..484b7764 100644 --- a/rl_games/networks/vision_networks.py +++ b/rl_games/networks/vision_networks.py @@ -523,12 +523,11 @@ def load(self, params): self.has_cnn = True self.permute_input = params['backbone'].get('permute_input', True) - self.require_rewards = params.get('require_rewards') - self.require_last_actions = params.get('require_last_actions') def _build_backbone(self, input_shape, backbone_params): backbone_type = backbone_params['type'] pretrained = backbone_params.get('pretrained', False) + modify_first_conv = backbone_params.get('modify_first_conv', False) self.preprocess_image = backbone_params.get('preprocess_image', False) if backbone_type == 'resnet18' or backbone_type == 'resnet34': @@ -538,11 +537,13 @@ def _build_backbone(self, input_shape, backbone_params): backbone = models.resnet34(pretrained=pretrained, zero_init_residual=True) # Modify the first convolution layer to match input shape if needed - # TODO: add low-res parameter - #backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False) + if input_shape[0] != 3: + backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False) + + # In case of lower resolution images, modify the first convolutional layer to work with smaller resolutions + if modify_first_conv: + backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False) # backbone.maxpool = nn.Identity() - # if input_shape[0] != 3: - # backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False) # 
Remove the fully connected layer backbone_output_size = backbone.fc.in_features print('backbone_output_size: ', backbone_output_size) @@ -577,21 +578,46 @@ def _build_backbone(self, input_shape, backbone_params): # Combine the resize layer and the backbone into a sequential model backbone = nn.Sequential(resize_layer, backbone) - # # Assuming your input image is a tensor or PIL image, resize it to 224x224 - # #obs = self.resize_transform(obs) - # backbone = models.vision_transformer.vit_b_16(pretrained=pretrained) + elif backbone_type == 'dinov2_vits14_reg': + # DINOv2 ViT-S/14 backbone with registers (the "_reg" suffix denotes register tokens, not regression) + try: + # Load the dinov2_vits14_reg model using torch.hub + backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg', pretrained=pretrained) + except Exception as e: + raise ValueError(f"Failed to load dinov2_vits14_reg: {e}") + + # The head is already Identity, no need to modify + # Add a resize layer so the input is 196x196 (14 * 14, a multiple of the ViT patch size of 14) + resize_size = 196 + resize_layer = nn.Upsample(size=(resize_size, resize_size), mode='bilinear', align_corners=False) - # backbone_output_size = backbone.heads.head.in_features - # backbone.heads.head = nn.Identity() + # Combine the resize layer and the backbone into a sequential model + backbone = nn.Sequential(resize_layer, backbone) + + # Set backbone_output_size to 384 since the model outputs 384-dimensional features + backbone_output_size = 384 else: raise ValueError(f'Unknown backbone type: {backbone_type}') - # Optionally freeze the follow-up layers, leaving the first convolutional layer unfrozen + # # Optionally freeze the follow-up layers, leaving the first convolutional layer unfrozen + # if backbone_params.get('freeze', False): + # print('Freezing backbone') + # for name, param in backbone.named_parameters(): + # if 'conv1' not in name: # Ensure the first conv layer is not frozen + # param.requires_grad = False + + # Optionally freeze layers except specified layers (for 
ResNet, ConvNeXt, EfficientNet, ViT) if backbone_params.get('freeze', False): - print('Freezing backbone') + print('Freezing backbone layers except specified layers') for name, param in backbone.named_parameters(): - if 'conv1' not in name: # Ensure the first conv layer is not frozen - param.requires_grad = False + if 'vit' in backbone_type.lower(): + # For all ViT backbones, avoid freezing 'patch_embed.proj' layer + if not ('patch_embed.proj' in name): + param.requires_grad = False + else: + # General case for other backbones (e.g., ResNet, ConvNeXt) + if 'conv1' not in name and 'features.0.0' not in name: + param.requires_grad = False return backbone, backbone_output_size