diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index 2d46eaea..36477469 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -176,34 +176,6 @@ def calc_gradients(self, input_dict):
                 rnn_masks
             )
 
-            a_loss = self.actor_loss_func(old_action_log_probs_batch, action_log_probs, advantage, self.ppo, curr_e_clip)
-
-            if self.has_value_loss:
-                c_loss = common_losses.critic_loss(
-                    self.model, value_preds_batch, values, curr_e_clip, return_batch,
-                    self.clip_value
-                )
-            else:
-                c_loss = torch.zeros(1, device=self.ppo_device)
-            if self.bound_loss_type == 'regularisation':
-                b_loss = self.reg_loss(mu)
-            elif self.bound_loss_type == 'bound':
-                b_loss = self.bound_loss(mu)
-            else:
-                b_loss = torch.zeros(1, device=self.ppo_device)
-
-            losses, sum_mask = torch_ext.apply_masks(
-                [
-                    a_loss.unsqueeze(1),
-                    c_loss,
-                    entropy.unsqueeze(1),
-                    b_loss.unsqueeze(1)
-                ],
-                rnn_masks
-            )
-            a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]
-
-            loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef
             aux_loss = self.model.get_aux_loss()
             self.aux_loss_dict = {}
             if aux_loss is not None:
@@ -213,6 +185,7 @@ def calc_gradients(self, input_dict):
                         self.aux_loss_dict[k] = v.detach()
                     else:
                         self.aux_loss_dict[k] = [v.detach()]
+
         if self.multi_gpu:
             self.optimizer.zero_grad()
         else:
diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py
index 05bc4b4f..8d1d73ca 100644
--- a/rl_games/algos_torch/models.py
+++ b/rl_games/algos_torch/models.py
@@ -268,7 +268,7 @@ def __init__(self, a2c_network, **kwargs):
 
         def get_aux_loss(self):
             return self.a2c_network.get_aux_loss()
-        
+
         def is_rnn(self):
             return self.a2c_network.is_rnn()
 
diff --git a/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml b/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml
index 0d8449e3..18f3039c 100644
--- a/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml
+++ b/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml
@@ -55,7 +55,7 @@ params:
       concat_output: True
 
   config:
-    name: PickCube_RGB_resnet18_LSTM_norm_embedding_64envs
+    name: PickCube_RGB_resnet18_LSTM_norm_embedding_64envs_auxloss
     env_name: maniskill
     reward_shaper:
       scale_value: 1.0
@@ -67,7 +67,7 @@ params:
     normalize_input: False
     normalize_value: True
     normalize_advantage: True
-    num_actors: 64 #128
+    num_actors: 64
     reward_shaper:
       scale_value: 1.0
     gamma: 0.99
diff --git a/rl_games/envs/maniskill.py b/rl_games/envs/maniskill.py
index bbd0cda5..40e86b5c 100644
--- a/rl_games/envs/maniskill.py
+++ b/rl_games/envs/maniskill.py
@@ -66,7 +66,8 @@ def observation(self, observation: Dict):
         # print("Observation:", observation.keys())
         # for key, value in observation.items():
         #     print(key, value.keys())
-        tcp_pose = observation['extra']['tcp_pose']
+        aux_target = observation['extra']['aux_target']
+        del observation['extra']['aux_target']
         # print("Input Obs:", observation.keys())
         # print("Input Obs Agent:", observation['agent'].keys())
         # print("Input Obs Extra:", observation['extra'].keys())
@@ -80,6 +81,7 @@ def observation(self, observation: Dict):
             if self.include_depth:
                 images.append(cam_data["depth"])
         images = torch.concat(images, axis=-1)
+
         # flatten the rest of the data which should just be state data
         observation = common.flatten_state_dict(observation, use_torch=True)
 
@@ -87,8 +89,8 @@ def observation(self, observation: Dict):
 
         if self.include_state:
             ret["proprio"] = observation
        if self.aux_loss:
-            ret['aux_target'] = tcp_pose
-        ret["camera"] = images.float() / 255.0
+            ret['aux_target'] = aux_target
+        ret["camera"] = images
         return ret
 
@@ -141,6 +143,7 @@ def __init__(self, config_name, num_envs, **kwargs):
         if self.obs_mode == 'rgbd':
             self.env = RlgFlattenRGBDObservationWrapper(self.env, aux_loss=self.aux_loss)
 
         policy_obs_space = self.env.unwrapped.single_observation_space
+        print("Observation Space Unwrapped After:", policy_obs_space)
 
         modified_policy_obs_space = {}
@@ -150,18 +153,20 @@ def __init__(self, config_name, num_envs, **kwargs):
             print("Value:", value)
             if key == 'rgbd':
                 print("RGBD Shape:", value.shape)
+                print("RGBD Dtype:", value.dtype)
+                print(value)
+
+                self.env.unwrapped.single_observation_space[key].dtype = np.uint8
+                value.dtype = np.int8
                 modified_policy_obs_space['camera'] = value
             elif key == 'state':
                 modified_policy_obs_space['proprio'] = value
             else:
                 modified_policy_obs_space[key] = value
-
-        # if self.aux_loss:
-        #     modified_policy_obs_space['aux_target'] = gymnasium.spaces.Dict(low=-np.Inf, high=np.Inf, shape=(3, ), dtype=np.float32)
-        print("Observation Space Unwrapped:", modified_policy_obs_space)
+        print("Observation Space Unwrapped Done:", modified_policy_obs_space)
 
         policy_obs_space = gymnasium.spaces.Dict(modified_policy_obs_space)
+        print("Observation Space After:", policy_obs_space)
 
         # from mani_skill.utils.wrappers import RecordEpisode
         # # to make it look a little more realistic, we will enable shadows which make the default lighting cast shadows
@@ -175,6 +180,8 @@ def __init__(self, config_name, num_envs, **kwargs):
 
         self._clip_obs = 5.0
 
+        self.observation_space = gym.spaces.Dict()
+
         # TODO: single function
         if isinstance(policy_obs_space, gymnasium.spaces.Dict):
             # check if we have a dictionary of observations
@@ -183,15 +190,17 @@ def __init__(self, config_name, num_envs, **kwargs):
                     raise NotImplementedError(
                        f"Dictinary of dictinary observations support was not testes: '{type(policy_obs_space[key])}'."
                     )
-            self.observation_space = gym.spaces.Dict(
-                {
-                    key: gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space[key].shape)
-                    for key in policy_obs_space.keys()
-                }
-            )
+
+                val = policy_obs_space[key]
+                if val.dtype == np.float16 or val.dtype == np.float32:
+                    self.observation_space[key] = gym.spaces.Box(-self._clip_obs, self._clip_obs, val.shape)
+                elif policy_obs_space[key].dtype == np.uint8:
+                    self.observation_space[key] = gym.spaces.Box(0, 255, val.shape, dtype=np.uint8)
         else:
             self.observation_space = gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape)
 
+        print("Observation Space:", self.observation_space)
+
         self._clip_actions = 1.0
         action_space = self.env.unwrapped.single_action_space
 
diff --git a/rl_games/networks/vision_networks.py b/rl_games/networks/vision_networks.py
index d1f8411e..ad503076 100644
--- a/rl_games/networks/vision_networks.py
+++ b/rl_games/networks/vision_networks.py
@@ -74,7 +74,7 @@ def __init__(self, params, **kwargs):
         self.aux_loss_linear = nn.Linear(out_size, self.target_shape)
 
         self.aux_loss_map = {
-            'aux_dist_loss' : None
+            'aux_dist_loss': None
         }
 
         self.value = self._build_value_layer(out_size, self.value_size)
@@ -129,7 +129,7 @@ def get_aux_loss(self):
     def forward(self, obs_dict):
         obs = obs_dict['obs']['camera']
         proprio = obs_dict['obs']['proprio']
-        target_obs = obs['tcp_pose'] # obs[self.target_key]
+        target_obs = obs[self.target_key]
 
         if self.permute_input:
             obs = obs.permute((0, 3, 1, 2))
@@ -281,6 +281,12 @@ def __init__(self, params, **kwargs):
 
         self.actions_num = kwargs.pop('actions_num')
         full_input_shape = kwargs.pop('input_shape')
+        print('full_input_shape: ', full_input_shape)
+
+        self.target_key = 'aux_target'
+        self.target_shape = full_input_shape[self.target_key]
+        print("Target shape: ", self.target_shape)
+        print("Observations shape: ", full_input_shape)
 
         self.proprio_size = 0 # Number of proprioceptive features
 
@@ -288,6 +294,11 @@ def __init__(self, params, **kwargs):
             input_shape = full_input_shape['camera']
             proprio_shape = full_input_shape['proprio']
             self.proprio_size = proprio_shape[0]
+
+            # # TODO: This is a hack to get the target shape
+            # for k, v in full_input_shape.items():
+            #     if self.target_key == k:
+            #         self.target_shape = v[0]
         else:
             input_shape = full_input_shape
 
@@ -330,6 +341,12 @@ def __init__(self, params, **kwargs):
 
         self.mlp = self._build_mlp(**mlp_args)
 
+        self.aux_loss_linear = nn.Linear(out_size, self.target_shape[0])
+
+        self.aux_loss_map = {
+            'aux_dist_loss': None
+        }
+
         self.value = self._build_value_layer(out_size, self.value_size)
         self.value_act = self.activations_factory.create(self.value_activation)
         self.flatten_act = self.activations_factory.create(self.activation)
@@ -365,6 +382,9 @@ def __init__(self, params, **kwargs):
 
         mlp_init(self.value.weight)
 
+    def get_aux_loss(self):
+        return self.aux_loss_map
+
     def forward(self, obs_dict):
         if self.proprio_size > 0:
             obs = obs_dict['obs']['camera']
@@ -372,6 +392,12 @@ def forward(self, obs_dict):
         else:
             obs = obs_dict['obs']
 
+        target_obs = obs_dict['obs'][self.target_key]
+
+        # print('obs.min(): ', obs.min())
+        # print('obs.max(): ', obs.max())
+        # print('obs.shape: ', obs.shape)
+
         if self.permute_input:
             obs = obs.permute((0, 3, 1, 2))
 
@@ -426,6 +452,9 @@ def forward(self, obs_dict):
 
         value = self.value_act(self.value(out))
 
+        y = self.aux_loss_linear(out)
+        self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs)
+
         if self.is_discrete:
             logits = self.logits(out)
             return logits, value, states
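
For readers of this patch: below is a minimal, self-contained sketch of the auxiliary-loss contract the changes rely on. The network publishes get_aux_loss() as a dict of named loss tensors (here an MSE between a linear head on the shared features and the aux_target observation), and the training code folds those losses into the objective while keeping detached copies for logging. The diff only shows the logging side of calc_gradients, so the summation below is illustrative; names such as AuxHeadNetwork, backbone_out_size and add_aux_losses are made up for this sketch and are not rl_games API.

import torch
import torch.nn as nn


class AuxHeadNetwork(nn.Module):
    """Toy network exposing the same get_aux_loss() contract as the patch."""

    def __init__(self, backbone_out_size, aux_target_size):
        super().__init__()
        self.aux_loss_linear = nn.Linear(backbone_out_size, aux_target_size)
        self.aux_loss_map = {'aux_dist_loss': None}

    def get_aux_loss(self):
        return self.aux_loss_map

    def forward(self, features, aux_target):
        # Predict the auxiliary target (e.g. a pose) from the shared features
        # and stash the MSE so the training loop can pick it up afterwards.
        pred = self.aux_loss_linear(features)
        self.aux_loss_map['aux_dist_loss'] = nn.functional.mse_loss(pred, aux_target)
        return features


def add_aux_losses(loss, network, aux_loss_dict):
    # Illustrative aggregation: every named auxiliary loss is added to the
    # total, and a detached copy is kept per key for logging, as in the
    # aux_loss_dict handling of calc_gradients above.
    aux_loss = network.get_aux_loss()
    if aux_loss is not None:
        for k, v in aux_loss.items():
            loss = loss + v
            aux_loss_dict[k] = v.detach()
    return loss


net = AuxHeadNetwork(backbone_out_size=64, aux_target_size=3)
features = torch.randn(8, 64)
aux_target = torch.randn(8, 3)
net(features, aux_target)
aux_logs = {}
total_loss = add_aux_losses(torch.zeros(()), net, aux_logs)
total_loss.backward()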
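
Likewise, a small sketch of the per-dtype observation-space construction that replaces the old blanket clipped-Box dict in maniskill.py: float entries (proprio, aux_target) keep the clipped range, while uint8 camera images keep the raw 0-255 range now that the wrapper no longer divides by 255. build_rlg_obs_space, clip_obs and the example shapes are illustrative only, not part of rl_games.

import numpy as np
import gym


def build_rlg_obs_space(per_key_spaces, clip_obs=5.0):
    # per_key_spaces: mapping from observation key to a space-like object with
    # .shape and .dtype (e.g. the unwrapped single_observation_space entries).
    spaces = {}
    for key, val in per_key_spaces.items():
        if val.dtype == np.float16 or val.dtype == np.float32:
            # low-dimensional float inputs are clipped to +/- clip_obs
            spaces[key] = gym.spaces.Box(-clip_obs, clip_obs, val.shape)
        elif val.dtype == np.uint8:
            # camera images stay uint8 in the raw 0-255 pixel range
            spaces[key] = gym.spaces.Box(0, 255, val.shape, dtype=np.uint8)
    return gym.spaces.Dict(spaces)


example = {
    'proprio': gym.spaces.Box(-np.inf, np.inf, (29,), dtype=np.float32),
    'aux_target': gym.spaces.Box(-np.inf, np.inf, (3,), dtype=np.float32),
    'camera': gym.spaces.Box(0, 255, (128, 128, 3), dtype=np.uint8),
}
print(build_rlg_obs_space(example))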