Skip to content

Commit

Permalink
add option for evaluating checkpoint (#246)
Browse files Browse the repository at this point in the history
* add option for evaluating checkpoint

* update setup.py

* update docs

* update docs
  • Loading branch information
kellyguo11 authored Jul 25, 2023
1 parent f2b45f2 commit 990b478
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 1 deletion.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ torchrun --standalone --nnodes=1 --nproc_per_node=2 runner.py --train --file rl_
| env_config | | | Env configuration block. It goes directly to the environment. This example was take for my atari wrapper. |
| skip | 4 | | Number of frames to skip |
| name | BreakoutNoFrameskip-v4 | | The exact name of an (atari) gym env. An example, depends on the training env this parameters can be different. |
| evaluation | True | False | Enables the evaluation feature for inferencing while training. |
| update_checkpoint_freq | 100 | 100 | Frequency in number of steps to look for new checkpoints. |
| dir_to_monitor | | | Directory to search for checkpoints in during evaluation. |

## Custom network example:
[simple test network](rl_games/envs/test_network.py)
Expand Down Expand Up @@ -299,6 +302,7 @@ Additional environment supported properties and functions
* Added shaped reward graph to the tensorboard.
* Fixed bug with SAC not saving weights with save_frequency.
* Added multi-node training support for GPU-accelerated training environments like Isaac Gym. No changes in training scripts are required. Thanks to @ankurhanda and @ArthurAllshire for assistance in implementation.
* Added evaluation feature for inferencing during training. Checkpoints from training process can be automatically picked up and updated in the inferencing process when enabled.

1.6.0

Expand Down
94 changes: 94 additions & 0 deletions rl_games/common/player.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import os
import shutil
import threading
import time
import gym
import numpy as np
import torch
import copy
from os.path import basename
from typing import Optional
from rl_games.common import vecenv
from rl_games.common import env_configurations
from rl_games.algos_torch import model_builder
Expand Down Expand Up @@ -71,6 +76,90 @@ def __init__(self, params):
self.max_steps = 108000 // 4
self.device = torch.device(self.device_name)

self.evaluation = self.player_config.get("evaluation", False)
self.update_checkpoint_freq = self.player_config.get("update_checkpoint_freq", 100)
# if we run player as evaluation worker this will take care of loading new checkpoints
self.dir_to_monitor = self.player_config.get("dir_to_monitor")
# path to the newest checkpoint
self.checkpoint_to_load: Optional[str] = None

if self.evaluation and self.dir_to_monitor is not None:
self.checkpoint_mutex = threading.Lock()
self.eval_checkpoint_dir = os.path.join(self.dir_to_monitor, "eval_checkpoints")
os.makedirs(self.eval_checkpoint_dir, exist_ok=True)

patterns = ["*.pth"]
from watchdog.observers import Observer
from watchdog.events import PatternMatchingEventHandler
self.file_events = PatternMatchingEventHandler(patterns)
self.file_events.on_created = self.on_file_created
self.file_events.on_modified = self.on_file_modified

self.file_observer = Observer()
self.file_observer.schedule(self.file_events, self.dir_to_monitor, recursive=False)
self.file_observer.start()

def wait_for_checkpoint(self):
if self.dir_to_monitor is None:
return

attempt = 0
while True:
attempt += 1
with self.checkpoint_mutex:
if self.checkpoint_to_load is not None:
if attempt % 10 == 0:
print(f"Evaluation: waiting for new checkpoint in {self.dir_to_monitor}...")
break
time.sleep(1.0)

print(f"Checkpoint {self.checkpoint_to_load} is available!")

def maybe_load_new_checkpoint(self):
# lock mutex while loading new checkpoint
with self.checkpoint_mutex:
if self.checkpoint_to_load is not None:
print(f"Evaluation: loading new checkpoint {self.checkpoint_to_load}...")
# try if we can load anything from the pth file, this will quickly fail if the file is corrupted
# without triggering the retry loop in "safe_filesystem_op()"
load_error = False
try:
torch.load(self.checkpoint_to_load)
except Exception as e:
print(f"Evaluation: checkpoint file is likely corrupted {self.checkpoint_to_load}: {e}")
load_error = True

if not load_error:
try:
self.restore(self.checkpoint_to_load)
except Exception as e:
print(f"Evaluation: failed to load new checkpoint {self.checkpoint_to_load}: {e}")

# whether we succeeded or not, forget about this checkpoint
self.checkpoint_to_load = None

def process_new_eval_checkpoint(self, path):
with self.checkpoint_mutex:
# print(f"New checkpoint {path} available for evaluation")
# copy file to eval_checkpoints dir using shutil
# since we're running the evaluation worker in a separate process,
# there is a chance that the file is changed/corrupted while we're copying it
# not sure what we can do about this. In practice it never happened so far though
try:
eval_checkpoint_path = os.path.join(self.eval_checkpoint_dir, basename(path))
shutil.copyfile(path, eval_checkpoint_path)
except Exception as e:
print(f"Failed to copy {path} to {eval_checkpoint_path}: {e}")
return

self.checkpoint_to_load = eval_checkpoint_path

def on_file_created(self, event):
self.process_new_eval_checkpoint(event.src_path)

def on_file_modified(self, event):
self.process_new_eval_checkpoint(event.src_path)

def load_networks(self, params):
builder = model_builder.ModelBuilder()
self.config['network'] = builder.load(params)
Expand Down Expand Up @@ -204,6 +293,8 @@ def run(self):
if has_masks_func:
has_masks = self.env.has_action_mask()

self.wait_for_checkpoint()

need_init_rnn = self.is_rnn
for _ in range(n_games):
if games_played >= n_games:
Expand All @@ -223,6 +314,9 @@ def run(self):
print_game_res = False

for n in range(self.max_steps):
if self.evaluation and n % self.update_checkpoint_freq == 0:
self.maybe_load_new_checkpoint()

if has_masks:
masks = self.env.get_action_mask()
action = self.get_masked_action(
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
'tensorboardX>=1.6',
'setproctitle',
'psutil',
'pyyaml'
'pyyaml',
'watchdog>=2.1.9,<3.0.0', # for evaluation process
],
)

0 comments on commit 990b478

Please sign in to comment.