diff --git a/rl_games/__init__.py b/rl_games/__init__.py
index e69de29b..7c443754 100644
--- a/rl_games/__init__.py
+++ b/rl_games/__init__.py
@@ -0,0 +1 @@
+from rl_games.networks import *
\ No newline at end of file
diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py
index ab047920..54781bd3 100644
--- a/rl_games/algos_torch/network_builder.py
+++ b/rl_games/algos_torch/network_builder.py
@@ -9,7 +9,6 @@
 from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
 from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
 
-
 def _create_initializer(func, **kwargs):
     return lambda v : func(v, **kwargs)
 
diff --git a/rl_games/configs/mujoco/ant_envpool_moe.yaml b/rl_games/configs/mujoco/ant_envpool_moe.yaml
new file mode 100644
index 00000000..814850cc
--- /dev/null
+++ b/rl_games/configs/mujoco/ant_envpool_moe.yaml
@@ -0,0 +1,71 @@
+params:
+  seed: 5
+  algo:
+    name: a2c_continuous
+
+  model:
+    name: continuous_a2c_logstd
+
+  network:
+    name: moe
+    space:
+      continuous:
+        mu_activation: None
+        sigma_activation: None
+        mu_init:
+          name: default
+        sigma_init:
+          name: const_initializer
+          val: 0
+        fixed_sigma: True
+    num_experts: 4
+    hidden_size: 256
+    gating_hidden_size: 128
+    use_sparse_gating: true
+    use_entropy_loss: true
+    use_diversity_loss: true
+    top_k: 2
+    lambda_entropy: 0.01
+    lambda_diversity: 0.01
+
+  config:
+    name: Ant-v4_envpool_moe
+    env_name: envpool
+    score_to_win: 20000
+    normalize_input: True
+    normalize_value: True
+    value_bootstrap: True
+    normalize_advantage: True
+    reward_shaper:
+      scale_value: 1
+
+    gamma: 0.99
+    tau: 0.95
+    learning_rate: 3e-4
+    lr_schedule: adaptive
+    kl_threshold: 0.008
+    grad_norm: 1.0
+    entropy_coef: 0.0
+    truncate_grads: True
+    e_clip: 0.2
+    clip_value: True
+    use_smooth_clamp: True
+    bound_loss_type: regularisation
+    bounds_loss_coef: 0.0
+    max_epochs: 2000
+    num_actors: 64
+    horizon_length: 64
+    minibatch_size: 2048
+    mini_epochs: 4
+    critic_coef: 2
+
+    env_config:
+      env_name: Ant-v4
+      seed: 5
+      #flat_observation: True
+
+    player:
+      render: False
+      num_actors: 64
+      games_num: 1000
+      use_vecenv: True
\ No newline at end of file
diff --git a/rl_games/networks/__init__.py b/rl_games/networks/__init__.py
index 1c99d866..1bfc0264 100644
--- a/rl_games/networks/__init__.py
+++ b/rl_games/networks/__init__.py
@@ -1,4 +1,7 @@
 from rl_games.networks.tcnn_mlp import TcnnNetBuilder
+from rl_games.networks.moe import MoENetBuilder
+
 from rl_games.algos_torch import model_builder
 
-model_builder.register_network('tcnnnet', TcnnNetBuilder)
\ No newline at end of file
+model_builder.register_network('tcnnnet', TcnnNetBuilder)
+model_builder.register_network('moe', MoENetBuilder)
\ No newline at end of file
diff --git a/rl_games/networks/moe.py b/rl_games/networks/moe.py
index 6c43efb8..081699e8 100644
--- a/rl_games/networks/moe.py
+++ b/rl_games/networks/moe.py
@@ -1,17 +1,33 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from rl_games.common import networks
-from rl_games.common import layers
+from rl_games.algos_torch.network_builder import NetworkBuilder
 
-class MoENet(networks.NetworkBuilder.BaseNetwork):
+class MoENet(NetworkBuilder.BaseNetwork):
     def __init__(self, params, **kwargs):
-        nn.Module.__init__(self)
+        NetworkBuilder.BaseNetwork.__init__(self)
         actions_num = kwargs.pop('actions_num')
         input_shape = kwargs.pop('input_shape')
         num_inputs = 0
 
+        self.has_space = 'space' in params
         self.central_value = params.get('central_value', False)
+        if self.has_space:
+            self.is_multi_discrete = 'multi_discrete' in params['space']
+            self.is_discrete = 'discrete' in params['space']
+            self.is_continuous = 'continuous' in params['space']
+            if self.is_continuous:
+                self.space_config = params['space']['continuous']
+                self.fixed_sigma = self.space_config['fixed_sigma']
+            elif self.is_discrete:
+                self.space_config = params['space']['discrete']
+            elif self.is_multi_discrete:
+                self.space_config = params['space']['multi_discrete']
+        else:
+            self.is_discrete = False
+            self.is_continuous = False
+            self.is_multi_discrete = False
+
         self.value_size = kwargs.pop('value_size', 1)
 
         # Parameters from params
@@ -19,14 +35,17 @@ def __init__(self, params, **kwargs):
         hidden_size = params.get('hidden_size', 128)
         gating_hidden_size = params.get('gating_hidden_size', 64)
         self.use_sparse_gating = params.get('use_sparse_gating', False)
