diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index 3dc48c58..5ba5492d 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -5,6 +5,7 @@ import torch.nn as nn from rl_games.algos_torch.d2rl import D2RLNet +from rl_games.common.layers.switch_ffn import MoEBlock from rl_games.algos_torch.sac_helper import SquashedNormal from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue @@ -67,6 +68,8 @@ def get_default_rnn_state(self): return None def get_aux_loss(self): + if self.moe_block: + return self.actor_mlp.get_aux_loss() return None def _calc_input_size(self, input_shape,cnn_layers=None): @@ -128,6 +131,9 @@ def _build_mlp(self, else: return self._build_sequential_mlp(input_size, units, activation, dense_func, norm_func_name = None,) + def _build_moe_block(self, input_size, expert_units, model_units, num_experts): + return MoEBlock(input_size, expert_units, model_units, num_experts) + def _build_conv(self, ctype, **kwargs): print('conv_name:', ctype) @@ -231,9 +237,8 @@ def __init__(self, params, **kwargs): cnn_output_size = self._calc_input_size(input_shape, self.actor_cnn) mlp_input_size = cnn_output_size - if len(self.units) == 0: - out_size = cnn_output_size - else: + out_size = mlp_input_size + if len(self.units) > 0: out_size = self.units[-1] if self.has_rnn: @@ -263,18 +268,23 @@ def __init__(self, params, **kwargs): if self.rnn_ln: self.layer_norm = torch.nn.LayerNorm(self.rnn_units) - mlp_args = { - 'input_size' : mlp_input_size, - 'units' : self.units, - 'activation' : self.activation, - 'norm_func_name' : self.normalization, - 'dense_func' : torch.nn.Linear, - 'd2rl' : self.is_d2rl, - 'norm_only_first_layer' : self.norm_only_first_layer - } - self.actor_mlp = self._build_mlp(**mlp_args) - if self.separate: - self.critic_mlp = self._build_mlp(**mlp_args) + + if self.moe_block: + self.actor_mlp = self._build_moe_block(mlp_input_size, self.expert_units, self.model_units, self.num_experts) + assert(not self.separate) + else: + mlp_args = { + 'input_size' : mlp_input_size, + 'units' : self.units, + 'activation' : self.activation, + 'norm_func_name' : self.normalization, + 'dense_func' : torch.nn.Linear, + 'd2rl' : self.is_d2rl, + 'norm_only_first_layer' : self.norm_only_first_layer + } + self.actor_mlp = self._build_mlp(**mlp_args) + if self.separate: + self.critic_mlp = self._build_mlp(**mlp_args) self.value = self._build_value_layer(out_size, self.value_size) self.value_act = self.activations_factory.create(self.value_activation) @@ -506,11 +516,22 @@ def get_default_rnn_state(self): def load(self, params): self.separate = params.get('separate', False) - self.units = params['mlp']['units'] - self.activation = params['mlp']['activation'] - self.initializer = params['mlp']['initializer'] - self.is_d2rl = params['mlp'].get('d2rl', False) - self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False) + self.moe_block = params.get('moe', False) + + if self.moe_block: + assert(not params.get('mlp', False)) + self.num_experts = self.moe_block['num_experts'] + self.expert_units = self.moe_block['expert_units'] + self.model_units = self.moe_block['model_units'] + self.initializer = self.moe_block['initializer'] + self.units = self.expert_units + + else: + self.units = params['mlp']['units'] + self.activation = params['mlp']['activation'] + self.initializer = params['mlp']['initializer'] + self.is_d2rl = params['mlp'].get('d2rl', False) + self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False) self.value_activation = params.get('value_activation', 'None') self.normalization = params.get('normalization', None) self.has_rnn = 'rnn' in params diff --git a/rl_games/common/layers/switch_ffn.py b/rl_games/common/layers/switch_ffn.py new file mode 100644 index 00000000..6101fa80 --- /dev/null +++ b/rl_games/common/layers/switch_ffn.py @@ -0,0 +1,210 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SwitchFeedForward(nn.Module): + + def __init__(self, + model_dim: int, + hidden_dim: int, + out_dim: int, + is_scale_prob: bool, + num_experts: int, + activation: nn.Module = nn.ReLU + + ): + super().__init__() + self.hidden_dim = hidden_dim + self.model_dim = model_dim + self.out_dim = out_dim + self.is_scale_prob = is_scale_prob + self.num_experts = num_experts + self.experts = nn.ModuleList([ + nn.Sequential( + nn.Linear(model_dim, out_dim), + activation(), + #nn.Linear(model_dim, hidden_dim), + #activation(), + #nn.Linear(hidden_dim, out_dim), + #activation(), + ) + for _ in range(num_experts) + ]) + # Routing layer and softmax + self.switch = nn.Linear(model_dim, num_experts) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x: torch.Tensor): + route_prob = self.softmax(self.switch(x)) + route_prob_max, routes = torch.max(route_prob, dim=-1) + indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.num_experts)] + + final_output = torch.zeros((x.size(0), self.out_dim), device=x.device) + counts = x.new_tensor([len(indexes_list[i]) for i in range(self.num_experts)]) + + # Get outputs of the expert FFNs + expert_output = [self.experts[i](x[indexes_list[i], :]) for i in range(self.num_experts)] + # Assign to final output + for i in range(self.num_experts): + final_output[indexes_list[i], :] = expert_output[i] + + if self.is_scale_prob: + # Multiply by the expert outputs by the probabilities $y = p_i(x) E_i(x)$ + final_output = final_output * route_prob_max.view(-1, 1) + else: + # not sure if this is correct + final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1) + + + return final_output, counts, route_prob.sum(0), route_prob_max + + + +class MoEFF(nn.Module): + def __init__(self, + model_dim: int, + hidden_dim: int, + out_dim: int, + num_experts: int, + activation: nn.Module = nn.ReLU, + **kwargs + ): + super().__init__() + + # Parameters from params + self.model_dim = model_dim + self.num_experts = num_experts + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self.gating_hidden_size = kwargs.get('gating_hidden_size', 64) + self.use_sparse_gating = kwargs.get('use_sparse_gating', True) + self.use_entropy_loss = kwargs.get('use_entropy_loss', True) + self.use_diversity_loss = kwargs.get('use_diversity_loss', True) + self.top_k = kwargs.get('top_k', 2) + self.lambda_entropy = kwargs.get('lambda_entropy', 0.01) + self.lambda_diversity = kwargs.get('lambda_diversity', 0.00) + + + # Gating Network + self.gating_fc1 = nn.Linear(self.model_dim, self.gating_hidden_size) + self.gating_fc2 = nn.Linear(self.gating_hidden_size, num_experts) + + # Expert Networks + self.expert_networks = nn.ModuleList([ + nn.Sequential( + nn.Linear(self.model_dim, out_dim), + activation(), + ) for _ in range(num_experts) + ]) + + + # Auxiliary loss map + self.aux_loss_map = { + } + if self.use_diversity_loss: + self.aux_loss_map['moe_diversity_loss'] = 0.0 + if self.use_entropy_loss: + self.aux_loss_map['moe_entropy_loss'] = 0.0 + + def get_aux_loss(self): + return self.aux_loss_map + + def forward(self, x): + + # Gating Network Forward Pass + gating_x = F.relu(self.gating_fc1(x)) + gating_logits = self.gating_fc2(gating_x) # Shape: [batch_size, num_experts] + orig_gating_weights = F.softmax(gating_logits, dim=1) + gating_weights = orig_gating_weights + # Apply Sparse Gating if enabled + if self.use_sparse_gating: + topk_values, topk_indices = torch.topk(gating_weights, self.top_k, dim=1) + sparse_mask = torch.zeros_like(gating_weights) + sparse_mask.scatter_(1, topk_indices, topk_values) + # probably better go with masked softmax + gating_weights = sparse_mask / sparse_mask.sum(dim=1, keepdim=True) + + if self.use_entropy_loss: + # Compute Entropy Loss for Gating Weights + entropy = -torch.sum(gating_weights * torch.log(gating_weights + 1e-8), dim=1) + entropy_loss = torch.mean(entropy) + self.aux_loss_map['moe_entropy_loss'] = -self.lambda_entropy * entropy_loss + + # Expert Networks Forward Pass + expert_outputs = [] + for expert in self.expert_networks: + expert_outputs.append(expert(x)) # Each output shape: [batch_size, hidden_size] + expert_outputs = torch.stack(expert_outputs, dim=1) # Shape: [batch_size, num_experts, hidden_size] + + # Compute Diversity Loss + if self.use_diversity_loss: + diversity_loss = 0.0 + num_experts = len(self.expert_networks) + for i in range(num_experts): + for j in range(i + 1, num_experts): + similarity = F.cosine_similarity(expert_outputs[:, i, :], expert_outputs[:, j, :], dim=-1) + diversity_loss += torch.mean(similarity) + num_pairs = num_experts * (num_experts - 1) / 2 + diversity_loss = diversity_loss / num_pairs + self.aux_loss_map['moe_diversity_loss'] = self.lambda_diversity * diversity_loss + + # Aggregate Expert Outputs + gating_weights = gating_weights.unsqueeze(-1) # Shape: [batch_size, num_experts, 1] + aggregated_output = torch.sum(gating_weights * expert_outputs, dim=1) # Shape: [batch_size, hidden_size] + out = aggregated_output + return out + + +class MoEBlock(nn.Module): + def __init__(self, + input_size: int, + model_units: list[int], + expert_units: list[int], + num_experts: int, + ): + super().__init__() + self.num_experts = num_experts + in_size = input_size + layers =[] + for u, m in zip(expert_units, model_units): + layers.append(MoEFF(in_size, m, u, num_experts)) + in_size = u + self.layers = nn.ModuleList(layers) + self.load_balancing_loss = None + + def get_aux_loss(self): + return { + "moe_load_balancing_loss": self.load_balancing_loss + } + + def forward(self, x: torch.Tensor): + moe_diversity_loss, moe_entropy_loss = 0, 0 + for layer in self.layers: + x = layer(x) + moe_diversity_loss = moe_diversity_loss + layer.get_aux_loss()['moe_diversity_loss'] + moe_entropy_loss = moe_diversity_loss + layer.get_aux_loss()['moe_entropy_loss'] + + self.load_balancing_loss = moe_diversity_loss / len(self.layers) + moe_entropy_loss / len(self.layers) + return x + +''' + def forward(self, x: torch.Tensor): + counts, route_prob_sums, route_prob_maxs = [], [], [] + for layer in self.layers: + x, count, route_prob_sum, route_prob_max = layer(x) + counts.append(count) + route_prob_sums.append(route_prob_sum) + route_prob_maxs.append(route_prob_max) + + counts = torch.stack(counts) + route_prob_sums = torch.stack(route_prob_sums) + route_prob_maxs = torch.stack(route_prob_maxs) + + total = counts.sum(dim=-1, keepdims=True) + route_frac = counts / total + route_prob = route_prob_sums / total + + self.load_balancing_loss = self.num_experts * (route_frac * route_prob).sum() + return x +''' \ No newline at end of file diff --git a/rl_games/configs/bark/ppo_merging.yaml b/rl_games/configs/bark/ppo_merging.yaml new file mode 100644 index 00000000..253c4273 --- /dev/null +++ b/rl_games/configs/bark/ppo_merging.yaml @@ -0,0 +1,64 @@ +params: + seed: 5 + algo: + name: a2c_continuous + + model: + name: continuous_a2c_logstd + + network: + name: actor_critic + separate: False + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: 0 + fixed_sigma: True + mlp: + units: [256, 128, 64] + activation: elu + initializer: + name: default + + config: + name: Ant-v3_ray + env_name: openai_gym + score_to_win: 20000 + normalize_input: True + normalize_value: True + value_bootstrap: True + reward_shaper: + scale_value: 0.1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + + learning_rate: 3e-4 + lr_schedule: adaptive + kl_threshold: 0.008 + grad_norm: 1.0 + entropy_coef: 0.0 + truncate_grads: True + e_clip: 0.2 + max_epochs: 2000 + num_actors: 8 #64 + horizon_length: 256 #64 + minibatch_size: 2048 + mini_epochs: 4 + critic_coef: 2 + clip_value: True + use_smooth_clamp: True + bound_loss_type: regularisation + bounds_loss_coef: 0.0 + + env_config: + name: "merging-v0" + seed: 5 + + player: + render: True \ No newline at end of file diff --git a/rl_games/configs/mujoco/ant_envpool_moe.yaml b/rl_games/configs/mujoco/ant_envpool_moe.yaml index 1c87fa6b..cdfbc1f2 100644 --- a/rl_games/configs/mujoco/ant_envpool_moe.yaml +++ b/rl_games/configs/mujoco/ant_envpool_moe.yaml @@ -7,7 +7,8 @@ params: name: continuous_a2c_logstd network: - name: moe + name: actor_critic + separate: False space: continuous: mu_activation: None @@ -18,16 +19,14 @@ params: name: const_initializer val: 0 fixed_sigma: True - num_experts: 4 - hidden_size: 256 - gating_hidden_size: 128 - use_sparse_gating: True - use_entropy_loss: True - use_diversity_loss: False - top_k: 2 - lambda_entropy: -0.01 - lambda_diversity: 0.01 - + + moe: + num_experts: 4 + expert_units: [256, 128, 64] + model_units: [256, 128, 64] + #expert_activation: elu + initializer: + name: default config: name: Ant-v4_envpool_moe env_name: envpool diff --git a/rl_games/configs/mujoco/humanoid_envpool_moe.yaml b/rl_games/configs/mujoco/humanoid_envpool_moe.yaml index 98eaf22d..1de67a92 100644 --- a/rl_games/configs/mujoco/humanoid_envpool_moe.yaml +++ b/rl_games/configs/mujoco/humanoid_envpool_moe.yaml @@ -7,7 +7,8 @@ params: name: continuous_a2c_logstd network: - name: moe + name: actor_critic + separate: False space: continuous: mu_activation: None @@ -18,54 +19,48 @@ params: name: const_initializer val: 0 fixed_sigma: True - num_experts: 4 - hidden_size: 512 - gating_hidden_size: 128 - use_sparse_gating: True - use_entropy_loss: True - use_diversity_loss: True - top_k: 2 - lambda_entropy: -0.01 - lambda_diversity: 0.01 + moe: + num_experts: 4 + expert_units: [512, 256, 128] + model_units: [512, 256, 128] + #expert_activation: elu + is_scale_prob: False + initializer: + name: default config: - name: Humanoid_envpool_moe - env_name: envpool - score_to_win: 20000 - normalize_input: True - normalize_value: True - value_bootstrap: True - normalize_advantage: True - reward_shaper: - scale_value: 0.1 + name: Humanoid-v4_envpool + env_name: envpool + score_to_win: 20000 + normalize_input: True + normalize_value: True + value_bootstrap: True + reward_shaper: + scale_value: 0.1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 - gamma: 0.99 - tau: 0.95 - learning_rate: 3e-4 - lr_schedule: adaptive - kl_threshold: 0.008 - grad_norm: 1.0 - entropy_coef: 0.0 - truncate_grads: True - e_clip: 0.2 - clip_value: True - use_smooth_clamp: True - bound_loss_type: regularisation - bounds_loss_coef: 0.0 - max_epochs: 2000 - num_actors: 64 - horizon_length: 128 - minibatch_size: 2048 - mini_epochs: 5 - critic_coef: 4 + learning_rate: 3e-4 + lr_schedule: adaptive + kl_threshold: 0.008 + grad_norm: 1.0 + entropy_coef: 0.0 + truncate_grads: True + e_clip: 0.2 + clip_value: True + use_smooth_clamp: True + bound_loss_type: regularisation + bounds_loss_coef: 0.0005 + max_epochs: 2000 + num_actors: 64 + horizon_length: 128 + minibatch_size: 2048 + mini_epochs: 5 + critic_coef: 4 - env_config: - env_name: Humanoid-v4 - seed: 5 - #flat_observation: True + env_config: + env_name: Humanoid-v4 - player: - render: False - num_actors: 64 - games_num: 1000 - use_vecenv: True \ No newline at end of file + player: + render: True \ No newline at end of file