Skip to content

Commit

Permalink
fix(pu): fix moe in feedforward layer of transformer and polish configs
Browse files Browse the repository at this point in the history
  • Loading branch information
dyyoungg committed Jul 19, 2024
1 parent b460d2f commit 5117459
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
19 changes: 12 additions & 7 deletions lzero/model/unizero_world_models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,22 @@ def __init__(self, config: TransformerConfig) -> None:
self.ln2 = nn.LayerNorm(config.embed_dim)
self.attn = SelfAttention(config)
if config.moe_in_transformer:
self.mlp = nn.Sequential(
nn.Linear(config.embed_dim, 4 * config.embed_dim),
nn.GELU(approximate='tanh'),
nn.Linear(4 * config.embed_dim, config.embed_dim),
nn.Dropout(config.resid_pdrop),
)
# Create multiple independent MLP instances (one per expert)
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(config.embed_dim, 4 * config.embed_dim),
nn.GELU(approximate='tanh'),
nn.Linear(4 * config.embed_dim, config.embed_dim),
nn.Dropout(config.resid_pdrop),
) for _ in range(config.num_experts_of_moe_in_transformer)
])

self.feed_forward = MoeLayer(
experts=[self.mlp for _ in range(config.num_experts_of_moe_in_transformer)],
experts=self.experts,
gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False),
num_experts_per_tok=1,
)

print("="*20)
print(f'use moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}')
print("="*20)
Expand Down
25 changes: 14 additions & 11 deletions zoo/atari/config/atari_unizero_multitask_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
# device='cpu', # 'cuda',
device='cuda', # 'cuda',
action_space_size=action_space_size,
num_layers=4, # NOTE
# num_layers=4, # NOTE
num_layers=2, # NOTE
num_heads=8,
embed_dim=768,
obs_type='image',
Expand All @@ -59,21 +60,22 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
task_num=len(env_id_list),
# num_experts_in_softmoe_head=4, # NOTE
num_experts_in_softmoe_head=-1, # NOTE
moe_in_transformer=True,
# moe_in_transformer=False, # NOTE
# num_experts_of_moe_in_transformer=4,
num_experts_of_moe_in_transformer=2,
# moe_in_transformer=True,
moe_in_transformer=False, # NOTE
num_experts_of_moe_in_transformer=4,
# num_experts_of_moe_in_transformer=2,
),
),
use_priority=False,
# use_priority=False,
# print_task_priority_logs=False,
use_priority=True, # TODO
print_task_priority_logs=False,
# use_priority=True, # TODO
# print_task_priority_logs=True,
cuda=True,
model_path=None,
num_unroll_steps=num_unroll_steps,
# update_per_collect=None,
update_per_collect=1000,
# update_per_collect=1000,
update_per_collect=500,
replay_ratio=0.25,
batch_size=batch_size,
optim_type='AdamW',
Expand All @@ -97,7 +99,8 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
# exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_CAGrad_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/'
# exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_MoCo_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/'
# exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_trans-ffw-moe1-same_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/'
exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_trans-ffw-moe2_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/'
# exp_name_prefix = f'data_unizero_mt_0719/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer2-nh8_max-bs1500_upc1000_seed{seed}/'
exp_name_prefix = f'data_unizero_mt_0719/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_lsd768-nlayer2-nh8_max-bs1500_upc500_seed{seed}/'

# exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_1-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/'

Expand Down Expand Up @@ -173,7 +176,7 @@ def create_env_manager():
num_unroll_steps = 10
infer_context_length = 4
norm_type = 'LN'
# norm_type = 'BN' # bad performance now
# # norm_type = 'BN' # bad performance now


# ======== TODO: only for debug ========
Expand Down

0 comments on commit 5117459

Please sign in to comment.