diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index ce5651c5..ab047920 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -216,23 +216,29 @@ def __init__(self, params, **kwargs): if self.separate: self.critic_cnn = self._build_conv( **cnn_args) - mlp_input_shape = self._calc_input_size(input_shape, self.actor_cnn) + cnn_output_size = self._calc_input_size(input_shape, self.actor_cnn) - in_mlp_shape = mlp_input_shape + mlp_input_size = cnn_output_size if len(self.units) == 0: - out_size = mlp_input_shape + out_size = cnn_output_size else: out_size = self.units[-1] if self.has_rnn: if not self.is_rnn_before_mlp: rnn_in_size = out_size - out_size = self.rnn_units if self.rnn_concat_input: - rnn_in_size += in_mlp_shape + rnn_in_size += cnn_output_size + + out_size = self.rnn_units + if self.rnn_concat_output: + out_size += cnn_output_size else: - rnn_in_size = in_mlp_shape - in_mlp_shape = self.rnn_units + rnn_in_size = cnn_output_size + + mlp_input_size = self.rnn_units + if self.rnn_concat_output: + mlp_input_size += cnn_output_size if self.separate: self.a_rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers) @@ -246,7 +252,7 @@ def __init__(self, params, **kwargs): self.layer_norm = torch.nn.LayerNorm(self.rnn_units) mlp_args = { - 'input_size' : in_mlp_shape, + 'input_size' : mlp_input_size, 'units' : self.units, 'activation' : self.activation, 'norm_func_name' : self.normalization, @@ -325,15 +331,15 @@ def forward(self, obs_dict): if self.has_rnn: seq_length = obs_dict.get('seq_length', 1) + a_cnn_out = a_out + c_cnn_out = c_out if not self.is_rnn_before_mlp: - a_out_in = a_out - c_out_in = c_out - a_out = self.actor_mlp(a_out_in) - c_out = self.critic_mlp(c_out_in) + a_out = self.actor_mlp(a_cnn_out) + c_out = self.critic_mlp(c_cnn_out) if self.rnn_concat_input: - a_out = torch.cat([a_out, a_out_in], dim=1) - c_out = torch.cat([c_out, c_out_in], dim=1) + a_out = torch.cat([a_out, a_cnn_out], dim=1) + c_out = torch.cat([c_out, c_cnn_out], dim=1) batch_size = a_out.size()[0] num_seqs = batch_size // seq_length @@ -369,6 +375,10 @@ def forward(self, obs_dict): c_states = (c_states,) states = a_states + c_states + if self.rnn_concat_output: + a_out = torch.cat([a_out, a_cnn_out], dim=1) + c_out = torch.cat([c_out, c_cnn_out], dim=1) + if self.is_rnn_before_mlp: a_out = self.actor_mlp(a_out) c_out = self.critic_mlp(c_out) @@ -402,12 +412,11 @@ def forward(self, obs_dict): if self.has_rnn: seq_length = obs_dict.get('seq_length', 1) - out_in = out + cnn_out = out if not self.is_rnn_before_mlp: - out_in = out out = self.actor_mlp(out) if self.rnn_concat_input: - out = torch.cat([out, out_in], dim=1) + out = torch.cat([out, cnn_out], dim=1) batch_size = out.size()[0] num_seqs = batch_size // seq_length @@ -426,6 +435,8 @@ def forward(self, obs_dict): if self.rnn_ln: out = self.layer_norm(out) + if self.rnn_concat_output: + out = torch.cat([out, cnn_out], dim=1) if self.is_rnn_before_mlp: out = self.actor_mlp(out) if type(states) is not tuple: @@ -518,6 +529,7 @@ def load(self, params): self.rnn_ln = params['rnn'].get('layer_norm', False) self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False) self.rnn_concat_input = params['rnn'].get('concat_input', False) + self.rnn_concat_output = params['rnn'].get('concat_output', False) if 'cnn' in params: self.has_cnn = True @@ -620,12 +632,11 @@ def __init__(self, params, **kwargs): input_shape = torch_ext.shape_whc_to_cwh(input_shape) self.cnn = self._build_impala(input_shape, self.conv_depths) - mlp_input_shape = self._calc_input_size(input_shape, self.cnn) - - in_mlp_shape = mlp_input_shape + cnn_output_size = self._calc_input_size(input_shape, self.cnn) + mlp_input_size = cnn_output_size if len(self.units) == 0: - out_size = mlp_input_shape + out_size = cnn_output_size else: out_size = self.units[-1] @@ -634,17 +645,19 @@ def __init__(self, params, **kwargs): rnn_in_size = out_size out_size = self.rnn_units else: - rnn_in_size = in_mlp_shape - in_mlp_shape = self.rnn_units + rnn_in_size = mlp_input_size + mlp_input_size = self.rnn_units + if self.require_rewards: rnn_in_size += 1 if self.require_last_actions: rnn_in_size += actions_num + self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers) #self.layer_norm = torch.nn.LayerNorm(self.rnn_units) mlp_args = { - 'input_size' : in_mlp_shape, + 'input_size' : mlp_input_size, 'units' :self.units, 'activation' : self.activation, 'norm_func_name' : self.normalization, @@ -899,8 +912,6 @@ def __init__(self, params, **kwargs): NetworkBuilder.BaseNetwork.__init__(self) self.load(params) - mlp_input_shape = input_shape - actor_mlp_args = { 'input_size' : obs_dim, 'units' : self.units,