diff --git a/rl_games/configs/maniskill/maniskill_pickcube_impala_lstm.yaml b/rl_games/configs/maniskill/maniskill_pickcube_impala_lstm.yaml
index a0a71480..ede946e7 100644
--- a/rl_games/configs/maniskill/maniskill_pickcube_impala_lstm.yaml
+++ b/rl_games/configs/maniskill/maniskill_pickcube_impala_lstm.yaml
@@ -54,7 +54,7 @@ params:
       concat_output: True
 
   config:
-    name: PickCube_RGB_impala_LSTM_norm
+    name: PickCube_RGB_AUX_LOSS_add_impala_LSTM_norm
     env_name: maniskill
     reward_shaper:
       scale_value: 1.0
@@ -70,13 +70,13 @@ params:
     reward_shaper:
       scale_value: 1.0
     gamma: 0.99
-    tau : 0.95
+    tau: 0.95
     learning_rate: 2e-4
     lr_schedule: linear
     kl_threshold: 0.008
     max_epochs: 20000
     save_best_after: 25
-    save_frequency: 500
+    save_frequency: 1000
     grad_norm: 1.0
     entropy_coef: 0.0
     truncate_grads: True
@@ -92,11 +92,12 @@ params:
 
     env_config:
      env_name: PickCube-v1
-      obs_mode: rgbd
+      obs_mode: rgb
       control_mode: pd_ee_delta_pose
       reward_mode: dense
+      aux_loss: True
 
     player:
-      render: True
+      render: False
       deterministic: True
 
diff --git a/rl_games/configs/maniskill/maniskill_pickcube_impala_small_lstm.yaml b/rl_games/configs/maniskill/maniskill_pickcube_impala_small_lstm.yaml
index 01ed1a7f..763b7eb1 100644
--- a/rl_games/configs/maniskill/maniskill_pickcube_impala_small_lstm.yaml
+++ b/rl_games/configs/maniskill/maniskill_pickcube_impala_small_lstm.yaml
@@ -15,8 +15,6 @@ params:
 
   network:
     name: vision_actor_critic
-    require_rewards: False
-    require_last_actions: False
     separate: False
     value_shape: 1
     space:
@@ -40,7 +38,7 @@ params:
       regularizer:
         name: None
     mlp:
-      units: [256]
+      units: [256, 128]
       activation: elu
       regularizer:
         name: None
@@ -55,7 +53,7 @@ params:
       concat_output: True
 
   config:
-    name: PickCube_RGB_Impala_Small_LSTM_RMSnorm_embedding_128envs
+    name: PickCube_RGB_AUX_LOSS_Impala_Small_LSTM_256x128_MLP_100weight
     env_name: maniskill
     reward_shaper:
       scale_value: 1.0
@@ -67,22 +65,22 @@ params:
     normalize_input: False
     normalize_value: True
     normalize_advantage: True
-    num_actors: 128 #256
+    num_actors: 256
     reward_shaper:
       scale_value: 1.0
     gamma: 0.99
-    tau : 0.95
+    tau: 0.95
     learning_rate: 2e-4
     lr_schedule: None
     kl_threshold: 0.008
     max_epochs: 10000
     save_best_after: 25
-    save_frequency: 500
+    save_frequency: 1000
     grad_norm: 1.0
     entropy_coef: 0.0
     truncate_grads: True
     e_clip: 0.2
-    horizon_length: 32
+    horizon_length: 64
     minibatch_size: 2048
     mini_epochs: 2
     critic_coef: 1
@@ -93,9 +91,10 @@ params:
 
     env_config:
       env_name: PickCube-v1
-      obs_mode: rgbd
+      obs_mode: rgb
       control_mode: pd_ee_delta_pose
       reward_mode: dense
+      aux_loss: True
 
     player:
       render: False
diff --git a/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml b/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml
index 6b687393..d6d2c2b6 100644
--- a/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml
+++ b/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml
@@ -40,7 +40,7 @@ params:
       zero_init_residual: True
       norm_layer: None
     mlp:
-      units: [256]
+      units: [256, 128]
       activation: elu
       regularizer:
         name: None
@@ -55,7 +55,7 @@ params:
       concat_output: True
 
   config:
-    name: PickCube_RGB_resnet18_LSTM_norm_embedding_128envs_2e-4_linear_lr_first_layer_retrain
+    name: PickCube_RGB_resnet18_LSTM_256x128_MLP
     env_name: maniskill
     reward_shaper:
       scale_value: 1.0
@@ -67,23 +67,23 @@ params:
     normalize_input: False
     normalize_value: True
     normalize_advantage: True
-    num_actors: 64
+    num_actors: 128
     reward_shaper:
       scale_value: 1.0
     gamma: 0.99
-    tau : 0.95
+    tau: 0.95
     learning_rate: 2e-4
     lr_schedule: linear
     kl_threshold: 0.008
-    max_epochs: 20000
+    max_epochs: 10000
     save_best_after: 25
-    save_frequency: 500
+    save_frequency: 1000
     grad_norm: 1.0
     entropy_coef: 0.0
     truncate_grads: True
     e_clip: 0.2
-    horizon_length: 128
-    minibatch_size: 512
+    horizon_length: 64 #128
+    minibatch_size: 1024
     mini_epochs: 2
     critic_coef: 1
     clip_value: True
@@ -93,9 +93,10 @@ params:
 
     env_config:
       env_name: PickCube-v1
-      obs_mode: rgbd
+      obs_mode: rgb
       control_mode: pd_ee_delta_pose
       reward_mode: dense
+      aux_loss: False
 
     player:
       render: False
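Note on the config changes above: in rl_games the keys under env_config are handed to the env factory as plain kwargs, so the new aux_loss flag and the rgb obs_mode reach the ManiSkill wrapper through create_maniskill. A minimal sketch of that flow, assuming the factory signature added in rl_games/envs/maniskill.py below; the dict values mirror the small-LSTM config:

    from rl_games.envs.maniskill import create_maniskill

    # env_config entries become **kwargs of the factory; num_actors is popped
    # off first and becomes the vectorized env count (num_envs).
    env_config = {
        'env_name': 'PickCube-v1',
        'obs_mode': 'rgb',                  # was rgbd before this change
        'control_mode': 'pd_ee_delta_pose',
        'reward_mode': 'dense',
        'aux_loss': True,                   # adds the auxiliary target to the obs dict
    }
    env = create_maniskill(num_actors=256, **env_config)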
diff --git a/rl_games/envs/maniskill.py b/rl_games/envs/maniskill.py
index fcabf937..81423e86 100644
--- a/rl_games/envs/maniskill.py
+++ b/rl_games/envs/maniskill.py
@@ -7,6 +7,7 @@
 import gymnasium as gym2
 import gymnasium.spaces.utils
 from gymnasium.vector.utils import batch_space
+from mani_skill.utils import common
 
 
 VecEnvObs = Dict[str, torch.Tensor | Dict[str, torch.Tensor]]
@@ -36,6 +37,19 @@ def _process_obs(self, obs_dict: VecEnvObs) -> torch.Tensor | dict[str, torch.Te
 
     # TODO: add state processing for asymmetric case
     return obs
 
+def save_images_to_file(images: torch.Tensor, file_path: str):
+    """Save images to file.
+
+    Args:
+        images: A tensor of shape (N, H, W, C) containing the images.
+        file_path: The path to save the images to.
+    """
+    from torchvision.utils import make_grid, save_image
+
+    save_image(
+        make_grid(torch.swapaxes(images.unsqueeze(1), 1, -1).squeeze(-1), nrow=round(images.shape[0] ** 0.5)), file_path
+    )
+
 
 class RlgFlattenRGBDObservationWrapper(gym2.ObservationWrapper):
     """
@@ -51,10 +65,11 @@ class RlgFlattenRGBDObservationWrapper(gym2.ObservationWrapper):
 
     def __init__(self, env, rgb=True, depth=False, state=True, aux_loss=False) -> None:
         from mani_skill.envs.sapien_env import BaseEnv
-        from mani_skill.utils import common
 
         self.base_env: BaseEnv = env.unwrapped
         self.aux_loss = aux_loss
+        self.write_image_to_file = False
+
         super().__init__(env)
         self.include_rgb = rgb
         self.include_depth = depth
@@ -83,6 +98,9 @@ def observation(self, observation: Dict):
                 images.append(cam_data["depth"])
 
         images = torch.concat(images, axis=-1)
+        if self.write_image_to_file:
+            save_images_to_file(images.float() / 255.0, f"pickup_cube_{'rgb'}.png")
+
         # flatten the rest of the data which should just be state data
         observation = common.flatten_state_dict(observation, use_torch=True)
 
@@ -103,7 +121,6 @@ def __init__(self, config_name, num_envs, **kwargs):
         import gymnasium
         import gymnasium as gym2
         import mani_skill.envs
-        from mani_skill.utils.wrappers import FlattenRGBDObservationWrapper
 
         # Can be any env_id from the list of Rigid-Body envs: https://maniskill.readthedocs.io/en/latest/tasks/index.html
         self.env_name = kwargs.pop('env_name', 'PickCube-v1') # can be one of ['PickCube-v1', 'PegInsertionSide-v1', 'StackCube-v1']
@@ -141,7 +158,7 @@ def __init__(self, config_name, num_envs, **kwargs):
         print("Observation Space Unwrapped Before:", policy_obs_space)
 
         # TODO: add pointcloud and Depth support
-        if self.obs_mode == 'rgbd':
+        if self.obs_mode == 'rgb' or self.obs_mode == 'rgbd':
             self.env = RlgFlattenRGBDObservationWrapper(self.env, aux_loss=self.aux_loss)
             policy_obs_space = self.env.unwrapped.single_observation_space
             print("Observation Space Unwrapped After:", policy_obs_space)
@@ -152,7 +169,7 @@ def __init__(self, config_name, num_envs, **kwargs):
             for key, value in policy_obs_space.items():
                 print("Key:", key)
                 print("Value:", value)
-                if key == 'rgbd':
+                if key == 'rgb' or key == 'rgbd':
                     print("RGBD Shape:", value.shape)
                     print("RGBD Dtype:", value.dtype)
                     print(value)
@@ -188,6 +205,8 @@ def __init__(self, config_name, num_envs, **kwargs):
             # check if we have a dictionary of observations
             for key in policy_obs_space.keys():
                 if not isinstance(policy_obs_space[key], gymnasium.spaces.Box):
+                    print("Key:", key)
+                    print("Value:", policy_obs_space[key])
                     raise NotImplementedError(
                         f"Dictinary of dictinary observations support was not testes: '{type(policy_obs_space[key])}'."
                     )
@@ -228,7 +247,7 @@ def step(self, actions):
         # dones = (terminated | truncated)
         dones = torch.logical_or(terminated, truncated)
         if dones.any():
-            env_idx = torch.arange(0, self.env.num_envs, device=self.env.device)[dones] # device=self.device
+            env_idx = torch.arange(0, self.env.unwrapped.num_envs, device=self.env.unwrapped.device)[dones] # device=self.device
             reset_obs, _ = self.env.reset(options=dict(env_idx=env_idx))
             obs_and_states['obs'] = reset_obs
 
@@ -278,4 +297,6 @@ def get_env_info(self):
 
 
 def create_maniskill(**kwargs):
-    return Maniskill("", num_envs=kwargs.pop('num_actors', 16), **kwargs)
\ No newline at end of file
+    print("Creating Maniskill env with the following parameters:")
+    print(kwargs)
+    return Maniskill("", num_envs=kwargs.pop('num_actors', 4), **kwargs)
\ No newline at end of file
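The new save_images_to_file helper above is a debugging aid: it takes a batch of (N, H, W, C) frames, converts them to NCHW via the unsqueeze/swapaxes/squeeze shuffle, tiles them into a roughly square grid with torchvision's make_grid, and writes a single PNG. A hypothetical standalone invocation; random frames stand in for the wrapper's concatenated camera images, which are only saved when write_image_to_file is flipped to True:

    import torch
    from rl_games.envs.maniskill import save_images_to_file

    frames = torch.randint(0, 256, (16, 128, 128, 3), dtype=torch.uint8)  # (N, H, W, C) uint8 RGB
    save_images_to_file(frames.float() / 255.0, "pickup_cube_rgb.png")    # save_image expects floats in [0, 1]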
diff --git a/rl_games/networks/vision_networks.py b/rl_games/networks/vision_networks.py
index 60c6580f..64b7ff80 100644
--- a/rl_games/networks/vision_networks.py
+++ b/rl_games/networks/vision_networks.py
@@ -18,7 +18,15 @@ class Network(NetworkBuilder.BaseNetwork):
         def __init__(self, params, **kwargs):
             self.actions_num = actions_num = kwargs.pop('actions_num')
             full_input_shape = kwargs.pop('input_shape')
-            self.use_aux_loss = kwargs.pop('use_aux_loss', False)
+            self.use_aux_loss = kwargs.pop('use_aux_loss', True)
+            print(kwargs)
+            print("params: ", params)
+            self.aux_loss_weight = 100.0 # kwargs.pop('aux_loss_weight', 1.0)
+            if self.use_aux_loss:
+                self.target_key = 'aux_target'
+                if 'aux_target' in full_input_shape:
+                    self.target_shape = full_input_shape[self.target_key]
+                    print("Target shape: ", self.target_shape)
             self.proprio_size = 0 # Number of proprioceptive features
 
             if type(full_input_shape) is dict:
@@ -41,12 +49,14 @@ def __init__(self, params, **kwargs):
             cnn_output_size = self._calc_input_size(input_shape, self.cnn)
 
             mlp_input_size = cnn_output_size + self.proprio_size
+            if self.use_aux_loss:
+                mlp_input_size += self.target_shape[0]
+
             if len(self.units) == 0:
                 out_size = cnn_output_size
             else:
                 out_size = self.units[-1]
 
-            self.running_mean_std = torch.jit.script(RunningMeanStd((mlp_input_size,)))
 
             self.layer_norm_emb = torch.nn.LayerNorm(mlp_input_size)
             #self.layer_norm_emb = torch.nn.RMSNorm(mlp_input_size)
@@ -73,7 +83,10 @@ def __init__(self, params, **kwargs):
 
             self.aux_loss_map = None
             if self.use_aux_loss:
-                self.aux_loss_linear = nn.Linear(out_size, self.target_shape)
+                print("Building aux loss")
+                print("cnn_output_size: ", cnn_output_size)
+                print("target_shape: ", self.target_shape)
+                self.aux_loss_linear = nn.Linear(cnn_output_size, self.target_shape[0])
                 self.aux_loss_map = {
                     'aux_dist_loss': None
                 }
@@ -116,14 +129,6 @@ def __init__(self, params, **kwargs):
 
             mlp_init(self.value.weight)
 
-        def norm_emb(self, embedding):
-            #with torch.no_grad():
-            return self.running_mean_std(embedding) if self.normalize_emb else embedding
-            # if len(self.units) == 0:
-            #     out_size = cnn_output_size
-            # else:
-            #     out_size = self.units[-1]
-
         def get_aux_loss(self):
             return self.aux_loss_map
 
@@ -147,10 +152,21 @@ def forward(self, obs_dict):
             out = obs
             out = self.cnn(out)
             out = out.flatten(1)
-            out = self.flatten_act(out)
+            cnn_out = self.flatten_act(out)
 
             if self.proprio_size > 0:
-                out = torch.cat([out, proprio], dim=1)
+                out = torch.cat([cnn_out, proprio], dim=1)
+
+            if self.use_aux_loss:
+                y = self.aux_loss_linear(cnn_out)
+                out = torch.cat([out, y], dim=1)
+                self.aux_loss_map['aux_dist_loss'] = self.aux_loss_weight * torch.nn.functional.mse_loss(y, target_obs)
+                # print("aux predicted shape: ", y.shape)
+                # print("aux predicted: ", y)
+                # print("aux target: ", target_obs)
+                # print("delta: ", y - target_obs)
+                # print("aux loss: ", self.aux_loss_map['aux_dist_loss'])
+
             out = self.layer_norm_emb(out)
 
             if self.has_rnn:
@@ -187,10 +203,6 @@ def forward(self, obs_dict):
 
             value = self.value_act(self.value(out))
 
-            if self.use_aux_loss:
-                y = self.aux_loss_linear(out)
-                self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs)
-
             if self.is_discrete:
                 logits = self.logits(out)
                 return logits, value, states
@@ -291,6 +303,7 @@ def __init__(self, params, **kwargs):
             full_input_shape = kwargs.pop('input_shape')
             self.use_aux_loss = kwargs.pop('use_aux_loss', False)
+            self.aux_loss_weight = 100.0
             if self.use_aux_loss:
                 self.target_key = 'aux_target'
                 if 'aux_target' in full_input_shape:
                     self.target_shape = full_input_shape[self.target_key]
@@ -324,12 +337,15 @@ def __init__(self, params, **kwargs):
             self.cnn, self.cnn_output_size = self._build_backbone(input_shape, params['backbone'])
 
             mlp_input_size = self.cnn_output_size + self.proprio_size
+            if self.use_aux_loss:
+                mlp_input_size += self.target_shape[0]
+
             if len(self.units) == 0:
                 out_size = self.cnn_output_size
             else:
                 out_size = self.units[-1]
 
-            self.layer_norm_emb = torch.nn.LayerNorm((mlp_input_size,))
+            self.layer_norm_emb = torch.nn.LayerNorm(mlp_input_size)
 
             if self.has_rnn:
                 if not self.is_rnn_before_mlp:
@@ -354,7 +370,7 @@ def __init__(self, params, **kwargs):
 
             self.aux_loss_map = None
             if self.use_aux_loss:
-                self.aux_loss_linear = nn.Linear(out_size, self.target_shape)
+                self.aux_loss_linear = nn.Linear(self.cnn_output_size, self.target_shape[0])
                 self.aux_loss_map = {
                     'aux_dist_loss': None
                 }
@@ -420,10 +436,15 @@ def forward(self, obs_dict):
             out = obs
             out = self.cnn(out)
             out = out.flatten(1)
-            out = self.flatten_act(out)
+            vis_out = self.flatten_act(out)
 
             if self.proprio_size > 0:
-                out = torch.cat([out, proprio], dim=1)
+                out = torch.cat([vis_out, proprio], dim=1)
+
+            if self.use_aux_loss:
+                y = self.aux_loss_linear(vis_out)
+                out = torch.cat([out, y], dim=1)
+                self.aux_loss_map['aux_dist_loss'] = self.aux_loss_weight * torch.nn.functional.mse_loss(y, target_obs)
 
             out = self.layer_norm_emb(out)
 
@@ -461,10 +482,6 @@ def forward(self, obs_dict):
 
             value = self.value_act(self.value(out))
 
-            if self.use_aux_loss:
-                y = self.aux_loss_linear(out)
-                self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs)
-
             if self.is_discrete:
                 logits = self.logits(out)
                 return logits, value, states
@@ -522,7 +539,7 @@ def _build_backbone(self, input_shape, backbone_params):
 
                 # Modify the first convolution layer to match input shape if needed
                 # TODO: add low-res parameter
-                backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False)
+                #backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False)
                 # backbone.maxpool = nn.Identity()
                 # if input_shape[0] != 3:
                 #     backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
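The substantive change in both network classes is the same: the auxiliary head now reads the CNN embedding instead of the post-MLP features, its prediction is concatenated into the trunk input ahead of the layer norm, and the MSE term is scaled by the hard-coded aux_loss_weight of 100.0. A self-contained sketch of that data flow; sizes are illustrative, not taken from the configs:

    import torch
    import torch.nn as nn

    cnn_output_size, proprio_size, target_size = 512, 9, 3  # illustrative sizes
    aux_loss_weight = 100.0                                 # matches the hard-coded weight above

    aux_loss_linear = nn.Linear(cnn_output_size, target_size)
    layer_norm_emb = nn.LayerNorm(cnn_output_size + proprio_size + target_size)

    cnn_out = torch.randn(8, cnn_output_size)   # flattened vision features
    proprio = torch.randn(8, proprio_size)      # proprioceptive slice of the obs dict
    target_obs = torch.randn(8, target_size)    # the 'aux_target' entry

    y = aux_loss_linear(cnn_out)                   # predict the aux target from vision features only
    out = torch.cat([cnn_out, proprio, y], dim=1)  # feed the prediction back into the trunk
    aux_dist_loss = aux_loss_weight * nn.functional.mse_loss(y, target_obs)
    out = layer_norm_emb(out)                      # normalized embedding continues to the MLP/RNN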
diff --git a/runner.py b/runner.py
index 38802593..6d043ed5 100644
--- a/runner.py
+++ b/runner.py
@@ -70,12 +70,12 @@
 
     runner.run(args)
 
+    if args["track"] and global_rank == 0:
+        wandb.finish()
+
     try:
         import ray
     except ImportError:
         pass
     else:
         ray.shutdown()
-
-    if args["track"] and global_rank == 0:
-        wandb.finish()