add option for evaluating checkpoint #246

Merged
merged 4 commits on Jul 25, 2023
4 changes: 4 additions & 0 deletions README.md
@@ -263,6 +263,9 @@ torchrun --standalone --nnodes=1 --nproc_per_node=2 runner.py --train --file rl_
| env_config | | | Env configuration block. It is passed directly to the environment. This example was taken from my Atari wrapper. |
| skip | 4 | | Number of frames to skip. |
| name | BreakoutNoFrameskip-v4 | | The exact name of an (Atari) gym env. This is just an example; depending on the training env, this parameter can differ. |
| evaluation | True | False | Enables the evaluation feature: run inference on checkpoints while training is in progress. |
| update_checkpoint_freq | 100 | 100 | Frequency, in steps, at which to look for new checkpoints. |
| dir_to_monitor | | | Directory to monitor for new checkpoints during evaluation (see the sketch below). |
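
A minimal sketch of how these options fit together, assuming the player config ends up as a plain dict read by `BasePlayer` via `player_config.get(...)`; the directory path is a placeholder, and the exact YAML layout depends on your config file:

```python
# Hypothetical player config; the keys mirror what rl_games/common/player.py reads.
# "runs/my_experiment/nn" is a placeholder path, not a value taken from this PR.
player_config = {
    "evaluation": True,               # run the player as an evaluation worker
    "update_checkpoint_freq": 100,    # look for a new checkpoint every 100 steps
    "dir_to_monitor": "runs/my_experiment/nn",  # where the trainer writes *.pth files
}
```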

## Custom network example:
[simple test network](rl_games/envs/test_network.py)
@@ -299,6 +302,7 @@ Additional environment supported properties and functions
* Added shaped reward graph to the tensorboard.
* Fixed bug with SAC not saving weights with save_frequency.
* Added multi-node training support for GPU-accelerated training environments like Isaac Gym. No changes in training scripts are required. Thanks to @ankurhanda and @ArthurAllshire for assistance in implementation.
* Added an evaluation feature for running inference during training. When enabled, checkpoints produced by the training process are automatically picked up and loaded by the inference (player) process.

1.6.0

94 changes: 94 additions & 0 deletions rl_games/common/player.py
@@ -1,8 +1,13 @@
import os
import shutil
import threading
import time
import gym
import numpy as np
import torch
import copy
from os.path import basename
from typing import Optional
from rl_games.common import vecenv
from rl_games.common import env_configurations
from rl_games.algos_torch import model_builder
@@ -71,6 +76,90 @@ def __init__(self, params):
        self.max_steps = 108000 // 4
        self.device = torch.device(self.device_name)

        self.evaluation = self.player_config.get("evaluation", False)
        # frequency (in steps) at which the evaluation worker looks for new checkpoints
        self.update_checkpoint_freq = self.player_config.get("update_checkpoint_freq", 100)
        # if we run the player as an evaluation worker, this directory is monitored for new checkpoints
        self.dir_to_monitor = self.player_config.get("dir_to_monitor")
        # path to the newest checkpoint
        self.checkpoint_to_load: Optional[str] = None

        if self.evaluation and self.dir_to_monitor is not None:
            self.checkpoint_mutex = threading.Lock()
            self.eval_checkpoint_dir = os.path.join(self.dir_to_monitor, "eval_checkpoints")
            os.makedirs(self.eval_checkpoint_dir, exist_ok=True)

            patterns = ["*.pth"]
            # watchdog is only needed when the evaluation feature is enabled, so import it lazily
            from watchdog.observers import Observer
            from watchdog.events import PatternMatchingEventHandler
            self.file_events = PatternMatchingEventHandler(patterns)
            self.file_events.on_created = self.on_file_created
            self.file_events.on_modified = self.on_file_modified

            self.file_observer = Observer()
            self.file_observer.schedule(self.file_events, self.dir_to_monitor, recursive=False)
            self.file_observer.start()

    def wait_for_checkpoint(self):
        if self.dir_to_monitor is None:
            return

        attempt = 0
        while True:
            attempt += 1
            with self.checkpoint_mutex:
                if self.checkpoint_to_load is not None:
                    break

            if attempt % 10 == 0:
                print(f"Evaluation: waiting for new checkpoint in {self.dir_to_monitor}...")

            time.sleep(1.0)

        print(f"Checkpoint {self.checkpoint_to_load} is available!")

    def maybe_load_new_checkpoint(self):
        # lock the mutex while loading a new checkpoint
        with self.checkpoint_mutex:
            if self.checkpoint_to_load is not None:
                print(f"Evaluation: loading new checkpoint {self.checkpoint_to_load}...")
                # check that the .pth file can be deserialized at all; this fails quickly if the
                # file is corrupted, without triggering the retry loop in "safe_filesystem_op()"
                load_error = False
                try:
                    torch.load(self.checkpoint_to_load)
                except Exception as e:
                    print(f"Evaluation: checkpoint file {self.checkpoint_to_load} is likely corrupted: {e}")
                    load_error = True

                if not load_error:
                    try:
                        self.restore(self.checkpoint_to_load)
                    except Exception as e:
                        print(f"Evaluation: failed to load new checkpoint {self.checkpoint_to_load}: {e}")

                # whether we succeeded or not, forget about this checkpoint
                self.checkpoint_to_load = None

    def process_new_eval_checkpoint(self, path):
        with self.checkpoint_mutex:
            # Copy the new checkpoint into eval_checkpoints before loading it. Since the evaluation
            # worker runs in a separate process, the file could in principle be changed or corrupted
            # while we are copying it; there is not much we can do about that, but in practice it has
            # not been a problem so far.
            try:
                eval_checkpoint_path = os.path.join(self.eval_checkpoint_dir, basename(path))
                shutil.copyfile(path, eval_checkpoint_path)
            except Exception as e:
                print(f"Failed to copy {path} to {eval_checkpoint_path}: {e}")
                return

            self.checkpoint_to_load = eval_checkpoint_path

    def on_file_created(self, event):
        self.process_new_eval_checkpoint(event.src_path)

    def on_file_modified(self, event):
        self.process_new_eval_checkpoint(event.src_path)

    def load_networks(self, params):
        builder = model_builder.ModelBuilder()
        self.config['network'] = builder.load(params)
@@ -204,6 +293,8 @@ def run(self):
        if has_masks_func:
            has_masks = self.env.has_action_mask()

        self.wait_for_checkpoint()

        need_init_rnn = self.is_rnn
        for _ in range(n_games):
            if games_played >= n_games:
@@ -223,6 +314,9 @@
            print_game_res = False

            for n in range(self.max_steps):
                # in evaluation mode, periodically check whether a newer checkpoint is available
                if self.evaluation and n % self.update_checkpoint_freq == 0:
                    self.maybe_load_new_checkpoint()

                if has_masks:
                    masks = self.env.get_action_mask()
                    action = self.get_masked_action(
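To see the moving parts of the player changes in isolation, here is a self-contained sketch of the same directory-monitoring pattern (watchdog observer, mutex, staging copy, periodic reload). It is illustrative only: `CheckpointMonitor` and the `checkpoints/` directory are hypothetical names, not part of rl_games.

```python
# Standalone sketch of the checkpoint-monitoring pattern used in the diff above.
# Assumes: watchdog is installed; "checkpoints/" is a hypothetical directory that
# some other process writes *.pth files into.
import os
import shutil
import threading
import time
from os.path import basename

import torch
from watchdog.events import PatternMatchingEventHandler
from watchdog.observers import Observer


class CheckpointMonitor:
    def __init__(self, dir_to_monitor: str):
        self.dir_to_monitor = dir_to_monitor
        self.checkpoint_mutex = threading.Lock()
        self.checkpoint_to_load = None

        # Stage copies in a subdirectory so the trainer never writes to the
        # exact file we are about to load.
        self.eval_checkpoint_dir = os.path.join(dir_to_monitor, "eval_checkpoints")
        os.makedirs(self.eval_checkpoint_dir, exist_ok=True)

        handler = PatternMatchingEventHandler(["*.pth"])
        handler.on_created = self._on_new_file
        handler.on_modified = self._on_new_file

        self.observer = Observer()
        self.observer.schedule(handler, dir_to_monitor, recursive=False)
        self.observer.start()

    def _on_new_file(self, event):
        # Called from the watchdog thread: copy the file aside and remember it.
        with self.checkpoint_mutex:
            staged = os.path.join(self.eval_checkpoint_dir, basename(event.src_path))
            try:
                shutil.copyfile(event.src_path, staged)
            except OSError as e:
                print(f"Failed to copy {event.src_path}: {e}")
                return
            self.checkpoint_to_load = staged

    def maybe_load(self):
        # Called from the main loop: load the newest staged checkpoint, if any.
        with self.checkpoint_mutex:
            if self.checkpoint_to_load is None:
                return None
            path, self.checkpoint_to_load = self.checkpoint_to_load, None
        try:
            return torch.load(path, map_location="cpu")
        except Exception as e:  # the file may still be partially written
            print(f"Could not load {path}: {e}")
            return None


if __name__ == "__main__":
    monitor = CheckpointMonitor("checkpoints")  # hypothetical directory
    for step in range(1000):
        if step % 100 == 0:  # mirrors update_checkpoint_freq
            state = monitor.maybe_load()
            if state is not None:
                print("New checkpoint picked up")
        time.sleep(0.1)
    monitor.observer.stop()
    monitor.observer.join()
```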
3 changes: 2 additions & 1 deletion setup.py
@@ -45,6 +45,7 @@
'tensorboardX>=1.6',
'setproctitle',
'psutil',
'pyyaml'
'pyyaml',
'watchdog>=2.1.9,<3.0.0', # for evaluation process
],
)