Skip to content

Commit

Permalink
WIP: adding proprio observations to resnet network.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Aug 16, 2024
1 parent ae043b9 commit 1e58c1e
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 21 deletions.
10 changes: 9 additions & 1 deletion rl_games/algos_torch/a2c_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, base_name, params):
'actions_num' : self.actions_num,
'input_shape' : obs_shape,
'num_seqs' : self.num_actors * self.num_agents,
'value_size': self.env_info.get('value_size',1),
'value_size': self.env_info.get('value_size', 1),
'normalize_value' : self.normalize_value,
'normalize_input': self.normalize_input,
}
Expand Down Expand Up @@ -144,6 +144,14 @@ def calc_gradients(self, input_dict):
'obs' : obs_batch,
}

# print("TEST")
# print("----------------")
# for key in input_dict:
# print(key)

# if "proprio" in input_dict:
# batch_dict['proprio'] = input_dict['proprio']

rnn_masks = None
if self.is_rnn:
rnn_masks = input_dict['rnn_masks']
Expand Down
5 changes: 4 additions & 1 deletion rl_games/algos_torch/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ def __init__(self):
self.network_factory.set_builders(NETWORK_REGISTRY)
self.network_factory.register_builder('actor_critic', lambda **kwargs: network_builder.A2CBuilder())
self.network_factory.register_builder('resnet_actor_critic',
lambda **kwargs: network_builder.A2CResnetBuilder())
lambda **kwargs: network_builder.A2CResnetBuilder())
self.network_factory.register_builder('vision_actor_critic',
lambda **kwargs: network_builder.A2CVisionBuilder())

self.network_factory.register_builder('rnd_curiosity', lambda **kwargs: network_builder.RNDCuriosityBuilder())
self.network_factory.register_builder('soft_actor_critic', lambda **kwargs: network_builder.SACBuilder())

Expand Down
1 change: 0 additions & 1 deletion rl_games/algos_torch/moving_mean_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def _get_stats(self):
else:
raise NotImplementedError(self.impl)


def _update_stats(self, x):
m = self.decay
if self.impl == 'off':
Expand Down
265 changes: 249 additions & 16 deletions rl_games/algos_torch/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ def __init__(self, params, **kwargs):

