diff --git a/rl_games/algos_torch/model_builder.py b/rl_games/algos_torch/model_builder.py
index 58378063..0ebb38db 100644
--- a/rl_games/algos_torch/model_builder.py
+++ b/rl_games/algos_torch/model_builder.py
@@ -1,7 +1,7 @@
 from rl_games.common import object_factory
 import rl_games.algos_torch
-from rl_games.algos_torch import network_builder
-from rl_games.algos_torch import models
+from rl_games.algos_torch import models, network_builder
+from rl_games.networks import vision_networks
 
 NETWORK_REGISTRY = {}
 MODEL_REGISTRY = {}
@@ -21,7 +21,9 @@ def __init__(self):
         self.network_factory.register_builder('resnet_actor_critic',
                                               lambda **kwargs: network_builder.A2CResnetBuilder())
         self.network_factory.register_builder('vision_actor_critic',
-                                              lambda **kwargs: network_builder.A2CVisionBuilder())
+                                              lambda **kwargs: vision_networks.A2CVisionBuilder())
+        self.network_factory.register_builder('e2e_vision_actor_critic',
+                                              lambda **kwargs: vision_networks.A2CVisionBackboneBuilder())
         self.network_factory.register_builder('rnd_curiosity',
                                               lambda **kwargs: network_builder.RNDCuriosityBuilder())
         self.network_factory.register_builder('soft_actor_critic', lambda **kwargs: network_builder.SACBuilder())
diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py
index 699ec7db..d89fc75f 100644
--- a/rl_games/algos_torch/network_builder.py
+++ b/rl_games/algos_torch/network_builder.py
@@ -849,239 +849,6 @@ def build(self, name, **kwargs):
         return net
 
 
-class A2CVisionBuilder(NetworkBuilder):
-    def __init__(self, **kwargs):
-        NetworkBuilder.__init__(self)
-
-    def load(self, params):
-        self.params = params
-
-    class Network(NetworkBuilder.BaseNetwork):
-        def __init__(self, params, **kwargs):
-            self.actions_num = actions_num = kwargs.pop('actions_num')
-            input_shape = kwargs.pop('input_shape')
-            print('input_shape:', input_shape)
-            if type(input_shape) is dict:
-                input_shape = input_shape['observation']
-            self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1)
-            self.value_size = kwargs.pop('value_size', 1)
-
-            # TODO: add proprioception from config
-            # no normilization for proprioception for now
-            proprio_shape = kwargs.pop('proprio_shape', None)
-            self.proprio_size = 68
-
-            NetworkBuilder.BaseNetwork.__init__(self)
-            self.load(params)
-            if self.permute_input:
-                input_shape = torch_ext.shape_whc_to_cwh(input_shape)
-
-            self.cnn = self._build_impala(input_shape, self.conv_depths)
-            cnn_output_size = self._calc_input_size(input_shape, self.cnn)
-
-            mlp_input_size = cnn_output_size + self.proprio_size
-            if len(self.units) == 0:
-                out_size = cnn_output_size
-            else:
-                out_size = self.units[-1]
-
-            if self.has_rnn:
-                if not self.is_rnn_before_mlp:
-                    rnn_in_size = out_size
-                    out_size = self.rnn_units
-                else:
-                    rnn_in_size = mlp_input_size
-                    mlp_input_size = self.rnn_units
-
-                if self.require_rewards:
-                    rnn_in_size += 1
-                if self.require_last_actions:
-                    rnn_in_size += actions_num
-
-                self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
-                self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
-
-            mlp_args = {
-                'input_size' : mlp_input_size,
-                'units' :self.units,
-                'activation' : self.activation,
-                'norm_func_name' : self.normalization,
-                'dense_func' : torch.nn.Linear
-            }
-
-            self.mlp = self._build_mlp(**mlp_args)
-
-            self.value = self._build_value_layer(out_size, self.value_size)
-            self.value_act = self.activations_factory.create(self.value_activation)
-            self.flatten_act = self.activations_factory.create(self.activation)
-
-            if self.is_discrete:
-                self.logits = torch.nn.Linear(out_size, actions_num)
-            if self.is_continuous:
-                self.mu = torch.nn.Linear(out_size, actions_num)
-                self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
-                mu_init = self.init_factory.create(**self.space_config['mu_init'])
-                self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation'])
-                sigma_init = self.init_factory.create(**self.space_config['sigma_init'])
-
-                if self.fixed_sigma:
-                    self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)
-                else:
-                    self.sigma = torch.nn.Linear(out_size, actions_num)
-
-            mlp_init = self.init_factory.create(**self.initializer)
-
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
-                    #nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
-            for m in self.mlp:
-                if isinstance(m, nn.Linear):
-                    mlp_init(m.weight)
-
-            if self.is_discrete:
-                mlp_init(self.logits.weight)
-            if self.is_continuous:
-                mu_init(self.mu.weight)
-                if self.fixed_sigma:
-                    sigma_init(self.sigma)
-                else:
-                    sigma_init(self.sigma.weight)
-
-            mlp_init(self.value.weight)
-
-        def forward(self, obs_dict):
-            # for key in obs_dict:
-            #     print(key)
-            obs = obs_dict['obs']['camera']
-            proprio = obs_dict['obs']['proprio']
-            if self.permute_input:
-                obs = obs.permute((0, 3, 1, 2))
-
-            dones = obs_dict.get('dones', None)
-            bptt_len = obs_dict.get('bptt_len', 0)
-            states = obs_dict.get('rnn_states', None)
-
-            out = obs
-            out = self.cnn(out)
-            out = out.flatten(1)
-            out = self.flatten_act(out)
-
-            out = torch.cat([out, proprio], dim=1)
-
-            if self.has_rnn:
-                #seq_length = obs_dict['seq_length']
-                seq_length = obs_dict.get('seq_length', 1)
-
-                out_in = out
-                if not self.is_rnn_before_mlp:
-                    out_in = out
-                    out = self.mlp(out)
-
-                obs_list = [out]
-                if self.require_rewards:
-                    obs_list.append(reward.unsqueeze(1))
-                if self.require_last_actions:
-                    obs_list.append(last_action)
-                out = torch.cat(obs_list, dim=1)
-                batch_size = out.size()[0]
-                num_seqs = batch_size // seq_length
-                out = out.reshape(num_seqs, seq_length, -1)
-
-                if len(states) == 1:
-                    states = states[0]
-
-                out = out.transpose(0, 1)
-                if dones is not None:
-                    dones = dones.reshape(num_seqs, seq_length, -1)
-                    dones = dones.transpose(0, 1)
-                out, states = self.rnn(out, states, dones, bptt_len)
-                out = out.transpose(0, 1)
-                out = out.contiguous().reshape(out.size()[0] * out.size()[1], -1)
-
-                if self.rnn_ln:
-                    out = self.layer_norm(out)
-                if self.is_rnn_before_mlp:
-                    out = self.mlp(out)
-                if type(states) is not tuple:
-                    states = (states,)
-            else:
-                out = self.mlp(out)
-
-            value = self.value_act(self.value(out))
-
-            if self.is_discrete:
-                logits = self.logits(out)
-                return logits, value, states
-
-            if self.is_continuous:
-                mu = self.mu_act(self.mu(out))
-                if self.fixed_sigma:
-                    sigma = self.sigma_act(self.sigma)
-                else:
-                    sigma = self.sigma_act(self.sigma(out))
-                return mu, mu*0 + sigma, value, states
-
-        def load(self, params):
-            self.separate = False
-            self.units = params['mlp']['units']
-            self.activation = params['mlp']['activation']
-            self.initializer = params['mlp']['initializer']
-            self.is_discrete = 'discrete' in params['space']
-            self.is_continuous = 'continuous' in params['space']
-            self.is_multi_discrete = 'multi_discrete'in params['space']
-            self.value_activation = params.get('value_activation', 'None')
-            self.normalization = params.get('normalization', None)
-
-            if self.is_continuous:
-                self.space_config = params['space']['continuous']
-                self.fixed_sigma = self.space_config['fixed_sigma']
-            elif self.is_discrete:
-                self.space_config = params['space']['discrete']
-            elif self.is_multi_discrete:
-                self.space_config = params['space']['multi_discrete']
-
-            self.has_rnn = 'rnn' in params
-            if self.has_rnn:
-                self.rnn_units = params['rnn']['units']
-                self.rnn_layers = params['rnn']['layers']
-                self.rnn_name = params['rnn']['name']
-                self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False)
-                self.rnn_ln = params['rnn'].get('layer_norm', False)
-
-            self.has_cnn = True
-            self.permute_input = params['cnn'].get('permute_input', True)
-            self.conv_depths = params['cnn']['conv_depths']
-            self.require_rewards = params.get('require_rewards')
-            self.require_last_actions = params.get('require_last_actions')
-
-        def _build_impala(self, input_shape, depths):
-            in_channels = input_shape[0]
-            layers = nn.ModuleList()
-            for d in depths:
-                layers.append(ImpalaSequential(in_channels, d))
-                in_channels = d
-            return nn.Sequential(*layers)
-
-        def is_separate_critic(self):
-            return False
-
-        def is_rnn(self):
-            return self.has_rnn
-
-        def get_default_rnn_state(self):
-            num_layers = self.rnn_layers
-            if self.rnn_name == 'lstm':
-                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
-                        torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
-            else:
-                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
-
-    def build(self, name, **kwargs):
-        net = A2CVisionBuilder.Network(self.params, **kwargs)
-        return net
-
-
 class DiagGaussianActor(NetworkBuilder.BaseNetwork):
     """torch.distributions implementation of an diagonal Gaussian policy."""
     def __init__(self, output_dim, log_std_bounds, **mlp_args):
diff --git a/rl_games/common/experience.py b/rl_games/common/experience.py
index 7bef426e..20ca19aa 100644
--- a/rl_games/common/experience.py
+++ b/rl_games/common/experience.py
@@ -379,10 +379,6 @@ def _create_tensor_from_space(self, space, base_shape):
         return t_dict
 
     def update_data(self, name, index, val):
-        print('name:', name)
-        print(self.tensor_dict.keys())
-        print(self.tensor_dict[name].shape)
-        print(self.tensor_dict["obses"].shape)
         if type(val) is dict:
             for k,v in val.items():
                 self.tensor_dict[name][k][index,:] = v
diff --git a/rl_games/common/player.py b/rl_games/common/player.py
index 98be6501..c9cae62d 100644
--- a/rl_games/common/player.py
+++ b/rl_games/common/player.py
@@ -179,7 +179,7 @@ def _preproc_obs(self, obs_batch):
 
     def env_step(self, env, actions):
         if not self.is_tensor_obses:
-            actions = actions.cpu().numpy()
+            actions = actions.cpu().detach().numpy()
         obs, rewards, dones, infos = env.step(actions)
         if hasattr(obs, 'dtype') and obs.dtype == np.float64:
             obs = np.float32(obs)
diff --git a/rl_games/envs/envpool.py b/rl_games/envs/envpool.py
index f80ab7a9..e89d7403 100644
--- a/rl_games/envs/envpool.py
+++ b/rl_games/envs/envpool.py
@@ -28,7 +28,7 @@ def __init__(self, config_name, num_actors, **kwargs):
         )
 
         if self.use_dict_obs_space:
-            self.observation_space= gym.spaces.Dict({
+            self.observation_space = gym.spaces.Dict({
                 'observation' : self.env.observation_space,
                 'reward' : gym.spaces.Box(low=0, high=1, shape=( ), dtype=np.float32),
                 'last_action': gym.spaces.Box(low=0, high=self.env.action_space.n, shape=(), dtype=int)
diff --git a/rl_games/networks/vision_networks.py b/rl_games/networks/vision_networks.py
new file mode 100644
index 00000000..fc221eb7
--- /dev/null
+++ b/rl_games/networks/vision_networks.py
@@ -0,0 +1,428 @@
+import torch
+from torch import nn
+from torchvision import models
+
+from rl_games.algos_torch import torch_ext
+from rl_games.algos_torch.network_builder import NetworkBuilder, ImpalaSequential
+
+
+class A2CVisionBuilder(NetworkBuilder):
+    def __init__(self, **kwargs):
+        NetworkBuilder.__init__(self)
+
+    def load(self, params):
+        self.params = params
+
+    class Network(NetworkBuilder.BaseNetwork):
+        def __init__(self, params, **kwargs):
+            self.actions_num = actions_num = kwargs.pop('actions_num')
+            input_shape = kwargs.pop('input_shape')
+            if type(input_shape) is dict:
+                # read the proprio shape before input_shape is rebound to the camera shape
+                proprio_shape = input_shape['proprio']
+                input_shape = input_shape['camera']
+            self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1)
+            self.value_size = kwargs.pop('value_size', 1)
+
+            NetworkBuilder.BaseNetwork.__init__(self)
+            self.load(params)
+            if self.permute_input:
+                input_shape = torch_ext.shape_whc_to_cwh(input_shape)
+
+            self.cnn = self._build_impala(input_shape, self.conv_depths)
+            cnn_output_size = self._calc_input_size(input_shape, self.cnn)
+            proprio_size = proprio_shape[0]  # number of proprioceptive features
+
+            mlp_input_size = cnn_output_size + proprio_size
+            if len(self.units) == 0:
+                out_size = cnn_output_size
+            else:
+                out_size = self.units[-1]
+
+            if self.has_rnn:
+                if not self.is_rnn_before_mlp:
+                    rnn_in_size = out_size
+                    out_size = self.rnn_units
+                else:
+                    rnn_in_size = mlp_input_size
+                    mlp_input_size = self.rnn_units
+
+                self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
+                self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
+
+            mlp_args = {
+                'input_size' : mlp_input_size,
+                'units' : self.units,
+                'activation' : self.activation,
+                'norm_func_name' : self.normalization,
+                'dense_func' : torch.nn.Linear
+            }
+
+            self.mlp = self._build_mlp(**mlp_args)
+
+            self.value = self._build_value_layer(out_size, self.value_size)
+            self.value_act = self.activations_factory.create(self.value_activation)
+            self.flatten_act = self.activations_factory.create(self.activation)
+
+            if self.is_discrete:
+                self.logits = torch.nn.Linear(out_size, actions_num)
+            if self.is_continuous:
+                self.mu = torch.nn.Linear(out_size, actions_num)
+                self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
+                mu_init = self.init_factory.create(**self.space_config['mu_init'])
+                self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation'])
+                sigma_init = self.init_factory.create(**self.space_config['sigma_init'])
+
+                if self.fixed_sigma:
+                    self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)
+                else:
+                    self.sigma = torch.nn.Linear(out_size, actions_num)
+
+            mlp_init = self.init_factory.create(**self.initializer)
+
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            for m in self.mlp:
+                if isinstance(m, nn.Linear):
+                    mlp_init(m.weight)
+
+            if self.is_discrete:
+                mlp_init(self.logits.weight)
+            if self.is_continuous:
+                mu_init(self.mu.weight)
+                if self.fixed_sigma:
+                    sigma_init(self.sigma)
+                else:
+                    sigma_init(self.sigma.weight)
+
+            mlp_init(self.value.weight)
+
+        def forward(self, obs_dict):
+            obs = obs_dict['camera']
+            proprio = obs_dict['proprio']
+            if self.permute_input:
+                obs = obs.permute((0, 3, 1, 2))
+
+            dones = obs_dict.get('dones', None)
+            bptt_len = obs_dict.get('bptt_len', 0)
+            states = obs_dict.get('rnn_states', None)
+
+            out = obs
+            out = self.cnn(out)
+            out = out.flatten(1)
+            out = self.flatten_act(out)
+
+            out = torch.cat([out, proprio], dim=1)
+
+            if self.has_rnn:
+                seq_length = obs_dict.get('seq_length', 1)
+
+                if not self.is_rnn_before_mlp:
+                    out = self.mlp(out)
+
+                batch_size = out.size()[0]
+                num_seqs = batch_size // seq_length
+                out = out.reshape(num_seqs, seq_length, -1)
+
+                if len(states) == 1:
+                    states = states[0]
+
+                out = out.transpose(0, 1)
+                if dones is not None:
+                    dones = dones.reshape(num_seqs, seq_length, -1)
+                    dones = dones.transpose(0, 1)
+                out, states = self.rnn(out, states, dones, bptt_len)
+                out = out.transpose(0, 1)
+                out = out.contiguous().reshape(out.size()[0] * out.size()[1], -1)
+
+                if self.rnn_ln:
+                    out = self.layer_norm(out)
+                if self.is_rnn_before_mlp:
+                    out = self.mlp(out)
+                if type(states) is not tuple:
+                    states = (states,)
+            else:
+                out = self.mlp(out)
+
+            value = self.value_act(self.value(out))
+
+            if self.is_discrete:
+                logits = self.logits(out)
+                return logits, value, states
+
+            if self.is_continuous:
+                mu = self.mu_act(self.mu(out))
+                if self.fixed_sigma:
+                    sigma = self.sigma_act(self.sigma)
+                else:
+                    sigma = self.sigma_act(self.sigma(out))
+                return mu, mu*0 + sigma, value, states
+
+        def load(self, params):
+            self.separate = False
+            self.units = params['mlp']['units']
+            self.activation = params['mlp']['activation']
+            self.initializer = params['mlp']['initializer']
+            self.is_discrete = 'discrete' in params['space']
+            self.is_continuous = 'continuous' in params['space']
+            self.is_multi_discrete = 'multi_discrete' in params['space']
+            self.value_activation = params.get('value_activation', 'None')
+            self.normalization = params.get('normalization', None)
+
+            if self.is_continuous:
+                self.space_config = params['space']['continuous']
+                self.fixed_sigma = self.space_config['fixed_sigma']
+            elif self.is_discrete:
+                self.space_config = params['space']['discrete']
+            elif self.is_multi_discrete:
+                self.space_config = params['space']['multi_discrete']
+
+            self.has_rnn = 'rnn' in params
+            if self.has_rnn:
+                self.rnn_units = params['rnn']['units']
+                self.rnn_layers = params['rnn']['layers']
+                self.rnn_name = params['rnn']['name']
+                self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False)
+                self.rnn_ln = params['rnn'].get('layer_norm', False)
+
+            self.has_cnn = True
+            self.permute_input = params['cnn'].get('permute_input', True)
+            self.conv_depths = params['cnn']['conv_depths']
+            self.require_rewards = params.get('require_rewards')
+            self.require_last_actions = params.get('require_last_actions')
+
+        def _build_impala(self, input_shape, depths):
+            in_channels = input_shape[0]
+            layers = nn.ModuleList()
+            for d in depths:
+                layers.append(ImpalaSequential(in_channels, d))
+                in_channels = d
+            return nn.Sequential(*layers)
+
+        def is_separate_critic(self):
+            return False
+
+        def is_rnn(self):
+            return self.has_rnn
+
+        def get_default_rnn_state(self):
+            num_layers = self.rnn_layers
+            if self.rnn_name == 'lstm':
+                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
+                        torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
+            else:
+                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),)
+
+    def build(self, name, **kwargs):
+        net = A2CVisionBuilder.Network(self.params, **kwargs)
+        return net
+
+
+class A2CVisionBackboneBuilder(NetworkBuilder):
+    def __init__(self, **kwargs):
+        NetworkBuilder.__init__(self)
+
+    def load(self, params):
+        self.params = params
+
+    class Network(NetworkBuilder.BaseNetwork):
+        def __init__(self, params, **kwargs):
+            self.actions_num = kwargs.pop('actions_num')
+            input_shape = kwargs.pop('input_shape')
+            if isinstance(input_shape, dict):
+                # read the proprio shape before input_shape is rebound to the camera shape
+                proprio_shape = input_shape['proprio']
+                input_shape = input_shape['camera']
+            self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1)
+            self.value_size = kwargs.pop('value_size', 1)
+
+            NetworkBuilder.BaseNetwork.__init__(self)
+            self.load(params)
+            if self.permute_input:
+                input_shape = torch_ext.shape_whc_to_cwh(input_shape)
+
+            # _build_resnet returns the backbone together with its feature width,
+            # which must be read from resnet.fc before that layer is stripped
+            self.cnn, cnn_output_size = self._build_resnet(input_shape, self.params['cnn'].get('pretrained', False))
+            proprio_size = proprio_shape[0]  # number of proprioceptive features
+
+            mlp_input_size = cnn_output_size + proprio_size
+            if len(self.units) == 0:
+                out_size = cnn_output_size
+            else:
+                out_size = self.units[-1]
+
+            if self.has_rnn:
+                if not self.is_rnn_before_mlp:
+                    rnn_in_size = out_size
+                    out_size = self.rnn_units
+                else:
+                    rnn_in_size = mlp_input_size
+                    mlp_input_size = self.rnn_units
+
+                self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
+                self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
+
+            mlp_args = {
+                'input_size': mlp_input_size,
+                'units': self.units,
+                'activation': self.activation,
+                'norm_func_name': self.normalization,
+                'dense_func': torch.nn.Linear
+            }
+
+            self.mlp = self._build_mlp(**mlp_args)
+
+            self.value = self._build_value_layer(out_size, self.value_size)
+            self.value_act = self.activations_factory.create(self.value_activation)
+            self.flatten_act = self.activations_factory.create(self.activation)
+
+            if self.is_discrete:
+                self.logits = torch.nn.Linear(out_size, self.actions_num)
+            if self.is_continuous:
+                self.mu = torch.nn.Linear(out_size, self.actions_num)
+                self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
+                mu_init = self.init_factory.create(**self.space_config['mu_init'])
+                self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation'])
+                sigma_init = self.init_factory.create(**self.space_config['sigma_init'])
+
+                if self.fixed_sigma:
+                    self.sigma = nn.Parameter(torch.zeros(self.actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)
+                else:
+                    self.sigma = torch.nn.Linear(out_size, self.actions_num)
+
+            mlp_init = self.init_factory.create(**self.initializer)
+
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            for m in self.mlp:
+                if isinstance(m, nn.Linear):
+                    mlp_init(m.weight)
+
+            if self.is_discrete:
+                mlp_init(self.logits.weight)
+            if self.is_continuous:
+                mu_init(self.mu.weight)
+                if self.fixed_sigma:
+                    sigma_init(self.sigma)
+                else:
+                    sigma_init(self.sigma.weight)
+
+            mlp_init(self.value.weight)
+
+        def forward(self, obs_dict):
+            # TODO: Add resnet preprocessing
+            obs = obs_dict['camera']
+            proprio = obs_dict['proprio']
+            if self.permute_input:
+                obs = obs.permute((0, 3, 1, 2))
+
+            dones = obs_dict.get('dones', None)
+            bptt_len = obs_dict.get('bptt_len', 0)
+            states = obs_dict.get('rnn_states', None)
+
+            out = obs
+            out = self.cnn(out)
+            out = out.flatten(1)
+            out = self.flatten_act(out)
+
+            out = torch.cat([out, proprio], dim=1)
+
+            if self.has_rnn:
+                seq_length = obs_dict.get('seq_length', 1)
+
+                if not self.is_rnn_before_mlp:
+                    out = self.mlp(out)
+
+                batch_size = out.size()[0]
+                num_seqs = batch_size // seq_length
+                out = out.reshape(num_seqs, seq_length, -1)
+
+                if len(states) == 1:
+                    states = states[0]
+
+                out = out.transpose(0, 1)
+                if dones is not None:
+                    dones = dones.reshape(num_seqs, seq_length, -1)
+                    dones = dones.transpose(0, 1)
+                out, states = self.rnn(out, states, dones, bptt_len)
+                out = out.transpose(0, 1)
+                out = out.contiguous().reshape(out.size()[0] * out.size()[1], -1)
+
+                if self.rnn_ln:
+                    out = self.layer_norm(out)
+                if self.is_rnn_before_mlp:
+                    out = self.mlp(out)
+                if not isinstance(states, tuple):
+                    states = (states,)
+            else:
+                out = self.mlp(out)
+
+            value = self.value_act(self.value(out))
+
+            if self.is_discrete:
+                logits = self.logits(out)
+                return logits, value, states
+
+            if self.is_continuous:
+                mu = self.mu_act(self.mu(out))
+                if self.fixed_sigma:
+                    sigma = self.sigma_act(self.sigma)
+                else:
+                    sigma = self.sigma_act(self.sigma(out))
+                return mu, mu * 0 + sigma, value, states
+
+        def load(self, params):
+            self.separate = False
+            self.units = params['mlp']['units']
+            self.activation = params['mlp']['activation']
+            self.initializer = params['mlp']['initializer']
+            self.is_discrete = 'discrete' in params['space']
+            self.is_continuous = 'continuous' in params['space']
+            self.is_multi_discrete = 'multi_discrete' in params['space']
+            self.value_activation = params.get('value_activation', 'None')
+            self.normalization = params.get('normalization', None)
+
+            if self.is_continuous:
+                self.space_config = params['space']['continuous']
+                self.fixed_sigma = self.space_config['fixed_sigma']
+            elif self.is_discrete:
+                self.space_config = params['space']['discrete']
+            elif self.is_multi_discrete:
+                self.space_config = params['space']['multi_discrete']
+
+            self.has_rnn = 'rnn' in params
+            if self.has_rnn:
+                self.rnn_units = params['rnn']['units']
+                self.rnn_layers = params['rnn']['layers']
+                self.rnn_name = params['rnn']['name']
+                self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False)
+                self.rnn_ln = params['rnn'].get('layer_norm', False)
+
+            self.has_cnn = True
+            self.permute_input = params['cnn'].get('permute_input', True)
+            self.require_rewards = params.get('require_rewards')
+            self.require_last_actions = params.get('require_last_actions')
+
+        def _build_resnet(self, input_shape, pretrained):
+            resnet = models.resnet18(pretrained=pretrained)
+            # modify the first convolution to match the input channel count if needed
+            if input_shape[0] != 3:
+                resnet.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
+            # remember the feature width, then drop the classification head
+            out_size = resnet.fc.in_features
+            resnet = nn.Sequential(*list(resnet.children())[:-1])
+            return resnet, out_size
+
+        def is_separate_critic(self):
+            return False
+
+        def is_rnn(self):
+            return self.has_rnn
+
+        def get_default_rnn_state(self):
+            num_layers = self.rnn_layers
+            if self.rnn_name == 'lstm':
+                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
+                        torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
+            else:
+                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),)
+
+    def build(self, name, **kwargs):
+        net = A2CVisionBackboneBuilder.Network(self.params, **kwargs)
+        return net
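
Review note: below is a minimal smoke-test sketch (not part of the diff) showing how the moved builder is meant to be driven. It assumes a params dict shaped like the `network` section of an rl_games YAML config; the concrete values (conv_depths, the 84x84 camera, the 68-dim proprio vector, actions_num=8) are illustrative assumptions, not values taken from this PR.

# Smoke test for the relocated A2CVisionBuilder; config values are made up.
import torch
from rl_games.networks import vision_networks

params = {
    'space': {'continuous': {
        'mu_activation': 'None', 'sigma_activation': 'None',
        'mu_init': {'name': 'default'},
        'sigma_init': {'name': 'const_initializer', 'val': 0},
        'fixed_sigma': True}},
    'mlp': {'units': [256, 128], 'activation': 'elu', 'initializer': {'name': 'default'}},
    'cnn': {'permute_input': True, 'conv_depths': [16, 32, 32]},
}

builder = vision_networks.A2CVisionBuilder()
builder.load(params)
# the dict input_shape carries both observation streams, matching forward()'s
# obs_dict['camera'] / obs_dict['proprio'] contract
net = builder.build('vision_net',
                    actions_num=8,
                    input_shape={'camera': (84, 84, 3), 'proprio': (68,)},
                    num_seqs=1)

obs = {'camera': torch.zeros(4, 84, 84, 3), 'proprio': torch.zeros(4, 68)}
mu, sigma, value, states = net(obs)
print(mu.shape, sigma.shape, value.shape)  # [4, 8], [4, 8], [4, 1]

The same flow should apply to A2CVisionBackboneBuilder through the new 'e2e_vision_actor_critic' factory key, with params['cnn'] carrying 'pretrained' instead of 'conv_depths'.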