Auxiliary losses visual training WIP. ManiSkill env switched to using uint8 for RGBD envs.
ViktorM committed Sep 10, 2024
1 parent b3467f2 commit 80ee4df
Showing 5 changed files with 57 additions and 46 deletions.
29 changes: 1 addition & 28 deletions rl_games/algos_torch/a2c_continuous.py
@@ -176,34 +176,6 @@ def calc_gradients(self, input_dict):
rnn_masks
)

a_loss = self.actor_loss_func(old_action_log_probs_batch, action_log_probs, advantage, self.ppo, curr_e_clip)

if self.has_value_loss:
c_loss = common_losses.critic_loss(
self.model, value_preds_batch, values, curr_e_clip, return_batch,
self.clip_value
)
else:
c_loss = torch.zeros(1, device=self.ppo_device)
if self.bound_loss_type == 'regularisation':
b_loss = self.reg_loss(mu)
elif self.bound_loss_type == 'bound':
b_loss = self.bound_loss(mu)
else:
b_loss = torch.zeros(1, device=self.ppo_device)

losses, sum_mask = torch_ext.apply_masks(
[
a_loss.unsqueeze(1),
c_loss,
entropy.unsqueeze(1),
b_loss.unsqueeze(1)
],
rnn_masks
)
a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]

loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef
aux_loss = self.model.get_aux_loss()
self.aux_loss_dict = {}
if aux_loss is not None:
@@ -213,6 +185,7 @@ def calc_gradients(self, input_dict):
self.aux_loss_dict[k] = v.detach()
else:
self.aux_loss_dict[k] = [v.detach()]

if self.multi_gpu:
self.optimizer.zero_grad()
else:
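For context, the hunk above wires the model's `get_aux_loss()` hook into `calc_gradients`: the dict of named auxiliary loss tensors returned by the network is detached into `self.aux_loss_dict` for logging. Below is a minimal sketch of that collection step using a stub model; whether and how the auxiliary terms are weighted into the total loss is not visible in this hunk, so the `aux_coef` factor is an assumption for illustration only.

import torch

class _StubModel:
    """Minimal stand-in exposing get_aux_loss(), as the model wrapper in models.py does."""
    def __init__(self, aux_losses):
        self._aux_losses = aux_losses

    def get_aux_loss(self):
        return self._aux_losses

def collect_aux_losses(model, loss, aux_coef=1.0):
    """Detach auxiliary losses for logging, mirroring the hunk above (aux_coef is illustrative)."""
    aux_loss_dict = {}
    aux_loss = model.get_aux_loss()
    if aux_loss is not None:
        for k, v in aux_loss.items():
            loss = loss + aux_coef * v       # assumed weighting into the total loss
            aux_loss_dict[k] = v.detach()    # logged value, excluded from the graph
    return loss, aux_loss_dict

# Usage sketch
params = torch.nn.Parameter(torch.randn(4))
main_loss = params.pow(2).mean()                                   # stand-in for the PPO loss
model = _StubModel({'aux_dist_loss': (params - 1.0).pow(2).mean()})
loss, aux_loss_dict = collect_aux_losses(model, main_loss)
loss.backward()
print(aux_loss_dict)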
2 changes: 1 addition & 1 deletion rl_games/algos_torch/models.py
@@ -268,7 +268,7 @@ def __init__(self, a2c_network, **kwargs):

def get_aux_loss(self):
return self.a2c_network.get_aux_loss()

def is_rnn(self):
return self.a2c_network.is_rnn()

4 changes: 2 additions & 2 deletions rl_games/configs/maniskill/maniskill_pickcube_vision.yaml
@@ -55,7 +55,7 @@ params:
concat_output: True

config:
name: PickCube_RGB_resnet18_LSTM_norm_embedding_64envs
name: PickCube_RGB_resnet18_LSTM_norm_embedding_64envs_auxloss
env_name: maniskill
reward_shaper:
scale_value: 1.0
@@ -67,7 +67,7 @@ params:
normalize_input: False
normalize_value: True
normalize_advantage: True
num_actors: 64 #128
num_actors: 64
reward_shaper:
scale_value: 1.0
gamma: 0.99
35 changes: 22 additions & 13 deletions rl_games/envs/maniskill.py
@@ -66,7 +66,8 @@ def observation(self, observation: Dict):
# print("Observation:", observation.keys())
# for key, value in observation.items():
# print(key, value.keys())
tcp_pose = observation['extra']['tcp_pose']
aux_target = observation['extra']['aux_target']
del observation['extra']['aux_target']
# print("Input Obs:", observation.keys())
# print("Input Obs Agent:", observation['agent'].keys())
# print("Input Obs Extra:", observation['extra'].keys())
@@ -80,15 +81,16 @@ def observation(self, observation: Dict):
if self.include_depth:
images.append(cam_data["depth"])
images = torch.concat(images, axis=-1)

# flatten the rest of the data which should just be state data
observation = common.flatten_state_dict(observation, use_torch=True)

ret = dict()
if self.include_state:
ret["proprio"] = observation
if self.aux_loss:
ret['aux_target'] = tcp_pose
ret["camera"] = images.float() / 255.0
ret['aux_target'] = aux_target
ret["camera"] = images

return ret

@@ -141,6 +143,7 @@ def __init__(self, config_name, num_envs, **kwargs):
if self.obs_mode == 'rgbd':
self.env = RlgFlattenRGBDObservationWrapper(self.env, aux_loss=self.aux_loss)
policy_obs_space = self.env.unwrapped.single_observation_space
print("Observation Space Unwrapped After:", policy_obs_space)

modified_policy_obs_space = {}

@@ -150,18 +153,20 @@ def __init__(self, config_name, num_envs, **kwargs):
print("Value:", value)
if key == 'rgbd':
print("RGBD Shape:", value.shape)
print("RGBD Dtype:", value.dtype)
print(value)
self.env.unwrapped.single_observation_space[key].dtype = np.uint8
value.dtype = np.uint8
modified_policy_obs_space['camera'] = value
elif key == 'state':
modified_policy_obs_space['proprio'] = value
else:
modified_policy_obs_space[key] = value

# if self.aux_loss:
# modified_policy_obs_space['aux_target'] = gymnasium.spaces.Dict(low=-np.Inf, high=np.Inf, shape=(3, ), dtype=np.float32)

print("Observation Space Unwrapped:", modified_policy_obs_space)
print("Observation Space Unwrapped Done:", modified_policy_obs_space)

policy_obs_space = gymnasium.spaces.Dict(modified_policy_obs_space)
print("Observation Space After:", policy_obs_space)

# from mani_skill.utils.wrappers import RecordEpisode
# # to make it look a little more realistic, we will enable shadows which make the default lighting cast shadows
@@ -175,6 +180,8 @@ def __init__(self, config_name, num_envs, **kwargs):

self._clip_obs = 5.0

self.observation_space = gym.spaces.Dict()

# TODO: single function
if isinstance(policy_obs_space, gymnasium.spaces.Dict):
# check if we have a dictionary of observations
@@ -183,15 +190,17 @@ def __init__(self, config_name, num_envs, **kwargs):
raise NotImplementedError(
f"Dictinary of dictinary observations support was not testes: '{type(policy_obs_space[key])}'."
)
self.observation_space = gym.spaces.Dict(
{
key: gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space[key].shape)
for key in policy_obs_space.keys()
}
)

val = policy_obs_space[key]
if val.dtype == np.float16 or val.dtype == np.float32:
self.observation_space[key] = gym.spaces.Box(-self._clip_obs, self._clip_obs, val.shape)
elif policy_obs_space[key].dtype == np.uint8:
self.observation_space[key] = gym.spaces.Box(0, 255, val.shape, dtype=np.uint8)
else:
self.observation_space = gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape)

print("Observation Space:", self.observation_space)

self._clip_actions = 1.0

action_space = self.env.unwrapped.single_action_space
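The maniskill.py changes above build the rl_games observation space per key: float entries keep the clipped Box used before, while uint8 image entries get a 0-255 Box so the raw RGBD tensors can stay uint8 all the way to the network. Below is a self-contained sketch of that dtype-aware construction; it assumes both gym and gymnasium are available (as the wrapper does), and the key names and shapes (`camera`, `proprio`, 128x128x4) are purely illustrative.

import numpy as np
import gym
import gymnasium

def build_rlg_obs_space(policy_obs_space, clip_obs=5.0):
    """Convert a per-env gymnasium Dict space into the gym Dict space rl_games expects.

    Float entries are clipped to [-clip_obs, clip_obs]; uint8 entries (RGB/RGBD images)
    stay as 0..255 boxes so the env can return raw uint8 tensors.
    """
    spaces = {}
    for key in policy_obs_space.keys():
        val = policy_obs_space[key]
        if val.dtype == np.float16 or val.dtype == np.float32:
            spaces[key] = gym.spaces.Box(-clip_obs, clip_obs, val.shape)
        elif val.dtype == np.uint8:
            spaces[key] = gym.spaces.Box(0, 255, val.shape, dtype=np.uint8)
        else:
            raise NotImplementedError(f"Unsupported observation dtype: {val.dtype}")
    return gym.spaces.Dict(spaces)

# Example mirroring the wrapper output: a uint8 RGBD image plus float proprioception.
example = gymnasium.spaces.Dict({
    'camera': gymnasium.spaces.Box(0, 255, (128, 128, 4), dtype=np.uint8),
    'proprio': gymnasium.spaces.Box(-np.inf, np.inf, (29,), dtype=np.float32),
})
print(build_rlg_obs_space(example))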
33 changes: 31 additions & 2 deletions rl_games/networks/vision_networks.py
@@ -74,7 +74,7 @@ def __init__(self, params, **kwargs):
self.aux_loss_linear = nn.Linear(out_size, self.target_shape)

self.aux_loss_map = {
'aux_dist_loss' : None
'aux_dist_loss': None
}

self.value = self._build_value_layer(out_size, self.value_size)
@@ -129,7 +129,7 @@ def get_aux_loss(self):
def forward(self, obs_dict):
obs = obs_dict['obs']['camera']
proprio = obs_dict['obs']['proprio']
target_obs = obs['tcp_pose'] # obs[self.target_key]
target_obs = obs[self.target_key]
if self.permute_input:
obs = obs.permute((0, 3, 1, 2))

@@ -281,13 +281,24 @@ def __init__(self, params, **kwargs):
self.actions_num = kwargs.pop('actions_num')
full_input_shape = kwargs.pop('input_shape')

print('full_input_shape: ', full_input_shape)

self.target_key = 'aux_target'
self.target_shape = full_input_shape[self.target_key]
print("Target shape: ", self.target_shape)

print("Observations shape: ", full_input_shape)

self.proprio_size = 0 # Number of proprioceptive features
if isinstance(full_input_shape, dict):
input_shape = full_input_shape['camera']
proprio_shape = full_input_shape['proprio']
self.proprio_size = proprio_shape[0]

# # TODO: This is a hack to get the target shape
# for k, v in full_input_shape.items():
# if self.target_key == k:
# self.target_shape = v[0]
else:
input_shape = full_input_shape

@@ -330,6 +341,12 @@ def __init__(self, params, **kwargs):

self.mlp = self._build_mlp(**mlp_args)

self.aux_loss_linear = nn.Linear(out_size, self.target_shape[0])

self.aux_loss_map = {
'aux_dist_loss': None
}

self.value = self._build_value_layer(out_size, self.value_size)
self.value_act = self.activations_factory.create(self.value_activation)
self.flatten_act = self.activations_factory.create(self.activation)
@@ -365,13 +382,22 @@ def __init__(self, params, **kwargs):

mlp_init(self.value.weight)

def get_aux_loss(self):
return self.aux_loss_map

def forward(self, obs_dict):
if self.proprio_size > 0:
obs = obs_dict['obs']['camera']
proprio = obs_dict['obs']['proprio']
else:
obs = obs_dict['obs']

target_obs = obs_dict['obs'][self.target_key]

# print('obs.min(): ', obs.min())
# print('obs.max(): ', obs.max())
# print('obs.shape: ', obs.shape)

if self.permute_input:
obs = obs.permute((0, 3, 1, 2))

@@ -426,6 +452,9 @@ def forward(self, obs_dict):

value = self.value_act(self.value(out))

y = self.aux_loss_linear(out)
self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs)

if self.is_discrete:
logits = self.logits(out)
return logits, value, states
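In vision_networks.py the auxiliary piece reduces to a linear head over the shared trunk features: it predicts the `aux_target` supplied by the env wrapper and caches the MSE in `aux_loss_map['aux_dist_loss']`, which `get_aux_loss()` exposes to the trainer. Below is a stripped-down sketch of just that head; the feature size of 512 and target size of 3 are assumptions, and this is not the full rl_games network.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AuxHead(nn.Module):
    """Auxiliary regression head: predict aux_target from the shared trunk features."""
    def __init__(self, feature_size, target_size):
        super().__init__()
        self.aux_loss_linear = nn.Linear(feature_size, target_size)
        self.aux_loss_map = {'aux_dist_loss': None}

    def forward(self, features, aux_target):
        pred = self.aux_loss_linear(features)
        # cache the loss; the trainer collects it later via get_aux_loss()
        self.aux_loss_map['aux_dist_loss'] = F.mse_loss(pred, aux_target)
        return pred

    def get_aux_loss(self):
        return self.aux_loss_map

# Usage sketch: features from the CNN/MLP trunk, target from obs_dict['obs']['aux_target'].
head = AuxHead(feature_size=512, target_size=3)
features = torch.randn(8, 512)
aux_target = torch.randn(8, 3)
head(features, aux_target)
print(head.get_aux_loss()['aux_dist_loss'])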
