feature(pu): add resume_training option to allow the envstep and train_iter resume seamlessly #835

Merged (8 commits) on Nov 5, 2024
4 changes: 3 additions & 1 deletion ding/entry/serial_entry.py
@@ -47,7 +47,7 @@ def serial_pipeline(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -86,6 +86,8 @@ def serial_pipeline(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.resume_training:
collector.envstep = learner.collector_envstep

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
11 changes: 8 additions & 3 deletions ding/entry/serial_entry_mbrl.py
@@ -30,7 +30,7 @@ def mbrl_entry_setup(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)

if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -81,6 +81,7 @@ def mbrl_entry_setup(
evaluator,
commander,
tb_logger,
resume_training
)


@@ -125,12 +126,14 @@ def serial_pipeline_dyna(
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger, resume_training = \
mbrl_entry_setup(input_cfg, seed, env_setting, model)

img_buffer = create_img_buffer(cfg, input_cfg, world_model, tb_logger)

learner.call_hook('before_run')
if cfg.policy.learn.resume_training:
collector.envstep = learner.collector_envstep

if cfg.policy.get('random_collect_size', 0) > 0:
random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
@@ -198,10 +201,12 @@ def serial_pipeline_dream(
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger, resume_training = \
mbrl_entry_setup(input_cfg, seed, env_setting, model)

learner.call_hook('before_run')
if cfg.policy.learn.resume_training:
collector.envstep = learner.collector_envstep

if cfg.policy.get('random_collect_size', 0) > 0:
random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
4 changes: 3 additions & 1 deletion ding/entry/serial_entry_ngu.py
@@ -47,7 +47,7 @@ def serial_pipeline_ngu(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -89,6 +89,8 @@ def serial_pipeline_ngu(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.resume_training:
collector.envstep = learner.collector_envstep

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
4 changes: 3 additions & 1 deletion ding/entry/serial_entry_onpolicy.py
@@ -45,7 +45,7 @@ def serial_pipeline_onpolicy(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -80,6 +80,8 @@ def serial_pipeline_onpolicy(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.resume_training:
collector.envstep = learner.collector_envstep

while True:
collect_kwargs = commander.step()
4 changes: 3 additions & 1 deletion ding/entry/serial_entry_onpolicy_ppg.py
@@ -45,7 +45,7 @@ def serial_pipeline_onpolicy_ppg(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -80,6 +80,8 @@ def serial_pipeline_onpolicy_ppg(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.resume_training:
collector.envstep = learner.collector_envstep

while True:
collect_kwargs = commander.step()
4 changes: 4 additions & 0 deletions ding/policy/base_policy.py
@@ -93,6 +93,10 @@ def default_config(cls: type) -> EasyDict:
traj_len_inf=False,
# neural network model config
model=dict(),
# If resume_training is True, the environment step count (collector.envstep) and training iteration (train_iter)
# will be loaded from the pretrained checkpoint, allowing training to resume seamlessly
# from where the ckpt left off.
learn=dict(resume_training=False),
)

def __init__(
12 changes: 11 additions & 1 deletion ding/worker/collector/battle_episode_serial_collector.py
@@ -162,10 +162,20 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

@envstep.setter
def envstep(self, value: int) -> None:
"""
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): The total envstep count.
"""
self._total_envstep_count = value

def close(self) -> None:
"""
Overview:
12 changes: 11 additions & 1 deletion ding/worker/collector/battle_sample_serial_collector.py
@@ -175,10 +175,20 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

@envstep.setter
def envstep(self, value: int) -> None:
"""
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): The total envstep count.
"""
self._total_envstep_count = value

def close(self) -> None:
"""
Overview:
12 changes: 11 additions & 1 deletion ding/worker/collector/episode_serial_collector.py
@@ -157,10 +157,20 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

@envstep.setter
def envstep(self, value: int) -> None:
"""
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): The total envstep count.
"""
self._total_envstep_count = value

def close(self) -> None:
"""
Overview:
12 changes: 11 additions & 1 deletion ding/worker/collector/sample_serial_collector.py
@@ -185,10 +185,20 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

@envstep.setter
def envstep(self, value: int) -> None:
"""
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): The total envstep count.
"""
self._total_envstep_count = value

def close(self) -> None:
"""
Overview:
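
The same read/write pair is added to all four serial collectors. As a standalone sketch (a toy class, not the real collector), the point of the new setter is that envstep used to be read-only, and the assignment is what lets a resumed entry seed the counter from a checkpoint instead of starting at 0:

class _CollectorSketch:
    """Toy stand-in for the serial collectors, illustrating the envstep property/setter pair."""

    def __init__(self) -> None:
        self._total_envstep_count = 0

    @property
    def envstep(self) -> int:
        # Read side: used by entries and loggers to report progress.
        return self._total_envstep_count

    @envstep.setter
    def envstep(self, value: int) -> None:
        # Write side (new): lets resume_training seed the counter from a checkpoint.
        self._total_envstep_count = value


collector = _CollectorSketch()
collector.envstep = 120000  # e.g. the 'last_step' value restored by the learner
assert collector.envstep == 120000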
22 changes: 22 additions & 0 deletions ding/worker/learner/base_learner.py
@@ -122,6 +122,8 @@ def __init__(
self._hooks = {'before_run': [], 'before_iter': [], 'after_iter': [], 'after_run': []}
# Last iteration. Used to record current iter.
self._last_iter = CountVar(init_val=0)
# Collector envstep. Used to record current envstep.
self._collector_envstep = 0

# Setup time wrapper and hook.
self._setup_wrapper()
@@ -177,6 +179,26 @@ def register_hook(self, hook: LearnerHook) -> None:
"""
add_learner_hook(self._hooks, hook)

@property
def collector_envstep(self) -> int:
"""
Overview:
Get current collector envstep.
Returns:
- collector_envstep (:obj:`int`): Current collector envstep.
"""
return self._collector_envstep

@collector_envstep.setter
def collector_envstep(self, value: int) -> None:
"""
Overview:
Set current collector envstep.
Arguments:
- value (:obj:`int`): Current collector envstep.
"""
self._collector_envstep = value

def train(self, data: dict, envstep: int = -1, policy_kwargs: Optional[dict] = None) -> None:
"""
Overview:
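
A companion sketch for the learner side (again a toy class, with one assumption about the surrounding code: in BaseLearner the counter is refreshed on each train(data, envstep) call, and it is this value that the save hook writes out as 'last_step'):

class _LearnerSketch:
    """Toy stand-in for BaseLearner, showing the collector_envstep property added here."""

    def __init__(self) -> None:
        self._collector_envstep = 0  # initialized so the save hook can read it before any train step

    @property
    def collector_envstep(self) -> int:
        return self._collector_envstep

    @collector_envstep.setter
    def collector_envstep(self, value: int) -> None:
        # Overwritten by the load hook when resuming from a checkpoint.
        self._collector_envstep = value

    def train(self, data: dict, envstep: int = -1) -> None:
        if envstep >= 0:
            self._collector_envstep = envstep  # keep the counter current for checkpointing


learner = _LearnerSketch()
learner.train({'obs': None}, envstep=2048)
assert learner.collector_envstep == 2048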
4 changes: 4 additions & 0 deletions ding/worker/learner/learner_hook.py
@@ -117,6 +117,9 @@ def __call__(self, engine: 'BaseLearner') -> None: # noqa
if 'last_iter' in state_dict:
last_iter = state_dict.pop('last_iter')
engine.last_iter.update(last_iter)
if 'last_step' in state_dict:
last_step = state_dict.pop('last_step')
engine._collector_envstep = last_step
engine.policy.load_state_dict(state_dict)
engine.info('{} load ckpt in {}'.format(engine.instance_name, path))

@@ -166,6 +169,7 @@ def __call__(self, engine: 'BaseLearner') -> None: # noqa
path = os.path.join(dirname, ckpt_name)
state_dict = engine.policy.state_dict()
state_dict.update({'last_iter': engine.last_iter.val})
state_dict.update({'last_step': engine.collector_envstep})
save_file(path, state_dict)
engine.info('{} save ckpt in {}'.format(engine.instance_name, path))

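
Putting the two hook changes together, the checkpoint payload now carries both counters next to the policy weights. A simplified save/load sketch follows: the 'last_iter' and 'last_step' keys come from this diff, while the file path and the 'model' contents are illustrative, and plain torch.save/torch.load stands in for ding's save_file helper.

import torch

# Save side (what the save_ckpt hook now writes, simplified):
state_dict = {
    'model': {'fc.weight': torch.zeros(2, 4)},  # policy weights (illustrative)
    'last_iter': 100,                            # training iteration at save time
    'last_step': 120000,                         # collector envstep at save time
}
torch.save(state_dict, './iteration_100.pth.tar')

# Load side (what the load_ckpt hook now does, simplified): pop both counters
# before the remaining dict is handed to policy.load_state_dict().
ckpt = torch.load('./iteration_100.pth.tar')
train_iter = ckpt.pop('last_iter', 0)
envstep = ckpt.pop('last_step', 0)
print(train_iter, envstep)  # 100 120000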
10 changes: 9 additions & 1 deletion dizoo/classic_control/cartpole/config/cartpole_ppo_config.py
@@ -26,7 +26,15 @@
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.2,
learner=dict(hook=dict(save_ckpt_after_iter=100)),
# Path to the pretrained checkpoint (ckpt).
# If set to an empty string (''), no pretrained model will be loaded.
# To load a pretrained ckpt, specify the path like this:
# learner=dict(hook=dict(load_ckpt_before_run='/path/to/your/ckpt/iteration_100.pth.tar')),

# If True, the environment step count (collector.envstep) and training iteration (train_iter)
# will be loaded from the pretrained checkpoint, allowing training to resume seamlessly
# from where the ckpt left off.
resume_training=False,
),
collect=dict(
n_sample=256,
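
A filled-in version of the commented-out knobs above, assuming a hypothetical checkpoint path and run through the on-policy entry that this cartpole PPO config normally uses:

from copy import deepcopy

from ding.entry import serial_pipeline_onpolicy
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import main_config, create_config

cfg = deepcopy(main_config)
# Point the before_run load hook at an existing checkpoint (hypothetical path).
cfg.policy.learn.learner = dict(hook=dict(load_ckpt_before_run='./cartpole_ppo_seed0/ckpt/iteration_100.pth.tar'))
# Keep the existing experiment dir and restore train_iter / collector.envstep from the ckpt.
cfg.policy.learn.resume_training = True
serial_pipeline_onpolicy((cfg, deepcopy(create_config)), seed=0)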
2 changes: 1 addition & 1 deletion dizoo/league_demo/league_demo_collector.py
@@ -163,7 +163,7 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count
