[WIP] Mixture of Experts MLP #309

Closed
wants to merge 7 commits
1 change: 1 addition & 0 deletions rl_games/__init__.py
@@ -0,0 +1 @@
from rl_games.networks import *
1 change: 0 additions & 1 deletion rl_games/algos_torch/network_builder.py
@@ -10,7 +10,6 @@
from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
from rl_games.algos_torch.spatial_softmax import SpatialSoftArgmax


def _create_initializer(func, **kwargs):
    return lambda v : func(v, **kwargs)

1 change: 0 additions & 1 deletion rl_games/common/env_configurations.py
@@ -10,7 +10,6 @@
import math



class HCRewardEnv(gym.RewardWrapper):
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)
15 changes: 15 additions & 0 deletions rl_games/common/player.py
@@ -12,6 +12,7 @@
from rl_games.common import env_configurations
from rl_games.algos_torch import model_builder

import pandas as pd

class BasePlayer(object):

@@ -271,6 +272,9 @@ def init_rnn(self):
            )[2]), dtype=torch.float32).to(self.device) for s in rnn_states]

    def run(self):
        # create pandas dataframe with fields: game_index, observation, action, reward and done
        df = pd.DataFrame(columns=['game_index', 'observation', 'action', 'reward', 'done'])

        n_games = self.games_num
        render = self.render_env
        n_game_life = self.n_game_life
@@ -313,6 +317,8 @@ def run(self):

            print_game_res = False

            game_indices = torch.arange(0, batch_size).to(self.device)
            cur_games = batch_size
            for n in range(self.max_steps):
                if self.evaluation and n % self.update_checkpoint_freq == 0:
                    self.maybe_load_new_checkpoint()
@@ -324,7 +330,11 @@
                else:
                    action = self.get_action(obses, is_deterministic)

                prev_obses = obses
                obses, r, done, info = self.env_step(self.env, action)

                for i in range(batch_size):
                    df.loc[len(df)] = [game_indices[i].cpu().numpy().item(), prev_obses[i].cpu().numpy(), action[i].cpu().numpy(), r[i].cpu().numpy().item(), done[i].cpu().numpy().item()]
                cr += r
                steps += 1

@@ -337,6 +347,9 @@
                done_count = len(done_indices)
                games_played += done_count

                for bid in done_indices:
                    game_indices[bid] = cur_games
                    cur_games += 1
                if done_count > 0:
                    if self.is_rnn:
                        for s in self.states:
@@ -379,6 +392,8 @@ def run(self):
        else:
            print('av reward:', sum_rewards / games_played * n_game_life,
                  'av steps:', sum_steps / games_played * n_game_life)

        df.to_parquet('game_data.parquet')

    def get_batch_size(self, obses, batch_size):
        obs_shape = self.obs_shape
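The new logging in run() records one row per environment step (game_index, observation, action, reward, done) and saves the table to game_data.parquet when the run finishes. A small usage sketch for inspecting that file afterwards, assuming pandas with a parquet engine (pyarrow or fastparquet) is installed:

import pandas as pd

# Load the rollout log written by BasePlayer.run() and summarize it per game.
df = pd.read_parquet('game_data.parquet')
episode_returns = df.groupby('game_index')['reward'].sum()   # total reward per game
episode_lengths = df.groupby('game_index')['done'].count()   # steps recorded per game
print(episode_returns.describe())
print(episode_lengths.describe())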
5 changes: 4 additions & 1 deletion rl_games/configs/mujoco/ant_envpool.yaml
@@ -62,4 +62,7 @@ params:
      #flat_observation: True

    player:
      render: False
      render: False
      num_actors: 64
      games_num: 1000
      use_vecenv: True
71 changes: 71 additions & 0 deletions rl_games/configs/mujoco/ant_envpool_moe.yaml
@@ -0,0 +1,71 @@
params:
  seed: 5
  algo:
    name: a2c_continuous

  model:
    name: continuous_a2c_logstd

  network:
    name: moe
    space:
      continuous:
        mu_activation: None
        sigma_activation: None
        mu_init:
          name: default
        sigma_init:
          name: const_initializer
          val: 0
        fixed_sigma: True
    num_experts: 4
    hidden_size: 256
    gating_hidden_size: 128
    use_sparse_gating: True
    use_entropy_loss: True
    use_diversity_loss: False
    top_k: 2
    lambda_entropy: -0.01
    lambda_diversity: 0.01

  config:
    name: Ant-v4_envpool_moe
    env_name: envpool
    score_to_win: 20000
    normalize_input: True
    normalize_value: True
    value_bootstrap: True
    normalize_advantage: True
    reward_shaper:
      scale_value: 1

    gamma: 0.99
    tau: 0.95
    learning_rate: 3e-4
    lr_schedule: adaptive
    kl_threshold: 0.008
    grad_norm: 1.0
    entropy_coef: 0.0
    truncate_grads: True
    e_clip: 0.2
    clip_value: True
    use_smooth_clamp: True
    bound_loss_type: regularisation
    bounds_loss_coef: 0.0
    max_epochs: 2000
    num_actors: 64
    horizon_length: 64
    minibatch_size: 2048
    mini_epochs: 4
    critic_coef: 2

    env_config:
      env_name: Ant-v4
      seed: 5
      #flat_observation: True

    player:
      render: False
      num_actors: 64
      games_num: 1000
      use_vecenv: True
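The moe builder referenced by name: moe lives in rl_games/networks/moe.py, which is not included in this diff, so the following is only a sketch of what the config keys above imply: num_experts expert MLPs of width hidden_size, a gating head of width gating_hidden_size, and optional top_k sparse gating. Every name in this snippet (SparseMoEMLP and its internals) is an assumption for illustration, not the PR's actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMoEMLP(nn.Module):
    # Hypothetical top-k gated mixture-of-experts MLP matching the config keys.
    def __init__(self, input_size, num_experts=4, hidden_size=256,
                 gating_hidden_size=128, top_k=2, use_sparse_gating=True):
        super().__init__()
        self.top_k = top_k
        self.use_sparse_gating = use_sparse_gating
        # One small MLP per expert; all experts share the same architecture.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_size, hidden_size), nn.ELU(),
                nn.Linear(hidden_size, hidden_size), nn.ELU(),
            ) for _ in range(num_experts)
        ])
        # Gating network: maps the observation to a distribution over experts.
        self.gate = nn.Sequential(
            nn.Linear(input_size, gating_hidden_size), nn.ELU(),
            nn.Linear(gating_hidden_size, num_experts),
        )

    def forward(self, x):
        gate_probs = F.softmax(self.gate(x), dim=-1)              # (batch, num_experts)
        if self.use_sparse_gating:
            # Keep only the top-k experts per sample and renormalize their weights.
            topk_vals, topk_idx = torch.topk(gate_probs, self.top_k, dim=-1)
            sparse = torch.zeros_like(gate_probs).scatter_(-1, topk_idx, topk_vals)
            gate_probs = sparse / sparse.sum(dim=-1, keepdim=True)
        # Evaluate all experts densely and mix them with the gate weights.
        expert_outs = torch.stack([e(x) for e in self.experts], dim=1)  # (batch, E, hidden)
        mixed = (gate_probs.unsqueeze(-1) * expert_outs).sum(dim=1)     # (batch, hidden)
        return mixed, gate_probs

Given use_entropy_loss: True with lambda_entropy: -0.01, the gate distribution is presumably regularized toward higher entropy (the negative coefficient turns the penalty into a bonus, discouraging collapse onto a single expert); this Ant config leaves use_diversity_loss off.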
71 changes: 71 additions & 0 deletions rl_games/configs/mujoco/humanoid_envpool_moe.yaml
@@ -0,0 +1,71 @@
params:
  seed: 5
  algo:
    name: a2c_continuous

  model:
    name: continuous_a2c_logstd

  network:
    name: moe
    space:
      continuous:
        mu_activation: None
        sigma_activation: None
        mu_init:
          name: default
        sigma_init:
          name: const_initializer
          val: 0
        fixed_sigma: True
    num_experts: 4
    hidden_size: 512
    gating_hidden_size: 128
    use_sparse_gating: True
    use_entropy_loss: True
    use_diversity_loss: True
    top_k: 2
    lambda_entropy: -0.01
    lambda_diversity: 0.01

  config:
    name: Humanoid_envpool_moe
    env_name: envpool
    score_to_win: 20000
    normalize_input: True
    normalize_value: True
    value_bootstrap: True
    normalize_advantage: True
    reward_shaper:
      scale_value: 0.1

    gamma: 0.99
    tau: 0.95
    learning_rate: 3e-4
    lr_schedule: adaptive
    kl_threshold: 0.008
    grad_norm: 1.0
    entropy_coef: 0.0
    truncate_grads: True
    e_clip: 0.2
    clip_value: True
    use_smooth_clamp: True
    bound_loss_type: regularisation
    bounds_loss_coef: 0.0
    max_epochs: 2000
    num_actors: 64
    horizon_length: 128
    minibatch_size: 2048
    mini_epochs: 5
    critic_coef: 4

    env_config:
      env_name: Humanoid-v4
      seed: 5
      #flat_observation: True

    player:
      render: False
      num_actors: 64
      games_num: 1000
      use_vecenv: True
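Unlike the Ant config, the Humanoid config turns on use_diversity_loss. The corresponding loss term is not shown in this diff; one plausible form, given lambda_diversity: 0.01, is a penalty on pairwise similarity between the experts' outputs so that they are pushed toward distinct behaviors. A minimal sketch under that assumption (the function and its names are made up for illustration):

import torch
import torch.nn.functional as F

def expert_diversity_loss(expert_outs):
    # expert_outs: (batch, num_experts, hidden) stacked expert activations.
    # Mean pairwise cosine similarity between experts; minimizing it (scaled by
    # lambda_diversity) encourages the experts to produce dissimilar features.
    normed = F.normalize(expert_outs, dim=-1)
    sim = torch.einsum('beh,bfh->bef', normed, normed)   # (batch, E, E) cosine similarities
    eye = torch.eye(sim.shape[-1], device=sim.device)
    return (sim - eye).mean()                            # subtracting eye zeroes self-similarity

How the PR actually wires this term (and the gate-entropy term) into the A2C loss is not visible from the config files alone.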
5 changes: 4 additions & 1 deletion rl_games/networks/__init__.py
@@ -1,4 +1,7 @@
from rl_games.networks.tcnn_mlp import TcnnNetBuilder
from rl_games.networks.moe import MoENetBuilder

from rl_games.algos_torch import model_builder

model_builder.register_network('tcnnnet', TcnnNetBuilder)
model_builder.register_network('tcnnnet', TcnnNetBuilder)
model_builder.register_network('moe', MoENetBuilder)
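Registering the builder under the name moe is what lets the new YAML configs select it through network: name: moe; the change to rl_games/__init__.py makes the registration run as soon as rl_games is imported. A usage sketch for driving one of the new configs through rl_games' standard Runner API (the Runner calls below reflect current rl_games usage and are not part of this PR):

import yaml

from rl_games.torch_runner import Runner  # importing rl_games also triggers the 'moe' registration above

# Load the new MoE config and hand it to the runner for training.
with open('rl_games/configs/mujoco/ant_envpool_moe.yaml') as f:
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)
runner.run({'train': True, 'play': False})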