From d21fd8fb1df56b3c4b845befffd541098a04c3f6 Mon Sep 17 00:00:00 2001 From: ViktorM Date: Sun, 1 Sep 2024 15:28:38 -0700 Subject: [PATCH] Maniskill RGB and Depth only observations. Added resnet and impala cube pickup configs. --- rl_games/algos_torch/a2c_continuous.py | 5 +- .../maniskill/maniskill_pickcube_impala.yaml | 101 ++++++++++++ .../maniskill/maniskill_pickcube_state.yaml | 9 +- .../maniskill/maniskill_pickcube_vision.yaml | 101 ++++++++++++ .../maniskill/maniskill_pushcube_state.yaml | 2 +- .../configs/maniskill/maniskill_resnet.yaml | 77 --------- rl_games/envs/maniskill.py | 151 ++++++++++++------ 7 files changed, 318 insertions(+), 128 deletions(-) create mode 100644 rl_games/configs/maniskill/maniskill_pickcube_impala.yaml create mode 100644 rl_games/configs/maniskill/maniskill_pickcube_vision.yaml delete mode 100644 rl_games/configs/maniskill/maniskill_resnet.yaml diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index 822c1568..12b24a0e 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -95,7 +95,7 @@ def restore_central_value_function(self, fn): def get_masked_action_values(self, obs, action_masks): assert False - @torch.compile() #(mode='max-autotune') +# @torch.compile() #(mode='max-autotune') def calc_losses(self, actor_loss_func, old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip, value_preds_batch, values, return_batch, mu, entropy, rnn_masks): a_loss = actor_loss_func(old_action_log_probs_batch, action_log_probs, advantage, self.ppo, curr_e_clip) @@ -109,7 +109,8 @@ def calc_losses(self, actor_loss_func, old_action_log_probs_batch, action_log_pr b_loss = self.bound_loss(mu) else: b_loss = torch.zeros(1, device=self.ppo_device) - losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss , entropy.unsqueeze(1), b_loss.unsqueeze(1)], rnn_masks) + + losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1), b_loss.unsqueeze(1)], rnn_masks) a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3] loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef diff --git a/rl_games/configs/maniskill/maniskill_pickcube_impala.yaml b/rl_games/configs/maniskill/maniskill_pickcube_impala.yaml new file mode 100644 index 00000000..8f3f4786 --- /dev/null +++ b/rl_games/configs/maniskill/maniskill_pickcube_impala.yaml @@ -0,0 +1,101 @@ +params: + seed: 42 + + # environment wrapper clipping + env: + # added to the wrapper + clip_observations: 5.0 + # can make custom wrapper? 
+ clip_actions: 1.0 + algo: + name: a2c_continuous + + model: + name: continuous_a2c_logstd + + network: + name: vision_actor_critic + require_rewards: False + require_last_actions: False + separate: False + value_shape: 1 + space: + continuous: + mu_activation: None + sigma_activation: None + + mu_init: + name: default + sigma_init: + name: const_initializer + val: 0 + fixed_sigma: True + cnn: + permute_input: True + conv_depths: [16, 32, 32] + activation: relu + initializer: + name: default + regularizer: + name: None + mlp: + units: [512, 256] + activation: elu + regularizer: + name: None + initializer: + name: default + # rnn: + # name: lstm + # units: 512 + # layers: 1 + # before_mlp: True + # concat_output: True + + config: + name: PickCube_RGB_impala + env_name: maniskill + reward_shaper: + scale_value: 1.0 + device: cuda:0 + device_name: cuda:0 + multi_gpu: False + ppo: True + mixed_precision: True + normalize_input: False + normalize_value: True + normalize_advantage: True + num_actors: 256 + reward_shaper: + scale_value: 1.0 + gamma: 0.99 + tau : 0.95 + learning_rate: 1e-4 + lr_schedule: adaptive + kl_threshold: 0.008 + max_epochs: 10000 + save_best_after: 25 + save_frequency: 500 + grad_norm: 1.0 + entropy_coef: 0.0 + truncate_grads: True + e_clip: 0.2 + horizon_length: 16 + minibatch_size: 2048 + mini_epochs: 2 + critic_coef: 1 + clip_value: True + seq_length: 8 + bounds_loss_coef: 0.0001 + #weight_decay: 0.001 + + env_config: + env_name: PickCube-v1 + obs_mode: rgbd + control_mode: pd_ee_delta_pose + reward_mode: dense + + player: + render: True + deterministic: True + diff --git a/rl_games/configs/maniskill/maniskill_pickcube_state.yaml b/rl_games/configs/maniskill/maniskill_pickcube_state.yaml index 0874e43e..9779fb41 100644 --- a/rl_games/configs/maniskill/maniskill_pickcube_state.yaml +++ b/rl_games/configs/maniskill/maniskill_pickcube_state.yaml @@ -28,11 +28,11 @@ params: config: name: PickCube_state env_name: maniskill + reward_shaper: + scale_value: 1.0 normalize_input: True normalize_value: True value_bootstrap: True - reward_shaper: - scale_value: 1.0 normalize_advantage: True gamma: 0.99 tau: 0.95 @@ -48,7 +48,7 @@ params: use_smooth_clamp: True bound_loss_type: regularisation bounds_loss_coef: 0.0005 - max_epochs: 4000 + max_epochs: 5000 save_best_after: 25 save_frequency: 100 num_actors: 4096 @@ -59,6 +59,9 @@ params: env_config: env_name: PickCube-v1 # todo: add list of all envs + observation_mode: state + control_mode: pd_ee_delta_pose + reward_mode: dense player: render: True \ No newline at end of file diff --git a/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml b/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml new file mode 100644 index 00000000..232d3b7c --- /dev/null +++ b/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml @@ -0,0 +1,101 @@ +params: + seed: 42 + + # environment wrapper clipping + env: + # added to the wrapper + clip_observations: 5.0 + # can make custom wrapper? 
+ clip_actions: 1.0 + algo: + name: a2c_continuous + + model: + name: continuous_a2c_logstd + + network: + name: e2e_vision_actor_critic + separate: False + value_shape: 1 + space: + continuous: + mu_activation: None + sigma_activation: None + + mu_init: + name: default + sigma_init: + name: const_initializer + val: 0 + fixed_sigma: True + + backbone: + type: resnet18 # can be efficientnet_v2_s #convnext_tiny #vit_b_16 #resnet18 #resnet34 + pretrained: True + permute_input: True + freeze: False + preprocess_image: True + args: + zero_init_residual: True + norm_layer: None + mlp: + units: [512, 256] + activation: elu + regularizer: + name: None + initializer: + name: default + # rnn: + # name: lstm + # units: 512 + # layers: 1 + # before_mlp: True + # concat_output: True + + config: + name: PickCube_RGB_resnet18 + env_name: maniskill + reward_shaper: + scale_value: 1.0 + device: cuda:0 + device_name: cuda:0 + multi_gpu: False + ppo: True + mixed_precision: True + normalize_input: False + normalize_value: True + normalize_advantage: True + num_actors: 128 + reward_shaper: + scale_value: 1.0 + gamma: 0.99 + tau : 0.95 + learning_rate: 1e-4 + lr_schedule: adaptive + kl_threshold: 0.008 + max_epochs: 20000 + save_best_after: 25 + save_frequency: 500 + grad_norm: 1.0 + entropy_coef: 0.0 + truncate_grads: True + e_clip: 0.2 + horizon_length: 16 + minibatch_size: 512 + mini_epochs: 2 + critic_coef: 1 + clip_value: True + seq_length: 8 + bounds_loss_coef: 0.0001 + #weight_decay: 0.001 + + env_config: + env_name: PickCube-v1 + obs_mode: rgbd + control_mode: pd_ee_delta_pose + reward_mode: dense + + player: + render: True + deterministic: True + diff --git a/rl_games/configs/maniskill/maniskill_pushcube_state.yaml b/rl_games/configs/maniskill/maniskill_pushcube_state.yaml index c95899bf..e550675f 100644 --- a/rl_games/configs/maniskill/maniskill_pushcube_state.yaml +++ b/rl_games/configs/maniskill/maniskill_pushcube_state.yaml @@ -26,7 +26,7 @@ params: name: default config: - name: PushCube_state + name: PushCube_EE_state env_name: maniskill normalize_input: True normalize_value: True diff --git a/rl_games/configs/maniskill/maniskill_resnet.yaml b/rl_games/configs/maniskill/maniskill_resnet.yaml deleted file mode 100644 index f36f4e6c..00000000 --- a/rl_games/configs/maniskill/maniskill_resnet.yaml +++ /dev/null @@ -1,77 +0,0 @@ -params: - algo: - name: a2c_discrete - - model: - name: discrete_a2c - - network: - name: e2e_vision_actor_critic - separate: False - value_shape: 1 - space: - continuous: - - backbone: - type: resnet18 # can be efficientnet_v2_s #convnext_tiny #vit_b_16 #resnet18 #resnet34 - pretrained: True - permute_input: False - freeze: False - preprocess_image: True - args: - zero_init_residual: True - norm_layer: None - - mlp: - units: [512, 256] - activation: elu - regularizer: - name: None - initializer: - name: default - rnn: - name: lstm - units: 512 - layers: 1 - before_mlp: True - concat_output: True - - config: - name: Maniskill_resnet18 - env_name: maniskill - score_to_win: 20.0 - mixed_precision: True - normalize_input: False - normalize_value: True - normalize_advantage: True - gamma: 0.99 - tau: 0.95 - grad_norm: 1.0 - entropy_coef: 0.01 - truncate_grads: True - e_clip: 0.2 - clip_value: True - save_best_after: 25 - save_frequency: 200 - num_actors: 64 - horizon_length: 128 - minibatch_size: 2048 - mini_epochs: 2 - critic_coef: 1 - learning_rate: 2e-4 - lr_schedule: linear - kl_threshold: 0.01 - use_diagnostics: True - seq_length: 8 - max_epochs: 500 - #weight_decay: 0.001 
- - env_config: - env_name: PickCube-v1 - - player: - render: True - games_num: 10 - n_game_life: 1 - deterministic: True - diff --git a/rl_games/envs/maniskill.py b/rl_games/envs/maniskill.py index d3dff7f4..d31a4695 100644 --- a/rl_games/envs/maniskill.py +++ b/rl_games/envs/maniskill.py @@ -3,31 +3,14 @@ import numpy as np import torch -from typing import Dict, Literal - - -# def flatten_dict(obs): -# res = [] -# for k,v in obs.items(): -# res.append(v.reshape(v.shape[0], -1)) - -# res = np.column_stack(res) -# return res +from typing import Dict +import gymnasium as gym2 +import gymnasium.spaces.utils +from gymnasium.vector.utils import batch_space - - -# # create an environment with our configs and then reset to a clean state -# env = gym.make(env_id, -# num_envs=4, -# obs_mode=obs_mode, -# reward_mode=reward_mode, -# control_mode=control_mode, -# robot_uids=robot_uids, -# enable_shadow=True # this makes the default lighting cast shadows -# ) -# obs, _ = env.reset() -# print("Action Space:", env.action_space) +from mani_skill.envs.sapien_env import BaseEnv +from mani_skill.utils import common VecEnvObs = Dict[str, torch.Tensor | Dict[str, torch.Tensor]] @@ -56,14 +39,63 @@ def _process_obs(self, obs_dict: VecEnvObs) -> torch.Tensor | dict[str, torch.Te obs[key] = obs[key].to(device=self._rl_device).clone() # TODO: add state processing for asymmetric case return obs + +def maniskill_process_obs(obs_dict: VecEnvObs) -> dict[str, torch.Tensor]: + obs_dict['obs']['camera'] = obs_dict['obs'].pop('rgbd') + obs_dict['obs']['proprio'] = obs_dict['obs'].pop('state') + obs_dict['obs']['camera'] = obs_dict['obs']['camera'].float() / 255.0 + return obs_dict + + +class RlgFlattenRGBDObservationWrapper(gym2.ObservationWrapper): + """ + Flattens the rgbd mode observations into a dictionary with two keys, "camera" and "proprio" + + Args: + rgb (bool): Whether to include rgb images in the observation + depth (bool): Whether to include depth images in the observation + state (bool): Whether to include state data in the observation + + Note that the returned observations will have a "rgbd" or "rgb" or "depth" key depending on the rgb/depth bool flags. 
+ """ + + def __init__(self, env, rgb=True, depth=False, state=True) -> None: + self.base_env: BaseEnv = env.unwrapped + super().__init__(env) + self.include_rgb = rgb + self.include_depth = depth + self.include_state = state + new_obs = self.observation(self.base_env._init_raw_obs) + self.base_env.update_obs_space(new_obs) + + def observation(self, observation: Dict): + sensor_data = observation.pop("sensor_data") + del observation["sensor_param"] + images = [] + for cam_data in sensor_data.values(): + if self.include_rgb: + images.append(cam_data["rgb"]) + if self.include_depth: + images.append(cam_data["depth"]) + images = torch.concat(images, axis=-1) + # flatten the rest of the data which should just be state data + observation = common.flatten_state_dict(observation, use_torch=True) + ret = dict() + if self.include_state: + ret["proprio"] = observation + ret["camera"] = images.float() / 255.0 + + return ret class Maniskill(IVecEnv): + def __init__(self, config_name, num_envs, **kwargs): import gym.spaces import gymnasium import gymnasium as gym2 import mani_skill.envs + from mani_skill.utils.wrappers import FlattenRGBDObservationWrapper # Can be any env_id from the list of Rigid-Body envs: https://maniskill.readthedocs.io/en/latest/tasks/index.html self.env_name = kwargs.pop('env_name', 'PickCube-v1') # can be one of ['PickCube-v1', 'PegInsertionSide-v1', 'StackCube-v1'] @@ -72,11 +104,19 @@ def __init__(self, config_name, num_envs, **kwargs): self.obs_mode = kwargs.pop('obs_mode', 'state') # can be one of ['pointcloud', 'rgbd', 'state_dict', 'state'] # a controller type / action space, see https://maniskill.readthedocs.io/en/latest/user_guide/concepts/controllers.html for a full list - self.control_mode = "pd_joint_delta_pos" # can be one of ['pd_ee_delta_pose', 'pd_ee_delta_pos', 'pd_joint_delta_pos', 'arm_pd_joint_pos_vel'] + # can be one of ['pd_ee_delta_pose', 'pd_ee_delta_pos', 'pd_joint_delta_pos', 'arm_pd_joint_pos_vel'] + self.control_mode = kwargs.pop('control_mode', 'pd_ee_delta_pose') #"pd_joint_delta_pos" - self.reward_mode = "dense" # can be one of ['sparse', 'dense'] + self.reward_mode = kwargs.pop('reward_mode', 'dense') # can be one of ['sparse', 'dense'] self.robot_uids = "panda" # can be one of ['panda', 'fetch'] + print("Creating Maniskill env with the following parameters:") + print("env_name:", self.env_name) + print("obs_mode:", self.obs_mode) + print("control_mode:", self.control_mode) + print("reward_mode:", self.reward_mode) + print("robot_uids:", self.robot_uids) + self.env = gym2.make(self.env_name, num_envs=num_envs, # render_mode="rgb_array", @@ -87,6 +127,31 @@ def __init__(self, config_name, num_envs, **kwargs): enable_shadow=True # this makes the default lighting cast shadows ) + print("Observation Space Before:", self.env.observation_space) + policy_obs_space = self.env.unwrapped.single_observation_space + print("Observation Space Unwrapped Before:", policy_obs_space) + + # TODO: add pointcloud and Depth support + if self.obs_mode == 'rgbd': + self.env = RlgFlattenRGBDObservationWrapper(self.env) + policy_obs_space = self.env.unwrapped.single_observation_space + + modified_policy_obs_space = {} + + # Copy existing keys and values, renaming as needed + for key, value in policy_obs_space.items(): + if key == 'rgbd': + print("RGBD Shape:", value.shape) + modified_policy_obs_space['camera'] = value + elif key == 'state': + modified_policy_obs_space['proprio'] = value + else: + modified_policy_obs_space[key] = value + + print("Observation Space 
Unwrapped:", modified_policy_obs_space) + + policy_obs_space = gymnasium.spaces.Dict(modified_policy_obs_space) + # from mani_skill.utils.wrappers import RecordEpisode # # to make it look a little more realistic, we will enable shadows which make the default lighting cast shadows # self.env = RecordEpisode( @@ -96,23 +161,6 @@ def __init__(self, config_name, num_envs, **kwargs): # # each 100 steps a new video is saved # max_steps_per_video=240 # ) - - # if self.use_dict_obs_space: - # self.observation_space = gym.spaces.Dict({ - # 'observation' : self.env.observation_space, - # 'reward' : gym.spaces.Box(low=0, high=1, shape=( ), dtype=np.float32), - # 'last_action': gym.spaces.Box(low=0, high=self.env.action_space.n, shape=(), dtype=int) - # }) - # else: - # self.observation_space = self.env.observation_space - - # if self.flatten_obs: - # self.orig_observation_space = self.observation_space - # self.observation_space = gym.spaces.flatten_space(self.observation_space) - - print("Observation Space:", self.env.observation_space) - policy_obs_space = self.env.unwrapped.single_observation_space - print("Observation Space Unwrapped:", policy_obs_space) self._clip_obs = 5.0 @@ -196,14 +244,27 @@ def step(self, actions): if "success" in extras: extras["successes"] = extras["success"].float().mean() - # if self.flatten_obs: - # next_obs = flatten_dict(next_obs) + # if self.obs_mode == 'rgbd': + # obs_and_states = maniskill_process_obs(obs_and_states) return obs_and_states, rew, dones, extras def reset(self): obs = self.env.reset() - return {'obs': obs[0]} + print("obs:", obs[0].keys()) + + obs_dict = {'obs': obs[0]} + + # if self.obs_mode == 'rgbd': + # obs_dict = maniskill_process_obs(obs_dict) + + # print("obs_dict:", obs_dict.keys()) + # print("obs_dict['obs']:", obs_dict['obs'].keys()) + # print("obs_dict['obs']['camera']:", obs_dict['obs']['camera'].shape) + # print("obs_dict['obs']['camera']:", obs_dict['obs']['camera'].dtype) + # print("obs_dict['obs']['camera']:", obs_dict['obs']['camera']) + + return obs_dict def render(self, mode='human'): self.env.render_human()