diff --git a/ding/entry/serial_entry.py b/ding/entry/serial_entry.py index 929c83a219..f7c039e494 100644 --- a/ding/entry/serial_entry.py +++ b/ding/entry/serial_entry.py @@ -47,7 +47,15 @@ def serial_pipeline( cfg, create_cfg = deepcopy(input_cfg) create_cfg.policy.type = create_cfg.policy.type + '_command' env_fn = None if env_setting is None else env_setting[0] - cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + cfg = compile_config( + cfg, + seed=seed, + env=env_fn, + auto=True, + create_cfg=create_cfg, + save_cfg=True, + renew_dir=not cfg.policy.learn.get('resume_training', False) + ) # Create main components: env, policy if env_setting is None: env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) @@ -86,6 +94,8 @@ def serial_pipeline( # ========== # Learner's before_run hook. learner.call_hook('before_run') + if cfg.policy.learn.get('resume_training', False): + collector.envstep = learner.collector_envstep # Accumulate plenty of data at the beginning of training. if cfg.policy.get('random_collect_size', 0) > 0: diff --git a/ding/entry/serial_entry_mbrl.py b/ding/entry/serial_entry_mbrl.py index 03d240c6ea..edb97e0c62 100644 --- a/ding/entry/serial_entry_mbrl.py +++ b/ding/entry/serial_entry_mbrl.py @@ -30,7 +30,15 @@ def mbrl_entry_setup( cfg, create_cfg = deepcopy(input_cfg) create_cfg.policy.type = create_cfg.policy.type + '_command' env_fn = None if env_setting is None else env_setting[0] - cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + cfg = compile_config( + cfg, + seed=seed, + env=env_fn, + auto=True, + create_cfg=create_cfg, + save_cfg=True, + renew_dir=not cfg.policy.learn.get('resume_training', False) + ) if env_setting is None: env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) @@ -70,18 +78,7 @@ def mbrl_entry_setup( cfg.policy.other.commander, learner, collector, evaluator, env_buffer, policy.command_mode ) - return ( - cfg, - policy, - world_model, - env_buffer, - learner, - collector, - collector_env, - evaluator, - commander, - tb_logger, - ) + return (cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger) def create_img_buffer( @@ -131,6 +128,8 @@ def serial_pipeline_dyna( img_buffer = create_img_buffer(cfg, input_cfg, world_model, tb_logger) learner.call_hook('before_run') + if cfg.policy.learn.get('resume_training', False): + collector.envstep = learner.collector_envstep if cfg.policy.get('random_collect_size', 0) > 0: random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer) @@ -202,6 +201,8 @@ def serial_pipeline_dream( mbrl_entry_setup(input_cfg, seed, env_setting, model) learner.call_hook('before_run') + if cfg.policy.learn.get('resume_training', False): + collector.envstep = learner.collector_envstep if cfg.policy.get('random_collect_size', 0) > 0: random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer) diff --git a/ding/entry/serial_entry_ngu.py b/ding/entry/serial_entry_ngu.py index 176f5558cd..1fcce53dc7 100644 --- a/ding/entry/serial_entry_ngu.py +++ b/ding/entry/serial_entry_ngu.py @@ -47,7 +47,15 @@ def serial_pipeline_ngu( cfg, create_cfg = deepcopy(input_cfg) create_cfg.policy.type = create_cfg.policy.type + '_command' env_fn = None if env_setting is None else env_setting[0] - cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + cfg = compile_config( + cfg, + seed=seed, + env=env_fn, + auto=True, + create_cfg=create_cfg, + save_cfg=True, + renew_dir=not cfg.policy.learn.get('resume_training', False) + ) # Create main components: env, policy if env_setting is None: env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) @@ -89,6 +97,8 @@ def serial_pipeline_ngu( # ========== # Learner's before_run hook. learner.call_hook('before_run') + if cfg.policy.learn.get('resume_training', False): + collector.envstep = learner.collector_envstep # Accumulate plenty of data at the beginning of training. if cfg.policy.get('random_collect_size', 0) > 0: diff --git a/ding/entry/serial_entry_onpolicy.py b/ding/entry/serial_entry_onpolicy.py index 22e9cf74f9..713fbcac58 100644 --- a/ding/entry/serial_entry_onpolicy.py +++ b/ding/entry/serial_entry_onpolicy.py @@ -45,7 +45,16 @@ def serial_pipeline_onpolicy( cfg, create_cfg = deepcopy(input_cfg) create_cfg.policy.type = create_cfg.policy.type + '_command' env_fn = None if env_setting is None else env_setting[0] - cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + cfg = compile_config( + cfg, + seed=seed, + env=env_fn, + auto=True, + create_cfg=create_cfg, + save_cfg=True, + renew_dir=not cfg.policy.learn.get('resume_training', False) + ) + # Create main components: env, policy if env_setting is None: env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) @@ -80,6 +89,8 @@ def serial_pipeline_onpolicy( # ========== # Learner's before_run hook. learner.call_hook('before_run') + if cfg.policy.learn.get('resume_training', False): + collector.envstep = learner.collector_envstep while True: collect_kwargs = commander.step() diff --git a/ding/entry/serial_entry_onpolicy_ppg.py b/ding/entry/serial_entry_onpolicy_ppg.py index 02c6dee307..90f31891ab 100644 --- a/ding/entry/serial_entry_onpolicy_ppg.py +++ b/ding/entry/serial_entry_onpolicy_ppg.py @@ -45,7 +45,15 @@ def serial_pipeline_onpolicy_ppg( cfg, create_cfg = deepcopy(input_cfg) create_cfg.policy.type = create_cfg.policy.type + '_command' env_fn = None if env_setting is None else env_setting[0] - cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + cfg = compile_config( + cfg, + seed=seed, + env=env_fn, + auto=True, + create_cfg=create_cfg, + save_cfg=True, + renew_dir=not cfg.policy.learn.get('resume_training', False) + ) # Create main components: env, policy if env_setting is None: env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) @@ -80,6 +88,8 @@ def serial_pipeline_onpolicy_ppg( # ========== # Learner's before_run hook. learner.call_hook('before_run') + if cfg.policy.learn.get('resume_training', False): + collector.envstep = learner.collector_envstep while True: collect_kwargs = commander.step() diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index 7e843f8429..61d8c1e845 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -93,6 +93,10 @@ def default_config(cls: type) -> EasyDict: traj_len_inf=False, # neural network model config model=dict(), + # If resume_training is True, the environment step count (collector.envstep) and training iteration (train_iter) + # will be loaded from the pretrained checkpoint, allowing training to resume seamlessly + # from where the ckpt left off. + learn=dict(resume_training=False), ) def __init__( diff --git a/ding/worker/collector/battle_episode_serial_collector.py b/ding/worker/collector/battle_episode_serial_collector.py index 6609adcaea..156ab986d6 100644 --- a/ding/worker/collector/battle_episode_serial_collector.py +++ b/ding/worker/collector/battle_episode_serial_collector.py @@ -162,10 +162,20 @@ def envstep(self) -> int: Overview: Print the total envstep count. Return: - - envstep (:obj:`int`): the total envstep count + - envstep (:obj:`int`): The total envstep count. """ return self._total_envstep_count + @envstep.setter + def envstep(self, value: int) -> None: + """ + Overview: + Set the total envstep count. + Arguments: + - value (:obj:`int`): The total envstep count. + """ + self._total_envstep_count = value + def close(self) -> None: """ Overview: diff --git a/ding/worker/collector/battle_sample_serial_collector.py b/ding/worker/collector/battle_sample_serial_collector.py index dffc43f5f7..33f4df11c9 100644 --- a/ding/worker/collector/battle_sample_serial_collector.py +++ b/ding/worker/collector/battle_sample_serial_collector.py @@ -175,10 +175,20 @@ def envstep(self) -> int: Overview: Print the total envstep count. Return: - - envstep (:obj:`int`): the total envstep count + - envstep (:obj:`int`): The total envstep count. """ return self._total_envstep_count + @envstep.setter + def envstep(self, value: int) -> None: + """ + Overview: + Set the total envstep count. + Arguments: + - value (:obj:`int`): The total envstep count. + """ + self._total_envstep_count = value + def close(self) -> None: """ Overview: diff --git a/ding/worker/collector/episode_serial_collector.py b/ding/worker/collector/episode_serial_collector.py index 6fca2283f8..5147f5a456 100644 --- a/ding/worker/collector/episode_serial_collector.py +++ b/ding/worker/collector/episode_serial_collector.py @@ -157,10 +157,20 @@ def envstep(self) -> int: Overview: Print the total envstep count. Return: - - envstep (:obj:`int`): the total envstep count + - envstep (:obj:`int`): The total envstep count. """ return self._total_envstep_count + @envstep.setter + def envstep(self, value: int) -> None: + """ + Overview: + Set the total envstep count. + Arguments: + - value (:obj:`int`): The total envstep count. + """ + self._total_envstep_count = value + def close(self) -> None: """ Overview: diff --git a/ding/worker/collector/sample_serial_collector.py b/ding/worker/collector/sample_serial_collector.py index 26db458edb..1688c78c65 100644 --- a/ding/worker/collector/sample_serial_collector.py +++ b/ding/worker/collector/sample_serial_collector.py @@ -185,10 +185,20 @@ def envstep(self) -> int: Overview: Print the total envstep count. Return: - - envstep (:obj:`int`): the total envstep count + - envstep (:obj:`int`): The total envstep count. """ return self._total_envstep_count + @envstep.setter + def envstep(self, value: int) -> None: + """ + Overview: + Set the total envstep count. + Arguments: + - value (:obj:`int`): The total envstep count. + """ + self._total_envstep_count = value + def close(self) -> None: """ Overview: diff --git a/ding/worker/learner/base_learner.py b/ding/worker/learner/base_learner.py index 1144a412cd..0b57b06c76 100644 --- a/ding/worker/learner/base_learner.py +++ b/ding/worker/learner/base_learner.py @@ -122,6 +122,8 @@ def __init__( self._hooks = {'before_run': [], 'before_iter': [], 'after_iter': [], 'after_run': []} # Last iteration. Used to record current iter. self._last_iter = CountVar(init_val=0) + # Collector envstep. Used to record current envstep. + self._collector_envstep = 0 # Setup time wrapper and hook. self._setup_wrapper() @@ -177,6 +179,26 @@ def register_hook(self, hook: LearnerHook) -> None: """ add_learner_hook(self._hooks, hook) + @property + def collector_envstep(self) -> int: + """ + Overview: + Get current collector envstep. + Returns: + - collector_envstep (:obj:`int`): Current collector envstep. + """ + return self._collector_envstep + + @collector_envstep.setter + def collector_envstep(self, value: int) -> None: + """ + Overview: + Set current collector envstep. + Arguments: + - value (:obj:`int`): Current collector envstep. + """ + self._collector_envstep = value + def train(self, data: dict, envstep: int = -1, policy_kwargs: Optional[dict] = None) -> None: """ Overview: diff --git a/ding/worker/learner/learner_hook.py b/ding/worker/learner/learner_hook.py index 250a8f1950..6797af077b 100644 --- a/ding/worker/learner/learner_hook.py +++ b/ding/worker/learner/learner_hook.py @@ -117,6 +117,9 @@ def __call__(self, engine: 'BaseLearner') -> None: # noqa if 'last_iter' in state_dict: last_iter = state_dict.pop('last_iter') engine.last_iter.update(last_iter) + if 'last_step' in state_dict: + last_step = state_dict.pop('last_step') + engine._collector_envstep = last_step engine.policy.load_state_dict(state_dict) engine.info('{} load ckpt in {}'.format(engine.instance_name, path)) @@ -166,6 +169,7 @@ def __call__(self, engine: 'BaseLearner') -> None: # noqa path = os.path.join(dirname, ckpt_name) state_dict = engine.policy.state_dict() state_dict.update({'last_iter': engine.last_iter.val}) + state_dict.update({'last_step': engine.collector_envstep}) save_file(path, state_dict) engine.info('{} save ckpt in {}'.format(engine.instance_name, path)) diff --git a/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py b/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py index e8a5108721..4c1333bbc8 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py @@ -26,7 +26,15 @@ value_weight=0.5, entropy_weight=0.01, clip_ratio=0.2, - learner=dict(hook=dict(save_ckpt_after_iter=100)), + # Path to the pretrained checkpoint (ckpt). + # If set to an empty string (''), no pretrained model will be loaded. + # To load a pretrained ckpt, specify the path like this: + # learner=dict(hook=dict(load_ckpt_before_run='/path/to/your/ckpt/iteration_100.pth.tar')), + + # If True, the environment step count (collector.envstep) and training iteration (train_iter) + # will be loaded from the pretrained checkpoint, allowing training to resume seamlessly + # from where the ckpt left off. + resume_training=False, ), collect=dict( n_sample=256, diff --git a/dizoo/cliffwalking/envs/cliffwalking_env.py b/dizoo/cliffwalking/envs/cliffwalking_env.py index 79d53ba64c..af5d094aa6 100644 --- a/dizoo/cliffwalking/envs/cliffwalking_env.py +++ b/dizoo/cliffwalking/envs/cliffwalking_env.py @@ -24,8 +24,8 @@ def __init__(self, cfg: dict) -> None: self._replay_path = None self._observation_space = gym.spaces.Box(low=0, high=1, shape=(48, ), dtype=np.float32) self._env = gym.make( - "CliffWalking", render_mode=self._cfg.render_mode, max_episode_steps=self._cfg.max_episode_steps - ) + "CliffWalking", render_mode=self._cfg.render_mode, max_episode_steps=self._cfg.max_episode_steps + ) self._action_space = self._env.action_space self._reward_space = gym.spaces.Box( low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32 diff --git a/dizoo/d4rl/config/antmaze_umaze_pd_config.py b/dizoo/d4rl/config/antmaze_umaze_pd_config.py index 96ca022545..5f111f58fd 100755 --- a/dizoo/d4rl/config/antmaze_umaze_pd_config.py +++ b/dizoo/d4rl/config/antmaze_umaze_pd_config.py @@ -45,7 +45,7 @@ value_model_cfg=dict( model='TemporalValue', model_cfg=dict( - horizon = 256, + horizon=256, transition_dim=37, dim=32, dim_mults=[1, 2, 4, 8], @@ -92,10 +92,8 @@ import_names=['dizoo.d4rl.envs.d4rl_env'], ), env_manager=dict(type='subprocess'), - policy=dict( - type='pd', - ), + policy=dict(type='pd', ), replay_buffer=dict(type='naive', ), ) create_config = EasyDict(create_config) -create_config = create_config \ No newline at end of file +create_config = create_config diff --git a/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py b/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py index 66c8ba8d91..d42692bc63 100755 --- a/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py @@ -45,7 +45,7 @@ value_model_cfg=dict( model='TemporalValue', model_cfg=dict( - horizon = 4, + horizon=4, transition_dim=23, dim=32, dim_mults=[1, 4, 8], @@ -92,9 +92,7 @@ import_names=['dizoo.d4rl.envs.d4rl_env'], ), env_manager=dict(type='subprocess'), - policy=dict( - type='pd', - ), + policy=dict(type='pd', ), replay_buffer=dict(type='naive', ), ) create_config = EasyDict(create_config) diff --git a/dizoo/d4rl/config/halfcheetah_medium_expert_qgpo_config.py b/dizoo/d4rl/config/halfcheetah_medium_expert_qgpo_config.py index 787a421d76..c8b75891ca 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_expert_qgpo_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_expert_qgpo_config.py @@ -7,9 +7,7 @@ evaluator_env_num=8, n_evaluator_episode=8, ), - dataset=dict( - env_id="halfcheetah-medium-expert-v2", - ), + dataset=dict(env_id="halfcheetah-medium-expert-v2", ), policy=dict( cuda=True, on_policy=False, @@ -44,8 +42,6 @@ create_config = dict( env_manager=dict(type='base'), - policy=dict( - type='qgpo', - ), + policy=dict(type='qgpo', ), ) create_config = EasyDict(create_config) diff --git a/dizoo/d4rl/config/halfcheetah_medium_pd_config.py b/dizoo/d4rl/config/halfcheetah_medium_pd_config.py index 674395a4e1..ae4145f9a3 100755 --- a/dizoo/d4rl/config/halfcheetah_medium_pd_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_pd_config.py @@ -45,7 +45,7 @@ value_model_cfg=dict( model='TemporalValue', model_cfg=dict( - horizon = 4, + horizon=4, transition_dim=23, dim=32, dim_mults=[1, 4, 8], @@ -92,9 +92,7 @@ import_names=['dizoo.d4rl.envs.d4rl_env'], ), env_manager=dict(type='subprocess'), - policy=dict( - type='pd', - ), + policy=dict(type='pd', ), replay_buffer=dict(type='naive', ), ) create_config = EasyDict(create_config) diff --git a/dizoo/d4rl/config/hopper_medium_expert_pd_config.py b/dizoo/d4rl/config/hopper_medium_expert_pd_config.py index 3df47f8d1b..d7eeeeca5b 100755 --- a/dizoo/d4rl/config/hopper_medium_expert_pd_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_pd_config.py @@ -45,7 +45,7 @@ value_model_cfg=dict( model='TemporalValue', model_cfg=dict( - horizon = 32, + horizon=32, transition_dim=14, dim=32, dim_mults=[1, 2, 4, 8], @@ -92,10 +92,8 @@ import_names=['dizoo.d4rl.envs.d4rl_env'], ), env_manager=dict(type='subprocess'), - policy=dict( - type='pd', - ), + policy=dict(type='pd', ), replay_buffer=dict(type='naive', ), ) create_config = EasyDict(create_config) -create_config = create_config \ No newline at end of file +create_config = create_config diff --git a/dizoo/d4rl/config/hopper_medium_expert_qgpo_config.py b/dizoo/d4rl/config/hopper_medium_expert_qgpo_config.py index b0941bc13c..d47656afa2 100644 --- a/dizoo/d4rl/config/hopper_medium_expert_qgpo_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_qgpo_config.py @@ -7,9 +7,7 @@ evaluator_env_num=8, n_evaluator_episode=8, ), - dataset=dict( - env_id="hopper-medium-expert-v2", - ), + dataset=dict(env_id="hopper-medium-expert-v2", ), policy=dict( cuda=True, on_policy=False, @@ -44,8 +42,6 @@ create_config = dict( env_manager=dict(type='base'), - policy=dict( - type='qgpo', - ), + policy=dict(type='qgpo', ), ) create_config = EasyDict(create_config) diff --git a/dizoo/d4rl/config/hopper_medium_pd_config.py b/dizoo/d4rl/config/hopper_medium_pd_config.py index 8dfee5d824..47f1c36ce0 100755 --- a/dizoo/d4rl/config/hopper_medium_pd_config.py +++ b/dizoo/d4rl/config/hopper_medium_pd_config.py @@ -45,7 +45,7 @@ value_model_cfg=dict( model='TemporalValue', model_cfg=dict( - horizon = 32, + horizon=32, transition_dim=14, dim=32, dim_mults=[1, 2, 4, 8], @@ -92,10 +92,8 @@ import_names=['dizoo.d4rl.envs.d4rl_env'], ), env_manager=dict(type='subprocess'), - policy=dict( - type='pd', - ), + policy=dict(type='pd', ), replay_buffer=dict(type='naive', ), ) create_config = EasyDict(create_config) -create_config = create_config \ No newline at end of file +create_config = create_config diff --git a/dizoo/d4rl/config/maze2d_large_pd_config.py b/dizoo/d4rl/config/maze2d_large_pd_config.py index a68838213a..a722488b89 100755 --- a/dizoo/d4rl/config/maze2d_large_pd_config.py +++ b/dizoo/d4rl/config/maze2d_large_pd_config.py @@ -75,10 +75,8 @@ import_names=['dizoo.d4rl.envs.d4rl_env'], ), env_manager=dict(type='subprocess'), - policy=dict( - type='pd', - ), + policy=dict(type='pd', ), replay_buffer=dict(type='naive', ), ) create_config = EasyDict(create_config) -create_config = create_config \ No newline at end of file +create_config = create_config diff --git a/dizoo/d4rl/config/maze2d_medium_pd_config.py b/dizoo/d4rl/config/maze2d_medium_pd_config.py index a14cac7480..2a3d43d443 100755 --- a/dizoo/d4rl/config/maze2d_medium_pd_config.py +++ b/dizoo/d4rl/config/maze2d_medium_pd_config.py @@ -75,10 +75,8 @@ import_names=['dizoo.d4rl.envs.d4rl_env'], ), env_manager=dict(type='subprocess'), - policy=dict( - type='pd', - ), + policy=dict(type='pd', ), replay_buffer=dict(type='naive', ), ) create_config = EasyDict(create_config) -create_config = create_config \ No newline at end of file +create_config = create_config diff --git a/dizoo/d4rl/config/maze2d_umaze_pd_config.py b/dizoo/d4rl/config/maze2d_umaze_pd_config.py index 462d10651e..a10fef5844 100755 --- a/dizoo/d4rl/config/maze2d_umaze_pd_config.py +++ b/dizoo/d4rl/config/maze2d_umaze_pd_config.py @@ -75,10 +75,8 @@ import_names=['dizoo.d4rl.envs.d4rl_env'], ), env_manager=dict(type='subprocess'), - policy=dict( - type='pd', - ), + policy=dict(type='pd', ), replay_buffer=dict(type='naive', ), ) create_config = EasyDict(create_config) -create_config = create_config \ No newline at end of file +create_config = create_config diff --git a/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py b/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py index 3d4c060e83..06a80588aa 100755 --- a/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py +++ b/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py @@ -45,7 +45,7 @@ value_model_cfg=dict( model='TemporalValue', model_cfg=dict( - horizon = 32, + horizon=32, transition_dim=23, dim=32, dim_mults=[1, 2, 4, 8], @@ -92,10 +92,8 @@ import_names=['dizoo.d4rl.envs.d4rl_env'], ), env_manager=dict(type='subprocess'), - policy=dict( - type='pd', - ), + policy=dict(type='pd', ), replay_buffer=dict(type='naive', ), ) create_config = EasyDict(create_config) -create_config = create_config \ No newline at end of file +create_config = create_config diff --git a/dizoo/d4rl/config/walker2d_medium_expert_qgpo_config.py b/dizoo/d4rl/config/walker2d_medium_expert_qgpo_config.py index b4b3bd7bb6..46c98f8dd1 100644 --- a/dizoo/d4rl/config/walker2d_medium_expert_qgpo_config.py +++ b/dizoo/d4rl/config/walker2d_medium_expert_qgpo_config.py @@ -7,9 +7,7 @@ evaluator_env_num=8, n_evaluator_episode=8, ), - dataset=dict( - env_id="walker2d-medium-expert-v2", - ), + dataset=dict(env_id="walker2d-medium-expert-v2", ), policy=dict( cuda=True, on_policy=False, @@ -44,8 +42,6 @@ create_config = dict( env_manager=dict(type='base'), - policy=dict( - type='qgpo', - ), + policy=dict(type='qgpo', ), ) create_config = EasyDict(create_config) diff --git a/dizoo/d4rl/config/walker2d_medium_pd_config.py b/dizoo/d4rl/config/walker2d_medium_pd_config.py index 29fce259c8..099364a763 100755 --- a/dizoo/d4rl/config/walker2d_medium_pd_config.py +++ b/dizoo/d4rl/config/walker2d_medium_pd_config.py @@ -45,7 +45,7 @@ value_model_cfg=dict( model='TemporalValue', model_cfg=dict( - horizon = 32, + horizon=32, transition_dim=23, dim=32, dim_mults=[1, 2, 4, 8], @@ -92,10 +92,8 @@ import_names=['dizoo.d4rl.envs.d4rl_env'], ), env_manager=dict(type='subprocess'), - policy=dict( - type='pd', - ), + policy=dict(type='pd', ), replay_buffer=dict(type='naive', ), ) create_config = EasyDict(create_config) -create_config = create_config \ No newline at end of file +create_config = create_config diff --git a/dizoo/d4rl/entry/d4rl_pd_main.py b/dizoo/d4rl/entry/d4rl_pd_main.py index 1ca3c5b299..0d355a77ea 100755 --- a/dizoo/d4rl/entry/d4rl_pd_main.py +++ b/dizoo/d4rl/entry/d4rl_pd_main.py @@ -5,7 +5,7 @@ def train(args): # launch from anywhere - config = Path(__file__).absolute().parent.parent / 'config' / args.config + config = Path(__file__).absolute().parent.parent / 'config' / args.config config = read_config(str(config)) config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) serial_pipeline_offline(config, seed=args.seed) @@ -18,4 +18,4 @@ def train(args): parser.add_argument('--seed', '-s', type=int, default=10) parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_pd_config.py') args = parser.parse_args() - train(args) \ No newline at end of file + train(args) diff --git a/dizoo/d4rl/envs/d4rl_env.py b/dizoo/d4rl/envs/d4rl_env.py index db770fd099..d14bdafaa5 100755 --- a/dizoo/d4rl/envs/d4rl_env.py +++ b/dizoo/d4rl/envs/d4rl_env.py @@ -14,11 +14,8 @@ from .d4rl_wrappers import wrap_d4rl from ding.utils import ENV_REGISTRY -MAZE_BOUNDS = { - 'maze2d-umaze-v1': (0, 5, 0, 5), - 'maze2d-medium-v1': (0, 8, 0, 8), - 'maze2d-large-v1': (0, 9, 0, 12) -} +MAZE_BOUNDS = {'maze2d-umaze-v1': (0, 5, 0, 5), 'maze2d-medium-v1': (0, 8, 0, 8), 'maze2d-large-v1': (0, 9, 0, 12)} + def plot2img(fig, remove_margins=True): # https://stackoverflow.com/a/35362787/2912349 @@ -34,11 +31,13 @@ def plot2img(fig, remove_margins=True): img_as_string, (width, height) = canvas.print_to_buffer() return np.fromstring(img_as_string, dtype='uint8').reshape((height, width, 4)) + def zipsafe(*args): length = len(args[0]) assert all([len(a) == length for a in args]) return zip(*args) + def zipkw(*args, **kwargs): nargs = len(args) keys = kwargs.keys() @@ -49,6 +48,7 @@ def zipkw(*args, **kwargs): zipped_kwargs = {k: v for k, v in zipsafe(keys, items[nargs:])} yield zipped_args, zipped_kwargs + @ENV_REGISTRY.register('d4rl') class D4RLEnv(BaseEnv): @@ -137,18 +137,17 @@ def renders(self, observations, conditions=None, title=None): plt.clf() fig = plt.gcf() fig.set_size_inches(5, 5) - plt.imshow(self._background * .5, - extent=self._extent, cmap=plt.cm.binary, vmin=0, vmax=1) + plt.imshow(self._background * .5, extent=self._extent, cmap=plt.cm.binary, vmin=0, vmax=1) path_length = len(observations) - colors = plt.cm.jet(np.linspace(0,1,path_length)) - plt.plot(observations[:,1], observations[:,0], c='black', zorder=10) - plt.scatter(observations[:,1], observations[:,0], c=colors, zorder=20) + colors = plt.cm.jet(np.linspace(0, 1, path_length)) + plt.plot(observations[:, 1], observations[:, 0], c='black', zorder=10) + plt.scatter(observations[:, 1], observations[:, 0], c=colors, zorder=20) plt.axis('off') plt.title(title) img = plot2img(fig, remove_margins=self._remove_margins) return img - + def composite(self, savepath, paths, ncol=5, **kwargs): assert len(paths) % ncol == 0, 'Number of paths must be divisible by number of columns' @@ -159,8 +158,7 @@ def composite(self, savepath, paths, ncol=5, **kwargs): images = np.stack(images, axis=0) nrow = len(images) // ncol - images = einops.rearrange(images, - '(nrow ncol) H W C -> (nrow H) (ncol W) C', nrow=nrow, ncol=ncol) + images = einops.rearrange(images, '(nrow ncol) H W C -> (nrow H) (ncol W) C', nrow=nrow, ncol=ncol) imageio.imsave(savepath, images) print(f'Saved {len(paths)} samples to: {savepath}') diff --git a/dizoo/ising_env/envs/ising_model_env.py b/dizoo/ising_env/envs/ising_model_env.py index 70ccea9b59..e3ccbe890a 100644 --- a/dizoo/ising_env/envs/ising_model_env.py +++ b/dizoo/ising_env/envs/ising_model_env.py @@ -146,7 +146,7 @@ def render(self, action_matrix, info) -> None: # save the figure to buffer fig.canvas.draw() image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8') - image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + image = image.reshape(fig.canvas.get_width_height()[::-1] + (3, )) plt.close(fig) return image diff --git a/dizoo/league_demo/league_demo_collector.py b/dizoo/league_demo/league_demo_collector.py index ce7985a6dc..4afe72bbe8 100644 --- a/dizoo/league_demo/league_demo_collector.py +++ b/dizoo/league_demo/league_demo_collector.py @@ -163,7 +163,7 @@ def envstep(self) -> int: Overview: Print the total envstep count. Return: - - envstep (:obj:`int`): the total envstep count + - envstep (:obj:`int`): The total envstep count. """ return self._total_envstep_count diff --git a/dizoo/minigrid/config/minigrid_dreamer_config.py b/dizoo/minigrid/config/minigrid_dreamer_config.py index 410f803d96..68afa2757c 100644 --- a/dizoo/minigrid/config/minigrid_dreamer_config.py +++ b/dizoo/minigrid/config/minigrid_dreamer_config.py @@ -62,7 +62,7 @@ cuda=cuda, model=dict( state_size=1344, - obs_type = 'vector', + obs_type='vector', action_size=7, action_type='discrete', encoder_hidden_size_list=[256, 128, 64, 64], diff --git a/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index bde84685f0..4be9687fe8 100644 --- a/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -14,6 +14,7 @@ class PTZRecordVideo(gym.wrappers.RecordVideo): + def step(self, action): """Steps through the environment using action, recording observations if :attr:`self.recording`.""" # gymnasium==0.27.1 @@ -180,7 +181,9 @@ def reset(self) -> np.ndarray: ) if self._replay_path is not None: self._env.render_mode = 'rgb_array' - self._env = PTZRecordVideo(self._env, self._replay_path, name_prefix=f'rl-video-{id(self)}', disable_logger=True) + self._env = PTZRecordVideo( + self._env, self._replay_path, name_prefix=f'rl-video-{id(self)}', disable_logger=True + ) self._init_flag = True if hasattr(self, '_seed'): obs = self._env.reset(seed=self._seed) @@ -400,9 +403,7 @@ def _execute_world_step(self): def render(self): if self.render_mode is None: - gym.logger.warn( - "You are calling render method without specifying any render mode." - ) + gym.logger.warn("You are calling render method without specifying any render mode.") return import pygame @@ -412,8 +413,4 @@ def render(self): observation = np.array(pygame.surfarray.pixels3d(self.screen)) if self.render_mode == "human": pygame.display.flip() - return ( - np.transpose(observation, axes=(1, 0, 2)) - if self.render_mode == "rgb_array" - else None - ) + return (np.transpose(observation, axes=(1, 0, 2)) if self.render_mode == "rgb_array" else None) diff --git a/dizoo/taxi/config/taxi_dqn_config.py b/dizoo/taxi/config/taxi_dqn_config.py index dabbceaad7..312442b9f1 100644 --- a/dizoo/taxi/config/taxi_dqn_config.py +++ b/dizoo/taxi/config/taxi_dqn_config.py @@ -5,18 +5,14 @@ env=dict( collector_env_num=8, evaluator_env_num=8, - n_evaluator_episode=8, - stop_value=20, - max_episode_steps=60, - env_id="Taxi-v3" + n_evaluator_episode=8, + stop_value=20, + max_episode_steps=60, + env_id="Taxi-v3" ), policy=dict( cuda=True, - model=dict( - obs_shape=34, - action_shape=6, - encoder_hidden_size_list=[128, 128] - ), + model=dict(obs_shape=34, action_shape=6, encoder_hidden_size_list=[128, 128]), random_collect_size=5000, nstep=3, discount_factor=0.99, @@ -24,22 +20,13 @@ update_per_collect=10, batch_size=64, learning_rate=0.0001, - learner=dict( - hook=dict( - log_show_after_iter=1000, - ) - ), + learner=dict(hook=dict(log_show_after_iter=1000, )), ), collect=dict(n_sample=32), - eval=dict(evaluator=dict(eval_freq=1000, )), + eval=dict(evaluator=dict(eval_freq=1000, )), other=dict( - eps=dict( - type="linear", - start=1, - end=0.05, - decay=3000000 - ), - replay_buffer=dict(replay_buffer_size=100000,), + eps=dict(type="linear", start=1, end=0.05, decay=3000000), + replay_buffer=dict(replay_buffer_size=100000, ), ), ) ) @@ -47,10 +34,7 @@ main_config = taxi_dqn_config taxi_dqn_create_config = dict( - env=dict( - type="taxi", - import_names=["dizoo.taxi.envs.taxi_env"] - ), + env=dict(type="taxi", import_names=["dizoo.taxi.envs.taxi_env"]), env_manager=dict(type='base'), policy=dict(type='dqn'), replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']), @@ -61,4 +45,4 @@ if __name__ == "__main__": from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), max_env_step=3000000, seed=0) \ No newline at end of file + serial_pipeline((main_config, create_config), max_env_step=3000000, seed=0) diff --git a/dizoo/taxi/entry/taxi_dqn_deploy.py b/dizoo/taxi/entry/taxi_dqn_deploy.py index 15470d8197..d2faba273e 100644 --- a/dizoo/taxi/entry/taxi_dqn_deploy.py +++ b/dizoo/taxi/entry/taxi_dqn_deploy.py @@ -9,6 +9,7 @@ from dizoo.taxi.config.taxi_dqn_config import create_config, main_config from dizoo.taxi.envs.taxi_env import TaxiEnv + def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str) -> None: main_config.exp_name = f'taxi_dqn_seed0_deploy' cfg = compile_config(main_config, create_cfg=create_config, auto=True) @@ -31,8 +32,4 @@ def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str) -> None if __name__ == "__main__": - main( - main_config=main_config, - create_config=create_config, - ckpt_path=f'./taxi_dqn_seed0/ckpt/ckpt_best.pth.tar' - ) \ No newline at end of file + main(main_config=main_config, create_config=create_config, ckpt_path=f'./taxi_dqn_seed0/ckpt/ckpt_best.pth.tar') diff --git a/dizoo/taxi/envs/__init__.py b/dizoo/taxi/envs/__init__.py index ce013dfe26..188628ae48 100644 --- a/dizoo/taxi/envs/__init__.py +++ b/dizoo/taxi/envs/__init__.py @@ -1 +1 @@ -from .taxi_env import TaxiEnv \ No newline at end of file +from .taxi_env import TaxiEnv diff --git a/dizoo/taxi/envs/taxi_env.py b/dizoo/taxi/envs/taxi_env.py index a2d5285e58..c47c8c49f1 100644 --- a/dizoo/taxi/envs/taxi_env.py +++ b/dizoo/taxi/envs/taxi_env.py @@ -13,59 +13,58 @@ from ding.torch_utils import to_ndarray from ding.utils import ENV_REGISTRY + @ENV_REGISTRY.register('taxi', force_overwrite=True) class TaxiEnv(BaseEnv): - + def __init__(self, cfg: EasyDict) -> None: - + self._cfg = cfg assert self._cfg.env_id == "Taxi-v3", "Your environment name is not Taxi-v3!" self._init_flag = False self._replay_path = None self._save_replay = False self._frames = [] - + def reset(self) -> np.ndarray: if not self._init_flag: self._env = gym.make( - id=self._cfg.env_id, - render_mode="single_rgb_array", - max_episode_steps=self._cfg.max_episode_steps + id=self._cfg.env_id, render_mode="single_rgb_array", max_episode_steps=self._cfg.max_episode_steps ) self._observation_space = self._env.observation_space self._action_space = self._env.action_space self._reward_space = Box( low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32 - ) - self._init_flag = True + ) + self._init_flag = True self._eval_episode_return = 0 if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: np_seed = 100 * np.random.randint(1, 1000) - self._env_seed = self._seed + np_seed + self._env_seed = self._seed + np_seed elif hasattr(self, '_seed'): self._env_seed = self._seed if hasattr(self, '_seed'): obs = self._env.reset(seed=self._env_seed) else: obs = self._env.reset() - + if self._save_replay: picture = self._env.render() self._frames.append(picture) self._eval_episode_return = 0. obs = self._encode_taxi(obs).astype(np.float32) return obs - + def close(self) -> None: if self._init_flag: self._env.close() self._init_flag = False - + def seed(self, seed: int, dynamic_seed: bool = True) -> None: self._seed = seed self._dynamic_seed = dynamic_seed np.random.seed(self._seed) - + def step(self, action: np.ndarray) -> BaseEnvTimestep: assert isinstance(action, np.ndarray), type(action) action = action.item() @@ -89,7 +88,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: rew = rew.astype(np.float32) obs = obs.astype(np.float32) return BaseEnvTimestep(obs, rew, done, info) - + def enable_save_replay(self, replay_path: Optional[str] = None) -> None: if replay_path is None: replay_path = './video' @@ -98,7 +97,7 @@ def enable_save_replay(self, replay_path: Optional[str] = None) -> None: self._replay_path = replay_path self._save_replay = True self._save_replay_count = 0 - + def random_action(self) -> np.ndarray: random_action = self.action_space.sample() if isinstance(random_action, np.ndarray): @@ -110,12 +109,12 @@ def random_action(self) -> np.ndarray: else: raise TypeError( '`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format( - type(random_action), random_action + type(random_action), random_action ) ) return random_action - - #todo encode the state into a vector + + #todo encode the state into a vector def _encode_taxi(self, obs: np.ndarray) -> np.ndarray: taxi_row, taxi_col, passenger_location, destination = self._env.unwrapped.decode(obs) encoded_obs = np.zeros(34) @@ -123,7 +122,7 @@ def _encode_taxi(self, obs: np.ndarray) -> np.ndarray: encoded_obs[25 + passenger_location] = 1 encoded_obs[30 + destination] = 1 return to_ndarray(encoded_obs) - + @property def observation_space(self) -> Space: return self._observation_space @@ -135,10 +134,10 @@ def action_space(self) -> Space: @property def reward_space(self) -> Space: return self._reward_space - + def __repr__(self) -> str: return "DI-engine Taxi-v3 Env" - + @staticmethod def frames_to_gif(frames: List[imageio.core.util.Array], gif_path: str, duration: float = 0.1) -> None: """ diff --git a/dizoo/taxi/envs/test_taxi_env.py b/dizoo/taxi/envs/test_taxi_env.py index 7334ce4a08..ad3802ac5e 100644 --- a/dizoo/taxi/envs/test_taxi_env.py +++ b/dizoo/taxi/envs/test_taxi_env.py @@ -3,16 +3,12 @@ from easydict import EasyDict from dizoo.taxi import TaxiEnv + @pytest.mark.envtest class TestTaxiEnv: - + def test_naive(self): - env = TaxiEnv( - EasyDict({ - "env_id": "Taxi-v3", - "max_episode_steps": 300 - }) - ) + env = TaxiEnv(EasyDict({"env_id": "Taxi-v3", "max_episode_steps": 300})) env.seed(314, dynamic_seed=False) assert env._seed == 314 obs = env.reset() @@ -38,4 +34,3 @@ def test_naive(self): assert timestep.reward <= env.reward_space.high print(env.observation_space, env.action_space, env.reward_space) env.close() - \ No newline at end of file