mlp_args = {
'input_size' : mlp_input_size,
'units' : self.units,
'activation' : self.activation,
'units' : self.units,
'activation' : self.activation,
'norm_func_name' : self.normalization,
'dense_func' : torch.nn.Linear,
'd2rl' : self.is_d2rl,
Expand Down Expand Up @@ -417,7 +417,7 @@ def forward(self, obs_dict):
else:
out = obs
out = self.actor_cnn(out)
out = out.flatten(1)
out = out.flatten(1)

if self.has_rnn:
seq_length = obs_dict.get('seq_length', 1)
Expand Down Expand Up @@ -664,12 +664,12 @@ def __init__(self, params, **kwargs):
rnn_in_size += actions_num

self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
#self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
self.layer_norm = torch.nn.LayerNorm(self.rnn_units)

mlp_args = {
'input_size' : mlp_input_size,
'units' :self.units,
'activation' : self.activation,
'units' :self.units,
'activation' : self.activation,
'norm_func_name' : self.normalization,
'dense_func' : torch.nn.Linear
}
Expand Down Expand Up @@ -701,7 +701,7 @@ def __init__(self, params, **kwargs):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
#nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
for m in self.mlp:
if isinstance(m, nn.Linear):
if isinstance(m, nn.Linear):
mlp_init(m.weight)

if self.is_discrete:
Expand Down Expand Up @@ -733,9 +733,242 @@ def forward(self, obs_dict):

out = obs
out = self.cnn(out)
out = out.flatten(1)
out = out.flatten(1)
out = self.flatten_act(out)

if self.has_rnn:
#seq_length = obs_dict['seq_length']
seq_length = obs_dict.get('seq_length', 1)

out_in = out
if not self.is_rnn_before_mlp:
out_in = out
out = self.mlp(out)

obs_list = [out]
if self.require_rewards:
obs_list.append(reward.unsqueeze(1))
if self.require_last_actions:
obs_list.append(last_action)
out = torch.cat(obs_list, dim=1)
batch_size = out.size()[0]
num_seqs = batch_size // seq_length
out = out.reshape(num_seqs, seq_length, -1)

if len(states) == 1:
states = states[0]

out = out.transpose(0, 1)
if dones is not None:
dones = dones.reshape(num_seqs, seq_length, -1)
dones = dones.transpose(0, 1)
out, states = self.rnn(out, states, dones, bptt_len)
out = out.transpose(0, 1)
out = out.contiguous().reshape(out.size()[0] * out.size()[1], -1)

if self.rnn_ln:
out = self.layer_norm(out)
if self.is_rnn_before_mlp:
out = self.mlp(out)
if type(states) is not tuple:
states = (states,)
else:
out = self.mlp(out)

value = self.value_act(self.value(out))

if self.is_discrete:
logits = self.logits(out)
return logits, value, states

if self.is_continuous:
mu = self.mu_act(self.mu(out))
if self.fixed_sigma:
sigma = self.sigma_act(self.sigma)
else:
sigma = self.sigma_act(self.sigma(out))
return mu, mu*0 + sigma, value, states

def load(self, params):
    """Parse the network config dict into attributes used by __init__/forward.

    Reads the 'mlp', 'space', 'cnn' and optional 'rnn' sections plus a few
    top-level flags; raises KeyError if a required section is missing.
    """
    mlp_cfg = params['mlp']
    space_cfg = params['space']
    cnn_cfg = params['cnn']

    self.separate = False
    self.units = mlp_cfg['units']
    self.activation = mlp_cfg['activation']
    self.initializer = mlp_cfg['initializer']

    # exactly one of these action-space flavors is expected in the config
    self.is_discrete = 'discrete' in space_cfg
    self.is_continuous = 'continuous' in space_cfg
    self.is_multi_discrete = 'multi_discrete' in space_cfg
    self.value_activation = params.get('value_activation', 'None')
    self.normalization = params.get('normalization', None)

    if self.is_continuous:
        self.space_config = space_cfg['continuous']
        self.fixed_sigma = self.space_config['fixed_sigma']
    elif self.is_discrete:
        self.space_config = space_cfg['discrete']
    elif self.is_multi_discrete:
        self.space_config = space_cfg['multi_discrete']

    self.has_rnn = 'rnn' in params
    if self.has_rnn:
        rnn_cfg = params['rnn']
        self.rnn_units = rnn_cfg['units']
        self.rnn_layers = rnn_cfg['layers']
        self.rnn_name = rnn_cfg['name']
        self.is_rnn_before_mlp = rnn_cfg.get('before_mlp', False)
        self.rnn_ln = rnn_cfg.get('layer_norm', False)

    # this builder always uses a convolutional (Impala) trunk
    self.has_cnn = True
    self.permute_input = cnn_cfg.get('permute_input', True)
    self.conv_depths = cnn_cfg['conv_depths']
    self.require_rewards = params.get('require_rewards')
    self.require_last_actions = params.get('require_last_actions')

def _build_impala(self, input_shape, depths):
    """Stack one ImpalaSequential block per entry in `depths`.

    Args:
        input_shape: observation shape, channels-first; only the channel
            count (`input_shape[0]`) is read here.
        depths: iterable of output channel counts, one per Impala block.

    Returns:
        nn.Sequential chaining the ImpalaSequential blocks.
    """
    # A plain list suffices here: nn.Sequential registers the sub-modules
    # itself, so the intermediate nn.ModuleList was redundant.
    layers = []
    in_channels = input_shape[0]
    for depth in depths:
        # each block maps in_channels -> depth; the next block consumes depth
        layers.append(ImpalaSequential(in_channels, depth))
        in_channels = depth
    return nn.Sequential(*layers)

def is_separate_critic(self):
    """Actor and critic share one trunk in this network; never separate."""
    return False

def is_rnn(self):
    """Whether a recurrent core was configured ('rnn' section in params)."""
    return self.has_rnn

def get_default_rnn_state(self):
    """Return zero-filled initial RNN state(s) shaped (layers, num_seqs, units)."""
    num_layers = self.rnn_layers
    if self.rnn_name == 'lstm':
        # LSTM carries both hidden and cell state -> a 2-tuple of tensors
        return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
                torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
    else:
        # NOTE(review): these parentheses do NOT form a tuple — a bare tensor
        # is returned, unlike the lstm branch; confirm callers accept both.
        return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))

