Skip to content

Commit

Permalink
Added aux loss to the backbone network. Added debug save images. Training configs improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Sep 16, 2024
1 parent ee19a1c commit fa3469a
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 58 deletions.
11 changes: 6 additions & 5 deletions rl_games/configs/maniskill/maniskill_pickcube_impala_lstm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ params:
concat_output: True

config:
name: PickCube_RGB_impala_LSTM_norm
name: PickCube_RGB_AUX_LOSS_add_impala_LSTM_norm
env_name: maniskill
reward_shaper:
scale_value: 1.0
Expand All @@ -70,13 +70,13 @@ params:
reward_shaper:
scale_value: 1.0
gamma: 0.99
tau : 0.95
tau: 0.95
learning_rate: 2e-4
lr_schedule: linear
kl_threshold: 0.008
max_epochs: 20000
save_best_after: 25
save_frequency: 500
save_frequency: 1000
grad_norm: 1.0
entropy_coef: 0.0
truncate_grads: True
Expand All @@ -92,11 +92,12 @@ params:

env_config:
env_name: PickCube-v1
obs_mode: rgbd
obs_mode: rgb
control_mode: pd_ee_delta_pose
reward_mode: dense
aux_loss: True

player:
render: True
render: False
deterministic: True

Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ params:

network:
name: vision_actor_critic
require_rewards: False
require_last_actions: False
separate: False
value_shape: 1
space:
Expand All @@ -40,7 +38,7 @@ params:
regularizer:
name: None
mlp:
units: [256]
units: [256, 128]
activation: elu
regularizer:
name: None
Expand All @@ -55,7 +53,7 @@ params:
concat_output: True

config:
name: PickCube_RGB_Impala_Small_LSTM_RMSnorm_embedding_128envs
name: PickCube_RGB_AUX_LOSS_Impala_Small_LSTM_256x128_MLP_100weight
env_name: maniskill
reward_shaper:
scale_value: 1.0
Expand All @@ -67,22 +65,22 @@ params:
normalize_input: False
normalize_value: True
normalize_advantage: True
num_actors: 128 #256
num_actors: 256
reward_shaper:
scale_value: 1.0
gamma: 0.99
tau : 0.95
tau: 0.95
learning_rate: 2e-4
lr_schedule: None
kl_threshold: 0.008
max_epochs: 10000
save_best_after: 25
save_frequency: 500
save_frequency: 1000
grad_norm: 1.0
entropy_coef: 0.0
truncate_grads: True
e_clip: 0.2
horizon_length: 32
horizon_length: 64
minibatch_size: 2048
mini_epochs: 2
critic_coef: 1
Expand All @@ -93,9 +91,10 @@ params:

env_config:
env_name: PickCube-v1
obs_mode: rgbd
obs_mode: rgb
control_mode: pd_ee_delta_pose
reward_mode: dense
aux_loss: True

player:
render: False
Expand Down
19 changes: 10 additions & 9 deletions rl_games/configs/maniskill/maniskill_pickcube_vision.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ params:
zero_init_residual: True
norm_layer: None
mlp:
units: [256]
units: [256, 128]
activation: elu
regularizer:
name: None
Expand All @@ -55,7 +55,7 @@ params:
concat_output: True

config:
name: PickCube_RGB_resnet18_LSTM_norm_embedding_128envs_2e-4_linear_lr_first_layer_retrain
name: PickCube_RGB_resnet18_LSTM_256x128_MLP
env_name: maniskill
reward_shaper:
scale_value: 1.0
Expand All @@ -67,23 +67,23 @@ params:
normalize_input: False
normalize_value: True
normalize_advantage: True
num_actors: 64
num_actors: 128
reward_shaper:
scale_value: 1.0
gamma: 0.99
tau : 0.95
tau: 0.95
learning_rate: 2e-4
lr_schedule: linear
kl_threshold: 0.008
max_epochs: 20000
max_epochs: 10000
save_best_after: 25
save_frequency: 500
save_frequency: 1000
grad_norm: 1.0
entropy_coef: 0.0
truncate_grads: True
e_clip: 0.2
horizon_length: 128
minibatch_size: 512
horizon_length: 64 #128
minibatch_size: 1024
mini_epochs: 2
critic_coef: 1
clip_value: True
Expand All @@ -93,9 +93,10 @@ params:

env_config:
env_name: PickCube-v1
obs_mode: rgbd
obs_mode: rgb
control_mode: pd_ee_delta_pose
reward_mode: dense
aux_loss: False

player:
render: False
Expand Down
33 changes: 27 additions & 6 deletions rl_games/envs/maniskill.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import gymnasium as gym2
import gymnasium.spaces.utils
from gymnasium.vector.utils import batch_space
from mani_skill.utils import common


