Add concat_output for rnn (#260)
tylerlum authored Nov 20, 2023
1 parent 32f70ee commit a5d788a
Showing 1 changed file with 37 additions and 26 deletions.
63 changes: 37 additions & 26 deletions rl_games/algos_torch/network_builder.py
@@ -216,23 +216,29 @@ def __init__(self, params, **kwargs):
             if self.separate:
                 self.critic_cnn = self._build_conv( **cnn_args)
 
-            mlp_input_shape = self._calc_input_size(input_shape, self.actor_cnn)
+            cnn_output_size = self._calc_input_size(input_shape, self.actor_cnn)
 
-            in_mlp_shape = mlp_input_shape
+            mlp_input_size = cnn_output_size
             if len(self.units) == 0:
-                out_size = mlp_input_shape
+                out_size = cnn_output_size
             else:
                 out_size = self.units[-1]
 
             if self.has_rnn:
                 if not self.is_rnn_before_mlp:
                     rnn_in_size = out_size
-                    out_size = self.rnn_units
                     if self.rnn_concat_input:
-                        rnn_in_size += in_mlp_shape
+                        rnn_in_size += cnn_output_size
+
+                    out_size = self.rnn_units
+                    if self.rnn_concat_output:
+                        out_size += cnn_output_size
                 else:
-                    rnn_in_size = in_mlp_shape
-                    in_mlp_shape = self.rnn_units
+                    rnn_in_size = cnn_output_size
+
+                    mlp_input_size = self.rnn_units
+                    if self.rnn_concat_output:
+                        mlp_input_size += cnn_output_size
 
                 if self.separate:
                     self.a_rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
@@ -246,7 +252,7 @@ def __init__(self, params, **kwargs):
                         self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
 
             mlp_args = {
-                'input_size' : in_mlp_shape,
+                'input_size' : mlp_input_size,
                 'units' : self.units,
                 'activation' : self.activation,
                 'norm_func_name' : self.normalization,
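
For readers tracing the new size bookkeeping in __init__, the following standalone sketch replays the arithmetic above for the RNN-after-MLP case (is_rnn_before_mlp == False). The numbers are illustrative only, not taken from the commit or any config.

# Illustrative sizes only (not from the commit)
cnn_output_size = 512              # flattened CNN feature size
units = [256, 128]                 # MLP hidden sizes
rnn_units = 64
rnn_concat_input = True
rnn_concat_output = True

mlp_input_size = cnn_output_size    # MLP still consumes the CNN features: 512
out_size = units[-1]                # MLP output width feeding the RNN: 128
rnn_in_size = out_size              # 128
if rnn_concat_input:
    rnn_in_size += cnn_output_size  # 128 + 512 = 640
out_size = rnn_units                # heads read the RNN output: 64
if rnn_concat_output:
    out_size += cnn_output_size     # 64 + 512 = 576, heads see [rnn, cnn]
print(mlp_input_size, rnn_in_size, out_size)  # 512 640 576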
@@ -325,15 +331,15 @@ def forward(self, obs_dict):
                 if self.has_rnn:
                     seq_length = obs_dict.get('seq_length', 1)
 
+                    a_cnn_out = a_out
+                    c_cnn_out = c_out
                     if not self.is_rnn_before_mlp:
-                        a_out_in = a_out
-                        c_out_in = c_out
-                        a_out = self.actor_mlp(a_out_in)
-                        c_out = self.critic_mlp(c_out_in)
+                        a_out = self.actor_mlp(a_cnn_out)
+                        c_out = self.critic_mlp(c_cnn_out)
 
                         if self.rnn_concat_input:
-                            a_out = torch.cat([a_out, a_out_in], dim=1)
-                            c_out = torch.cat([c_out, c_out_in], dim=1)
+                            a_out = torch.cat([a_out, a_cnn_out], dim=1)
+                            c_out = torch.cat([c_out, c_cnn_out], dim=1)
 
                     batch_size = a_out.size()[0]
                     num_seqs = batch_size // seq_length
@@ -369,6 +375,10 @@ def forward(self, obs_dict):
                         c_states = (c_states,)
                     states = a_states + c_states
 
+                    if self.rnn_concat_output:
+                        a_out = torch.cat([a_out, a_cnn_out], dim=1)
+                        c_out = torch.cat([c_out, c_cnn_out], dim=1)
+
                     if self.is_rnn_before_mlp:
                         a_out = self.actor_mlp(a_out)
                         c_out = self.critic_mlp(c_out)
@@ -402,12 +412,11 @@ def forward(self, obs_dict):
                 if self.has_rnn:
                     seq_length = obs_dict.get('seq_length', 1)
 
-                    out_in = out
+                    cnn_out = out
                     if not self.is_rnn_before_mlp:
-                        out_in = out
                         out = self.actor_mlp(out)
                         if self.rnn_concat_input:
-                            out = torch.cat([out, out_in], dim=1)
+                            out = torch.cat([out, cnn_out], dim=1)
 
                     batch_size = out.size()[0]
                     num_seqs = batch_size // seq_length
@@ -426,6 +435,8 @@
 
                     if self.rnn_ln:
                         out = self.layer_norm(out)
+                    if self.rnn_concat_output:
+                        out = torch.cat([out, cnn_out], dim=1)
                     if self.is_rnn_before_mlp:
                         out = self.actor_mlp(out)
                     if type(states) is not tuple:
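
In the forward pass, concat_output means the tensor handed to the layers after the RNN is the RNN output concatenated with the saved CNN features along the feature dimension. A minimal plain-PyTorch sketch of that operation, with made-up shapes and outside the rl_games classes:

import torch

# Hypothetical shapes for illustration only
batch, cnn_feats, rnn_units = 4, 512, 64

cnn_out = torch.randn(batch, cnn_feats)    # flattened CNN features saved before the MLP/RNN
rnn_out = torch.randn(batch, rnn_units)    # RNN output for the same batch

# concat_output=True: downstream layers see both the recurrent state and the raw CNN features
out = torch.cat([rnn_out, cnn_out], dim=1)
assert out.shape == (batch, rnn_units + cnn_feats)  # (4, 576)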
@@ -518,6 +529,7 @@ def load(self, params):
                 self.rnn_ln = params['rnn'].get('layer_norm', False)
                 self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False)
                 self.rnn_concat_input = params['rnn'].get('concat_input', False)
+                self.rnn_concat_output = params['rnn'].get('concat_output', False)
 
             if 'cnn' in params:
                 self.has_cnn = True
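
load() only flips the flag when the incoming params ask for it. Below is a hedged example of the 'rnn' sub-dict this method would read, with the new key enabled; the name/units/layers values are placeholders for illustration, not taken from any shipped config:

# Illustrative params fragment; only the keys shown are the ones read in this file
params = {
    'rnn': {
        'name': 'lstm',          # placeholder choice
        'units': 256,            # placeholder size
        'layers': 1,
        'layer_norm': False,
        'before_mlp': False,
        'concat_input': True,
        'concat_output': True,   # new option introduced by this commit
    },
    # ... 'mlp', 'cnn', 'space', etc. omitted
}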
@@ -620,12 +632,11 @@ def __init__(self, params, **kwargs):
             input_shape = torch_ext.shape_whc_to_cwh(input_shape)
 
             self.cnn = self._build_impala(input_shape, self.conv_depths)
-            mlp_input_shape = self._calc_input_size(input_shape, self.cnn)
-
-            in_mlp_shape = mlp_input_shape
-
+            cnn_output_size = self._calc_input_size(input_shape, self.cnn)
+
+            mlp_input_size = cnn_output_size
             if len(self.units) == 0:
-                out_size = mlp_input_shape
+                out_size = cnn_output_size
             else:
                 out_size = self.units[-1]
 
@@ -634,17 +645,19 @@
                     rnn_in_size = out_size
                     out_size = self.rnn_units
                 else:
-                    rnn_in_size = in_mlp_shape
-                    in_mlp_shape = self.rnn_units
+                    rnn_in_size = mlp_input_size
+                    mlp_input_size = self.rnn_units
+
                 if self.require_rewards:
                     rnn_in_size += 1
                 if self.require_last_actions:
                     rnn_in_size += actions_num
+
                 self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
                 #self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
 
             mlp_args = {
-                'input_size' : in_mlp_shape,
+                'input_size' : mlp_input_size,
                 'units' :self.units,
                 'activation' : self.activation,
                 'norm_func_name' : self.normalization,
@@ -899,8 +912,6 @@ def __init__(self, params, **kwargs):
             NetworkBuilder.BaseNetwork.__init__(self)
             self.load(params)
 
-            mlp_input_shape = input_shape
-
             actor_mlp_args = {
                 'input_size' : obs_dim,
                 'units' : self.units,
