diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index 84f9a14a..97814c40 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -859,15 +859,20 @@ def load(self, params): class Network(NetworkBuilder.BaseNetwork): def __init__(self, params, **kwargs): self.actions_num = actions_num = kwargs.pop('actions_num') - input_shape = kwargs.pop('input_shape') - print('input_shape:', input_shape) - if type(input_shape) is dict: - input_shape = input_shape['camera'] - proprio_shape = input_shape['proprio'] + full_input_shape = kwargs.pop('input_shape') + proprio_size = 0 # Number of proprioceptive features + if type(full_input_shape) is dict: + input_shape = full_input_shape['camera'] + proprio_shape = full_input_shape['proprio'] + proprio_size = proprio_shape[0] + else: + input_shape = full_input_shape self.num_seqs = kwargs.pop('num_seqs', 1) self.value_size = kwargs.pop('value_size', 1) + print(params) + NetworkBuilder.BaseNetwork.__init__(self) self.load(params) if self.permute_input: @@ -875,7 +880,6 @@ def __init__(self, params, **kwargs): self.cnn = self._build_impala(input_shape, self.conv_depths) cnn_output_size = self._calc_input_size(input_shape, self.cnn) - proprio_size = proprio_shape[0] # Number of proprioceptive features mlp_input_size = cnn_output_size + proprio_size if len(self.units) == 0: @@ -943,10 +947,11 @@ def __init__(self, params, **kwargs): mlp_init(self.value.weight) def forward(self, obs_dict): - # for key in obs_dict: - # print(key) - obs = obs_dict['camera'] - proprio = obs_dict['proprio'] + # print(obs_dict.keys()) + # print(obs_dict['obs'].keys()) + # currently works only dictinary of camera and proprio observations + obs = obs_dict['obs']['camera'] + proprio = obs_dict['obs']['proprio'] if self.permute_input: obs = obs.permute((0, 3, 1, 2)) @@ -962,7 +967,9 @@ def forward(self, obs_dict): out = torch.cat([out, proprio], dim=1) if self.has_rnn: - seq_length = obs_dict['seq_length'] + # TODO: Double check, it's not lways present!!! + #seq_length = obs_dict['seq_length'] + seq_length = obs_dict.get('seq_length', 1) out_in = out if not self.is_rnn_before_mlp: @@ -1022,7 +1029,21 @@ def load(self, params): self.space_config = params['space']['continuous'] self.fixed_sigma = self.space_config['fixed_sigma'] elif self.is_discrete: - self.space_config = params['sA2CVisionBuildernv_depths'] + self.space_config = params['space']['discrete'] + elif self.is_multi_discrete: + self.space_config = params['space']['multi_discrete'] + + self.has_rnn = 'rnn' in params + if self.has_rnn: + self.rnn_units = params['rnn']['units'] + self.rnn_layers = params['rnn']['layers'] + self.rnn_name = params['rnn']['name'] + self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False) + self.rnn_ln = params['rnn'].get('layer_norm', False) + + self.has_cnn = True + self.permute_input = params['cnn'].get('permute_input', True) + self.conv_depths = params['cnn']['conv_depths'] self.require_rewards = params.get('require_rewards') self.require_last_actions = params.get('require_last_actions') @@ -1052,6 +1073,7 @@ def build(self, name, **kwargs): net = A2CVisionBuilder.Network(self.params, **kwargs) return net + class DiagGaussianActor(NetworkBuilder.BaseNetwork): """torch.distributions implementation of an diagonal Gaussian policy.""" def __init__(self, output_dim, log_std_bounds, **mlp_args): diff --git a/rl_games/common/experience.py b/rl_games/common/experience.py index 7ff5ad46..b298e602 100644 --- a/rl_games/common/experience.py +++ b/rl_games/common/experience.py @@ -357,10 +357,10 @@ def _init_from_aux_dict(self, tensor_dict): def _create_tensor_from_space(self, space, base_shape): if type(space) is gym.spaces.Box: dtype = numpy_to_torch_dtype_dict[space.dtype] - return torch.zeros(base_shape + space.shape, dtype= dtype, device = self.device) + return torch.zeros(base_shape + space.shape, dtype=dtype, device=self.device) if type(space) is gym.spaces.Discrete: dtype = numpy_to_torch_dtype_dict[space.dtype] - return torch.zeros(base_shape, dtype= dtype, device = self.device) + return torch.zeros(base_shape, dtype=dtype, device = self.device) if type(space) is gym.spaces.Tuple: ''' assuming that tuple is only Discrete tuple