Commit
seems to be working :)
DenSumy committed Oct 19, 2024
1 parent 53c475f commit 54813ee
Showing 5 changed files with 136 additions and 32 deletions.
1 change: 1 addition & 0 deletions rl_games/__init__.py
@@ -0,0 +1 @@
from rl_games.networks import *
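
Note: with this wildcard import, a plain "import rl_games" now executes rl_games/networks/__init__.py, so the register_network calls below run as a package-import side effect and networks such as 'moe' become resolvable by name from configs.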
1 change: 0 additions & 1 deletion rl_games/algos_torch/network_builder.py
@@ -9,7 +9,6 @@
from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue


def _create_initializer(func, **kwargs):
return lambda v : func(v, **kwargs)

71 changes: 71 additions & 0 deletions rl_games/configs/mujoco/ant_envpool_moe.yaml
@@ -0,0 +1,71 @@
params:
  seed: 5
  algo:
    name: a2c_continuous

  model:
    name: continuous_a2c_logstd

  network:
    name: moe
    space:
      continuous:
        mu_activation: None
        sigma_activation: None
        mu_init:
          name: default
        sigma_init:
          name: const_initializer
          val: 0
        fixed_sigma: True
    num_experts: 4
    hidden_size: 256
    gating_hidden_size: 128
    use_sparse_gating: true
    use_entropy_loss: true
    use_diversity_loss: true
    top_k: 2
    lambda_entropy: 0.01
    lambda_diversity: 0.01

  config:
    name: Ant-v4_envpool_moe
    env_name: envpool
    score_to_win: 20000
    normalize_input: True
    normalize_value: True
    value_bootstrap: True
    normalize_advantage: True
    reward_shaper:
      scale_value: 1

    gamma: 0.99
    tau: 0.95
    learning_rate: 3e-4
    lr_schedule: adaptive
    kl_threshold: 0.008
    grad_norm: 1.0
    entropy_coef: 0.0
    truncate_grads: True
    e_clip: 0.2
    clip_value: True
    use_smooth_clamp: True
    bound_loss_type: regularisation
    bounds_loss_coef: 0.0
    max_epochs: 2000
    num_actors: 64
    horizon_length: 64
    minibatch_size: 2048
    mini_epochs: 4
    critic_coef: 2

    env_config:
      env_name: Ant-v4
      seed: 5
      #flat_observation: True

    player:
      render: False
      num_actors: 64
      games_num: 1000
      use_vecenv: True
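
A minimal sketch of launching training with this config, assuming the standard rl_games Runner API and an installed envpool; the exact run() argument keys follow common rl_games usage and may differ between versions:

import yaml

# importing rl_games pulls in rl_games.networks (see rl_games/__init__.py above),
# which registers the 'moe' builder before the config is parsed
from rl_games.torch_runner import Runner

with open('rl_games/configs/mujoco/ant_envpool_moe.yaml') as f:
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)
runner.run({'train': True, 'play': False})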
5 changes: 4 additions & 1 deletion rl_games/networks/__init__.py
@@ -1,4 +1,7 @@
from rl_games.networks.tcnn_mlp import TcnnNetBuilder
from rl_games.networks.moe import MoENetBuilder

from rl_games.algos_torch import model_builder

model_builder.register_network('tcnnnet', TcnnNetBuilder)
model_builder.register_network('moe', MoENetBuilder)
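
With both builders registered, model_builder resolves the network name field from a config to the matching builder class; this is how the "name: moe" entry in the YAML above finds MoENetBuilder.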
90 changes: 60 additions & 30 deletions rl_games/networks/moe.py
@@ -1,32 +1,51 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from rl_games.common import networks
from rl_games.common import layers
from rl_games.algos_torch.network_builder import NetworkBuilder

-class MoENet(networks.NetworkBuilder.BaseNetwork):
+class MoENet(NetworkBuilder.BaseNetwork):
    def __init__(self, params, **kwargs):
-        nn.Module.__init__(self)
+        NetworkBuilder.BaseNetwork.__init__(self)
actions_num = kwargs.pop('actions_num')
input_shape = kwargs.pop('input_shape')
num_inputs = 0

self.has_space = 'space' in params
self.central_value = params.get('central_value', False)
if self.has_space:
            self.is_multi_discrete = 'multi_discrete' in params['space']
            self.is_discrete = 'discrete' in params['space']
            self.is_continuous = 'continuous' in params['space']
if self.is_continuous:
self.space_config = params['space']['continuous']
self.fixed_sigma = self.space_config['fixed_sigma']
elif self.is_discrete:
self.space_config = params['space']['discrete']
elif self.is_multi_discrete:
self.space_config = params['space']['multi_discrete']
else:
self.is_discrete = False
self.is_continuous = False
self.is_multi_discrete = False

self.value_size = kwargs.pop('value_size', 1)

# Parameters from params
num_experts = params.get('num_experts', 3)
hidden_size = params.get('hidden_size', 128)
gating_hidden_size = params.get('gating_hidden_size', 64)
self.use_sparse_gating = params.get('use_sparse_gating', False)
self.use_entropy_loss = params.get('use_entropy_loss', True)
self.use_diversity_loss = params.get('use_diversity_loss', True)
self.top_k = params.get('top_k', 1)
self.lambda_entropy = params.get('lambda_entropy', 0.01)
self.lambda_diversity = params.get('lambda_diversity', 0.01)

# Input processing
-        assert isinstance(input_shape, dict), "Input shape must be a dict"
-        for k, v in input_shape.items():
-            num_inputs += v[0]
+        #assert isinstance(input_shape, dict), "Input shape must be a dict"
+        #for k, v in input_shape.items():
+        #    num_inputs += v[0]
+        num_inputs = input_shape[0]

# Gating Network
self.gating_fc1 = nn.Linear(num_inputs, gating_hidden_size)
Expand All @@ -44,15 +63,27 @@ def __init__(self, params, **kwargs):
) for _ in range(num_experts)
])

-        # Output layers
-        self.mean_linear = nn.Linear(hidden_size, actions_num)
+        if self.is_discrete:
+            self.logits = torch.nn.Linear(hidden_size, actions_num)
+        if self.is_multi_discrete:
+            self.logits = torch.nn.ModuleList([torch.nn.Linear(hidden_size, num) for num in actions_num])
+        if self.is_continuous:
+            self.mu = torch.nn.Linear(hidden_size, actions_num)
+            self.sigma = torch.nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32),
+                                            requires_grad=True)
+            self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
+            #mu_init = self.init_factory.create(**self.space_config['mu_init'])
+            self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation'])
+            #sigma_init = self.init_factory.create(**self.space_config['sigma_init'])
self.value = nn.Linear(hidden_size, self.value_size)

# Auxiliary loss map
self.aux_loss_map = {
'entropy_loss': None,
'diversity_loss': None,
}
if self.use_diversity_loss:
self.aux_loss_map['diversity_loss'] = 0.0
if self.use_entropy_loss:
self.aux_loss_map['entropy_loss'] = 0.0

def is_rnn(self):
return False
@@ -61,11 +92,7 @@ def get_aux_loss(self):
return self.aux_loss_map

def forward(self, obs_dict):
-        # Combine observations
-        obs = []
-        for k in obs_dict:
-            obs.append(obs_dict[k])
-        obs = torch.cat(obs, dim=-1)
+        obs = obs_dict['obs']

# Gating Network Forward Pass
gating_x = F.relu(self.gating_fc1(obs))
@@ -79,10 +106,12 @@
sparse_mask.scatter_(1, topk_indices, topk_values)
gating_weights = sparse_mask / sparse_mask.sum(dim=1, keepdim=True) # Re-normalize


-        entropy = -torch.sum(gating_weights * torch.log(gating_weights + 1e-8), dim=1)
-        entropy_loss = torch.mean(entropy)
-        self.aux_loss_map['entropy_loss'] = self.lambda_entropy * entropy_loss
+        if self.use_entropy_loss:
+            # Compute Entropy Loss for Gating Weights
+            entropy = -torch.sum(gating_weights * torch.log(gating_weights + 1e-8), dim=1)
+            entropy_loss = torch.mean(entropy)
+            self.aux_loss_map['entropy_loss'] = self.lambda_entropy * entropy_loss

# Expert Networks Forward Pass
expert_outputs = []
@@ -91,22 +120,23 @@
expert_outputs = torch.stack(expert_outputs, dim=1) # Shape: [batch_size, num_experts, hidden_size]

# Compute Diversity Loss
-        diversity_loss = 0.0
-        num_experts = len(self.expert_networks)
-        for i in range(num_experts):
-            for j in range(i + 1, num_experts):
-                similarity = F.cosine_similarity(expert_outputs[:, i, :], expert_outputs[:, j, :], dim=-1)
-                diversity_loss += torch.mean(similarity)
-        num_pairs = num_experts * (num_experts - 1) / 2
-        diversity_loss = diversity_loss / num_pairs
-        self.aux_loss_map['diversity_loss'] = self.lambda_diversity * diversity_loss
+        if self.use_diversity_loss:
+            diversity_loss = 0.0
+            num_experts = len(self.expert_networks)
+            for i in range(num_experts):
+                for j in range(i + 1, num_experts):
+                    similarity = F.cosine_similarity(expert_outputs[:, i, :], expert_outputs[:, j, :], dim=-1)
+                    diversity_loss += torch.mean(similarity)
+            num_pairs = num_experts * (num_experts - 1) / 2
+            diversity_loss = diversity_loss / num_pairs
+            self.aux_loss_map['diversity_loss'] = self.lambda_diversity * diversity_loss

# Aggregate Expert Outputs
gating_weights = gating_weights.unsqueeze(-1) # Shape: [batch_size, num_experts, 1]
aggregated_output = torch.sum(gating_weights * expert_outputs, dim=1) # Shape: [batch_size, hidden_size]

out = aggregated_output
-        value = self.value_act(self.value(out))
+        value = self.value(out)
states = None
if self.central_value:
return value, states
@@ -123,7 +153,7 @@ def forward(self, obs_dict):
sigma = self.sigma_act(self.sigma)
else:
sigma = self.sigma_act(self.sigma(out))
            return mu, mu*0 + sigma, value, states


from rl_games.algos_torch.network_builder import NetworkBuilder
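
For reference, a self-contained sketch of the three mechanisms the new forward pass combines: top-k sparse gating with re-normalization, the entropy penalty on the gating weights, and the pairwise cosine diversity penalty on expert outputs. The helper names here are illustrative, not part of this commit:

import torch
import torch.nn.functional as F

def sparse_gating(logits, top_k):
    # softmax over experts, then keep only the top-k weights per sample
    weights = F.softmax(logits, dim=1)              # [batch, num_experts]
    topk_values, topk_indices = torch.topk(weights, top_k, dim=1)
    mask = torch.zeros_like(weights)
    mask.scatter_(1, topk_indices, topk_values)     # zero out non-top-k experts
    return mask / mask.sum(dim=1, keepdim=True)     # re-normalize to sum to 1

def gating_entropy(weights):
    # mean entropy of the gating distribution; the 1e-8 guards log(0)
    return torch.mean(-torch.sum(weights * torch.log(weights + 1e-8), dim=1))

def diversity(expert_outputs):
    # mean pairwise cosine similarity; expert_outputs: [batch, num_experts, hidden]
    num_experts = expert_outputs.shape[1]
    loss, num_pairs = 0.0, num_experts * (num_experts - 1) / 2
    for i in range(num_experts):
        for j in range(i + 1, num_experts):
            loss = loss + F.cosine_similarity(
                expert_outputs[:, i, :], expert_outputs[:, j, :], dim=-1).mean()
    return loss / num_pairs

# toy shapes matching the config above: 4 experts, top_k=2, hidden_size=256
w = sparse_gating(torch.randn(8, 4), top_k=2)
assert torch.allclose(w.sum(dim=1), torch.ones(8))
print(gating_entropy(w), diversity(torch.randn(8, 4, 256)))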
