Skip to content

Commit

Permalink
fix(nyz): polish dmc2gym sac entry
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Mar 8, 2023
1 parent b81ce53 commit f798002
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_attribute(self, name: str) -> Any:


def call_gae_estimator(batch_size: int = 32, trajectory_end_idx_size: int = 5, buffer: Optional[Buffer] = None):
cfg = EasyDict({'policy': {'collect': {'discount_factor': 0.9, 'gae_lambda': 0.95}}})
cfg = EasyDict({'policy': {'collect': {'discount_factor': 0.9, 'gae_lambda': 0.95}, 'cuda': False}})

ctx = OnlineRLContext()
assert trajectory_end_idx_size <= batch_size
Expand Down
48 changes: 3 additions & 45 deletions dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
env_id='dmc2gym-v0',
domain_name="cartpole",
task_name="swingup",
frame_skip=2,
frame_skip=4,
warp_frame=True,
scale=True,
clip_rewards=False,
Expand All @@ -17,46 +17,26 @@
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
# collector_env_num=1,
# evaluator_env_num=1,
# n_evaluator_episode=1,
stop_value=1e6,
manager=dict(shared_memory=False, ),
),
policy=dict(
model_type='pixel',
cuda=True,
random_collect_size=10000,
# random_collect_size=10,
model=dict(
obs_shape=(3, 84, 84),
action_shape=1,
twin_critic=True,
encoder_hidden_size_list=[32, 32, 50],
encoder_hidden_size_list=[32, 32, 32],
actor_head_hidden_size=1024,
critic_head_hidden_size=1024,

# different option about whether to share_conv_encoder in two Q networks
# and whether to use embed_action

share_conv_encoder=False,
embed_action=False,

# share_conv_encoder=True,
# embed_action=False,

# share_conv_encoder=False,
# embed_action=True,

# share_conv_encoder=True,
# embed_action=True,
embed_action_density=0.1,
share_encoder=True,
),
learn=dict(
ignore_done=True,
update_per_collect=1,
batch_size=128,
# batch_size=4, # debug
learning_rate_q=1e-3,
learning_rate_policy=1e-3,
learning_rate_alpha=3e-4,
Expand All @@ -70,7 +50,6 @@
n_sample=1,
unroll_len=1,
),
command=dict(),
eval=dict(),
other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
),
Expand All @@ -85,7 +64,6 @@
import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
),
env_manager=dict(type='subprocess'),
# env_manager=dict(type='base'), # debug
policy=dict(
type='sac',
import_names=['ding.policy.sac'],
Expand All @@ -94,23 +72,3 @@
)
dmc2gym_sac_create_config = EasyDict(dmc2gym_sac_create_config)
create_config = dmc2gym_sac_create_config

# if __name__ == "__main__":
# # or you can enter `ding -m serial -c dmc2gym_sac_pixel_config.py -s 0`
# from ding.entry import serial_pipeline
# serial_pipeline([main_config, create_config], seed=0)


if __name__ == "__main__":
    import copy
    import argparse
    from ding.entry import serial_pipeline

    # Parse the CLI exactly once. The previous version rebuilt the parser inside
    # the seed loop with a per-iteration default, which meant an explicit
    # ``--seed X`` silently ran the SAME seed three times instead of sweeping.
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', '-s', type=int, default=None)
    args = parser.parse_args()

    # No --seed given: sweep the default seeds; --seed given: run just that one.
    seeds = [args.seed] if args.seed is not None else [0, 1, 2]
    for seed in seeds:
        # Deep-copy the configs per run: serial_pipeline mutates them in place.
        main_config.exp_name = 'dmc2gym_sac_pixel_scet-eat01-detach' + 'seed' + f'{seed}'
        serial_pipeline(
            [copy.deepcopy(main_config), copy.deepcopy(create_config)],
            seed=seed,
            max_env_step=int(3e6)
        )
