Fixed Maniskill resets. Push and pick configs.
ViktorM committed Sep 1, 2024
1 parent ead882c commit 8be57aa
Showing 6 changed files with 112 additions and 48 deletions.
8 changes: 5 additions & 3 deletions rl_games/common/player.py
@@ -62,7 +62,9 @@ def __init__(self, params):
'central_value_config') is not None
self.device_name = self.config.get('device_name', 'cuda')
self.render_env = self.player_config.get('render', False)
self.games_num = self.player_config.get('games_num', 2000)
# A large constant for the default number of games (evaluations)
self.inf_games = int(1e9)
self.games_num = self.player_config.get('games_num', self.inf_games)

if 'deterministic' in self.player_config:
self.is_deterministic = self.player_config['deterministic']
@@ -271,14 +273,14 @@ def init_rnn(self):
)[2]), dtype=torch.float32).to(self.device) for s in rnn_states]

def run(self):
n_games = self.games_num
render = self.render_env
n_games = self.games_num
n_game_life = self.n_game_life
n_games = n_games * n_game_life
is_deterministic = self.is_deterministic
sum_rewards = 0
sum_steps = 0
sum_game_res = 0
n_games = n_games * n_game_life
games_played = 0
has_masks = False
has_masks_func = getattr(self.env, "has_action_mask", None) is not None
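With the new default above, a player config that omits games_num now evaluates indefinitely rather than stopping after 2000 games. A quick sketch of the budget computed in run() (the player_config dict and n_game_life value here are made up for illustration):

# Sketch of the evaluation budget; names mirror the diff, the config is hypothetical.
inf_games = int(1e9)
player_config = {}  # e.g. a parsed 'player:' block with no 'games_num' key
games_num = player_config.get('games_num', inf_games)
n_game_life = 1
n_games = games_num * n_game_life
print(n_games)  # 1000000000 -> the loop effectively runs until interrupted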
@@ -48,13 +48,13 @@ params:
use_smooth_clamp: True
bound_loss_type: regularisation
bounds_loss_coef: 0.0005
max_epochs: 1000
max_epochs: 4000
save_best_after: 25
save_frequency: 100
num_actors: 4096
horizon_length: 16
minibatch_size: 16384
mini_epochs: 4
mini_epochs: 6
critic_coef: 2

env_config:
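Back-of-the-envelope check of the rollout sizes in this config (standalone arithmetic, not repository code): 4096 actors with a horizon of 16 collect 65536 samples per epoch, which minibatch_size 16384 splits into 4 minibatches, each revisited for the 6 mini-epochs set above.

num_actors = 4096
horizon_length = 16
minibatch_size = 16384
mini_epochs = 6

batch_size = num_actors * horizon_length          # 65536 samples per epoch
assert batch_size % minibatch_size == 0           # rl_games expects an even split
print(batch_size // minibatch_size, mini_epochs)  # 4 minibatches, 6 passes each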
65 changes: 65 additions & 0 deletions rl_games/configs/maniskill/maniskill_pushcube_state.yaml
@@ -0,0 +1,65 @@
params:
seed: 5
algo:
name: a2c_continuous

model:
name: continuous_a2c_logstd

network:
name: actor_critic
separate: False
space:
continuous:
mu_activation: None
sigma_activation: None
mu_init:
name: default
sigma_init:
name: const_initializer
val: 0
fixed_sigma: True
mlp:
units: [256, 128, 64]
activation: elu
initializer:
name: default

config:
name: PushCube_state
env_name: maniskill
normalize_input: True
normalize_value: True
value_bootstrap: True
reward_shaper:
scale_value: 1.0
normalize_advantage: True
gamma: 0.99
tau: 0.95

learning_rate: 3e-4
lr_schedule: adaptive
kl_threshold: 0.008
grad_norm: 1.0
entropy_coef: 0.0
truncate_grads: True
e_clip: 0.2
clip_value: True
use_smooth_clamp: True
bound_loss_type: regularisation
bounds_loss_coef: 0.0005
max_epochs: 500
save_best_after: 25
save_frequency: 50
num_actors: 4096
horizon_length: 16
minibatch_size: 16384
mini_epochs: 6
critic_coef: 2

env_config:
env_name: PushCube-v1 # todo: add list of all envs

player:
render: True
render_sleep: 0.0
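For reference, one way to launch training with the new PushCube config (a sketch: it assumes the maniskill env_name is registered with rl_games in this branch, and the args dict mirrors the keys runner.py normally passes):

import yaml
from rl_games.torch_runner import Runner

with open('rl_games/configs/maniskill/maniskill_pushcube_state.yaml') as f:
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)
runner.run({'train': True, 'play': False, 'checkpoint': None, 'sigma': None})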
76 changes: 37 additions & 39 deletions rl_games/envs/maniskill.py
@@ -77,25 +77,26 @@ def __init__(self, config_name, num_envs, **kwargs):
self.reward_mode = "dense" # can be one of ['sparse', 'dense']
self.robot_uids = "panda" # can be one of ['panda', 'fetch']

#self.batch_size = num_envs # ???

#self.use_dict_obs_space = kwargs.pop('use_dict_obs_space', True)

# self.env = gym2.make( self.env_name,
# env_type=kwargs.pop('env_type', 'gym'),
# num_envs=num_envs,
# batch_size=self.batch_size,
# **kwargs
# )
self.env = gym2.make(self.env_name,
num_envs=num_envs,
# render_mode="rgb_array",
obs_mode=self.obs_mode,
reward_mode=self.reward_mode,
control_mode=self.control_mode,
robot_uids=self.robot_uids,
enable_shadow=True # this makes the default lighting cast shadows
)

# from mani_skill.utils.wrappers import RecordEpisode
# # to make it look a little more realistic, we will enable shadows which make the default lighting cast shadows
# self.env = RecordEpisode(
# self.env,
# "./videos", # the directory to save replay videos and trajectories to
# # on GPU sim we record intervals, not by single episodes as there are multiple envs
# # each 100 steps a new video is saved
# max_steps_per_video=240
# )

# if self.use_dict_obs_space:
# self.observation_space = gym.spaces.Dict({
# 'observation' : self.env.observation_space,
@@ -113,7 +114,7 @@ def __init__(self, config_name, num_envs, **kwargs):
policy_obs_space = self.env.unwrapped.single_observation_space
print("Observation Space Unwrapped:", policy_obs_space)

self._clip_obs = np.inf
self._clip_obs = 5.0

# TODO: single function
if isinstance(policy_obs_space, gymnasium.spaces.Dict):
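The clip bound above goes from np.inf to a finite 5.0. A small sketch of how such a bound is typically used to build the flat observation space and clamp incoming states (assumed usage — the diff only shows the constant changing), mirroring the action-space handling further down:

import torch
import gym  # classic gym spaces on the rl_games-facing side of the wrapper

_clip_obs = 5.0
obs_dim = 42  # hypothetical flat state size reported by the ManiSkill env

observation_space = gym.spaces.Box(-_clip_obs, _clip_obs, (obs_dim,))  # bounded instead of (-inf, inf)

obs = torch.randn(4, obs_dim) * 10.0
obs = torch.clamp(obs, -_clip_obs, _clip_obs)  # the clamp the TODO in step() refers to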
@@ -154,27 +155,35 @@ def __init__(self, config_name, num_envs, **kwargs):
print("Single action apace:", action_space)
self.action_space = gym.spaces.Box(-self._clip_actions, self._clip_actions, action_space.shape)

def step(self, action):
# # move actions to sim-device
def step(self, actions):
# TODO: use env device
# TODO: add reward/observation clamping
# TODO: move buffers to rl-device
# TODO: move actions to sim-device
# actions = actions.detach().clone().to(device=self._sim_device)
# # clip the actions
# actions = torch.clamp(actions, -self._clip_actions, self._clip_actions)
actions = torch.clamp(actions, -self._clip_actions, self._clip_actions)

obs_dict, rew, terminated, truncated, extras = self.env.step(action)
obs_dict, rew, terminated, truncated, extras = self.env.step(actions)
#self.env.render_human()
# move time out information to the extras dict
# this is only needed for infinite horizon tasks
# note: only useful when `value_bootstrap` is True in the agent configuration
extras["time_outs"] = truncated #truncated.to(device=self._rl_device)

extras["time_outs"] = truncated

# process observations and states
#obs_and_states = self._process_obs(obs_dict)

obs_and_states = {'obs': obs_dict}
# move buffers to rl-device
# note: we perform clone to prevent issues when rl-device and sim-device are the same.
#rew = rew.to(device=self._rl_device)
#dones = (terminated | truncated).to(device=self._rl_device)
dones = (terminated | truncated) # stop if any environment terminates/truncates

# dones = (terminated | truncated)
dones = torch.logical_or(terminated, truncated)
if dones.any():
env_idx = torch.arange(0, self.env.num_envs, device=self.env.device)[dones] # device=self.device
reset_obs, _ = self.env.reset(options=dict(env_idx=env_idx))
obs_and_states['obs'] = reset_obs

#print('extras keys:', extras.keys())
# extras = {
# k: v.to(device=self._rl_device, non_blocking=True) if hasattr(v, "to") else v for k, v in extras.items()
# }
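The reset block above is the fix named in the commit title: rather than waiting for every environment to finish, it collects the indices of the done ones and asks ManiSkill for a partial reset via options=dict(env_idx=...). A self-contained toy of the index selection (per-env boolean tensors assumed):

import torch

num_envs = 8
terminated = torch.tensor([False, True, False, False, True, False, False, False])
truncated  = torch.tensor([False, False, False, True, False, False, False, False])

dones = torch.logical_or(terminated, truncated)
if dones.any():
    env_idx = torch.arange(0, num_envs)[dones]  # indices of the finished envs
    print(env_idx)  # tensor([1, 3, 4])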
@@ -183,32 +192,21 @@ def step(self, action):
if "log" in extras:
extras["episode"] = extras.pop("log")

# done = (terminated | truncated).any() # stop if any environment terminates/truncates
# info['time_outs'] = truncated

# if self.obs_mode == 'state_dict':
# next_obs = obs
# TODO: revisit success calculation
if "success" in extras:
extras["successes"] = extras["success"].float().mean()

# if self.flatten_obs:
# next_obs = flatten_dict(next_obs)

# if self.use_dict_obs_space:
# next_obs = {
# 'observation': next_obs,
# 'reward': np.clip(reward, -1, 1),
# 'last_action': action
# }
#return next_obs, reward, is_done, info
return obs_and_states, rew, dones, extras

def reset(self):
obs = self.env.reset()
# print(obs)
print("obs reset shape:", obs[0].shape)
# if self.flatten_obs:
# obs = flatten_dict(obs)

return {'obs': obs[0]}

def render(self, mode='human'):
self.env.render_human()

def get_number_of_agents(self):
return 1
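extras['time_outs'] set in step() above exists so that, with value_bootstrap: True in the configs, hitting the episode time limit is not treated as a real terminal state. A toy version of the standard bootstrap trick (a sketch of the idea, not rl_games' exact code):

import torch

gamma = 0.99
rewards   = torch.tensor([0.5, 0.5])
values    = torch.tensor([10.0, 10.0])  # critic estimates V(s_t)
time_outs = torch.tensor([0.0, 1.0])    # second env was truncated, not failed

shaped_rewards = rewards + gamma * values * time_outs  # add the value back only on timeout
print(shaped_rewards)  # tensor([ 0.5000, 10.4000])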
3 changes: 0 additions & 3 deletions rl_games/torch_runner.py
@@ -67,9 +67,6 @@ def __init__(self, algo_observer=None):

self.algo_observer = algo_observer if algo_observer else DefaultAlgoObserver()
torch.backends.cudnn.benchmark = True
### it did not help for lots of openai gym envs anyway :(
#torch.backends.cudnn.deterministic = True
#torch.use_deterministic_algorithms(True)

def reset(self):
pass
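The comments removed above pointed at PyTorch's determinism switches. For completeness, the usual toggle set if reproducibility ever matters more than speed (standard PyTorch calls; as the deleted comment notes, it did not help much for many gym envs):

import torch

torch.manual_seed(5)                    # seed matching the configs in this commit
torch.backends.cudnn.benchmark = False  # benchmark mode trades determinism for speed
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True, warn_only=True)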
4 changes: 3 additions & 1 deletion runner.py
@@ -48,7 +48,9 @@
else:
ray.init(object_store_memory=1024*1024*1000)

runner = Runner()
from rl_games.common.algo_observer import IsaacAlgoObserver
runner = Runner(IsaacAlgoObserver())

try:
runner.load(config)
except yaml.YAMLError as exc:
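The switch to IsaacAlgoObserver above is what consumes the per-episode statistics the ManiSkill wrapper now reports (the extras['episode'] dict renamed from 'log', plus 'successes'). A hypothetical minimal observer sketched against the AlgoObserver base class in rl_games.common.algo_observer; treat the exact method signatures as assumptions:

from rl_games.common.algo_observer import AlgoObserver

class SuccessRateObserver(AlgoObserver):
    """Logs the mean success rate the wrapper exposes under infos['successes']."""

    def after_init(self, algo):
        self.writer = algo.writer  # the agent's TensorBoard writer

    def process_infos(self, infos, done_indices):
        if isinstance(infos, dict) and 'successes' in infos:
            self.last_success = float(infos['successes'])

    def after_print_stats(self, frame, epoch_num, total_time):
        if hasattr(self, 'last_success'):
            self.writer.add_scalar('episode/success_rate', self.last_success, frame)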
