Commit 2c05bcc

updated builder

DenSumy committed Oct 21, 2024
1 parent 5145df3 commit 2c05bcc
Showing 5 changed files with 367 additions and 78 deletions.
61 changes: 41 additions & 20 deletions rl_games/algos_torch/network_builder.py
@@ -5,6 +5,7 @@
import torch.nn as nn

from rl_games.algos_torch.d2rl import D2RLNet
+from rl_games.common.layers.switch_ffn import MoEBlock
from rl_games.algos_torch.sac_helper import SquashedNormal
from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
@@ -67,6 +68,8 @@ def get_default_rnn_state(self):
return None

def get_aux_loss(self):
+            if self.moe_block:
+                return self.actor_mlp.get_aux_loss()
return None

def _calc_input_size(self, input_shape,cnn_layers=None):
@@ -128,6 +131,9 @@ def _build_mlp(self,
else:
return self._build_sequential_mlp(input_size, units, activation, dense_func, norm_func_name = None,)

+        def _build_moe_block(self, input_size, expert_units, model_units, num_experts):
+            return MoEBlock(input_size, model_units=model_units, expert_units=expert_units, num_experts=num_experts)

def _build_conv(self, ctype, **kwargs):
print('conv_name:', ctype)

@@ -231,9 +237,8 @@ def __init__(self, params, **kwargs):
cnn_output_size = self._calc_input_size(input_shape, self.actor_cnn)

mlp_input_size = cnn_output_size
-            if len(self.units) == 0:
-                out_size = cnn_output_size
-            else:
+            out_size = mlp_input_size
+            if len(self.units) > 0:
                out_size = self.units[-1]

if self.has_rnn:
@@ -263,18 +268,23 @@ def __init__(self, params, **kwargs):
if self.rnn_ln:
self.layer_norm = torch.nn.LayerNorm(self.rnn_units)

-            mlp_args = {
-                'input_size' : mlp_input_size,
-                'units' : self.units,
-                'activation' : self.activation,
-                'norm_func_name' : self.normalization,
-                'dense_func' : torch.nn.Linear,
-                'd2rl' : self.is_d2rl,
-                'norm_only_first_layer' : self.norm_only_first_layer
-            }
-            self.actor_mlp = self._build_mlp(**mlp_args)
-            if self.separate:
-                self.critic_mlp = self._build_mlp(**mlp_args)
+            if self.moe_block:
+                self.actor_mlp = self._build_moe_block(mlp_input_size, self.expert_units, self.model_units, self.num_experts)
+                assert(not self.separate)
+            else:
+                mlp_args = {
+                    'input_size' : mlp_input_size,
+                    'units' : self.units,
+                    'activation' : self.activation,
+                    'norm_func_name' : self.normalization,
+                    'dense_func' : torch.nn.Linear,
+                    'd2rl' : self.is_d2rl,
+                    'norm_only_first_layer' : self.norm_only_first_layer
+                }
+                self.actor_mlp = self._build_mlp(**mlp_args)
+                if self.separate:
+                    self.critic_mlp = self._build_mlp(**mlp_args)

self.value = self._build_value_layer(out_size, self.value_size)
self.value_act = self.activations_factory.create(self.value_activation)
@@ -506,11 +516,22 @@ def get_default_rnn_state(self):

def load(self, params):
self.separate = params.get('separate', False)
-        self.units = params['mlp']['units']
-        self.activation = params['mlp']['activation']
-        self.initializer = params['mlp']['initializer']
-        self.is_d2rl = params['mlp'].get('d2rl', False)
-        self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False)
+        self.moe_block = params.get('moe', False)
+
+        if self.moe_block:
+            assert(not params.get('mlp', False))
+            self.num_experts = self.moe_block['num_experts']
+            self.expert_units = self.moe_block['expert_units']
+            self.model_units = self.moe_block['model_units']
+            self.initializer = self.moe_block['initializer']
+            self.units = self.expert_units
+
+        else:
+            self.units = params['mlp']['units']
+            self.activation = params['mlp']['activation']
+            self.initializer = params['mlp']['initializer']
+            self.is_d2rl = params['mlp'].get('d2rl', False)
+            self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False)
self.value_activation = params.get('value_activation', 'None')
self.normalization = params.get('normalization', None)
self.has_rnn = 'rnn' in params
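For reference, load() above now accepts a moe section in place of mlp. A hypothetical params fragment exercising that path (key names follow the reads in load(); the concrete sizes are illustrative, not from this commit):

params = {
    'separate': False,                # the MoE path asserts separate == False
    'moe': {
        'num_experts': 4,
        'expert_units': [256, 128],   # per-layer output widths; the last entry becomes out_size
        'model_units': [128, 128],    # per-layer hidden widths (stored by MoEFF, unused while the deeper expert layers stay commented out)
        'initializer': {'name': 'default'},
    },
}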
210 changes: 210 additions & 0 deletions rl_games/common/layers/switch_ffn.py
@@ -0,0 +1,210 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwitchFeedForward(nn.Module):

def __init__(self,
model_dim: int,
hidden_dim: int,
out_dim: int,
is_scale_prob: bool,
num_experts: int,
                 activation: nn.Module = nn.ReLU):
super().__init__()
self.hidden_dim = hidden_dim
self.model_dim = model_dim
self.out_dim = out_dim
self.is_scale_prob = is_scale_prob
self.num_experts = num_experts
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(model_dim, out_dim),
activation(),
#nn.Linear(model_dim, hidden_dim),
#activation(),
#nn.Linear(hidden_dim, out_dim),
#activation(),
)
for _ in range(num_experts)
])
# Routing layer and softmax
self.switch = nn.Linear(model_dim, num_experts)
self.softmax = nn.Softmax(dim=-1)

def forward(self, x: torch.Tensor):
route_prob = self.softmax(self.switch(x))
route_prob_max, routes = torch.max(route_prob, dim=-1)
indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.num_experts)]

final_output = torch.zeros((x.size(0), self.out_dim), device=x.device)
counts = x.new_tensor([len(indexes_list[i]) for i in range(self.num_experts)])

# Get outputs of the expert FFNs
expert_output = [self.experts[i](x[indexes_list[i], :]) for i in range(self.num_experts)]
# Assign to final output
for i in range(self.num_experts):
final_output[indexes_list[i], :] = expert_output[i]

if self.is_scale_prob:
            # Multiply the expert outputs by the probabilities $y = p_i(x) E_i(x)$
final_output = final_output * route_prob_max.view(-1, 1)
else:
            # the output value is unchanged (the ratio is 1), but gradients still flow through route_prob_max
            final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1)


return final_output, counts, route_prob.sum(0), route_prob_max



class MoEFF(nn.Module):
def __init__(self,
model_dim: int,
hidden_dim: int,
out_dim: int,
num_experts: int,
activation: nn.Module = nn.ReLU,
**kwargs
):
super().__init__()

# Parameters from params
self.model_dim = model_dim
self.num_experts = num_experts
self.hidden_dim = hidden_dim
self.out_dim = out_dim
self.gating_hidden_size = kwargs.get('gating_hidden_size', 64)
self.use_sparse_gating = kwargs.get('use_sparse_gating', True)
self.use_entropy_loss = kwargs.get('use_entropy_loss', True)
self.use_diversity_loss = kwargs.get('use_diversity_loss', True)
self.top_k = kwargs.get('top_k', 2)
self.lambda_entropy = kwargs.get('lambda_entropy', 0.01)
self.lambda_diversity = kwargs.get('lambda_diversity', 0.00)


# Gating Network
self.gating_fc1 = nn.Linear(self.model_dim, self.gating_hidden_size)
self.gating_fc2 = nn.Linear(self.gating_hidden_size, num_experts)

# Expert Networks
self.expert_networks = nn.ModuleList([
nn.Sequential(
nn.Linear(self.model_dim, out_dim),
activation(),
) for _ in range(num_experts)
])


# Auxiliary loss map
        self.aux_loss_map = {}
if self.use_diversity_loss:
self.aux_loss_map['moe_diversity_loss'] = 0.0
if self.use_entropy_loss:
self.aux_loss_map['moe_entropy_loss'] = 0.0

def get_aux_loss(self):
return self.aux_loss_map

def forward(self, x):

# Gating Network Forward Pass
gating_x = F.relu(self.gating_fc1(x))
gating_logits = self.gating_fc2(gating_x) # Shape: [batch_size, num_experts]
orig_gating_weights = F.softmax(gating_logits, dim=1)
gating_weights = orig_gating_weights
# Apply Sparse Gating if enabled
if self.use_sparse_gating:
topk_values, topk_indices = torch.topk(gating_weights, self.top_k, dim=1)
sparse_mask = torch.zeros_like(gating_weights)
sparse_mask.scatter_(1, topk_indices, topk_values)
# probably better go with masked softmax
gating_weights = sparse_mask / sparse_mask.sum(dim=1, keepdim=True)

if self.use_entropy_loss:
# Compute Entropy Loss for Gating Weights
entropy = -torch.sum(gating_weights * torch.log(gating_weights + 1e-8), dim=1)
entropy_loss = torch.mean(entropy)
self.aux_loss_map['moe_entropy_loss'] = -self.lambda_entropy * entropy_loss

# Expert Networks Forward Pass
expert_outputs = []
for expert in self.expert_networks:
expert_outputs.append(expert(x)) # Each output shape: [batch_size, hidden_size]
expert_outputs = torch.stack(expert_outputs, dim=1) # Shape: [batch_size, num_experts, hidden_size]

# Compute Diversity Loss
if self.use_diversity_loss:
diversity_loss = 0.0
num_experts = len(self.expert_networks)
for i in range(num_experts):
for j in range(i + 1, num_experts):
similarity = F.cosine_similarity(expert_outputs[:, i, :], expert_outputs[:, j, :], dim=-1)
diversity_loss += torch.mean(similarity)
num_pairs = num_experts * (num_experts - 1) / 2
diversity_loss = diversity_loss / num_pairs
self.aux_loss_map['moe_diversity_loss'] = self.lambda_diversity * diversity_loss

# Aggregate Expert Outputs
gating_weights = gating_weights.unsqueeze(-1) # Shape: [batch_size, num_experts, 1]
aggregated_output = torch.sum(gating_weights * expert_outputs, dim=1) # Shape: [batch_size, hidden_size]
out = aggregated_output
return out


class MoEBlock(nn.Module):
def __init__(self,
input_size: int,
model_units: list[int],
expert_units: list[int],
num_experts: int,
):
super().__init__()
self.num_experts = num_experts
in_size = input_size
        layers = []
for u, m in zip(expert_units, model_units):
layers.append(MoEFF(in_size, m, u, num_experts))
in_size = u
self.layers = nn.ModuleList(layers)
self.load_balancing_loss = None

def get_aux_loss(self):
return {
"moe_load_balancing_loss": self.load_balancing_loss
}

    def forward(self, x: torch.Tensor):
        moe_diversity_loss, moe_entropy_loss = 0, 0
        for layer in self.layers:
            x = layer(x)
            moe_diversity_loss = moe_diversity_loss + layer.get_aux_loss()['moe_diversity_loss']
            moe_entropy_loss = moe_entropy_loss + layer.get_aux_loss()['moe_entropy_loss']

        # average the per-layer auxiliary terms over the stack
        self.load_balancing_loss = moe_diversity_loss / len(self.layers) + moe_entropy_loss / len(self.layers)
        return x

'''
def forward(self, x: torch.Tensor):
counts, route_prob_sums, route_prob_maxs = [], [], []
for layer in self.layers:
x, count, route_prob_sum, route_prob_max = layer(x)
counts.append(count)
route_prob_sums.append(route_prob_sum)
route_prob_maxs.append(route_prob_max)
counts = torch.stack(counts)
route_prob_sums = torch.stack(route_prob_sums)
route_prob_maxs = torch.stack(route_prob_maxs)
total = counts.sum(dim=-1, keepdims=True)
route_frac = counts / total
route_prob = route_prob_sums / total
self.load_balancing_loss = self.num_experts * (route_frac * route_prob).sum()
return x
'''
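A minimal usage sketch for MoEBlock as defined above (sizes are illustrative and assume the keyword-argument call used in _build_moe_block):

import torch

from rl_games.common.layers.switch_ffn import MoEBlock

# Two MoEFF layers, 64 -> 256 -> 128, each with 4 experts and top-2 sparse gating by default.
block = MoEBlock(input_size=64, model_units=[128, 128], expert_units=[256, 128], num_experts=4)
x = torch.randn(32, 64)          # [batch, input_size]
out = block(x)                   # [batch, expert_units[-1]] == [32, 128]
print(block.get_aux_loss())      # {'moe_load_balancing_loss': tensor(...)}

The commented-out forward pairs with SwitchFeedForward rather than MoEFF: it implements the Switch Transformer load-balancing objective, loss = N * sum_i(f_i * P_i), where f_i is the fraction of inputs routed to expert i and P_i is the mean router probability assigned to it.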
64 changes: 64 additions & 0 deletions rl_games/configs/bark/ppo_merging.yaml
@@ -0,0 +1,64 @@
params:
seed: 5
algo:
name: a2c_continuous

model:
name: continuous_a2c_logstd

network:
name: actor_critic
separate: False
space:
continuous:
mu_activation: None
sigma_activation: None
mu_init:
name: default
sigma_init:
name: const_initializer
val: 0
fixed_sigma: True
mlp:
units: [256, 128, 64]
activation: elu
initializer:
name: default

config:
name: Ant-v3_ray
env_name: openai_gym
score_to_win: 20000
normalize_input: True
normalize_value: True
value_bootstrap: True
reward_shaper:
scale_value: 0.1
normalize_advantage: True
gamma: 0.99
tau: 0.95

learning_rate: 3e-4
lr_schedule: adaptive
kl_threshold: 0.008
grad_norm: 1.0
entropy_coef: 0.0
truncate_grads: True
e_clip: 0.2
max_epochs: 2000
num_actors: 8 #64
horizon_length: 256 #64
minibatch_size: 2048
mini_epochs: 4
critic_coef: 2
clip_value: True
use_smooth_clamp: True
bound_loss_type: regularisation
bounds_loss_coef: 0.0

env_config:
name: "merging-v0"
seed: 5

player:
render: True