diff --git a/README.md b/README.md
index 4eb73ed4..6410f09d 100644
--- a/README.md
+++ b/README.md
@@ -263,6 +263,9 @@ torchrun --standalone --nnodes=1 --nproc_per_node=2 runner.py --train --file rl_
 | env_config |  |  | Env configuration block. It goes directly to the environment. This example was take for my atari wrapper. |
 | skip | 4 |  | Number of frames to skip |
 | name | BreakoutNoFrameskip-v4 |  | The exact name of an (atari) gym env. An example, depends on the training env this parameters can be different. |
+| evaluation | True | False | Enables the evaluation feature for inferencing while training. |
+| update_checkpoint_freq | 100 | 100 | Frequency in number of steps to look for new checkpoints. |
+| dir_to_monitor |  |  | Directory to search for checkpoints in during evaluation. |

 ## Custom network example: [simple test network](rl_games/envs/test_network.py)

@@ -299,6 +302,9 @@ Additional environment supported properties and functions
 * Added shaped reward graph to the tensorboard.
 * Fixed bug with SAC not saving weights with save_frequency.
 * Added multi-node training support for GPU-accelerated training environments like Isaac Gym. No changes in training scripts are required. Thanks to @ankurhanda and @ArthurAllshire for assistance in implementation.
+* Added evaluation feature for inferencing during training. Checkpoints from training process can be automatically picked up and updated in the inferencing process when enabled.
+* Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code.
+* Fixed SAC not loading weights properly.

 1.6.0

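The three new player parameters documented above are read from the `player` section of the config. A minimal sketch of the equivalent settings, written as the Python dict the player ends up seeing (in a YAML config this would normally sit under `config: player:`); the monitored directory is a hypothetical example:

# Hypothetical player settings for running an evaluation worker next to training.
player_config = {
    'evaluation': True,                # run this player as an evaluation worker
    'update_checkpoint_freq': 100,     # every 100 steps, check whether a newer checkpoint appeared
    'dir_to_monitor': 'runs/MyExperiment/nn',  # hypothetical directory watched for new .pth files
}
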
diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index 2fa97c28..09f9f35e 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -10,6 +10,7 @@


 class A2CAgent(a2c_common.ContinuousA2CBase):
+
     def __init__(self, base_name, params):
         a2c_common.ContinuousA2CBase.__init__(self, base_name, params)
         obs_shape = self.obs_shape
@@ -66,9 +67,9 @@ def save(self, fn):
         state = self.get_full_state_weights()
         torch_ext.save_checkpoint(fn, state)

-    def restore(self, fn):
+    def restore(self, fn, set_epoch=True):
         checkpoint = torch_ext.load_checkpoint(fn)
-        self.set_full_state_weights(checkpoint)
+        self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

     def get_masked_action_values(self, obs, action_masks):
         assert False
diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py
index a4f20d7b..d386cd8e 100644
--- a/rl_games/algos_torch/a2c_discrete.py
+++ b/rl_games/algos_torch/a2c_discrete.py
@@ -70,9 +70,9 @@ def save(self, fn):
         state = self.get_full_state_weights()
         torch_ext.save_checkpoint(fn, state)

-    def restore(self, fn):
+    def restore(self, fn, set_epoch=True):
         checkpoint = torch_ext.load_checkpoint(fn)
-        self.set_full_state_weights(checkpoint)
+        self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

     def get_masked_action_values(self, obs, action_masks):
         processed_obs = self._preproc_obs(obs['obs'])
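With the `set_epoch` flag threaded through `restore()` and `set_full_state_weights()`, a checkpoint can either resume a run or only warm-start it. A small illustrative sketch (the `agent` object and the checkpoint path are hypothetical, not part of the diff):

# Resume training: weights, optimizer state and epoch/frame counters are all restored.
agent.restore('runs/Ant/nn/Ant.pth')
# Warm start: weights and optimizer state only; epoch/frame counters keep their current values.
agent.restore('runs/Ant/nn/Ant.pth', set_epoch=False)
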
diff --git a/rl_games/algos_torch/players.py b/rl_games/algos_torch/players.py
index 1d19fb16..4a4dfaab 100644
--- a/rl_games/algos_torch/players.py
+++ b/rl_games/algos_torch/players.py
@@ -16,6 +16,7 @@ def rescale_actions(low, high, action):


 class PpoPlayerContinuous(BasePlayer):
+
     def __init__(self, params):
         BasePlayer.__init__(self, params)
         self.network = self.config['network']
@@ -187,6 +188,7 @@ def reset(self):


 class SACPlayer(BasePlayer):
+
     def __init__(self, params):
         BasePlayer.__init__(self, params)
         self.network = self.config['network']
diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py
index c9b4dc15..3697a7f3 100644
--- a/rl_games/algos_torch/sac_agent.py
+++ b/rl_games/algos_torch/sac_agent.py
@@ -38,6 +38,7 @@ def __init__(self, base_name, params):

         self.num_steps_per_episode = config.get("num_steps_per_episode", 1)
         self.normalize_input = config.get("normalize_input", False)
+        # TODO: double-check! To use bootstrap instead?
         self.max_env_steps = config.get("max_env_steps", 1000) # temporary, in future we will use other approach

         print(self.batch_size, self.num_actors, self.num_agents)
@@ -60,7 +61,6 @@ def __init__(self, base_name, params):
             'action_dim': self.env_info["action_space"].shape[0],
             'actions_num' : self.actions_num,
             'input_shape' : obs_shape,
-            'normalize_input' : self.normalize_input,
             'normalize_input': self.normalize_input,
         }
         self.model = self.network.build(net_config)
@@ -88,12 +88,8 @@ def __init__(self, base_name, params):
         self.target_entropy = self.target_entropy_coef * -self.env_info['action_space'].shape[0]
         print("Target entropy", self.target_entropy)

-        self.step = 0
         self.algo_observer = config['features']['observer']

-        # TODO: Is there a better way to get the maximum number of episodes?
-        self.max_episodes = torch.ones(self.num_actors, device=self._device)*self.num_steps_per_episode
-
     def load_networks(self, params):
         builder = model_builder.ModelBuilder()
         self.config['network'] = builder.load(params)
@@ -133,6 +129,8 @@ def base_init(self, base_name, config):
         self.max_epochs = self.config.get('max_epochs', -1)
         self.max_frames = self.config.get('max_frames', -1)

+        self.save_freq = config.get('save_frequency', 0)
+
         self.network = config['network']
         self.rewards_shaper = config['reward_shaper']
         self.num_agents = self.env_info.get('agents', 1)
@@ -146,10 +144,10 @@ def base_init(self, base_name, config):
         self.min_alpha = torch.tensor(np.log(1)).float().to(self._device)

         self.frame = 0
+        self.epoch_num = 0
         self.update_time = 0
-        self.last_mean_rewards = -100500
+        self.last_mean_rewards = -1000000000
         self.play_time = 0
-        self.epoch_num = 0

         # TODO: put it into the separate class
         pbt_str = ''
@@ -205,17 +203,8 @@ def alpha(self):
     def device(self):
         return self._device

-    def get_full_state_weights(self):
-        state = self.get_weights()
-
-        state['steps'] = self.step
-        state['actor_optimizer'] = self.actor_optimizer.state_dict()
-        state['critic_optimizer'] = self.critic_optimizer.state_dict()
-        state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict()
-
-        return state
-
     def get_weights(self):
+        print("Loading weights")
         state = {'actor': self.model.sac_network.actor.state_dict(),
                  'critic': self.model.sac_network.critic.state_dict(),
                  'critic_target': self.model.sac_network.critic_target.state_dict()}
@@ -233,17 +222,45 @@ def set_weights(self, weights):
         if self.normalize_input and 'running_mean_std' in weights:
             self.model.running_mean_std.load_state_dict(weights['running_mean_std'])

-    def set_full_state_weights(self, weights):
+    def get_full_state_weights(self):
+        print("Loading full weights")
+        state = self.get_weights()
+
+        state['epoch'] = self.epoch_num
+        state['frame'] = self.frame
+        state['actor_optimizer'] = self.actor_optimizer.state_dict()
+        state['critic_optimizer'] = self.critic_optimizer.state_dict()
+        state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict()
+
+        return state
+
+    def set_full_state_weights(self, weights, set_epoch=True):
         self.set_weights(weights)

-        self.step = weights['step']
+        if set_epoch:
+            self.epoch_num = weights['epoch']
+            self.frame = weights['frame']
+
         self.actor_optimizer.load_state_dict(weights['actor_optimizer'])
         self.critic_optimizer.load_state_dict(weights['critic_optimizer'])
         self.log_alpha_optimizer.load_state_dict(weights['log_alpha_optimizer'])

-    def restore(self, fn):
+        self.last_mean_rewards = weights.get('last_mean_rewards', -1000000000)
+
+        if self.vec_env is not None:
+            env_state = weights.get('env_state', None)
+            self.vec_env.set_env_state(env_state)
+
+    def restore(self, fn, set_epoch=True):
+        print("SAC restore")
         checkpoint = torch_ext.load_checkpoint(fn)
-        self.set_full_state_weights(checkpoint)
+        self.set_full_state_weights(checkpoint, set_epoch=set_epoch)
+
+    def get_param(self, param_name):
+        pass
+
+    def set_param(self, param_name, param_value):
+        pass

     def get_masked_action_values(self, obs, action_masks):
         assert False
@@ -334,6 +351,7 @@ def preproc_obs(self, obs):
         if isinstance(obs, dict):
             obs = obs['obs']
         obs = self.model.norm_obs(obs)
+
         return obs

     def cast_obs(self, obs):
@@ -348,7 +366,7 @@ def cast_obs(self, obs):

         return obs

-    # todo: move to common utils
+    # TODO: move to common utils
     def obs_to_tensors(self, obs):
         obs_is_dict = isinstance(obs, dict)
         if obs_is_dict:
@@ -359,6 +377,7 @@ def obs_to_tensors(self, obs):
             upd_obs = self.cast_obs(obs)
         if not obs_is_dict or 'obs' not in obs:
             upd_obs = {'obs' : upd_obs}
+
         return upd_obs

     def _obs_to_tensors_internal(self, obs):
@@ -368,18 +387,19 @@ def _obs_to_tensors_internal(self, obs):
                 upd_obs[key] = self._obs_to_tensors_internal(value)
         else:
             upd_obs = self.cast_obs(obs)
+
         return upd_obs

     def preprocess_actions(self, actions):
         if not self.is_tensor_obses:
             actions = actions.cpu().numpy()
+
         return actions

     def env_step(self, actions):
         actions = self.preprocess_actions(actions)
         obs, rewards, dones, infos = self.vec_env.step(actions) # (obs_space) -> (n, obs_space)

-        self.step += self.num_actors
         if self.is_tensor_obses:
             return self.obs_to_tensors(obs), rewards.to(self._device), dones.to(self._device), infos
         else:
@@ -415,7 +435,7 @@ def extract_actor_stats(self, actor_losses, entropies, alphas, alpha_losses, act
     def clear_stats(self):
         self.game_rewards.clear()
         self.game_lengths.clear()
-        self.mean_rewards = self.last_mean_rewards = -100500
+        self.mean_rewards = self.last_mean_rewards = -1000000000
         self.algo_observer.after_clear_stats()

     def play_steps(self, random_exploration = False):
@@ -431,6 +451,11 @@ def play_steps(self, random_exploration = False):
         critic2_losses = []

         obs = self.obs
+        if isinstance(obs, dict):
+            obs = self.obs['obs']
+
+        next_obs_processed = obs.clone()
+
         for s in range(self.num_steps_per_episode):
             self.set_eval()
             if random_exploration:
@@ -466,16 +491,17 @@ def play_steps(self, random_exploration = False):
             self.current_rewards = self.current_rewards * not_dones
             self.current_lengths = self.current_lengths * not_dones

-            if isinstance(obs, dict):
-                obs = obs['obs']
             if isinstance(next_obs, dict):
-                next_obs = next_obs['obs']
+                next_obs_processed = next_obs['obs']
+
+            self.obs = next_obs.clone()

             rewards = self.rewards_shaper(rewards)

-            self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs, torch.unsqueeze(dones, 1))
+            self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs_processed, torch.unsqueeze(dones, 1))

-            self.obs = obs = next_obs.clone()
+            if isinstance(obs, dict):
+                obs = self.obs['obs']

             if not random_exploration:
                 self.set_train()
@@ -505,10 +531,9 @@ def train_epoch(self):

     def train(self):
         self.init_tensors()
         self.algo_observer.after_init(self)
-        self.last_mean_rewards = -100500
         total_time = 0
         # rep_count = 0
-        self.frame = 0
+
         self.obs = self.env_reset()

         while True:
@@ -560,7 +585,7 @@ def train(self):
                 should_exit = False

                 if self.save_freq > 0:
-                    if (self.epoch_num % self.save_freq) == 0:
+                    if self.epoch_num % self.save_freq == 0:
                         self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))

                 if mean_rewards > self.last_mean_rewards and self.epoch_num >= self.save_best_after:
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 2080c684..a25a654e 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -161,6 +161,7 @@ def __init__(self, base_name, params):
         self.rnn_states = None
         self.name = base_name

+        # TODO: do we still need it?
         self.ppo = config.get('ppo', True)
         self.max_epochs = self.config.get('max_epochs', -1)
         self.max_frames = self.config.get('max_frames', -1)
@@ -229,22 +230,26 @@ def __init__(self, base_name, params):
         self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
         self.obs = None
         self.games_num = self.config['minibatch_size'] // self.seq_len # it is used only for current rnn implementation
+
         self.batch_size = self.horizon_length * self.num_actors * self.num_agents
         self.batch_size_envs = self.horizon_length * self.num_actors
+
         assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config))
         self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0)
         self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env)
-        self.mini_epochs_num = self.config['mini_epochs']
+
         self.num_minibatches = self.batch_size // self.minibatch_size
         assert(self.batch_size % self.minibatch_size == 0)
+
+        self.mini_epochs_num = self.config['mini_epochs']
+
         self.mixed_precision = self.config.get('mixed_precision', False)
         self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)

         self.last_lr = self.config['learning_rate']
         self.frame = 0
         self.update_time = 0
-        self.mean_rewards = self.last_mean_rewards = -100500
+        self.mean_rewards = self.last_mean_rewards = -1000000000
         self.play_time = 0
         self.epoch_num = 0
         self.curr_frames = 0
@@ -588,10 +593,11 @@ def train_central_value(self):
     def get_full_state_weights(self):
         state = self.get_weights()
         state['epoch'] = self.epoch_num
+        state['frame'] = self.frame
         state['optimizer'] = self.optimizer.state_dict()
+
         if self.has_central_value:
             state['assymetric_vf_nets'] = self.central_value_net.state_dict()
-        state['frame'] = self.frame

         # This is actually the best reward ever achieved. last_mean_rewards is perhaps not the best variable name
         # We save it to the checkpoint to prevent overriding the "best ever" checkpoint upon experiment restart
@@ -603,19 +609,22 @@ def get_full_state_weights(self):

         return state

-    def set_full_state_weights(self, weights):
+    def set_full_state_weights(self, weights, set_epoch=True):
+
         self.set_weights(weights)

-        self.epoch_num = weights['epoch'] # frames as well?
+        if set_epoch:
+            self.epoch_num = weights['epoch']
+            self.frame = weights['frame']
+
         if self.has_central_value:
             self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])

         self.optimizer.load_state_dict(weights['optimizer'])
-        self.frame = weights.get('frame', 0)
-        self.last_mean_rewards = weights.get('last_mean_rewards', -100500)
-        env_state = weights.get('env_state', None)

+        self.last_mean_rewards = weights.get('last_mean_rewards', -1000000000)
+
         if self.vec_env is not None:
+            env_state = weights.get('env_state', None)
             self.vec_env.set_env_state(env_state)

     def get_weights(self):
@@ -651,6 +660,55 @@ def set_weights(self, weights):
         self.model.load_state_dict(weights['model'], strict=False)
         self.set_stats_weights(weights)

+    def get_param(self, param_name):
+        if param_name in [
+            "grad_norm",
+            "critic_coef",
+            "bounds_loss_coef",
+            "entropy_coef",
+            "kl_threshold",
+            "gamma",
+            "tau",
+            "mini_epochs_num",
+            "e_clip",
+        ]:
+            return getattr(self, param_name)
+        elif param_name == "learning_rate":
+            return self.last_lr
+        else:
+            raise NotImplementedError(f"Can't get param {param_name}")
+
+    def set_param(self, param_name, param_value):
+        if param_name in [
+            "grad_norm",
+            "critic_coef",
+            "bounds_loss_coef",
+            "entropy_coef",
+            "gamma",
+            "tau",
+            "mini_epochs_num",
+            "e_clip",
+        ]:
+            setattr(self, param_name, param_value)
+        elif param_name == "learning_rate":
+            if self.global_rank == 0:
+                if self.is_adaptive_lr:
+                    raise NotImplementedError("Can't directly mutate LR on this schedule")
+                else:
+                    self.learning_rate = param_value
+
+                for param_group in self.optimizer.param_groups:
+                    param_group["lr"] = self.learning_rate
+        elif param_name == "kl_threshold":
+            if self.global_rank == 0:
+                if self.is_adaptive_lr:
+                    self.kl_threshold = param_value
+                    self.scheduler.kl_threshold = param_value
+                else:
+                    raise NotImplementedError("Can't directly mutate kl threshold")
+        else:
+            raise NotImplementedError(f"No param found for {param_value}")
+
     def _preproc_obs(self, obs_batch):
         if type(obs_batch) is dict:
             obs_batch = copy.copy(obs_batch)
@@ -912,7 +970,6 @@ def prepare_dataset(self, batch_dict):
                 values = self.value_mean_std(values)
                 returns = self.value_mean_std(returns)
                 self.value_mean_std.eval()
-
         advantages = torch.sum(advantages, axis=1)

@@ -1024,7 +1081,7 @@ def train(self):
                 checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])

                 if self.save_freq > 0:
-                    if (epoch_num % self.save_freq) == 0:
+                    if epoch_num % self.save_freq == 0:
                         self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))

                 if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:
@@ -1301,7 +1358,7 @@ def train(self):
                 checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])

                 if self.save_freq > 0:
-                    if (epoch_num % self.save_freq) == 0:
+                    if epoch_num % self.save_freq == 0:
                         self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))

                 if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:
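The `get_param`/`set_param` pair above is the runtime get/set API mentioned in the changelog: an external controller (for example a PBT scheduler) can read and mutate selected hyperparameters between epochs without rebuilding the agent. A short illustrative sketch, assuming `agent` is an already-constructed PPO agent; the values are made up:

# Read the current learning rate (returns agent.last_lr).
lr = agent.get_param("learning_rate")
# Mutate a few parameters mid-training; "learning_rate" also updates the optimizer's
# param groups, but only on rank 0 and only for non-adaptive LR schedules.
agent.set_param("learning_rate", lr * 0.5)
agent.set_param("gamma", 0.995)
agent.set_param("e_clip", 0.15)
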
diff --git a/rl_games/common/player.py b/rl_games/common/player.py
index d41fad0b..98be6501 100644
--- a/rl_games/common/player.py
+++ b/rl_games/common/player.py
@@ -1,8 +1,13 @@
+import os
+import shutil
+import threading
 import time
 import gym
 import numpy as np
 import torch
 import copy
+from os.path import basename
+from typing import Optional

 from rl_games.common import vecenv
 from rl_games.common import env_configurations
 from rl_games.algos_torch import model_builder
@@ -71,6 +76,90 @@ def __init__(self, params):
         self.max_steps = 108000 // 4
         self.device = torch.device(self.device_name)

+        self.evaluation = self.player_config.get("evaluation", False)
+        self.update_checkpoint_freq = self.player_config.get("update_checkpoint_freq", 100)
+        # if we run player as evaluation worker this will take care of loading new checkpoints
+        self.dir_to_monitor = self.player_config.get("dir_to_monitor")
+        # path to the newest checkpoint
+        self.checkpoint_to_load: Optional[str] = None
+
+        if self.evaluation and self.dir_to_monitor is not None:
+            self.checkpoint_mutex = threading.Lock()
+            self.eval_checkpoint_dir = os.path.join(self.dir_to_monitor, "eval_checkpoints")
+            os.makedirs(self.eval_checkpoint_dir, exist_ok=True)
+
+            patterns = ["*.pth"]
+            from watchdog.observers import Observer
+            from watchdog.events import PatternMatchingEventHandler
+            self.file_events = PatternMatchingEventHandler(patterns)
+            self.file_events.on_created = self.on_file_created
+            self.file_events.on_modified = self.on_file_modified
+
+            self.file_observer = Observer()
+            self.file_observer.schedule(self.file_events, self.dir_to_monitor, recursive=False)
+            self.file_observer.start()
+
+    def wait_for_checkpoint(self):
+        if self.dir_to_monitor is None:
+            return
+
+        attempt = 0
+        while True:
+            attempt += 1
+            with self.checkpoint_mutex:
+                if self.checkpoint_to_load is not None:
+                    break
+
+            if attempt % 10 == 0:
+                print(f"Evaluation: waiting for new checkpoint in {self.dir_to_monitor}...")
+
+            time.sleep(1.0)
+
+        print(f"Checkpoint {self.checkpoint_to_load} is available!")
+
+    def maybe_load_new_checkpoint(self):
+        # lock mutex while loading new checkpoint
+        with self.checkpoint_mutex:
+            if self.checkpoint_to_load is not None:
+                print(f"Evaluation: loading new checkpoint {self.checkpoint_to_load}...")
+                # try if we can load anything from the pth file, this will quickly fail if the file is corrupted
+                # without triggering the retry loop in "safe_filesystem_op()"
+                load_error = False
+                try:
+                    torch.load(self.checkpoint_to_load)
+                except Exception as e:
+                    print(f"Evaluation: checkpoint file is likely corrupted {self.checkpoint_to_load}: {e}")
+                    load_error = True
+
+                if not load_error:
+                    try:
+                        self.restore(self.checkpoint_to_load)
+                    except Exception as e:
+                        print(f"Evaluation: failed to load new checkpoint {self.checkpoint_to_load}: {e}")
+
+                # whether we succeeded or not, forget about this checkpoint
+                self.checkpoint_to_load = None
+
+    def process_new_eval_checkpoint(self, path):
+        with self.checkpoint_mutex:
+            # print(f"New checkpoint {path} available for evaluation")
+            # copy file to eval_checkpoints dir using shutil
+            # since we're running the evaluation worker in a separate process,
+            # there is a chance that the file is changed/corrupted while we're copying it
+            # not sure what we can do about this. In practice it never happened so far though
+            try:
+                eval_checkpoint_path = os.path.join(self.eval_checkpoint_dir, basename(path))
+                shutil.copyfile(path, eval_checkpoint_path)
+            except Exception as e:
+                print(f"Failed to copy {path} to {eval_checkpoint_path}: {e}")
+                return
+
+            self.checkpoint_to_load = eval_checkpoint_path
+
+    def on_file_created(self, event):
+        self.process_new_eval_checkpoint(event.src_path)
+
+    def on_file_modified(self, event):
+        self.process_new_eval_checkpoint(event.src_path)
+
     def load_networks(self, params):
         builder = model_builder.ModelBuilder()
         self.config['network'] = builder.load(params)
@@ -204,6 +293,8 @@ def run(self):
         if has_masks_func:
             has_masks = self.env.has_action_mask()

+        self.wait_for_checkpoint()
+
         need_init_rnn = self.is_rnn
         for _ in range(n_games):
             if games_played >= n_games:
@@ -223,6 +314,9 @@ def run(self):
             print_game_res = False

             for n in range(self.max_steps):
+                if self.evaluation and n % self.update_checkpoint_freq == 0:
+                    self.maybe_load_new_checkpoint()
+
                 if has_masks:
                     masks = self.env.get_action_mask()
                     action = self.get_masked_action(
@@ -270,9 +364,9 @@ def run(self):
                         cur_rewards_done = cur_rewards/done_count
                         cur_steps_done = cur_steps/done_count
                         if print_game_res:
-                            print(f'reward: {cur_rewards_done:.4} steps: {cur_steps_done:.4} w: {game_res}')
+                            print(f'reward: {cur_rewards_done:.2f} steps: {cur_steps_done:.1f} w: {game_res}')
                         else:
-                            print(f'reward: {cur_rewards_done:.4} steps: {cur_steps_done:.4f}')
+                            print(f'reward: {cur_rewards_done:.2f} steps: {cur_steps_done:.1f}')

                         sum_game_res += game_res
                         if batch_size//self.num_agents == 1 or games_played >= n_games:
diff --git a/rl_games/common/vecenv.py b/rl_games/common/vecenv.py
index d6261656..50ab4999 100644
--- a/rl_games/common/vecenv.py
+++ b/rl_games/common/vecenv.py
@@ -8,6 +8,7 @@ from time import sleep
 import torch

+
 class RayWorker:
     def __init__(self, config_name, config):
         self.env = configurations[config_name]['env_creator'](**config)
diff --git a/rl_games/configs/mujoco/sac_ant_envpool.yaml b/rl_games/configs/mujoco/sac_ant_envpool.yaml
index e8390ed5..719e8b8f 100644
--- a/rl_games/configs/mujoco/sac_ant_envpool.yaml
+++ b/rl_games/configs/mujoco/sac_ant_envpool.yaml
@@ -29,7 +29,7 @@ params:
     max_epochs: 10000
     num_steps_per_episode: 8
     save_best_after: 500
-    save_frequency: 10000
+    save_frequency: 1000
     gamma: 0.99
     init_alpha: 1
     alpha_lr: 5e-3
diff --git a/rl_games/interfaces/base_algorithm.py b/rl_games/interfaces/base_algorithm.py
index 054483f7..0edf4504 100644
--- a/rl_games/interfaces/base_algorithm.py
+++ b/rl_games/interfaces/base_algorithm.py
@@ -1,6 +1,7 @@
 from abc import ABC
 from abc import abstractmethod, abstractproperty

+
 class BaseAlgorithm(ABC):
     def __init__(self, base_name, config):
         pass
@@ -26,7 +27,7 @@ def get_full_state_weights(self):
         pass

     @abstractmethod
-    def set_full_state_weights(self, weights):
+    def set_full_state_weights(self, weights, set_epoch):
         pass

     @abstractmethod
@@ -37,5 +38,14 @@ def get_weights(self):
     def set_weights(self, weights):
         pass

+    # Get algo training parameters
+    @abstractmethod
+    def get_param(self, param_name):
+        pass
+
+    # Set algo training parameters
+    @abstractmethod
+    def set_param(self, param_name, param_value):
+        pass
diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py
index e8ec70ee..39b13edb 100644
--- a/rl_games/torch_runner.py
+++ b/rl_games/torch_runner.py
@@ -131,11 +131,8 @@ def reset(self):
         pass

     def run(self, args):
-        load_path = None
-
         if args['train']:
             self.run_train(args)
-
         elif args['play']:
             self.run_play(args)
         else:
diff --git a/runner.py b/runner.py
index 25f79af4..c1188751 100644
--- a/runner.py
+++ b/runner.py
@@ -50,11 +50,9 @@
         except yaml.YAMLError as exc:
             print(exc)

-    #rank = int(os.getenv("LOCAL_RANK", "0"))
     global_rank = int(os.getenv("RANK", "0"))
     if args["track"] and global_rank == 0:
         import wandb
-
         wandb.init(
             project=args["wandb_project_name"],
             entity=args["wandb_entity"],
diff --git a/setup.py b/setup.py
index 9ce0cd32..475a555d 100644
--- a/setup.py
+++ b/setup.py
@@ -45,6 +45,7 @@
                       'tensorboardX>=1.6',
                       'setproctitle',
                       'psutil',
-                      'pyyaml'
+                      'pyyaml',
+                      'watchdog>=2.1.9,<3.0.0', # for evaluation process
                       ],
     )