[WIP] Added MoE layer #310

Open · wants to merge 11 commits into base: master
1 change: 1 addition & 0 deletions rl_games/__init__.py
@@ -0,0 +1 @@
from rl_games.networks import *
2 changes: 1 addition & 1 deletion rl_games/algos_torch/a2c_continuous.py
@@ -41,7 +41,7 @@ def __init__(self, base_name, params):
self.init_rnn_from_model(self.model)
self.last_lr = float(self.last_lr)
self.bound_loss_type = self.config.get('bound_loss_type', 'bound') # 'regularisation' or 'bound'
self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay)
self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay, fused=True)

if self.has_central_value:
cv_config = {
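A note on the one-line change above: `fused=True` for `torch.optim.Adam` has historically been implemented only for CUDA floating-point parameters, so hard-coding it can break CPU runs. A minimal guarded construction, as a sketch rather than part of this PR, could look like:

# Sketch (assumption, not in this diff): enable fused Adam only when every
# parameter lives on a CUDA device, since fused=True errors out on CPU tensors.
use_fused = all(p.is_cuda for p in self.model.parameters())
self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr),
                            eps=1e-08, weight_decay=self.weight_decay, fused=use_fused)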
62 changes: 41 additions & 21 deletions rl_games/algos_torch/network_builder.py
@@ -5,12 +5,12 @@
import torch.nn as nn

from rl_games.algos_torch.d2rl import D2RLNet
from rl_games.common.layers.switch_ffn import MoEBlock
from rl_games.algos_torch.sac_helper import SquashedNormal
from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
from rl_games.algos_torch.spatial_softmax import SpatialSoftArgmax


def _create_initializer(func, **kwargs):
return lambda v : func(v, **kwargs)

@@ -68,6 +68,8 @@ def get_default_rnn_state(self):
return None

def get_aux_loss(self):
if getattr(self, 'moe_block', False): # guard: only the A2C builder's load() sets this attribute
return self.actor_mlp.get_aux_loss()
return None

def _calc_input_size(self, input_shape,cnn_layers=None):
@@ -129,6 +131,9 @@ def _build_mlp(self,
else:
return self._build_sequential_mlp(input_size, units, activation, dense_func, norm_func_name = None,)

def _build_moe_block(self, input_size, expert_units, model_units, num_experts):
# MoEBlock declares model_units before expert_units, so pass by keyword to avoid swapping them
return MoEBlock(input_size, model_units=model_units, expert_units=expert_units, num_experts=num_experts)

def _build_conv(self, ctype, **kwargs):
print('conv_name:', ctype)

@@ -232,9 +237,8 @@ def __init__(self, params, **kwargs):
cnn_output_size = self._calc_input_size(input_shape, self.actor_cnn)

mlp_input_size = cnn_output_size
if len(self.units) == 0:
out_size = cnn_output_size
else:
out_size = mlp_input_size
if len(self.units) > 0:
out_size = self.units[-1]

if self.has_rnn:
@@ -264,18 +268,23 @@
if self.rnn_ln:
self.layer_norm = torch.nn.LayerNorm(self.rnn_units)

mlp_args = {
'input_size' : mlp_input_size,
'units' : self.units,
'activation' : self.activation,
'norm_func_name' : self.normalization,
'dense_func' : torch.nn.Linear,
'd2rl' : self.is_d2rl,
'norm_only_first_layer' : self.norm_only_first_layer
}
self.actor_mlp = self._build_mlp(**mlp_args)
if self.separate:
self.critic_mlp = self._build_mlp(**mlp_args)

if self.moe_block:
assert not self.separate, 'MoE block does not support a separate critic MLP'
self.actor_mlp = self._build_moe_block(mlp_input_size, self.expert_units, self.model_units, self.num_experts)
else:
mlp_args = {
'input_size' : mlp_input_size,
'units' : self.units,
'activation' : self.activation,
'norm_func_name' : self.normalization,
'dense_func' : torch.nn.Linear,
'd2rl' : self.is_d2rl,
'norm_only_first_layer' : self.norm_only_first_layer
}
self.actor_mlp = self._build_mlp(**mlp_args)
if self.separate:
self.critic_mlp = self._build_mlp(**mlp_args)

self.value = self._build_value_layer(out_size, self.value_size)
self.value_act = self.activations_factory.create(self.value_activation)
@@ -507,11 +516,22 @@ def get_default_rnn_state(self):

def load(self, params):
self.separate = params.get('separate', False)
self.units = params['mlp']['units']
self.activation = params['mlp']['activation']
self.initializer = params['mlp']['initializer']
self.is_d2rl = params['mlp'].get('d2rl', False)
self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False)
self.moe_block = params.get('moe', False)

if self.moe_block:
assert not params.get('mlp', False), 'mlp and moe sections are mutually exclusive'
self.num_experts = self.moe_block['num_experts']
self.expert_units = self.moe_block['expert_units']
self.model_units = self.moe_block['model_units']
self.initializer = self.moe_block['initializer']
self.units = self.expert_units

else:
self.units = params['mlp']['units']
self.activation = params['mlp']['activation']
self.initializer = params['mlp']['initializer']
self.is_d2rl = params['mlp'].get('d2rl', False)
self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False)
self.value_activation = params.get('value_activation', 'None')
self.normalization = params.get('normalization', None)
self.has_rnn = 'rnn' in params
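For reference, the new `moe` section replaces the `mlp` section of the network config. A hypothetical, abridged params dict matching what `load()` reads above (field names are inferred from the code, not taken from a shipped rl_games config):

# Hypothetical config fragment (assumption): the keys load() consumes when MoE is enabled.
params = {
    'separate': False,
    'moe': {
        'num_experts': 4,
        'expert_units': [512, 256],   # out_dim of each MoE layer; last entry becomes out_size
        'model_units': [512, 256],    # hidden width handed to each MoEFF
        'initializer': {'name': 'default'},
    },
    # note: 'mlp' must be absent -- load() asserts mlp and moe are mutually exclusive
}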
1 change: 0 additions & 1 deletion rl_games/common/env_configurations.py
@@ -10,7 +10,6 @@
import math



class HCRewardEnv(gym.RewardWrapper):
def __init__(self, env):
gym.RewardWrapper.__init__(self, env)
210 changes: 210 additions & 0 deletions rl_games/common/layers/switch_ffn.py
@@ -0,0 +1,210 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwitchFeedForward(nn.Module):

def __init__(self,
model_dim: int,
hidden_dim: int,
out_dim: int,
is_scale_prob: bool,
num_experts: int,
activation: nn.Module = nn.ReLU
):
super().__init__()
self.hidden_dim = hidden_dim
self.model_dim = model_dim
self.out_dim = out_dim
self.is_scale_prob = is_scale_prob
self.num_experts = num_experts
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(model_dim, out_dim),
activation(),
#nn.Linear(model_dim, hidden_dim),
#activation(),
#nn.Linear(hidden_dim, out_dim),
#activation(),
)
for _ in range(num_experts)
])
# Routing layer and softmax
self.switch = nn.Linear(model_dim, num_experts)
self.softmax = nn.Softmax(dim=-1)

def forward(self, x: torch.Tensor):
route_prob = self.softmax(self.switch(x))
route_prob_max, routes = torch.max(route_prob, dim=-1)
indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.num_experts)]

final_output = torch.zeros((x.size(0), self.out_dim), device=x.device)
counts = x.new_tensor([len(indexes_list[i]) for i in range(self.num_experts)])

# Get outputs of the expert FFNs
expert_output = [self.experts[i](x[indexes_list[i], :]) for i in range(self.num_experts)]
# Assign to final output
for i in range(self.num_experts):
final_output[indexes_list[i], :] = expert_output[i]

if self.is_scale_prob:
# Scale the expert outputs by the routing probabilities: y = p_i(x) * E_i(x)
final_output = final_output * route_prob_max.view(-1, 1)
else:
# p / p.detach() == 1, so the value is unchanged but gradients still flow to the router
final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1)


return final_output, counts, route_prob.sum(0), route_prob_max
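SwitchFeedForward returns exactly the routing statistics needed for the Switch Transformer load-balancing loss, L_aux = N * sum_i f_i * P_i, where f_i is the fraction of inputs routed to expert i and P_i is the batch-mean routing probability for that expert. A usage sketch, assuming the imports at the top of this file and arbitrary dimensions:

# Usage sketch for SwitchFeedForward; all sizes here are made up.
ffn = SwitchFeedForward(model_dim=64, hidden_dim=128, out_dim=32,
                        is_scale_prob=True, num_experts=4)
x = torch.randn(16, 64)
out, counts, route_prob_sum, route_prob_max = ffn(x)
total = counts.sum()
# Switch Transformer auxiliary loss: N * sum_i f_i * P_i
load_balancing_loss = ffn.num_experts * ((counts / total) * (route_prob_sum / total)).sum()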



class MoEFF(nn.Module):
def __init__(self,
model_dim: int,
hidden_dim: int,
out_dim: int,
num_experts: int,
activation: nn.Module = nn.ReLU,
**kwargs
):
super().__init__()

# Parameters from params
self.model_dim = model_dim
self.num_experts = num_experts
self.hidden_dim = hidden_dim
self.out_dim = out_dim
self.gating_hidden_size = kwargs.get('gating_hidden_size', 64)
self.use_sparse_gating = kwargs.get('use_sparse_gating', True)
self.use_entropy_loss = kwargs.get('use_entropy_loss', True)
self.use_diversity_loss = kwargs.get('use_diversity_loss', True)
self.top_k = kwargs.get('top_k', 2)
self.lambda_entropy = kwargs.get('lambda_entropy', 0.01)
self.lambda_diversity = kwargs.get('lambda_diversity', 0.00)


# Gating Network
self.gating_fc1 = nn.Linear(self.model_dim, self.gating_hidden_size)
self.gating_fc2 = nn.Linear(self.gating_hidden_size, num_experts)

# Expert Networks
self.expert_networks = nn.ModuleList([
nn.Sequential(
nn.Linear(self.model_dim, out_dim),
activation(),
) for _ in range(num_experts)
])


# Auxiliary loss map
self.aux_loss_map = {
}
if self.use_diversity_loss:
self.aux_loss_map['moe_diversity_loss'] = 0.0
if self.use_entropy_loss:
self.aux_loss_map['moe_entropy_loss'] = 0.0

def get_aux_loss(self):
return self.aux_loss_map

def forward(self, x):

# Gating Network Forward Pass
gating_x = F.relu(self.gating_fc1(x))
gating_logits = self.gating_fc2(gating_x) # Shape: [batch_size, num_experts]
orig_gating_weights = F.softmax(gating_logits, dim=1)
gating_weights = orig_gating_weights
# Apply Sparse Gating if enabled
if self.use_sparse_gating:
topk_values, topk_indices = torch.topk(gating_weights, self.top_k, dim=1)
sparse_mask = torch.zeros_like(gating_weights)
sparse_mask.scatter_(1, topk_indices, topk_values)
# renormalize the kept probabilities (a masked softmax over the top-k logits may be preferable)
gating_weights = sparse_mask / sparse_mask.sum(dim=1, keepdim=True)

if self.use_entropy_loss:
# Compute Entropy Loss for Gating Weights
entropy = -torch.sum(gating_weights * torch.log(gating_weights + 1e-8), dim=1)
entropy_loss = torch.mean(entropy)
self.aux_loss_map['moe_entropy_loss'] = -self.lambda_entropy * entropy_loss

# Expert Networks Forward Pass
expert_outputs = []
for expert in self.expert_networks:
expert_outputs.append(expert(x)) # Each output shape: [batch_size, out_dim]
expert_outputs = torch.stack(expert_outputs, dim=1) # Shape: [batch_size, num_experts, out_dim]

# Compute Diversity Loss
if self.use_diversity_loss:
diversity_loss = 0.0
num_experts = len(self.expert_networks)
for i in range(num_experts):
for j in range(i + 1, num_experts):
similarity = F.cosine_similarity(expert_outputs[:, i, :], expert_outputs[:, j, :], dim=-1)
diversity_loss += torch.mean(similarity)
num_pairs = num_experts * (num_experts - 1) / 2
diversity_loss = diversity_loss / num_pairs
self.aux_loss_map['moe_diversity_loss'] = self.lambda_diversity * diversity_loss

# Aggregate Expert Outputs
gating_weights = gating_weights.unsqueeze(-1) # Shape: [batch_size, num_experts, 1]
aggregated_output = torch.sum(gating_weights * expert_outputs, dim=1) # Shape: [batch_size, out_dim]
return aggregated_output
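The comment in the sparse-gating branch suggests a masked softmax; a sketch of that alternative (a hypothetical helper, not part of this diff, reusing the same shapes and the module's torch/F imports):

def masked_topk_softmax(gating_logits: torch.Tensor, top_k: int) -> torch.Tensor:
    # keep the top-k logits, push the rest to -inf, then softmax once
    topk_values, topk_indices = torch.topk(gating_logits, top_k, dim=1)
    masked = torch.full_like(gating_logits, float('-inf'))
    masked.scatter_(1, topk_indices, topk_values)
    return F.softmax(masked, dim=1)  # exact zeros outside the top-k, proper distribution inside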


class MoEBlock(nn.Module):
def __init__(self,
input_size: int,
model_units: list[int],
expert_units: list[int],
num_experts: int,
):
super().__init__()
self.num_experts = num_experts
in_size = input_size
layers = []
# each layer maps in_size -> u (an expert_units entry), with hidden width m from model_units
for u, m in zip(expert_units, model_units):
layers.append(MoEFF(in_size, m, u, num_experts))
in_size = u
self.layers = nn.ModuleList(layers)
self.load_balancing_loss = None

def get_aux_loss(self):
return {
"moe_load_balancing_loss": self.load_balancing_loss
}

def forward(self, x: torch.Tensor):
moe_diversity_loss, moe_entropy_loss = 0, 0
for layer in self.layers:
x = layer(x)
moe_diversity_loss = moe_diversity_loss + layer.get_aux_loss()['moe_diversity_loss']
moe_entropy_loss = moe_entropy_loss + layer.get_aux_loss()['moe_entropy_loss']

# note: despite the name, this aggregates the per-layer diversity and entropy losses
self.load_balancing_loss = moe_diversity_loss / len(self.layers) + moe_entropy_loss / len(self.layers)
return x

'''
def forward(self, x: torch.Tensor):
counts, route_prob_sums, route_prob_maxs = [], [], []
for layer in self.layers:
x, count, route_prob_sum, route_prob_max = layer(x)
counts.append(count)
route_prob_sums.append(route_prob_sum)
route_prob_maxs.append(route_prob_max)

counts = torch.stack(counts)
route_prob_sums = torch.stack(route_prob_sums)
route_prob_maxs = torch.stack(route_prob_maxs)

total = counts.sum(dim=-1, keepdim=True)
route_frac = counts / total
route_prob = route_prob_sums / total

self.load_balancing_loss = self.num_experts * (route_frac * route_prob).sum()
return x
'''
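Finally, a small end-to-end sketch of the block as wired up in network_builder (sizes are arbitrary; note that MoEBlock declares model_units before expert_units, which is why the builder now passes them by keyword):

# Usage sketch for MoEBlock; all sizes are assumptions.
block = MoEBlock(input_size=64, model_units=[512, 256], expert_units=[256, 128], num_experts=4)
x = torch.randn(32, 64)
y = block(x)                # shape: [32, 128] -- the last expert_units entry
aux = block.get_aux_loss()  # {'moe_load_balancing_loss': ...}, populated by forward()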