From ead882c38374afb0cd87fdde1293b4249d9179c3 Mon Sep 17 00:00:00 2001
From: ViktorM
Date: Sat, 31 Aug 2024 02:09:35 -0700
Subject: [PATCH] Fixed Maniskill state-based training.

---
 rl_games/algos_torch/network_builder.py   |  3 +--
 rl_games/common/a2c_common.py             |  4 ++--
 rl_games/configs/maniskill/maniskill.yaml | 12 +++++++-----
 rl_games/envs/maniskill.py                | 17 +++++++++++++----
 4 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py
index 4b0acfa3..7d39c118 100644
--- a/rl_games/algos_torch/network_builder.py
+++ b/rl_games/algos_torch/network_builder.py
@@ -189,7 +189,6 @@ def _build_value_layer(self, input_size, output_size, value_type='legacy'):
 
             raise ValueError('value type is not "default", "legacy" or "two_hot_encoded"')
 
-
 class A2CBuilder(NetworkBuilder):
     def __init__(self, **kwargs):
         NetworkBuilder.__init__(self)
@@ -471,7 +470,7 @@ def forward(self, obs_dict):
                 else:
                     sigma = self.sigma_act(self.sigma(out))
                 return mu, mu*0 + sigma, value, states
-                    
+
     def is_separate_critic(self):
         return self.separate
 
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 1870f316..d754e595 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -416,7 +416,6 @@ def get_action_values(self, obs):
             'obs' : processed_obs,
             'rnn_states' : self.rnn_states
         }
-
         with torch.no_grad():
             res_dict = self.model(input_dict)
             if self.has_central_value:
@@ -757,6 +756,7 @@ def play_steps(self):
                 res_dict = self.get_masked_action_values(self.obs, masks)
             else:
                 res_dict = self.get_action_values(self.obs)
+
             self.experience_buffer.update_data('obses', n, self.obs['obs'])
             self.experience_buffer.update_data('dones', n, self.dones)
 
@@ -782,7 +782,7 @@
             self.current_lengths += 1
             all_done_indices = self.dones.nonzero(as_tuple=False)
             env_done_indices = all_done_indices[::self.num_agents]
-            
+
             self.game_rewards.update(self.current_rewards[env_done_indices])
             self.game_shaped_rewards.update(self.current_shaped_rewards[env_done_indices])
             self.game_lengths.update(self.current_lengths[env_done_indices])
diff --git a/rl_games/configs/maniskill/maniskill.yaml b/rl_games/configs/maniskill/maniskill.yaml
index 59367677..04ee3a4b 100644
--- a/rl_games/configs/maniskill/maniskill.yaml
+++ b/rl_games/configs/maniskill/maniskill.yaml
@@ -20,13 +20,13 @@ params:
           val: 0
         fixed_sigma: True
     mlp:
-      units: [512, 256, 128]
+      units: [256, 128, 64]
       activation: elu
       initializer:
        name: default
 
  config:
-    name: Maniskill
+    name: PickCube_state
    env_name: maniskill
    normalize_input: True
    normalize_value: True
@@ -48,11 +48,13 @@ params:
    use_smooth_clamp: True
    bound_loss_type: regularisation
    bounds_loss_coef: 0.0005
-    max_epochs: 2000
+    max_epochs: 1000
+    save_best_after: 25
+    save_frequency: 100
    num_actors: 4096
-    horizon_length: 64
+    horizon_length: 16
    minibatch_size: 16384
-    mini_epochs: 5
+    mini_epochs: 4
    critic_coef: 2
 
    env_config:
diff --git a/rl_games/envs/maniskill.py b/rl_games/envs/maniskill.py
index c0e78640..6b5fc30a 100644
--- a/rl_games/envs/maniskill.py
+++ b/rl_games/envs/maniskill.py
@@ -148,7 +148,11 @@ def __init__(self, config_name, num_envs, **kwargs):
         # else:
         #     self.observation_space = gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape)
 
-        self.action_space = self.env.unwrapped.single_action_space
+        self._clip_actions = 1.0
+
+        action_space = self.env.unwrapped.single_action_space
+        print("Single action space:", action_space)
+        self.action_space = gym.spaces.Box(-self._clip_actions, self._clip_actions, action_space.shape)
 
     def step(self, action):
         # # move actions to sim-device
@@ -163,12 +167,14 @@ def step(self, action):
         extras["time_outs"] = truncated #truncated.to(device=self._rl_device)
         # process observations and states
         #obs_and_states = self._process_obs(obs_dict)
-        obs_and_states = obs_dict
+
+        obs_and_states = {'obs': obs_dict}
         # move buffers to rl-device
         # note: we perform clone to prevent issues when rl-device and sim-device are the same.
         #rew = rew.to(device=self._rl_device)
         #dones = (terminated | truncated).to(device=self._rl_device)
-        dones = (terminated | truncated).any() # stop if any environment terminates/truncates
+        dones = (terminated | truncated) # per-environment done flags: terminated or truncated
+
         # extras = {
         #     k: v.to(device=self._rl_device, non_blocking=True) if hasattr(v, "to") else v for k, v in extras.items()
         # }
@@ -197,10 +203,12 @@ def step(self, action):
 
     def reset(self):
         obs = self.env.reset()
+        # print(obs)
+        print("obs reset shape:", obs[0].shape)
         # if self.flatten_obs:
         #     obs = flatten_dict(obs)
 
-        return obs
+        return {'obs': obs[0]}
 
     def get_number_of_agents(self):
         return 1
@@ -209,6 +217,7 @@ def get_env_info(self):
         info = {}
         info['action_space'] = self.action_space
         info['observation_space'] = self.observation_space
+        print("info:", info)
 
         return info
 
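Reviewer note, not part of the patch: the sketch below illustrates the environment contract the wrapper moves to with this change — observations wrapped in a dict under the 'obs' key, per-environment done flags rather than a single .any() scalar, and an action space reported as a Box clipped to [-1, 1]. The class name, tensor shapes, and zero-filled values are hypothetical stand-ins, not ManiSkill or rl_games API.

import gym
import torch

class DictObsVecEnvSketch:
    # Illustrative only: mimics the interface the patched Maniskill wrapper exposes to rl_games.
    def __init__(self, num_envs, obs_dim, act_dim, clip_actions=1.0):
        self.num_envs = num_envs
        self.observation_space = gym.spaces.Box(-float('inf'), float('inf'), (obs_dim,))
        # the action space is reported as a clipped Box, as in the patch
        self.action_space = gym.spaces.Box(-clip_actions, clip_actions, (act_dim,))

    def reset(self):
        # observations are returned as a dict under the 'obs' key
        return {'obs': torch.zeros(self.num_envs, self.observation_space.shape[0])}

    def step(self, action):
        obs = {'obs': torch.zeros(self.num_envs, self.observation_space.shape[0])}
        rew = torch.zeros(self.num_envs)
        # per-environment done flags (terminated | truncated), not collapsed with .any()
        dones = torch.zeros(self.num_envs, dtype=torch.bool)
        infos = {'time_outs': torch.zeros(self.num_envs, dtype=torch.bool)}
        return obs, rew, dones, infos

Keeping dones per environment lets play_steps in a2c_common reset and track episode statistics for each environment independently, instead of treating the whole batch as finished when any single environment terminates.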
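A quick arithmetic check on the new rollout sizes in maniskill.yaml, assuming rl_games' usual requirement that the rollout batch splits evenly into minibatches:

# values copied from the config hunk above
num_actors = 4096
horizon_length = 16
minibatch_size = 16384

batch_size = num_actors * horizon_length   # 65536 transitions per rollout
assert batch_size % minibatch_size == 0    # even split required
print(batch_size // minibatch_size)        # 4 minibatches per mini-epoch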