def build(self, name, **kwargs):
    """Instantiate the Resnet/Impala actor-critic network.

    `name` is accepted for factory-interface compatibility but unused.
    """
    return A2CResnetBuilder.Network(self.params, **kwargs)


class A2CVisionBuilder(NetworkBuilder):
    """Factory for an actor-critic network over camera + proprio observations."""

    def __init__(self, **kwargs):
        """Set up base factories; the network config arrives later via load()."""
        NetworkBuilder.__init__(self)

    def load(self, params):
        """Store the raw config dict; Network.__init__ parses it on build()."""
        self.params = params

class Network(NetworkBuilder.BaseNetwork):
def __init__(self, params, **kwargs):
    """Vision actor-critic network.

    An Impala-style CNN encodes camera observations; the flattened CNN
    features are concatenated with proprioceptive observations and fed
    through an (optionally recurrent) MLP trunk with value and policy heads.

    Args:
        params: network config dict (see load() for the keys consumed).
        kwargs: 'actions_num', 'input_shape' (dict spaces use the
            'observation' entry), 'num_seqs', 'value_size', and an optional
            'proprio_shape'.
    """
    self.actions_num = actions_num = kwargs.pop('actions_num')
    input_shape = kwargs.pop('input_shape')
    if type(input_shape) is dict:
        # dict observation spaces keep the camera tensor under 'observation'
        input_shape = input_shape['observation']
    self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1)
    self.value_size = kwargs.pop('value_size', 1)

    # Proprioception width: taken from 'proprio_shape' when the caller
    # supplies it, otherwise fall back to the historical hard-coded 68.
    # NOTE(review): no normalization is applied to proprio inputs for now.
    proprio_shape = kwargs.pop('proprio_shape', None)
    self.proprio_size = proprio_shape[0] if proprio_shape is not None else 68

    NetworkBuilder.BaseNetwork.__init__(self)
    self.load(params)
    if self.permute_input:
        input_shape = torch_ext.shape_whc_to_cwh(input_shape)

    self.cnn = self._build_impala(input_shape, self.conv_depths)
    cnn_output_size = self._calc_input_size(input_shape, self.cnn)

    mlp_input_size = cnn_output_size + self.proprio_size
    if len(self.units) == 0:
        # No MLP layers: the trunk output is the concatenated CNN+proprio
        # features (was cnn_output_size, which ignored the proprio part and
        # would mis-size the value/policy heads).
        out_size = mlp_input_size
    else:
        out_size = self.units[-1]

    if self.has_rnn:
        if not self.is_rnn_before_mlp:
            rnn_in_size = out_size
            out_size = self.rnn_units
        else:
            rnn_in_size = mlp_input_size
            mlp_input_size = self.rnn_units

        if self.require_rewards:
            rnn_in_size += 1
        if self.require_last_actions:
            rnn_in_size += actions_num

        self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
        self.layer_norm = torch.nn.LayerNorm(self.rnn_units)

    mlp_args = {
        'input_size': mlp_input_size,
        'units': self.units,
        'activation': self.activation,
        'norm_func_name': self.normalization,
        'dense_func': torch.nn.Linear
    }

    self.mlp = self._build_mlp(**mlp_args)

    self.value = self._build_value_layer(out_size, self.value_size)
    self.value_act = self.activations_factory.create(self.value_activation)
    self.flatten_act = self.activations_factory.create(self.activation)

    if self.is_discrete:
        self.logits = torch.nn.Linear(out_size, actions_num)
    if self.is_continuous:
        self.mu = torch.nn.Linear(out_size, actions_num)
        self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
        mu_init = self.init_factory.create(**self.space_config['mu_init'])
        self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation'])
        sigma_init = self.init_factory.create(**self.space_config['sigma_init'])

        if self.fixed_sigma:
            # state-independent sigma, learned as a free parameter
            self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)
        else:
            self.sigma = torch.nn.Linear(out_size, actions_num)

    mlp_init = self.init_factory.create(**self.initializer)

    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    for m in self.mlp:
        if isinstance(m, nn.Linear):
            mlp_init(m.weight)

    if self.is_discrete:
        mlp_init(self.logits.weight)
    if self.is_continuous:
        mu_init(self.mu.weight)
        if self.fixed_sigma:
            sigma_init(self.sigma)
        else:
            sigma_init(self.sigma.weight)

    mlp_init(self.value.weight)

