Commit

Depth training config. Dino2 vit support.
ViktorM committed Sep 19, 2024
1 parent fa3469a commit 092e559
Showing 6 changed files with 270 additions and 26 deletions.
10 changes: 5 additions & 5 deletions rl_games/common/experience.py
@@ -360,14 +360,14 @@ def _create_tensor_from_space(self, space, base_shape):
return torch.zeros(base_shape + space.shape, dtype=dtype, device=self.device)
if type(space) is gym.spaces.Discrete:
dtype = numpy_to_torch_dtype_dict[space.dtype]
return torch.zeros(base_shape, dtype=dtype, device = self.device)
return torch.zeros(base_shape, dtype=dtype, device=self.device)
if type(space) is gym.spaces.Tuple:
'''
assuming that tuple is only Discrete tuple
'''
dtype = numpy_to_torch_dtype_dict[space.dtype]
tuple_len = len(space)
return torch.zeros(base_shape +(tuple_len,), dtype= dtype, device = self.device)
return torch.zeros(base_shape +(tuple_len,), dtype=dtype, device=self.device)
if type(space) is gym.spaces.Dict:
t_dict = {}
for k,v in space.spaces.items():
@@ -381,12 +381,12 @@ def update_data(self, name, index, val):
else:
self.tensor_dict[name][index,:] = val

def update_data_rnn(self, name, indices,play_mask, val):
def update_data_rnn(self, name, indices, play_mask, val):
if type(val) is dict:
for k,v in val:
self.tensor_dict[name][k][indices,play_mask] = v
self.tensor_dict[name][k][indices, play_mask] = v
else:
self.tensor_dict[name][indices,play_mask] = val
self.tensor_dict[name][indices, play_mask] = val

def get_transformed(self, transform_op):
res_dict = {}
104 changes: 104 additions & 0 deletions rl_games/configs/maniskill/maniskill_pickcube_depth.yaml
@@ -0,0 +1,104 @@
params:
seed: 42

# environment wrapper clipping
env:
# added to the wrapper
clip_observations: 5.0
# can make custom wrapper?
clip_actions: 1.0
algo:
name: a2c_continuous

model:
name: continuous_a2c_logstd

network:
name: e2e_vision_actor_critic
separate: False
value_shape: 1
space:
continuous:
mu_activation: None
sigma_activation: None

mu_init:
name: default
scale: 0.02
sigma_init:
name: const_initializer
val: 0
fixed_sigma: True

backbone:
type: resnet18 # can be efficientnet_v2_s #convnext_tiny #vit_b_16 #resnet18 #resnet34
pretrained: True
permute_input: True
freeze: False
preprocess_image: False
args:
zero_init_residual: True
norm_layer: None
mlp:
units: [256, 128]
activation: elu
regularizer:
name: None
initializer:
name: default
rnn:
name: lstm
layer_norm: True
units: 512
layers: 1
before_mlp: True
concat_output: True

config:
name: PickCube_RGB_resnet18_LSTM_256x128_MLP_depth
env_name: maniskill
reward_shaper:
scale_value: 1.0
device: cuda:0
device_name: cuda:0
multi_gpu: False
ppo: True
mixed_precision: True
normalize_input: False
normalize_value: True
normalize_advantage: True
num_actors: 128
reward_shaper:
scale_value: 1.0
gamma: 0.99
tau: 0.95
learning_rate: 2e-4
lr_schedule: linear
kl_threshold: 0.008
max_epochs: 10000
save_best_after: 25
save_frequency: 1000
grad_norm: 1.0
entropy_coef: 0.0
truncate_grads: True
e_clip: 0.2
horizon_length: 64 #128
minibatch_size: 1024
mini_epochs: 2
critic_coef: 1
clip_value: True
seq_length: 8
bounds_loss_coef: 0.0001
#weight_decay: 0.001

env_config:
env_name: PickCube-v1
obs_mode: depth #rgbd #rgb
control_mode: pd_ee_delta_pose
reward_mode: dense
aux_loss: False

player:
render: False
deterministic: True
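Note: a minimal sketch (not part of this commit) of launching training with the new depth config through the usual rl_games Runner entry point, assuming the repository root as the working directory:

import yaml
from rl_games.torch_runner import Runner

# Load the depth config added in this commit and start PPO training.
with open('rl_games/configs/maniskill/maniskill_pickcube_depth.yaml') as f:
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)                         # builds a2c_continuous with the e2e vision network
runner.run({'train': True, 'play': False})  # same effect as `python runner.py --train --file <config>`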

4 changes: 2 additions & 2 deletions rl_games/configs/maniskill/maniskill_pickcube_vision.yaml
@@ -55,7 +55,7 @@ params:
concat_output: True

config:
name: PickCube_RGB_resnet18_LSTM_256x128_MLP
name: PickCube_RGB_resnet18_LSTM_256x128_MLP_frozen_RGB
env_name: maniskill
reward_shaper:
scale_value: 1.0
@@ -82,7 +82,7 @@ params:
entropy_coef: 0.0
truncate_grads: True
e_clip: 0.2
horizon_length: 64 #128
horizon_length: 64
minibatch_size: 1024
mini_epochs: 2
critic_coef: 1
104 changes: 104 additions & 0 deletions rl_games/configs/maniskill/maniskill_pickcube_vit.yaml
@@ -0,0 +1,104 @@
params:
seed: 42

# environment wrapper clipping
env:
# added to the wrapper
clip_observations: 5.0
# can make custom wrapper?
clip_actions: 1.0
algo:
name: a2c_continuous

model:
name: continuous_a2c_logstd

network:
name: e2e_vision_actor_critic
separate: False
value_shape: 1
space:
continuous:
mu_activation: None
sigma_activation: None

mu_init:
name: default
scale: 0.02
sigma_init:
name: const_initializer
val: 0
fixed_sigma: True

backbone:
type: dinov2_vits14_reg #vit_b_16 # can be efficientnet_v2_s #convnext_tiny #vit_b_16 #resnet18 #resnet34
pretrained: True
permute_input: True
freeze: True
preprocess_image: True
args:
zero_init_residual: True
norm_layer: None
mlp:
units: [256, 128]
activation: elu
regularizer:
name: None
initializer:
name: default
rnn:
name: lstm
layer_norm: True
units: 512
layers: 1
before_mlp: True
concat_output: True

config:
name: PickCube_RGB_DINO2_VITS14_REG_LSTM_256x128_MLP_frozen_RGB
env_name: maniskill
reward_shaper:
scale_value: 1.0
device: cuda:0
device_name: cuda:0
multi_gpu: False
ppo: True
mixed_precision: True
normalize_input: False
normalize_value: True
normalize_advantage: True
num_actors: 64 #128
reward_shaper:
scale_value: 1.0
gamma: 0.99
tau: 0.95
learning_rate: 2e-4
lr_schedule: linear
kl_threshold: 0.008
max_epochs: 10000
save_best_after: 25
save_frequency: 1000
grad_norm: 1.0
entropy_coef: 0.0
truncate_grads: True
e_clip: 0.2
horizon_length: 64
minibatch_size: 256 #512 #1024
mini_epochs: 2
critic_coef: 1
clip_value: True
seq_length: 8
bounds_loss_coef: 0.0001
#weight_decay: 0.001

env_config:
env_name: PickCube-v1
obs_mode: rgb
control_mode: pd_ee_delta_pose
reward_mode: dense
aux_loss: False

player:
render: False
deterministic: True
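Note: the PPO batch geometry implied by this config, as a quick arithmetic sketch using the values from the file above:

num_actors, horizon_length = 64, 64
minibatch_size, seq_length = 256, 8
batch_size = num_actors * horizon_length           # 4096 transitions per rollout
num_minibatches = batch_size // minibatch_size     # 16 minibatches per mini-epoch
seqs_per_minibatch = minibatch_size // seq_length  # 32 LSTM sequences per minibatch
print(batch_size, num_minibatches, seqs_per_minibatch)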

18 changes: 14 additions & 4 deletions rl_games/envs/maniskill.py
@@ -109,7 +109,11 @@ def observation(self, observation: Dict):
ret["proprio"] = observation
if self.aux_loss:
ret['aux_target'] = aux_target
ret["camera"] = images

if not self.include_rgb and self.include_depth:
ret["camera"] = images.float() / 32768.0
else:
ret["camera"] = images

return ret
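Note: a standalone sketch of the depth branch above (assumed image shape, not part of the commit); int16 depth is cast to float and scaled by 32768, the int16 range, so non-negative depth lands in roughly [0, 1) before reaching the vision backbone:

import torch

# Fake int16 depth batch with an assumed (N, H, W, 1) layout.
depth_int16 = torch.randint(0, 2000, (4, 128, 128, 1), dtype=torch.int16)
depth_float = depth_int16.float() / 32768.0
assert depth_float.dtype == torch.float32
assert depth_float.min() >= 0.0 and depth_float.max() < 1.0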

@@ -158,8 +162,10 @@ def __init__(self, config_name, num_envs, **kwargs):
print("Observation Space Unwrapped Before:", policy_obs_space)

# TODO: add pointcloud and Depth support
if self.obs_mode == 'rgb' or self.obs_mode == 'rgbd':
self.env = RlgFlattenRGBDObservationWrapper(self.env, aux_loss=self.aux_loss)
use_rgb = self.obs_mode == 'rgbd' or self.obs_mode == 'rgb'
use_depth = self.obs_mode == 'rgbd' or self.obs_mode == 'depth'
if self.obs_mode == 'rgb' or self.obs_mode == 'rgbd' or self.obs_mode == 'depth':
self.env = RlgFlattenRGBDObservationWrapper(self.env, aux_loss=self.aux_loss, rgb=use_rgb, depth=use_depth)
policy_obs_space = self.env.unwrapped.single_observation_space
print("Observation Space Unwrapped After:", policy_obs_space)

@@ -213,7 +219,11 @@ def __init__(self, config_name, num_envs, **kwargs):

val = policy_obs_space[key]
if val.dtype == np.float16 or val.dtype == np.float32:
self.observation_space[key] = gym.spaces.Box(-self._clip_obs, self._clip_obs, val.shape)
self.observation_space[key] = gym.spaces.Box(-self._clip_obs, self._clip_obs, val.shape, dtype=val.dtype)
elif val.dtype == np.int16:
# to fix!!!
#self.observation_space[key] = gym.spaces.Box(-32768, 32767, val.shape, dtype=np.int16)
self.observation_space[key] = gym.spaces.Box(-1.0, 1.0, val.shape, dtype=np.float32)
elif policy_obs_space[key].dtype == np.uint8:
self.observation_space[key] = gym.spaces.Box(0, 255, val.shape, dtype=np.uint8)
else:
56 changes: 41 additions & 15 deletions rl_games/networks/vision_networks.py
@@ -523,12 +523,11 @@ def load(self, params):

self.has_cnn = True
self.permute_input = params['backbone'].get('permute_input', True)
self.require_rewards = params.get('require_rewards')
self.require_last_actions = params.get('require_last_actions')

def _build_backbone(self, input_shape, backbone_params):
backbone_type = backbone_params['type']
pretrained = backbone_params.get('pretrained', False)
modify_first_conv = backbone_params.get('modify_first_conv', False)
self.preprocess_image = backbone_params.get('preprocess_image', False)

if backbone_type == 'resnet18' or backbone_type == 'resnet34':
@@ -538,11 +537,13 @@ def _build_backbone(self, input_shape, backbone_params):
backbone = models.resnet34(pretrained=pretrained, zero_init_residual=True)

# Modify the first convolution layer to match input shape if needed
# TODO: add low-res parameter
#backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False)
if input_shape[0] != 3:
backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)

# In case of lower resolution images, modify the first convolutional layer to work with smaller resolutions
if modify_first_conv:
backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False)
# backbone.maxpool = nn.Identity()
# if input_shape[0] != 3:
# backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
# Remove the fully connected layer
backbone_output_size = backbone.fc.in_features
print('backbone_output_size: ', backbone_output_size)
@@ -577,21 +578,46 @@ def _build_backbone(self, input_shape, backbone_params):

# Combine the resize layer and the backbone into a sequential model
backbone = nn.Sequential(resize_layer, backbone)
# # Assuming your input image is a tensor or PIL image, resize it to 224x224
# #obs = self.resize_transform(obs)
# backbone = models.vision_transformer.vit_b_16(pretrained=pretrained)
elif backbone_type == 'dinov2_vits14_reg':
# DINOv2 ViT-S/14 backbone with registers (dinov2_vits14_reg)
try:
# Load the dinov2_vits14_reg model using torch.hub
backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg', pretrained=pretrained)
except Exception as e:
raise ValueError(f"Failed to load dinov2_vits14_reg: {e}")

# The head is already Identity, no need to modify
# Add a resize layer to ensure the input is correctly sized for DinoV2
resize_size = 196
resize_layer = nn.Upsample(size=(resize_size, resize_size), mode='bilinear', align_corners=False)

# backbone_output_size = backbone.heads.head.in_features
# backbone.heads.head = nn.Identity()
# Combine the resize layer and the backbone into a sequential model
backbone = nn.Sequential(resize_layer, backbone)

# Set backbone_output_size to 384 since the model outputs 384-dimensional features
backbone_output_size = 384
else:
raise ValueError(f'Unknown backbone type: {backbone_type}')

# Optionally freeze the follow-up layers, leaving the first convolutional layer unfrozen
# # Optionally freeze the follow-up layers, leaving the first convolutional layer unfrozen
# if backbone_params.get('freeze', False):
# print('Freezing backbone')
# for name, param in backbone.named_parameters():
# if 'conv1' not in name: # Ensure the first conv layer is not frozen
# param.requires_grad = False

# Optionally freeze the backbone, keeping the input/stem layer trainable (ResNet, ConvNeXt, EfficientNet, ViT)
if backbone_params.get('freeze', False):
print('Freezing backbone')
print('Freezing backbone layers except specified layers')
for name, param in backbone.named_parameters():
if 'conv1' not in name: # Ensure the first conv layer is not frozen
param.requires_grad = False
if 'vit' in backbone_type.lower():
# For all ViT backbones, avoid freezing 'patch_embed.proj' layer
if not ('patch_embed.proj' in name):
param.requires_grad = False
else:
# General case for other backbones (e.g., ResNet, ConvNeXt)
if 'conv1' not in name and 'features.0.0' not in name:
param.requires_grad = False

return backbone, backbone_output_size
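Note: a standalone sketch (not part of the diff, assumes network access for the torch.hub download) that loads the same dinov2_vits14_reg backbone, applies the freezing rule above, and checks the 384-dimensional output for a 196x196 input (196 being a multiple of the 14-pixel patch size):

import torch

# Same hub call as in the diff: ViT-S/14 DINOv2 model with registers.
backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg', pretrained=True)

# Freeze everything except the patch-embedding projection, mirroring the rule above.
for name, param in backbone.named_parameters():
    param.requires_grad = 'patch_embed.proj' in name

x = torch.rand(1, 3, 196, 196)  # matches the 196x196 Upsample layer in the diff
with torch.no_grad():
    features = backbone(x)
print(features.shape)           # expected: torch.Size([1, 384]) for ViT-S/14

trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
print(f'trainable parameters: {trainable:,}')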

