diff --git a/rl_games/algos_torch/model_builder.py b/rl_games/algos_torch/model_builder.py index 0ebb38db..e33bd86f 100644 --- a/rl_games/algos_torch/model_builder.py +++ b/rl_games/algos_torch/model_builder.py @@ -1,7 +1,7 @@ from rl_games.common import object_factory import rl_games.algos_torch -from rl_games.algos_torch import models, network_builder -from rl_games.networks import vision_networks +from rl_games.algos_torch import network_builder, models + NETWORK_REGISTRY = {} MODEL_REGISTRY = {} @@ -21,9 +21,9 @@ def __init__(self): self.network_factory.register_builder('resnet_actor_critic', lambda **kwargs: network_builder.A2CResnetBuilder()) self.network_factory.register_builder('vision_actor_critic', - lambda **kwargs: vision_networks.A2CVisionBuilder()) - self.network_factory.register_builder('e2e_vision_actor_critic', - lambda **kwargs: vision_networks.A2CVisionBackboneBuilder()) + lambda **kwargs: network_builder.A2CVisionBuilder()) + # self.network_factory.register_builder('e2e_vision_actor_critic', + # lambda **kwargs: vision_networks.A2CVisionBackboneBuilder()) self.network_factory.register_builder('rnd_curiosity', lambda **kwargs: network_builder.RNDCuriosityBuilder()) self.network_factory.register_builder('soft_actor_critic', lambda **kwargs: network_builder.SACBuilder()) diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py index 8781091c..fb518c9c 100644 --- a/rl_games/algos_torch/models.py +++ b/rl_games/algos_torch/models.py @@ -10,6 +10,7 @@ from rl_games.algos_torch.running_mean_std import RunningMeanStd, RunningMeanStdObs from rl_games.algos_torch.moving_mean_std import GeneralizedMovingStats + class BaseModel(): def __init__(self, model_class): self.model_class = model_class @@ -25,12 +26,14 @@ def get_value_layer(self): def build(self, config): obs_shape = config['input_shape'] + print(f"obs_shape: {obs_shape}") normalize_value = config.get('normalize_value', False) normalize_input = config.get('normalize_input', 
False) value_size = config.get('value_size', 1) return self.Network(self.network_builder.build(self.model_class, **config), obs_shape=obs_shape, normalize_value=normalize_value, normalize_input=normalize_input, value_size=value_size) + class BaseModelNetwork(nn.Module): def __init__(self, obs_shape, normalize_value, normalize_input, value_size): nn.Module.__init__(self) @@ -366,5 +369,3 @@ def forward(self, input_dict): dist = SquashedNormal(mu, sigma) return dist - - diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index d89fc75f..84f9a14a 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -488,19 +488,19 @@ def get_default_rnn_state(self): rnn_units = self.rnn_units if self.rnn_name == 'lstm': if self.separate: - return (torch.zeros((num_layers, self.num_seqs, rnn_units)), + return (torch.zeros((num_layers, self.num_seqs, rnn_units)), + torch.zeros((num_layers, self.num_seqs, rnn_units)), torch.zeros((num_layers, self.num_seqs, rnn_units)), - torch.zeros((num_layers, self.num_seqs, rnn_units)), torch.zeros((num_layers, self.num_seqs, rnn_units))) else: - return (torch.zeros((num_layers, self.num_seqs, rnn_units)), + return (torch.zeros((num_layers, self.num_seqs, rnn_units)), torch.zeros((num_layers, self.num_seqs, rnn_units))) else: if self.separate: - return (torch.zeros((num_layers, self.num_seqs, rnn_units)), + return (torch.zeros((num_layers, self.num_seqs, rnn_units)), torch.zeros((num_layers, self.num_seqs, rnn_units))) else: - return (torch.zeros((num_layers, self.num_seqs, rnn_units)),) + return (torch.zeros((num_layers, self.num_seqs, rnn_units)),) def load(self, params): self.separate = params.get('separate', False) @@ -849,6 +849,209 @@ def build(self, name, **kwargs): return net +class A2CVisionBuilder(NetworkBuilder): + def __init__(self, **kwargs): + NetworkBuilder.__init__(self) + + def load(self, params): + self.params = params + + class 
Network(NetworkBuilder.BaseNetwork):
+        def __init__(self, params, **kwargs):
+            self.actions_num = actions_num = kwargs.pop('actions_num')
+            input_shape = kwargs.pop('input_shape')
+            print('input_shape:', input_shape)
+            if type(input_shape) is dict:
+                proprio_shape = input_shape['proprio']
+                input_shape = input_shape['camera']
+
+            self.num_seqs = kwargs.pop('num_seqs', 1)
+            self.value_size = kwargs.pop('value_size', 1)
+
+            NetworkBuilder.BaseNetwork.__init__(self)
+            self.load(params)
+            if self.permute_input:
+                input_shape = torch_ext.shape_whc_to_cwh(input_shape)
+
+            self.cnn = self._build_impala(input_shape, self.conv_depths)
+            cnn_output_size = self._calc_input_size(input_shape, self.cnn)
+            proprio_size = proprio_shape[0] # Number of proprioceptive features
+
+            mlp_input_size = cnn_output_size + proprio_size
+            if len(self.units) == 0:
+                out_size = cnn_output_size
+            else:
+                out_size = self.units[-1]
+
+            if self.has_rnn:
+                if not self.is_rnn_before_mlp:
+                    rnn_in_size = out_size
+                    out_size = self.rnn_units
+                else:
+                    rnn_in_size = mlp_input_size
+                    mlp_input_size = self.rnn_units
+
+                self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
+                self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
+
+            mlp_args = {
+                'input_size' : mlp_input_size,
+                'units' :self.units,
+                'activation' : self.activation,
+                'norm_func_name' : self.normalization,
+                'dense_func' : torch.nn.Linear
+            }
+
+            self.mlp = self._build_mlp(**mlp_args)
+
+            self.value = self._build_value_layer(out_size, self.value_size)
+            self.value_act = self.activations_factory.create(self.value_activation)
+            self.flatten_act = self.activations_factory.create(self.activation)
+
+            if self.is_discrete:
+                self.logits = torch.nn.Linear(out_size, actions_num)
+            if self.is_continuous:
+                self.mu = torch.nn.Linear(out_size, actions_num)
+                self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
+                mu_init = self.init_factory.create(**self.space_config['mu_init'])
+
self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation']) + sigma_init = self.init_factory.create(**self.space_config['sigma_init']) + + if self.fixed_sigma: + self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True) + else: + self.sigma = torch.nn.Linear(out_size, actions_num) + + mlp_init = self.init_factory.create(**self.initializer) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + for m in self.mlp: + if isinstance(m, nn.Linear): + mlp_init(m.weight) + + if self.is_discrete: + mlp_init(self.logits.weight) + if self.is_continuous: + mu_init(self.mu.weight) + if self.fixed_sigma: + sigma_init(self.sigma) + else: + sigma_init(self.sigma.weight) + + mlp_init(self.value.weight) + + def forward(self, obs_dict): + # for key in obs_dict: + # print(key) + obs = obs_dict['camera'] + proprio = obs_dict['proprio'] + if self.permute_input: + obs = obs.permute((0, 3, 1, 2)) + + dones = obs_dict.get('dones', None) + bptt_len = obs_dict.get('bptt_len', 0) + states = obs_dict.get('rnn_states', None) + + out = obs + out = self.cnn(out) + out = out.flatten(1) + out = self.flatten_act(out) + + out = torch.cat([out, proprio], dim=1) + + if self.has_rnn: + seq_length = obs_dict['seq_length'] + + out_in = out + if not self.is_rnn_before_mlp: + out_in = out + out = self.mlp(out) + + batch_size = out.size()[0] + num_seqs = batch_size // seq_length + out = out.reshape(num_seqs, seq_length, -1) + + if len(states) == 1: + states = states[0] + + out = out.transpose(0, 1) + if dones is not None: + dones = dones.reshape(num_seqs, seq_length, -1) + dones = dones.transpose(0, 1) + out, states = self.rnn(out, states, dones, bptt_len) + out = out.transpose(0, 1) + out = out.contiguous().reshape(out.size()[0] * out.size()[1], -1) + + if self.rnn_ln: + out = self.layer_norm(out) + if self.is_rnn_before_mlp: + out = self.mlp(out) 
+                if type(states) is not tuple:
+                    states = (states,)
+            else:
+                out = self.mlp(out)
+
+            value = self.value_act(self.value(out))
+
+            if self.is_discrete:
+                logits = self.logits(out)
+                return logits, value, states
+
+            if self.is_continuous:
+                mu = self.mu_act(self.mu(out))
+                if self.fixed_sigma:
+                    sigma = self.sigma_act(self.sigma)
+                else:
+                    sigma = self.sigma_act(self.sigma(out))
+                return mu, mu*0 + sigma, value, states
+
+        def load(self, params):
+            self.separate = False
+            self.units = params['mlp']['units']
+            self.activation = params['mlp']['activation']
+            self.initializer = params['mlp']['initializer']
+            self.is_discrete = 'discrete' in params['space']
+            self.is_continuous = 'continuous' in params['space']
+            self.is_multi_discrete = 'multi_discrete' in params['space']
+            self.value_activation = params.get('value_activation', 'None')
+            self.normalization = params.get('normalization', None)
+
+            if self.is_continuous:
+                self.space_config = params['space']['continuous']
+                self.fixed_sigma = self.space_config['fixed_sigma']
+            elif self.is_discrete:
+                self.space_config = params['space']['discrete']
+            elif self.is_multi_discrete:
+                self.space_config = params['space']['multi_discrete']
+
+            self.has_rnn = 'rnn' in params
+            if self.has_rnn:
+                self.rnn_units = params['rnn']['units']
+                self.rnn_layers = params['rnn']['layers']
+                self.rnn_name = params['rnn']['name']
+                self.rnn_ln = params['rnn'].get('layer_norm', False)
+                self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False)
+
+            self.has_cnn = True
+            self.permute_input = params['cnn'].get('permute_input', True)
+            self.conv_depths = params['cnn']['conv_depths']
+            self.require_rewards = params.get('require_rewards')
+            self.require_last_actions = params.get('require_last_actions')
+
+        def _build_impala(self, input_shape, depths):
+            in_channels = input_shape[0]
+            layers = nn.ModuleList()
+            for d in depths:
+                layers.append(ImpalaSequential(in_channels, d))
+                in_channels = d
+            return nn.Sequential(*layers)
+
+        def is_separate_critic(self):
+            return False
+
+        def is_rnn(self):
+            return self.has_rnn
+
+        def get_default_rnn_state(self):
+            num_layers = self.rnn_layers
+            if self.rnn_name == 'lstm':
+                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
+                        torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
+            else:
+                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),)
+
+    def build(self, name, **kwargs):
+        net = A2CVisionBuilder.Network(self.params, **kwargs)
+        return net
+
+
 class 