+        self.use_entropy_loss = params.get('use_entropy_loss', True)
+        self.use_diversity_loss = params.get('use_diversity_loss', True)
         self.top_k = params.get('top_k', 1)
         self.lambda_entropy = params.get('lambda_entropy', 0.01)
         self.lambda_diversity = params.get('lambda_diversity', 0.01)
 
         # Input processing
-        assert isinstance(input_shape, dict), "Input shape must be a dict"
-        for k, v in input_shape.items():
-            num_inputs += v[0]
+        #assert isinstance(input_shape, dict), "Input shape must be a dict"
+        #for k, v in input_shape.items():
+        #    num_inputs += v[0]
+        num_inputs = input_shape[0]
 
         # Gating Network
         self.gating_fc1 = nn.Linear(num_inputs, gating_hidden_size)
@@ -44,15 +63,27 @@ def __init__(self, params, **kwargs):
             ) for _ in range(num_experts)
         ])
 
-        # Output layers
-        self.mean_linear = nn.Linear(hidden_size, actions_num)
+        if self.is_discrete:
+            self.logits = torch.nn.Linear(hidden_size, actions_num)
+        if self.is_multi_discrete:
+            self.logits = torch.nn.ModuleList([torch.nn.Linear(hidden_size, num) for num in actions_num])
+        if self.is_continuous:
+            self.mu = torch.nn.Linear(hidden_size, actions_num)
+            self.sigma = torch.nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32),
+                                            requires_grad=True)
+            self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
+            #mu_init = self.init_factory.create(**self.space_config['mu_init'])
+            self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation'])
+            #sigma_init = self.init_factory.create(**self.space_config['sigma_init'])
         self.value = nn.Linear(hidden_size, self.value_size)
 
         # Auxiliary loss map
         self.aux_loss_map = {
-            'entropy_loss': None,
-            'diversity_loss': None,
         }
+        if self.use_diversity_loss:
+            self.aux_loss_map['diversity_loss'] = 0.0
+        if self.use_entropy_loss:
+            self.aux_loss_map['entropy_loss'] = 0.0
 
     def is_rnn(self):
         return False
@@ -61,11 +92,7 @@ def get_aux_loss(self):
         return self.aux_loss_map
 
     def forward(self, obs_dict):
-        # Combine observations
-        obs = []
-        for k in obs_dict:
-            obs.append(obs_dict[k])
-        obs = torch.cat(obs, dim=-1)
+        obs = obs_dict['obs']
 
         # Gating Network Forward Pass
         gating_x = F.relu(self.gating_fc1(obs))
@@ -79,10 +106,12 @@
             sparse_mask.scatter_(1, topk_indices, topk_values)
             gating_weights = sparse_mask / sparse_mask.sum(dim=1, keepdim=True)  # Re-normalize
 
+
+        if self.use_entropy_loss:
         # Compute Entropy Loss for Gating Weights
-        entropy = -torch.sum(gating_weights * torch.log(gating_weights + 1e-8), dim=1)
-        entropy_loss = torch.mean(entropy)
-        self.aux_loss_map['entropy_loss'] = self.lambda_entropy * entropy_loss
+            entropy = -torch.sum(gating_weights * torch.log(gating_weights + 1e-8), dim=1)
+            entropy_loss = torch.mean(entropy)
+            self.aux_loss_map['entropy_loss'] = self.lambda_entropy * entropy_loss
 
         # Expert Networks Forward Pass
         expert_outputs = []
@@ -91,22 +120,23 @@
         expert_outputs = torch.stack(expert_outputs, dim=1)  # Shape: [batch_size, num_experts, hidden_size]
 
         # Compute Diversity Loss
-        diversity_loss = 0.0
-        num_experts = len(self.expert_networks)
-        for i in range(num_experts):
-            for j in range(i + 1, num_experts):
-                similarity = F.cosine_similarity(expert_outputs[:, i, :], expert_outputs[:, j, :], dim=-1)
-                diversity_loss += torch.mean(similarity)
-        num_pairs = num_experts * (num_experts - 1) / 2
-        diversity_loss = diversity_loss / num_pairs
-        self.aux_loss_map['diversity_loss'] = self.lambda_diversity * diversity_loss
+        if self.use_diversity_loss:
+            diversity_loss = 0.0
+            num_experts = len(self.expert_networks)
+            for i in range(num_experts):
+                for j in range(i + 1, num_experts):
+                    similarity = F.cosine_similarity(expert_outputs[:, i, :], expert_outputs[:, j, :], dim=-1)
+                    diversity_loss += torch.mean(similarity)
+            num_pairs = num_experts * (num_experts - 1) / 2
+            diversity_loss = diversity_loss / num_pairs
+            self.aux_loss_map['diversity_loss'] = self.lambda_diversity * diversity_loss
 
         # Aggregate Expert Outputs
         gating_weights = gating_weights.unsqueeze(-1)  # Shape: [batch_size, num_experts, 1]
         aggregated_output = torch.sum(gating_weights * expert_outputs, dim=1)  # Shape: [batch_size, hidden_size]
 
         out = aggregated_output
-        value = self.value_act(self.value(out))
+        value = self.value(out)
         states = None
         if self.central_value:
             return value, states
@@ -123,7 +153,7 @@
                 sigma = self.sigma_act(self.sigma)
             else:
                 sigma = self.sigma_act(self.sigma(out))
-            return mu, mu*0 + sigma, value, states
+            return mu, mu*0 + sigma, value, states
 
 
 from rl_games.algos_torch.network_builder import NetworkBuilder