Fixed ManiSkill state-based training.
ViktorM committed Aug 31, 2024
1 parent d073f4d commit ead882c
Showing 4 changed files with 23 additions and 13 deletions.
3 changes: 1 addition & 2 deletions rl_games/algos_torch/network_builder.py
@@ -189,7 +189,6 @@ def _build_value_layer(self, input_size, output_size, value_type='legacy'):
raise ValueError('value type is not "default", "legacy" or "two_hot_encoded"')



class A2CBuilder(NetworkBuilder):
def __init__(self, **kwargs):
NetworkBuilder.__init__(self)
@@ -471,7 +470,7 @@ def forward(self, obs_dict):
else:
sigma = self.sigma_act(self.sigma(out))
return mu, mu*0 + sigma, value, states

def is_separate_critic(self):
return self.separate

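As context for the lines above: the `mu*0 + sigma` term in `forward` broadcasts a state-independent sigma to the batch shape of `mu`, which is what the `fixed_sigma: True` setting in the ManiSkill config relies on. A minimal, standalone sketch of that broadcast (the batch and action sizes below are assumed for illustration, not taken from the commit):

import torch

mu = torch.zeros(4096, 8)          # assumed batch of action means (num_actors x action_dim)
sigma = torch.full((8,), -0.5)     # one learnable (log-)std value per action dimension
broadcast_sigma = mu * 0 + sigma   # same expression as in A2CBuilder.forward
assert broadcast_sigma.shape == mu.shape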
4 changes: 2 additions & 2 deletions rl_games/common/a2c_common.py
@@ -416,7 +416,6 @@ def get_action_values(self, obs):
'obs' : processed_obs,
'rnn_states' : self.rnn_states
}

with torch.no_grad():
res_dict = self.model(input_dict)
if self.has_central_value:
@@ -757,6 +756,7 @@ def play_steps(self):
res_dict = self.get_masked_action_values(self.obs, masks)
else:
res_dict = self.get_action_values(self.obs)

self.experience_buffer.update_data('obses', n, self.obs['obs'])
self.experience_buffer.update_data('dones', n, self.dones)

@@ -782,7 +782,7 @@ def play_steps(self):
self.current_lengths += 1
all_done_indices = self.dones.nonzero(as_tuple=False)
env_done_indices = all_done_indices[::self.num_agents]

self.game_rewards.update(self.current_rewards[env_done_indices])
self.game_shaped_rewards.update(self.current_shaped_rewards[env_done_indices])
self.game_lengths.update(self.current_lengths[env_done_indices])
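As a side note on the `play_steps` lines above, `all_done_indices[::self.num_agents]` keeps one done index per environment when several agents share an environment. A small standalone illustration (the tensor values are invented for the example):

import torch

num_agents = 2
# Done flags for 3 environments with 2 agents each, flattened environment-major.
dones = torch.tensor([1, 1, 0, 0, 1, 1], dtype=torch.bool)
all_done_indices = dones.nonzero(as_tuple=False)    # [[0], [1], [4], [5]]
env_done_indices = all_done_indices[::num_agents]   # [[0], [4]], i.e. one index per environment
print(env_done_indices.squeeze(-1).tolist())        # [0, 4]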
12 changes: 7 additions & 5 deletions rl_games/configs/maniskill/maniskill.yaml
@@ -20,13 +20,13 @@ params:
val: 0
fixed_sigma: True
mlp:
- units: [512, 256, 128]
+ units: [256, 128, 64]
activation: elu
initializer:
name: default

config:
- name: Maniskill
+ name: PickCube_state
env_name: maniskill
normalize_input: True
normalize_value: True
@@ -48,11 +48,13 @@ params:
use_smooth_clamp: True
bound_loss_type: regularisation
bounds_loss_coef: 0.0005
- max_epochs: 2000
+ max_epochs: 1000
save_best_after: 25
save_frequency: 100
num_actors: 4096
- horizon_length: 64
+ horizon_length: 16
+ minibatch_size: 16384
- mini_epochs: 5
+ mini_epochs: 4
critic_coef: 2

env_config:
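For reference, a config like the one above is normally launched through rl_games' runner. The snippet below is only an illustrative sketch: the config path and the `train`/`play` flags are assumptions for the example, not something introduced by this commit.

import yaml
from rl_games.torch_runner import Runner

# Assumed location of the config edited in this commit.
with open('rl_games/configs/maniskill/maniskill.yaml', 'r') as f:
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)
runner.run({'train': True, 'play': False})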
17 changes: 13 additions & 4 deletions rl_games/envs/maniskill.py
@@ -148,7 +148,11 @@ def __init__(self, config_name, num_envs, **kwargs):
# else:
# self.observation_space = gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape)

- self.action_space = self.env.unwrapped.single_action_space
+ self._clip_actions = 1.0
+
+ action_space = self.env.unwrapped.single_action_space
+ print("Single action space:", action_space)
+ self.action_space = gym.spaces.Box(-self._clip_actions, self._clip_actions, action_space.shape)

def step(self, action):
# # move actions to sim-device
@@ -163,12 +167,14 @@ def step(self, action):
extras["time_outs"] = truncated #truncated.to(device=self._rl_device)
# process observations and states
#obs_and_states = self._process_obs(obs_dict)
- obs_and_states = obs_dict

+ obs_and_states = {'obs': obs_dict}
# move buffers to rl-device
# note: we perform clone to prevent issues when rl-device and sim-device are the same.
#rew = rew.to(device=self._rl_device)
#dones = (terminated | truncated).to(device=self._rl_device)
- dones = (terminated | truncated).any() # stop if any environment terminates/truncates
+ dones = (terminated | truncated) # per-environment done flags

# extras = {
# k: v.to(device=self._rl_device, non_blocking=True) if hasattr(v, "to") else v for k, v in extras.items()
# }
@@ -197,10 +203,12 @@ def step(self, action):

def reset(self):
obs = self.env.reset()
+ # print(obs)
+ print("obs reset shape:", obs[0].shape)
# if self.flatten_obs:
# obs = flatten_dict(obs)

- return obs
+ return {'obs': obs[0]}

def get_number_of_agents(self):
return 1
@@ -209,6 +217,7 @@ def get_env_info(self):
info = {}
info['action_space'] = self.action_space
info['observation_space'] = self.observation_space
print("info:", info)
return info


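Taken together, the wrapper changes above switch to dict-based observations, a clipped symmetric action space, and per-environment done flags. The sketch below condenses that interface into one place; the class name and structure are illustrative stand-ins, not the actual contents of rl_games/envs/maniskill.py.

import gym


class ManiskillWrapperSketch:
    """Illustrative summary of the interface the edited wrapper exposes to rl_games."""

    def __init__(self, env, clip_actions=1.0):
        self.env = env
        self._clip_actions = clip_actions
        single_space = env.unwrapped.single_action_space
        # Symmetric Box clipped to +/- clip_actions, as in the commit.
        self.action_space = gym.spaces.Box(-clip_actions, clip_actions, single_space.shape)

    def reset(self):
        obs = self.env.reset()
        # ManiSkill's reset returns (obs, info); rl_games expects {'obs': tensor}.
        return {'obs': obs[0]}

    def step(self, action):
        obs, rew, terminated, truncated, extras = self.env.step(action)
        extras['time_outs'] = truncated
        # Per-environment done flags rather than a single aggregated value.
        dones = terminated | truncated
        return {'obs': obs}, rew, dones, extras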
