Layer norm for vision model. Better maniskill training configs.
ViktorM committed Sep 8, 2024
1 parent a78caac commit 8b274d1
Showing 6 changed files with 27 additions and 46 deletions.
1 change: 0 additions & 1 deletion rl_games/configs/maniskill/maniskill_ant.yaml
@@ -36,7 +36,6 @@ params:
normalize_advantage: True
gamma: 0.99
tau: 0.95
-
learning_rate: 3e-4
lr_schedule: adaptive
kl_threshold: 0.008
4 changes: 2 additions & 2 deletions
@@ -55,7 +55,7 @@ params:
concat_output: True

config:
-name: PickCube_RGB_Impala_Small_LSTM_norm_embedding
+name: PickCube_RGB_Impala_Small_LSTM_RMSnorm_embedding_128envs
env_name: maniskill
reward_shaper:
scale_value: 1.0
@@ -67,7 +67,7 @@ params:
normalize_input: False
normalize_value: True
normalize_advantage: True
-num_actors: 256
+num_actors: 128 #256
reward_shaper:
scale_value: 1.0
gamma: 0.99
7 changes: 3 additions & 4 deletions rl_games/configs/maniskill/maniskill_pickcube_state.yaml
@@ -36,7 +36,6 @@ params:
normalize_advantage: True
gamma: 0.99
tau: 0.95
-
learning_rate: 3e-4
lr_schedule: adaptive
kl_threshold: 0.008
@@ -48,17 +47,17 @@ params:
use_smooth_clamp: True
bound_loss_type: regularisation
bounds_loss_coef: 0.0005
-max_epochs: 5000
+max_epochs: 2000
save_best_after: 25
-save_frequency: 100
+save_frequency: 500
num_actors: 4096
horizon_length: 16
minibatch_size: 16384
mini_epochs: 6
critic_coef: 2

env_config:
-env_name: PickCube-v1 # todo: add list of all envs
+env_name: PickCube-v1 # todo: add list of all envs
observation_mode: state
control_mode: pd_ee_delta_pose
reward_mode: dense
27 changes: 14 additions & 13 deletions rl_games/configs/maniskill/maniskill_pickcube_vision.yaml
@@ -40,21 +40,22 @@ params:
zero_init_residual: True
norm_layer: None
mlp:
-units: [512]
+units: [256]
activation: elu
regularizer:
name: None
initializer:
name: default
-# rnn:
-#   name: lstm
-#   units: 512
-#   layers: 1
-#   before_mlp: True
-#   concat_output: True
+rnn:
+  name: lstm
+  layer_norm: True
+  units: 512
+  layers: 1
+  before_mlp: True
+  concat_output: True

config:
-name: PickCube_RGB_resnet18
+name: PickCube_RGB_resnet18_LSTM_norm_embedding_64envs
env_name: maniskill
reward_shaper:
scale_value: 1.0
@@ -66,22 +67,22 @@ params:
normalize_input: False
normalize_value: True
normalize_advantage: True
-num_actors: 128
+num_actors: 64 #128
reward_shaper:
scale_value: 1.0
gamma: 0.99
tau : 0.95
learning_rate: 1e-4
-lr_schedule: None
+lr_schedule: linear
kl_threshold: 0.008
-max_epochs: 50000
+max_epochs: 20000
save_best_after: 25
save_frequency: 500
grad_norm: 1.0
entropy_coef: 0.0
truncate_grads: True
e_clip: 0.2
-horizon_length: 64
+horizon_length: 128
minibatch_size: 512
mini_epochs: 2
critic_coef: 1
@@ -97,6 +98,6 @@ params:
reward_mode: dense

player:
-render: True
+render: False
deterministic: True

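Note on the rollout sizing above: halving num_actors (128 → 64) while doubling horizon_length (64 → 128) keeps the batch collected per update at 64 × 128 = 8192 transitions, which minibatch_size: 512 still divides evenly into 16 minibatches per mini-epoch.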
25 changes: 2 additions & 23 deletions rl_games/envs/maniskill.py
@@ -126,7 +126,7 @@ def __init__(self, config_name, num_envs, **kwargs):
robot_uids=self.robot_uids,
enable_shadow=True # this makes the default lighting cast shadows
)
-
+print("Observation Space Before:", self.env.observation_space)
policy_obs_space = self.env.unwrapped.single_observation_space
print("Observation Space Unwrapped Before:", policy_obs_space)
@@ -181,22 +181,6 @@ def __init__(self, config_name, num_envs, **kwargs):
else:
self.observation_space = gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape)

-# if isinstance(critic_obs_space, gymnasium.spaces.Dict):
-#     # check if we have a dictionary of observations
-#     for key in critic_obs_space.keys():
-#         if not isinstance(critic_obs_space[key], gymnasium.spaces.Box):
-#             raise NotImplementedError(
-#                 f"Dictinary of dictinary observations support has not been tested yet: '{type(policy_obs_space[key])}'."
-#             )
-#     self.state_observation_space = gym.spaces.Dict(
-#         {
-#             key: gym.spaces.Box(-self._clip_obs, self._clip_obs, critic_obs_space[key].shape)
-#             for key in critic_obs_space.keys()
-#         }
-#     )
-# else:
-#     self.observation_space = gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape)
-
self._clip_actions = 1.0

action_space = self.env.unwrapped.single_action_space
@@ -213,10 +197,8 @@ def step(self, actions):
actions = torch.clamp(actions, -self._clip_actions, self._clip_actions)

obs_dict, rew, terminated, truncated, extras = self.env.step(actions)
-#self.env.render_human()
# move time out information to the extras dict
-# this is only needed for infinite horizon tasks
-# note: only useful when `value_bootstrap` is True in the agent configuration
+# note: only used when `value_bootstrap` is True in the agent configuration

extras["time_outs"] = truncated

@@ -240,7 +222,6 @@ def step(self, actions):
if "log" in extras:
extras["episode"] = extras.pop("log")

-# TODO: revisit success calculation
if "success" in extras:
extras["successes"] = extras["success"].float().mean()

@@ -251,8 +232,6 @@

def reset(self):
obs = self.env.reset()
print("obs:", obs[0].keys())

obs_dict = {'obs': obs[0]}

# if self.obs_mode == 'rgbd':
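For context on the extras["time_outs"] flag set in step above: when `value_bootstrap` is True, PPO-style trainers typically add the critic's value estimate back into the reward on truncated (timed-out) steps, since a time limit is not a true terminal state. A minimal sketch of that pattern, with illustrative names rather than the actual rl_games code:

import torch

def bootstrap_truncated(rewards: torch.Tensor,
                        values: torch.Tensor,
                        time_outs: torch.Tensor,
                        gamma: float = 0.99) -> torch.Tensor:
    # On a time-limit truncation the episode is cut off, not finished, so the
    # critic's value estimate stands in for the return that was cut away.
    return rewards + gamma * values * time_outs.float()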
9 changes: 6 additions & 3 deletions rl_games/networks/vision_networks.py
@@ -48,13 +48,14 @@ def __init__(self, params, **kwargs):

self.running_mean_std = torch.jit.script(RunningMeanStd((mlp_input_size,)))
self.layer_norm_emb = torch.nn.LayerNorm(mlp_input_size)
+#self.layer_norm_emb = torch.nn.RMSNorm(mlp_input_size)

if self.has_rnn:
if not self.is_rnn_before_mlp:
rnn_in_size = out_size
out_size = self.rnn_units
else:
-rnn_in_size = mlp_input_size
+rnn_in_size = mlp_input_size
mlp_input_size = self.rnn_units

self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
@@ -132,8 +133,6 @@ def forward(self, obs_dict):
out = self.flatten_act(out)

out = torch.cat([out, proprio], dim=1)
-#print('out shape: ', out.shape)
-#out = self.norm_emb(out)
out = self.layer_norm_emb(out)

if self.has_rnn:
@@ -295,6 +294,8 @@ def __init__(self, params, **kwargs):
else:
out_size = self.units[-1]

+self.layer_norm_emb = torch.nn.LayerNorm((mlp_input_size,))
+
if self.has_rnn:
if not self.is_rnn_before_mlp:
rnn_in_size = out_size
@@ -376,6 +377,8 @@ def forward(self, obs_dict):
if self.proprio_size > 0:
out = torch.cat([out, proprio], dim=1)

+out = self.layer_norm_emb(out)
+
if self.has_rnn:
seq_length = obs_dict.get('seq_length', 1)


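The substantive network change in this commit is normalizing the concatenated vision + proprioception embedding before it reaches the MLP/RNN (with torch.nn.RMSNorm, available in recent PyTorch releases, left as a commented-out alternative). A minimal sketch of the pattern, with illustrative module and argument names rather than the full rl_games network:

import torch
import torch.nn as nn

class NormalizedEmbeddingLSTM(nn.Module):
    """Concat CNN features with proprioception, LayerNorm the embedding, then LSTM."""

    def __init__(self, cnn_out_size: int, proprio_size: int, rnn_units: int = 512):
        super().__init__()
        emb_size = cnn_out_size + proprio_size
        self.layer_norm_emb = nn.LayerNorm(emb_size)
        self.rnn = nn.LSTM(emb_size, rnn_units, num_layers=1, batch_first=True)

    def forward(self, cnn_features: torch.Tensor, proprio: torch.Tensor) -> torch.Tensor:
        out = torch.cat([cnn_features, proprio], dim=1)   # (batch, emb_size)
        out = self.layer_norm_emb(out)                    # normalize the fused embedding
        out, _ = self.rnn(out.unsqueeze(1))               # sequence length 1 per step
        return out.squeeze(1)                             # (batch, rnn_units)

Normalizing at this point keeps the CNN features and raw proprioceptive values on a comparable scale, which tends to stabilize the recurrent rollout.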