diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index 89acc1ee..1861f216 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -161,7 +161,7 @@ def calc_gradients(self, input_dict):
             if self.zero_rnn_on_done:
                 batch_dict['dones'] = input_dict['dones']
 
-        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
+        with torch.amp.autocast("cuda", enabled=self.mixed_precision):
             res_dict = self.model(batch_dict)
             action_log_probs = res_dict['prev_neglogp']
             values = res_dict['values']
diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py
index 495dbdc2..b72d0506 100644
--- a/rl_games/algos_torch/a2c_discrete.py
+++ b/rl_games/algos_torch/a2c_discrete.py
@@ -159,7 +159,7 @@ def calc_gradients(self, input_dict):
             if self.zero_rnn_on_done:
                 batch_dict['dones'] = input_dict['dones']
 
-        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
+        with torch.amp.autocast("cuda", enabled=self.mixed_precision):
             res_dict = self.model(batch_dict)
             action_log_probs = res_dict['prev_neglogp']
             values = res_dict['values']
diff --git a/rl_games/networks/vision_networks.py b/rl_games/networks/vision_networks.py
index dd0d39f7..d2f25ca6 100644
--- a/rl_games/networks/vision_networks.py
+++ b/rl_games/networks/vision_networks.py
@@ -15,12 +15,15 @@ def load(self, params):
 
     class Network(NetworkBuilder.BaseNetwork):
         def __init__(self, params, **kwargs):
-            self.actions_num = actions_num = kwargs.pop('actions_num')
-            input_shape = kwargs.pop('input_shape')
-            print('input_shape:', input_shape)
-            if type(input_shape) is dict:
-                input_shape = input_shape['camera']
-                proprio_shape = input_shape['proprio']
+            self.actions_num = kwargs.pop('actions_num')
+            full_input_shape = kwargs.pop('input_shape')
+            proprio_size = 0 # Number of proprioceptive features
+            if type(full_input_shape) is dict:
+                input_shape = full_input_shape['camera']
+                proprio_shape = full_input_shape['proprio']
+                proprio_size = proprio_shape[0]
+            else:
+                input_shape = full_input_shape
 
             self.num_seqs = kwargs.pop('num_seqs', 1)
             self.value_size = kwargs.pop('value_size', 1)
@@ -32,7 +35,6 @@ def __init__(self, params, **kwargs):
             self.cnn = self._build_impala(input_shape, self.conv_depths)
             cnn_output_size = self._calc_input_size(input_shape, self.cnn)
 
-            proprio_size = proprio_shape[0] # Number of proprioceptive features
             mlp_input_size = cnn_output_size + proprio_size
 
             if len(self.units) == 0:
@@ -66,18 +68,18 @@ def __init__(self, params, **kwargs):
             self.flatten_act = self.activations_factory.create(self.activation)
 
             if self.is_discrete:
-                self.logits = torch.nn.Linear(out_size, actions_num)
+                self.logits = torch.nn.Linear(out_size, self.actions_num)
 
             if self.is_continuous:
-                self.mu = torch.nn.Linear(out_size, actions_num)
+                self.mu = torch.nn.Linear(out_size, self.actions_num)
                 self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
                 mu_init = self.init_factory.create(**self.space_config['mu_init'])
                 self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation'])
                 sigma_init = self.init_factory.create(**self.space_config['sigma_init'])
 
                 if self.fixed_sigma:
-                    self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)
+                    self.sigma = nn.Parameter(torch.zeros(self.actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)
                 else:
-                    self.sigma = torch.nn.Linear(out_size, actions_num)
+                    self.sigma = torch.nn.Linear(out_size, self.actions_num)
 
             mlp_init = self.init_factory.create(**self.initializer)
@@ -97,13 +99,14 @@ def __init__(self, params, **kwargs):
             else:
                 sigma_init(self.sigma.weight)
 
-            mlp_init(self.value.weight)
+            mlp_init(self.value.weight)
 
         def forward(self, obs_dict):
-            # for key in obs_dict:
-            #     print(key)
-            obs = obs_dict['camera']
-            proprio = obs_dict['proprio']
+            # print(obs_dict.keys())
+            # print(obs_dict['obs'].keys())
+            # currently works only with a dictionary of camera and proprio observations
+            obs = obs_dict['obs']['camera']
+            proprio = obs_dict['obs']['proprio']
 
             if self.permute_input:
                 obs = obs.permute((0, 3, 1, 2))
@@ -119,7 +122,9 @@ def forward(self, obs_dict):
             out = torch.cat([out, proprio], dim=1)
 
             if self.has_rnn:
-                seq_length = obs_dict['seq_length']
+                # TODO: Double check, it's not always present!!!
+                #seq_length = obs_dict['seq_length']
+                seq_length = obs_dict.get('seq_length', 1)
 
                 out_in = out
                 if not self.is_rnn_before_mlp:
@@ -179,7 +184,21 @@ def load(self, params):
                 self.space_config = params['space']['continuous']
                 self.fixed_sigma = self.space_config['fixed_sigma']
             elif self.is_discrete:
-                self.space_config = params['sA2CVisionBuildernv_depths']
+                self.space_config = params['space']['discrete']
+            elif self.is_multi_discrete:
+                self.space_config = params['space']['multi_discrete']
+
+            self.has_rnn = 'rnn' in params
+            if self.has_rnn:
+                self.rnn_units = params['rnn']['units']
+                self.rnn_layers = params['rnn']['layers']
+                self.rnn_name = params['rnn']['name']
+                self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False)
+                self.rnn_ln = params['rnn'].get('layer_norm', False)
+
+            self.has_cnn = True
+            self.permute_input = params['cnn'].get('permute_input', True)
+            self.conv_depths = params['cnn']['conv_depths']
 
             self.require_rewards = params.get('require_rewards')
             self.require_last_actions = params.get('require_last_actions')
@@ -200,10 +219,10 @@ def is_rnn(self):
         def get_default_rnn_state(self):
             num_layers = self.rnn_layers
             if self.rnn_name == 'lstm':
-                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)), 
+                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
                         torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
             else:
-                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units))) 
+                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
 
     def build(self, name, **kwargs):
        net = A2CVisionBuilder.Network(self.params, **kwargs)
        return net
@@ -220,11 +239,14 @@ def load(self, params):
 
     class Network(NetworkBuilder.BaseNetwork):
         def __init__(self, params, **kwargs):
             self.actions_num = kwargs.pop('actions_num')
-            input_shape = kwargs.pop('input_shape')
-            print('input_shape:', input_shape)
-            if isinstance(input_shape, dict):
-                input_shape = input_shape['camera']
-                proprio_shape = input_shape['proprio']
+            full_input_shape = kwargs.pop('input_shape')
+            proprio_size = 0 # Number of proprioceptive features
+            if type(full_input_shape) is dict:
+                input_shape = full_input_shape['camera']
+                proprio_shape = full_input_shape['proprio']
+                proprio_size = proprio_shape[0]
+            else:
+                input_shape = full_input_shape
 
             self.num_seqs = kwargs.pop('num_seqs', 1)
             self.value_size = kwargs.pop('value_size', 1)
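
Note on the autocast change: torch.cuda.amp.autocast(enabled=...) is deprecated in recent PyTorch releases in favor of the device-agnostic torch.amp.autocast, which takes the device type as its first argument. A minimal standalone sketch of the replacement form used in both trainers above; the model, input, and mixed_precision flag here are placeholders, not part of this patch:

    import torch

    mixed_precision = True
    model = torch.nn.Linear(16, 4)   # stand-in for self.model
    batch = torch.randn(8, 16)

    # New device-agnostic API: the device type is passed explicitly.
    with torch.amp.autocast("cuda", enabled=mixed_precision):
        out = model(batch)           # runs under autocast when CUDA is available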
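
The input-shape handling added to both Network classes accepts either a plain observation shape or a dict with 'camera' and 'proprio' entries. A standalone sketch of that branching, with a hypothetical helper name and illustrative shapes:

    def split_input_shape(full_input_shape):
        # Mirrors the patch: a dict is split into the camera shape and the
        # number of proprioceptive features; a plain shape passes through.
        proprio_size = 0
        if type(full_input_shape) is dict:
            input_shape = full_input_shape['camera']
            proprio_size = full_input_shape['proprio'][0]
        else:
            input_shape = full_input_shape
        return input_shape, proprio_size

    assert split_input_shape((84, 84, 3)) == ((84, 84, 3), 0)
    assert split_input_shape({'camera': (84, 84, 3), 'proprio': (9,)}) == ((84, 84, 3), 9)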
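
forward() now reads the nested 'obs' dict and tolerates a missing 'seq_length'. A sketch of the observation layout the updated network expects; batch size and feature shapes are illustrative only:

    import torch

    obs_dict = {
        'obs': {
            'camera': torch.zeros(8, 84, 84, 3),  # NHWC images, permuted to NCHW when permute_input is True
            'proprio': torch.zeros(8, 9),         # proprioceptive features, concatenated after the CNN
        },
        # 'seq_length' is only present when the trainer prepares an RNN (BPTT) minibatch
    }
    seq_length = obs_dict.get('seq_length', 1)    # falls back to 1 during rollouts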