VecEnvObs = Dict[str, torch.Tensor | Dict[str, torch.Tensor]]
Expand Down Expand Up @@ -36,6 +37,19 @@ def _process_obs(self, obs_dict: VecEnvObs) -> torch.Tensor | dict[str, torch.Te
# TODO: add state processing for asymmetric case
return obs

def save_images_to_file(images: torch.Tensor, file_path: str):
    """Tile a batch of images into one grid picture and write it to disk.

    Args:
        images: Batch of channel-last images, expected as (N, H, W, C).
        file_path: Destination path passed through to ``save_image``.
    """
    from torchvision.utils import make_grid, save_image

    # Bring the trailing channel axis to position 1: insert a placeholder
    # axis at 1, swap it with the last axis, then drop the (now size-1)
    # trailing placeholder. For (N, H, W, C) input this yields (N, C, H, W).
    with_placeholder = images.unsqueeze(1)
    channel_first = with_placeholder.transpose(1, -1).squeeze(-1)

    # Aim for a roughly square grid: images per row = rounded sqrt of batch size.
    images_per_row = round(images.shape[0] ** 0.5)

    grid = make_grid(channel_first, nrow=images_per_row)
    save_image(grid, file_path)


class RlgFlattenRGBDObservationWrapper(gym2.ObservationWrapper):
"""
Expand All @@ -51,10 +65,11 @@ class RlgFlattenRGBDObservationWrapper(gym2.ObservationWrapper):

def __init__(self, env, rgb=True, depth=False, state=True, aux_loss=False) -> None:
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common

self.base_env: BaseEnv = env.unwrapped
self.aux_loss = aux_loss
self.write_image_to_file = False

super().__init__(env)
self.include_rgb = rgb
self.include_depth = depth
Expand Down Expand Up @@ -83,6 +98,9 @@ def observation(self, observation: Dict):
images.append(cam_data["depth"])
images = torch.concat(images, axis=-1)

if self.write_image_to_file:
save_images_to_file(images.float() / 255.0, f"pickup_cube_{'rgb'}.png")

# flatten the rest of the data which should just be state data
observation = common.flatten_state_dict(observation, use_torch=True)

Expand All @@ -103,7 +121,6 @@ def __init__(self, config_name, num_envs, **kwargs):
import gymnasium
import gymnasium as gym2
import mani_skill.envs
from mani_skill.utils.wrappers import FlattenRGBDObservationWrapper

# Can be any env_id from the list of Rigid-Body envs: https://maniskill.readthedocs.io/en/latest/tasks/index.html
self.env_name = kwargs.pop('env_name', 'PickCube-v1') # can be one of ['PickCube-v1', 'PegInsertionSide-v1', 'StackCube-v1']
Expand Down Expand Up @@ -141,7 +158,7 @@ def __init__(self, config_name, num_envs, **kwargs):
print("Observation Space Unwrapped Before:", policy_obs_space)

# TODO: add pointcloud and Depth support
if self.obs_mode == 'rgbd':
if self.obs_mode == 'rgb' or self.obs_mode == 'rgbd':
self.env = RlgFlattenRGBDObservationWrapper(self.env, aux_loss=self.aux_loss)
policy_obs_space = self.env.unwrapped.single_observation_space
print("Observation Space Unwrapped After:", policy_obs_space)
Expand All @@ -152,7 +169,7 @@ def __init__(self, config_name, num_envs, **kwargs):
for key, value in policy_obs_space.items():
print("Key:", key)
print("Value:", value)
if key == 'rgbd':
if key == 'rgb' or key == 'rgbd':
print("RGBD Shape:", value.shape)
print("RGBD Dtype:", value.dtype)
print(value)
Expand Down Expand Up @@ -188,6 +205,8 @@ def __init__(self, config_name, num_envs, **kwargs):
# check if we have a dictionary of observations
for key in policy_obs_space.keys():
if not isinstance(policy_obs_space[key], gymnasium.spaces.Box):
print("Key:", key)
print("Value:", policy_obs_space[key])
raise NotImplementedError(
f"Dictinary of dictinary observations support was not testes: '{type(policy_obs_space[key])}'."
)
Expand Down Expand Up @@ -228,7 +247,7 @@ def step(self, actions):
# dones = (terminated | truncated)
dones = torch.logical_or(terminated, truncated)
if dones.any():
env_idx = torch.arange(0, self.env.num_envs, device=self.env.device)[dones] # device=self.device
env_idx = torch.arange(0, self.env.unwrapped.num_envs, device=self.env.unwrapped.device)[dones] # device=self.device
reset_obs, _ = self.env.reset(options=dict(env_idx=env_idx))
obs_and_states['obs'] = reset_obs

Expand Down Expand Up @@ -278,4 +297,6 @@ def get_env_info(self):


def create_maniskill(**kwargs):
return Maniskill("", num_envs=kwargs.pop('num_actors', 16), **kwargs)
print("Creating Maniskill env with the following parameters:")
print(kwargs)
return Maniskill("", num_envs=kwargs.pop('num_actors', 4), **kwargs)
Loading

0 comments on commit fa3469a

Please sign in to comment.