diff --git a/README.md b/README.md
index 4eb73ed4..329684aa 100644
--- a/README.md
+++ b/README.md
@@ -263,6 +263,9 @@ torchrun --standalone --nnodes=1 --nproc_per_node=2 runner.py --train --file rl_
 | env_config | | | Env configuration block. It goes directly to the environment. This example was take for my atari wrapper. |
 | skip | 4 | | Number of frames to skip |
 | name | BreakoutNoFrameskip-v4 | | The exact name of an (atari) gym env. An example, depends on the training env this parameters can be different. |
+| evaluation | True | False | Enables the evaluation feature: the player runs inference alongside training and picks up new checkpoints. |
+| update_checkpoint_freq | 100 | 100 | Frequency, in steps, at which the player checks for new checkpoints. |
+| dir_to_monitor | | | Directory to monitor for new checkpoints during evaluation. |
 
 ## Custom network example: [simple test network](rl_games/envs/test_network.py)
 
@@ -299,6 +302,7 @@ Additional environment supported properties and functions
 * Added shaped reward graph to the tensorboard.
 * Fixed bug with SAC not saving weights with save_frequency.
 * Added multi-node training support for GPU-accelerated training environments like Isaac Gym. No changes in training scripts are required. Thanks to @ankurhanda and @ArthurAllshire for assistance in implementation.
+* Added an evaluation feature for running inference during training. When enabled, checkpoints written by the training process are automatically picked up and loaded by the inference process.
 
 1.6.0
 
diff --git a/rl_games/common/player.py b/rl_games/common/player.py
index d41fad0b..f1a5c35e 100644
--- a/rl_games/common/player.py
+++ b/rl_games/common/player.py
@@ -1,8 +1,13 @@
+import os
+import shutil
+import threading
 import time
 import gym
 import numpy as np
 import torch
 import copy
+from os.path import basename
+from typing import Optional
 from rl_games.common import vecenv
 from rl_games.common import env_configurations
 from rl_games.algos_torch import model_builder
@@ -71,6 +76,90 @@ def __init__(self, params):
         self.max_steps = 108000 // 4
         self.device = torch.device(self.device_name)
 
+        self.evaluation = self.player_config.get("evaluation", False)
+        self.update_checkpoint_freq = self.player_config.get("update_checkpoint_freq", 100)
+        # if we run the player as an evaluation worker, this takes care of loading new checkpoints
+        self.dir_to_monitor = self.player_config.get("dir_to_monitor")
+        # path to the newest checkpoint
+        self.checkpoint_to_load: Optional[str] = None
+
+        if self.evaluation and self.dir_to_monitor is not None:
+            self.checkpoint_mutex = threading.Lock()
+            self.eval_checkpoint_dir = os.path.join(self.dir_to_monitor, "eval_checkpoints")
+            os.makedirs(self.eval_checkpoint_dir, exist_ok=True)
+
+            patterns = ["*.pth"]
+            from watchdog.observers import Observer
+            from watchdog.events import PatternMatchingEventHandler
+            self.file_events = PatternMatchingEventHandler(patterns)
+            self.file_events.on_created = self.on_file_created
+            self.file_events.on_modified = self.on_file_modified
+
+            self.file_observer = Observer()
+            self.file_observer.schedule(self.file_events, self.dir_to_monitor, recursive=False)
+            self.file_observer.start()
+
+    def wait_for_checkpoint(self):
+        if self.dir_to_monitor is None:
+            return
+
+        attempt = 0
+        while True:
+            attempt += 1
+            with self.checkpoint_mutex:
+                if self.checkpoint_to_load is not None:
+                    break
+
+            if attempt % 10 == 0:
+                print(f"Evaluation: waiting for new checkpoint in {self.dir_to_monitor}...")
+            time.sleep(1.0)
+
+        print(f"Checkpoint {self.checkpoint_to_load} is available!")
+
+    def maybe_load_new_checkpoint(self):
+        # lock the mutex while loading a new checkpoint
+        with self.checkpoint_mutex:
+            if self.checkpoint_to_load is not None:
+                print(f"Evaluation: loading new checkpoint {self.checkpoint_to_load}...")
+                # check whether anything can be loaded from the pth file; this fails quickly if the
+                # file is corrupted, without triggering the retry loop in "safe_filesystem_op()"
+                load_error = False
+                try:
+                    torch.load(self.checkpoint_to_load)
+                except Exception as e:
+                    print(f"Evaluation: checkpoint file is likely corrupted {self.checkpoint_to_load}: {e}")
+                    load_error = True
+
+                if not load_error:
+                    try:
+                        self.restore(self.checkpoint_to_load)
+                    except Exception as e:
+                        print(f"Evaluation: failed to load new checkpoint {self.checkpoint_to_load}: {e}")
+
+                # whether we succeeded or not, forget about this checkpoint
+                self.checkpoint_to_load = None
+
+    def process_new_eval_checkpoint(self, path):
+        with self.checkpoint_mutex:
+            # copy the file to the eval_checkpoints dir using shutil
+            # since the evaluation worker runs in a separate process, there is a chance
+            # that the file is changed/corrupted while we're copying it;
+            # not much we can do about that, and in practice it has not happened so far
+            eval_checkpoint_path = os.path.join(self.eval_checkpoint_dir, basename(path))
+            try:
+                shutil.copyfile(path, eval_checkpoint_path)
+            except Exception as e:
+                print(f"Failed to copy {path} to {eval_checkpoint_path}: {e}")
+                return
+
+            self.checkpoint_to_load = eval_checkpoint_path
+
+    def on_file_created(self, event):
+        self.process_new_eval_checkpoint(event.src_path)
+
+    def on_file_modified(self, event):
+        self.process_new_eval_checkpoint(event.src_path)
+
     def load_networks(self, params):
         builder = model_builder.ModelBuilder()
         self.config['network'] = builder.load(params)
@@ -204,6 +293,8 @@ def run(self):
         if has_masks_func:
             has_masks = self.env.has_action_mask()
 
+        self.wait_for_checkpoint()
+
         need_init_rnn = self.is_rnn
         for _ in range(n_games):
             if games_played >= n_games:
@@ -223,6 +314,9 @@ def run(self):
             print_game_res = False
 
             for n in range(self.max_steps):
+                if self.evaluation and n % self.update_checkpoint_freq == 0:
+                    self.maybe_load_new_checkpoint()
+
                 if has_masks:
                     masks = self.env.get_action_mask()
                     action = self.get_masked_action(
diff --git a/setup.py b/setup.py
index 514751a7..99c2ea82 100644
--- a/setup.py
+++ b/setup.py
@@ -45,6 +45,7 @@
         'tensorboardX>=1.6',
         'setproctitle',
         'psutil',
-        'pyyaml'
+        'pyyaml',
+        'watchdog>=2.1.9,<3.0.0',  # for evaluation process
     ],
 )
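
For reference, here is a minimal sketch of how the new player options might be set in a training config. Only `evaluation`, `update_checkpoint_freq`, and `dir_to_monitor` come from this diff; the placement under a `player:` block and the checkpoint directory path are illustrative assumptions.

```yaml
# hypothetical config fragment: key names follow the README table above,
# the block placement and the directory path are examples only
player:
  evaluation: True                 # run this player instance as an evaluation worker
  update_checkpoint_freq: 100      # look for a new checkpoint every 100 steps
  dir_to_monitor: runs/my_run/nn   # directory where the training process writes .pth files
```

With a config like this, the player would typically run in its own process alongside training; the watchdog observer copies each new `.pth` file from `dir_to_monitor` into an `eval_checkpoints` subdirectory, and the player reloads the latest copy at the configured frequency.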