DiagGaussianActor(NetworkBuilder.BaseNetwork): """torch.distributions implementation of an diagonal Gaussian policy.""" def __init__(self, output_dim, log_std_bounds, **mlp_args): diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index b3a90937..5f484860 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -10,7 +10,7 @@ from rl_games.common.experience import ExperienceBuffer from rl_games.common.interval_summary_writer import IntervalSummaryWriter from rl_games.common.diagnostics import DefaultDiagnostics, PpoDiagnostics -from rl_games.algos_torch import model_builder +from rl_games.algos_torch import model_builder from rl_games.interfaces.base_algorithm import BaseAlgorithm import numpy as np import time @@ -356,7 +356,7 @@ def trancate_gradients_and_step(self): def load_networks(self, params): builder = model_builder.ModelBuilder() self.config['network'] = builder.load(params) - has_central_value_net = self.config.get('central_value_config') is not None + has_central_value_net = self.config.get('central_value_config') is not None if has_central_value_net: print('Adding Central Value Network') if 'model' not in params['config']['central_value_config']: diff --git a/rl_games/networks/__init__.py b/rl_games/networks/__init__.py index 1c99d866..ca68c624 100644 --- a/rl_games/networks/__init__.py +++ b/rl_games/networks/__init__.py @@ -1,4 +1,7 @@ from rl_games.networks.tcnn_mlp import TcnnNetBuilder +#from rl_games.networks.vision_networks import A2CVisionBuilder, A2CVisionBackboneBuilder from rl_games.algos_torch import model_builder -model_builder.register_network('tcnnnet', TcnnNetBuilder) \ No newline at end of file +model_builder.register_network('tcnnnet', TcnnNetBuilder) +# model_builder.register_network('vision_actor_critic', A2CVisionBuilder) +# model_builder.register_network('e2e_vision_actor_critic', A2CVisionBackboneBuilder) \ No newline at end of file diff --git a/rl_games/networks/vision_networks.py 
b/rl_games/networks/vision_networks.py index 0647c238..dd0d39f7 100644 --- a/rl_games/networks/vision_networks.py +++ b/rl_games/networks/vision_networks.py @@ -2,7 +2,7 @@ from torch import nn from torchvision import models import torch.nn.functional as F -import torch_ext +from rl_games.algos_torch import torch_ext from rl_games.algos_torch.network_builder import NetworkBuilder, ImpalaSequential