diff --git a/rl_games/common/player.py b/rl_games/common/player.py
index 527602bd..bafbf804 100644
--- a/rl_games/common/player.py
+++ b/rl_games/common/player.py
@@ -1,11 +1,18 @@
+import os
+import shutil
+import threading
 import time
 import gym
 import numpy as np
 import torch
 import copy
+from os.path import basename
+from typing import Optional
+
 from rl_games.common import env_configurations
 from rl_games.algos_torch import model_builder
 
+
 class BasePlayer(object):
     def __init__(self, params):
         self.config = config = params['config']
@@ -52,6 +59,90 @@ def __init__(self, params):
         self.max_steps = 108000 // 4
         self.device = torch.device(self.device_name)
 
+        self.evaluation = self.player_config.get("evaluation", False)
+        self.update_checkpoint_freq = self.player_config.get("update_checkpoint_freq", 100)
+        # if we run the player as an evaluation worker, this will take care of loading new checkpoints
+        self.dir_to_monitor = self.player_config.get("dir_to_monitor")
+        # path to the newest checkpoint
+        self.checkpoint_to_load: Optional[str] = None
+        self.checkpoint_mutex = threading.Lock()
+
+        if self.evaluation and self.dir_to_monitor is not None:
+            self.eval_checkpoint_dir = os.path.join(self.dir_to_monitor, "eval_checkpoints")
+            os.makedirs(self.eval_checkpoint_dir, exist_ok=True)
+
+            patterns = ["*.pth"]
+            from watchdog.observers import Observer
+            from watchdog.events import PatternMatchingEventHandler
+            self.file_events = PatternMatchingEventHandler(patterns)
+            self.file_events.on_created = self.on_file_created
+            self.file_events.on_modified = self.on_file_modified
+
+            self.file_observer = Observer()
+            self.file_observer.schedule(self.file_events, self.dir_to_monitor, recursive=False)
+            self.file_observer.start()
+
+    def wait_for_checkpoint(self):
+        if self.dir_to_monitor is None:
+            return
+
+        attempt = 0
+        while True:
+            attempt += 1
+            with self.checkpoint_mutex:
+                if self.checkpoint_to_load is not None:
+                    break
+            if attempt % 10 == 0:
+                print(f"Evaluation: waiting for new checkpoint in {self.dir_to_monitor}...")
+            time.sleep(1.0)
+
+        print(f"Checkpoint {self.checkpoint_to_load} is available!")
+
+    def maybe_load_new_checkpoint(self):
+        # lock mutex while loading new checkpoint
+        with self.checkpoint_mutex:
+            if self.checkpoint_to_load is not None:
+                print(f"Evaluation: loading new checkpoint {self.checkpoint_to_load}...")
+                # check if we can load anything from the pth file; this will quickly fail if the file is corrupted
+                # without triggering the retry loop in "safe_filesystem_op()"
+                load_error = False
+                try:
+                    torch.load(self.checkpoint_to_load)
+                except Exception as e:
+                    print(f"Evaluation: checkpoint file {self.checkpoint_to_load} is likely corrupted: {e}")
+                    load_error = True
+
+                if not load_error:
+                    try:
+                        self.restore(self.checkpoint_to_load)
+                    except Exception as e:
+                        print(f"Evaluation: failed to load new checkpoint {self.checkpoint_to_load}: {e}")
+
+                # whether we succeeded or not, forget about this checkpoint
+                self.checkpoint_to_load = None
+
+    def process_new_eval_checkpoint(self, path):
+        with self.checkpoint_mutex:
+            # print(f"New checkpoint {path} available for evaluation")
+            # copy the file to the eval_checkpoints dir using shutil;
+            # since we're running the evaluation worker in a separate process,
+            # there is a chance that the file is changed/corrupted while we're copying it;
+            # not sure what we can do about this, but in practice it has not happened so far
+            eval_checkpoint_path = os.path.join(self.eval_checkpoint_dir, basename(path))
+            try:
+                shutil.copyfile(path, eval_checkpoint_path)
+            except Exception as e:
+                print(f"Failed to copy {path} to {eval_checkpoint_path}: {e}")
+                return
+
+            self.checkpoint_to_load = eval_checkpoint_path
+
+    def on_file_created(self, event):
+        self.process_new_eval_checkpoint(event.src_path)
+
+    def on_file_modified(self, event):
+        self.process_new_eval_checkpoint(event.src_path)
+
     def load_networks(self, params):
         builder = model_builder.ModelBuilder()
         self.config['network'] = builder.load(params)
@@ -184,6 +275,8 @@ def run(self):
         if has_masks_func:
             has_masks = self.env.has_action_mask()
 
+        self.wait_for_checkpoint()
+
         need_init_rnn = self.is_rnn
         for _ in range(n_games):
             if games_played >= n_games:
@@ -203,6 +296,9 @@
             print_game_res = False
 
             for n in range(self.max_steps):
+                if n % self.update_checkpoint_freq == 0:
+                    self.maybe_load_new_checkpoint()
+
                 if has_masks:
                     masks = self.env.get_action_mask()
                     action = self.get_masked_action(
diff --git a/setup.py b/setup.py
index 10ae34a3..683fe70e 100644
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,7 @@
         'setproctitle',
         'psutil',
-        'pyyaml'
+        'pyyaml',
+        'watchdog>=2.1.9,<3.0.0',  # for evaluation process (IsaacGymEnvs 1.4.0 feature)
         # Optional dependencies
         # 'ray>=1.1.0',
     ],
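
Usage sketch: the three new keys are read from the player section of the config in BasePlayer.__init__. The key names below match the diff; the overall params layout, the player class name, and the monitored directory are assumptions for illustration.

    # hypothetical rl_games params dict; the 'player' section feeds self.player_config
    params['config']['player'] = {
        'evaluation': True,                  # run the player as an evaluation worker
        'dir_to_monitor': 'runs/my_exp/nn',  # hypothetical dir where the trainer writes *.pth checkpoints
        'update_checkpoint_freq': 100,       # poll for a newer checkpoint every N env steps
    }

    player = SomePlayer(params)  # SomePlayer stands in for any BasePlayer subclass
    player.run()  # blocks in wait_for_checkpoint(), then reloads new checkpoints as they arrive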