32 changes: 0 additions & 32 deletions dizoo/dmc2gym/config/dmc2gym_sac_state_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
model_type='state',
cuda=True,
random_collect_size=10000,
load_path="/root/dmc2gym_cartpole_swingup_state_sac_eval/ckpt/ckpt_best.pth.tar",
model=dict(
obs_shape=5,
action_shape=1,
Expand All @@ -46,7 +45,6 @@
n_sample=1,
unroll_len=1,
),
command=dict(),
eval=dict(),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
),
Expand All @@ -69,33 +67,3 @@
)
dmc2gym_sac_create_config = EasyDict(dmc2gym_sac_create_config)
create_config = dmc2gym_sac_create_config

# if __name__ == "__main__":
# # or you can enter `ding -m serial -c dmc2gym_sac_state_config.py -s 0`
# from ding.entry import serial_pipeline
# serial_pipeline([main_config, create_config], seed=0)

# if __name__ == "__main__":
# # or you can enter `ding -m serial -c dmc2gym_sac_config.py -s 0`
# from ding.entry import serial_pipeline
# serial_pipeline([main_config, create_config], seed=0)

def train(args):
    """Run one state-based SAC experiment for ``args.seed``, capped at 5M env steps.

    NOTE(review): relies on ``copy`` and ``serial_pipeline`` being imported by the
    ``__main__`` block below before this is called — confirm call order.
    """
    main_config.exp_name = f'dmc2gym_sac_state_old_check/seed{args.seed}_5M'
    # Deep-copy so repeated calls do not see configs mutated by earlier runs.
    run_configs = [copy.deepcopy(main_config), copy.deepcopy(create_config)]
    serial_pipeline(run_configs, seed=args.seed, max_env_step=int(5e6))


if __name__ == "__main__":
    import copy
    import argparse
    from ding.entry import serial_pipeline

    # Parse the CLI exactly once. The previous version rebuilt the parser inside
    # the seed loop with a per-iteration default, which meant an explicit
    # ``--seed X`` silently ran the SAME seed three times instead of sweeping.
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', '-s', type=int, default=None)
    args = parser.parse_args()

    # No --seed given: sweep the default seeds; --seed given: run just that one.
    seeds = [args.seed] if args.seed is not None else [0, 1, 2]
    for seed in seeds:
        # Deep-copy the configs per run: serial_pipeline mutates them in place.
        main_config.exp_name = 'dmc2gym_sac_state' + 'seed' + f'{seed}'
        serial_pipeline(
            [copy.deepcopy(main_config), copy.deepcopy(create_config)],
            seed=seed,
            max_env_step=int(3e6)
        )
15 changes: 8 additions & 7 deletions dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from tensorboardX import SummaryWriter
from ditk import logging
from ding.model.template.qac import QACPixel
import os
import numpy as np
from ding.model.template.qac import QAC
from ding.policy import SACPolicy
from ding.envs import BaseEnvManagerV2
from ding.data import DequeBuffer
Expand All @@ -11,9 +14,6 @@
from ding.utils import set_pkg_seed
from dizoo.dmc2gym.envs.dmc2gym_env import DMC2GymEnv
from dizoo.dmc2gym.config.dmc2gym_sac_pixel_config import main_config, create_config
import numpy as np
from tensorboardX import SummaryWriter
import os

def main():
logging.getLogger().setLevel(logging.INFO)
Expand All @@ -35,7 +35,8 @@ def main():

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

model = QACPixel(**cfg.policy.model)
model = QAC(**cfg.policy.model)
logging.info(model)
buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
policy = SACPolicy(cfg.policy, model=model)

Expand Down Expand Up @@ -72,8 +73,8 @@ def _add_train_scalar(ctx):
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
task.use(_add_train_scalar)
task.use(CkptSaver(cfg, policy, train_freq=100))
task.use(termination_checker(max_env_step=int(5000000)))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=int(1e5)))
task.use(termination_checker(max_env_step=int(5e6)))
task.run()


Expand Down
4 changes: 2 additions & 2 deletions dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def _add_train_scalar(ctx):
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
task.use(_add_train_scalar)
task.use(CkptSaver(cfg, policy, train_freq=100))
task.use(termination_checker(max_env_step=int(5000000)))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=int(1e5)))
task.use(termination_checker(max_env_step=int(5e6)))
task.run()


Expand Down

0 comments on commit f798002

Please sign in to comment.