def forward(self, obs_dict):
# for key in obs_dict:
# print(key)
obs = obs_dict['obs']['camera']
proprio = obs_dict['obs']['proprio']
if self.permute_input:
obs = obs.permute((0, 3, 1, 2))

dones = obs_dict.get('dones', None)
bptt_len = obs_dict.get('bptt_len', 0)
states = obs_dict.get('rnn_states', None)

out = obs
out = self.cnn(out)
out = out.flatten(1)
out = self.flatten_act(out)

out = torch.cat([out, proprio], dim=1)

if self.has_rnn:
#seq_length = obs_dict['seq_length']
seq_length = obs_dict.get('seq_length', 1)
Expand Down Expand Up @@ -824,7 +1057,7 @@ def load(self, params):

def _build_impala(self, input_shape, depths):
in_channels = input_shape[0]
layers = nn.ModuleList()
layers = nn.ModuleList()
for d in depths:
layers.append(ImpalaSequential(in_channels, d))
in_channels = d
Expand All @@ -845,7 +1078,7 @@ def get_default_rnn_state(self):
return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))

def build(self, name, **kwargs):
    """Instantiate the vision actor-critic network.

    `name` is accepted for factory-interface compatibility but unused.
    """
    return A2CVisionBuilder.Network(self.params, **kwargs)


Expand Down Expand Up @@ -923,19 +1156,19 @@ def __init__(self, params, **kwargs):
self.load(params)

actor_mlp_args = {
'input_size' : obs_dim,
'units' : self.units,
'activation' : self.activation,
'input_size' : obs_dim,
'units' : self.units,
'activation' : self.activation,
'norm_func_name' : self.normalization,
'dense_func' : torch.nn.Linear,
'd2rl' : self.is_d2rl,
'norm_only_first_layer' : self.norm_only_first_layer
}

critic_mlp_args = {
'input_size' : obs_dim + action_dim,
'units' : self.units,
'activation' : self.activation,
'input_size' : obs_dim + action_dim,
'units' : self.units,
'activation' : self.activation,
'norm_func_name' : self.normalization,
'dense_func' : torch.nn.Linear,
'd2rl' : self.is_d2rl,
Expand Down
2 changes: 1 addition & 1 deletion rl_games/algos_torch/torch_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def get_coord(x):
x.size(0), 1, 1, 1).type_as(x)
CoordConv2d.pool[key] = coord
return CoordConv2d.pool[key]

def forward(self, x):
return torch.nn.functional.conv2d(torch.cat([x, self.get_coord(x).type_as(x)], 1), self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
Expand Down Expand Up @@ -245,7 +246,6 @@ def forward(self, x):
return self.gamma.expand_as(x) * (x - mean) / (std + self.eps) + self.beta.expand_as(x)



class DiscreteActionsEncoder(nn.Module):
def __init__(self, actions_max, mlp_out, emb_size, num_agents, use_embedding):
super().__init__()
Expand Down
Loading

0 comments on commit 1e58c1e

Please sign in to comment.