
Merge pull request #455 from kengz/resume
`train` mode with resume; `enjoy` mode refactor
kengz authored Apr 14, 2020
2 parents b608395 + 111af12 commit 7605a82
Showing 24 changed files with 162 additions and 203 deletions.
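
In short, this merge extends the `lab_mode` argument to accept an optional `@` suffix: `train@{predir}` resumes a previous training run (with `train@latest` resolving to the most recent data directory for the spec), while `enjoy@{session_spec_file}` replays a saved session spec. A minimal sketch of the mode parsing, using only the convention visible in the run_lab.py diff below (the data path is illustrative):

# Sketch of the lab_mode@{pre_} convention introduced by this PR; the path is illustrative.
lab_mode = 'train@data/reinforce_cartpole_2020_04_13_232521'
if '@' in lab_mode:
    lab_mode, pre_ = lab_mode.split('@')  # -> ('train', 'data/reinforce_cartpole_2020_04_13_232521')
else:
    pre_ = None  # plain modes such as train, dev, search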
51 changes: 37 additions & 14 deletions run_lab.py
@@ -1,13 +1,12 @@
# The SLM Lab entrypoint
from glob import glob
from slm_lab import EVAL_MODES, TRAIN_MODES
from slm_lab.experiment import search
from slm_lab.experiment.control import Session, Trial, Experiment
from slm_lab.lib import logger, util
from slm_lab.spec import spec_util
import os
import pydash as ps
import sys
import torch
import torch.multiprocessing as mp


@@ -19,34 +18,58 @@
logger = logger.get_logger(__name__)


def get_spec(spec_file, spec_name, lab_mode, pre_):
'''Get spec using args processed from inputs'''
if lab_mode in TRAIN_MODES:
if pre_ is None: # new train trial
spec = spec_util.get(spec_file, spec_name)
else:
# for resuming with train@{predir}
# e.g. train@latest (will find the latest predir)
# e.g. train@data/reinforce_cartpole_2020_04_13_232521
predir = pre_
if predir == 'latest':
predir = sorted(glob(f'data/{spec_name}*/'))[-1] # get the latest predir with spec_name
_, _, _, _, experiment_ts = util.prepath_split(predir) # get experiment_ts to resume train spec
logger.info(f'Resolved to train@{predir}')
spec = spec_util.get(spec_file, spec_name, experiment_ts)
elif lab_mode == 'enjoy':
# for enjoy@{session_spec_file}
# e.g. enjoy@data/reinforce_cartpole_2020_04_13_232521/reinforce_cartpole_t0_s0_spec.json
session_spec_file = pre_
assert session_spec_file is not None, 'enjoy mode must specify an `enjoy@{session_spec_file}`'
spec = util.read(f'{session_spec_file}')
else:
raise ValueError(f'Unrecognizable lab_mode not of {TRAIN_MODES} or {EVAL_MODES}')
return spec


def run_spec(spec, lab_mode):
'''Run a spec in lab_mode'''
os.environ['lab_mode'] = lab_mode
os.environ['lab_mode'] = lab_mode # set lab_mode
spec = spec_util.override_spec(spec, lab_mode) # conditionally override spec
if lab_mode in TRAIN_MODES:
spec_util.save(spec) # first save the new spec
if lab_mode == 'dev':
spec = spec_util.override_dev_spec(spec)
if lab_mode == 'search':
spec_util.tick(spec, 'experiment')
Experiment(spec).run()
else:
spec_util.tick(spec, 'trial')
Trial(spec).run()
elif lab_mode in EVAL_MODES:
spec = spec_util.override_enjoy_spec(spec)
Session(spec).run()
else:
raise ValueError(f'Unrecognizable lab_mode not of {TRAIN_MODES} or {EVAL_MODES}')


def read_spec_and_run(spec_file, spec_name, lab_mode):
def get_spec_and_run(spec_file, spec_name, lab_mode):
'''Read a spec and run it in lab mode'''
logger.info(f'Running lab spec_file:{spec_file} spec_name:{spec_name} in mode:{lab_mode}')
if lab_mode in TRAIN_MODES:
spec = spec_util.get(spec_file, spec_name)
else: # eval mode
lab_mode, prename = lab_mode.split('@')
spec = spec_util.get_eval_spec(spec_file, prename)
if '@' in lab_mode: # process lab_mode@{predir/prename}
lab_mode, pre_ = lab_mode.split('@')
else:
pre_ = None
spec = get_spec(spec_file, spec_name, lab_mode, pre_)

if 'spec_params' not in spec:
run_spec(spec, lab_mode)
@@ -62,10 +85,10 @@ def main():
job_file = args[0] if len(args) == 1 else 'job/experiments.json'
for spec_file, spec_and_mode in util.read(job_file).items():
for spec_name, lab_mode in spec_and_mode.items():
read_spec_and_run(spec_file, spec_name, lab_mode)
get_spec_and_run(spec_file, spec_name, lab_mode)
else: # run single spec
assert len(args) == 3, f'To use sys args, specify spec_file, spec_name, lab_mode'
read_spec_and_run(*args)
get_spec_and_run(*args)


if __name__ == '__main__':
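For reference, a standalone sketch of how the entrypoint resolves `train@latest` above: timestamped run directories sort lexicographically, so the newest one is simply the last glob match (directory names are illustrative; `util.prepath_split` is used as shown in the diff and not reproduced here):

# Minimal reproduction of the train@latest resolution in get_spec(); run dirs are illustrative.
from glob import glob

spec_name = 'reinforce_cartpole'
predirs = sorted(glob(f'data/{spec_name}*/'))
predir = predirs[-1] if predirs else None  # e.g. 'data/reinforce_cartpole_2020_04_13_232521/'
# A resumed run would then presumably be launched as:
#   python run_lab.py {spec_file} reinforce_cartpole train@latest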
26 changes: 17 additions & 9 deletions slm_lab/agent/__init__.py
@@ -47,7 +47,7 @@ def act(self, state):
def update(self, state, action, reward, next_state, done):
'''Update per timestep after env transitions, e.g. memory, algorithm, update agent params, train net'''
self.body.update(state, action, reward, next_state, done)
if util.in_eval_lab_modes(): # eval does not update agent for training
if util.in_eval_lab_mode(): # eval does not update agent for training
return
self.body.memory.update(state, action, reward, next_state, done)
loss = self.algorithm.train()
Expand All @@ -59,7 +59,7 @@ def update(self, state, action, reward, next_state, done):
@lab_api
def save(self, ckpt=None):
'''Save agent'''
if util.in_eval_lab_modes(): # eval does not save new models
if util.in_eval_lab_mode(): # eval does not save new models
return
self.algorithm.save(ckpt=ckpt)

@@ -103,8 +103,16 @@ def __init__(self, env, spec, aeb=(0, 0, 0)):
self.train_df = pd.DataFrame(columns=[
'epi', 't', 'wall_t', 'opt_step', 'frame', 'fps', 'total_reward', 'total_reward_ma', 'loss', 'lr',
'explore_var', 'entropy_coef', 'entropy', 'grad_norm'])

# in train@ mode, override from saved train_df if exists
if util.in_train_lab_mode() and self.spec['meta']['resume']:
train_df_filepath = util.get_session_df_path(self.spec, 'train')
if os.path.exists(train_df_filepath):
self.train_df = util.read(train_df_filepath)
self.env.clock.load(self.train_df)

# track eval data within run_eval. the same as train_df except for reward
if ps.get(self.spec, 'meta.rigorous_eval'):
if self.spec['meta']['rigorous_eval']:
self.eval_df = self.train_df.copy()
else:
self.eval_df = self.train_df
@@ -178,6 +186,7 @@ def ckpt(self, env, df_mode):
df = getattr(self, f'{df_mode}_df')
df.loc[len(df)] = row # append efficiently to df
df.iloc[-1]['total_reward_ma'] = total_reward_ma = df[-viz.PLOT_MA_WINDOW:]['total_reward'].mean()
df.drop_duplicates('frame', inplace=True) # remove any duplicates by the same frame
self.total_reward_ma = total_reward_ma

def get_mean_lr(self):
@@ -192,10 +201,9 @@ def get_mean_lr(self):

def get_log_prefix(self):
'''Get the prefix for logging'''
spec = self.agent.spec
spec_name = spec['name']
trial_index = spec['meta']['trial']
session_index = spec['meta']['session']
spec_name = self.spec['name']
trial_index = self.spec['meta']['trial']
session_index = self.spec['meta']['session']
prefix = f'Trial {trial_index} session {session_index} {spec_name}_t{trial_index}_s{session_index}'
return prefix

@@ -232,8 +240,8 @@ def log_tensorboard(self):
self.tb_actions = [] # store actions for tensorboard
logger.info(f'Using TensorBoard logging for dev mode. Run `tensorboard --logdir={log_prepath}` to start TensorBoard.')

trial_index = self.agent.spec['meta']['trial']
session_index = self.agent.spec['meta']['session']
trial_index = self.spec['meta']['trial']
session_index = self.spec['meta']['session']
if session_index != 0: # log only session 0
return
idx_suffix = f'trial{trial_index}_session{session_index}'
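Two small behaviors in the Body changes above are worth noting: in a resumed train@ run the saved train_df is read back and the env clock is fast-forwarded from it, and `ckpt()` now drops duplicate rows keyed on `frame`, presumably so a checkpoint re-logged at the same frame after resume does not appear twice. A tiny illustration of the pandas behavior the second change relies on:

# Illustration of df.drop_duplicates('frame', inplace=True) as used in ckpt();
# only the first row per frame value is kept.
import pandas as pd

df = pd.DataFrame({'frame': [100, 200, 200], 'total_reward': [1.0, 2.0, 2.5]})
df.drop_duplicates('frame', inplace=True)
print(df.to_dict('list'))  # {'frame': [100, 200], 'total_reward': [1.0, 2.0]}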
4 changes: 1 addition & 3 deletions slm_lab/agent/algorithm/actor_critic.py
@@ -162,7 +162,7 @@ def init_nets(self, global_nets=None):
self.critic_optim = net_util.get_optim(self.critic_net, self.critic_net.optim_spec)
self.critic_lr_scheduler = net_util.get_lr_scheduler(self.critic_optim, self.critic_net.lr_scheduler_spec)
net_util.set_global_nets(self, global_nets)
self.post_init_nets()
self.end_init_nets()

@lab_api
def calc_pdparam(self, x, net=None):
@@ -278,8 +278,6 @@ def calc_val_loss(self, v_preds, v_targets):

def train(self):
'''Train actor critic by computing the loss in batch efficiently'''
if util.in_eval_lab_modes():
return np.nan
clock = self.body.env.clock
if self.to_train == 1:
batch = self.sample()
21 changes: 10 additions & 11 deletions slm_lab/agent/algorithm/base.py
@@ -22,7 +22,7 @@ def __init__(self, agent, global_nets=None):
self.body = self.agent.body
self.init_algorithm_params()
self.init_nets(global_nets)
logger.info(util.self_desc(self))
logger.info(util.self_desc(self, omit=['algorithm_spec', 'name', 'memory_spec', 'net_spec', 'body']))

@abstractmethod
@lab_api
@@ -37,19 +37,20 @@ def init_nets(self, global_nets=None):
raise NotImplementedError

@lab_api
def post_init_nets(self):
'''
Method to conditionally load models.
Call at the end of init_nets() after setting self.net_names
'''
def end_init_nets(self):
'''Checkers and conditional loaders called at the end of init_nets()'''
# check all nets naming
assert hasattr(self, 'net_names')
for net_name in self.net_names:
assert net_name.endswith('net'), f'Naming convention: net_name must end with "net"; got {net_name}'
if util.in_eval_lab_modes():

# load algorithm if is in train@ resume or enjoy mode
lab_mode = util.get_lab_mode()
if self.agent.spec['meta']['resume'] or lab_mode == 'enjoy':
self.load()
logger.info(f'Loaded algorithm models for lab_mode: {util.get_lab_mode()}')
logger.info(f'Loaded algorithm models for lab_mode: {lab_mode}')
else:
logger.info(f'Initialized algorithm models for lab_mode: {util.get_lab_mode()}')
logger.info(f'Initialized algorithm models for lab_mode: {lab_mode}')

@lab_api
def calc_pdparam(self, x, net=None):
@@ -76,8 +77,6 @@ def sample(self):
@lab_api
def train(self):
'''Implement algorithm train, or throw NotImplementedError'''
if util.in_eval_lab_modes():
return np.nan
raise NotImplementedError

@abstractmethod
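The rename of `post_init_nets()` to `end_init_nets()` above also changes when models are loaded: previously loading was keyed on being in an eval lab mode, whereas now it happens when `meta.resume` is set or the mode is enjoy. A condensed sketch of the new condition, with the spec dict and mode string shown as plain inputs rather than the SLM Lab objects from the diff:

# Condensed sketch of the load condition in end_init_nets() (names taken from the diff above).
def should_load_models(spec, lab_mode):
    '''Load saved nets only when resuming a train@ run or replaying in enjoy mode.'''
    return spec['meta']['resume'] or lab_mode == 'enjoy'

# e.g. should_load_models({'meta': {'resume': False}}, 'enjoy') -> True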
10 changes: 3 additions & 7 deletions slm_lab/agent/algorithm/dqn.py
@@ -1,11 +1,9 @@
from slm_lab.agent import net
from slm_lab.agent.algorithm import policy_util
from slm_lab.agent.algorithm.sarsa import SARSA
from slm_lab.agent.net import net_util
from slm_lab.lib import logger, math_util, util
from slm_lab.lib import logger, util
from slm_lab.lib.decorator import lab_api
import numpy as np
import pydash as ps
import torch

logger = logger.get_logger(__name__)
@@ -87,7 +85,7 @@ def init_nets(self, global_nets=None):
self.optim = net_util.get_optim(self.net, self.net.optim_spec)
self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
net_util.set_global_nets(self, global_nets)
self.post_init_nets()
self.end_init_nets()

def calc_q_loss(self, batch):
'''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
@@ -130,8 +128,6 @@ def train(self):
For each of the batches, the target Q values (q_targets) are computed and a single training step is taken k times
Otherwise this function does nothing.
'''
if util.in_eval_lab_modes():
return np.nan
clock = self.body.env.clock
if self.to_train == 1:
total_loss = torch.tensor(0.0)
@@ -187,7 +183,7 @@ def init_nets(self, global_nets=None):
self.optim = net_util.get_optim(self.net, self.net.optim_spec)
self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
net_util.set_global_nets(self, global_nets)
self.post_init_nets()
self.end_init_nets()
self.online_net = self.target_net
self.eval_net = self.target_net

7 changes: 3 additions & 4 deletions slm_lab/agent/algorithm/policy_util.py
@@ -5,7 +5,6 @@
from slm_lab.lib import distribution, logger, math_util, util
from torch import distributions
import numpy as np
import pydash as ps
import torch

logger = logger.get_logger(__name__)
@@ -61,7 +60,7 @@ def guard_tensor(state, body):
if isinstance(state, LazyFrames):
state = state.__array__() # realize data
state = torch.from_numpy(state.astype(np.float32))
if not body.env.is_venv or util.in_eval_lab_modes():
if not body.env.is_venv:
# singleton state, unsqueeze as minibatch for net input
state = state.unsqueeze(dim=0)
return state
@@ -142,7 +141,7 @@ def default(state, algorithm, body):

def random(state, algorithm, body):
'''Random action using gym.action_space.sample(), with the same format as default()'''
if body.env.is_venv and not util.in_eval_lab_modes():
if body.env.is_venv:
_action = [body.action_space.sample() for _ in range(body.env.num_envs)]
else:
_action = [body.action_space.sample()]
@@ -269,7 +268,7 @@ def __init__(self, var_decay_spec=None):

def update(self, algorithm, clock):
'''Get an updated value for var'''
if (util.in_eval_lab_modes()) or self._updater_name == 'no_decay':
if (util.in_eval_lab_mode()) or self._updater_name == 'no_decay':
return self.end_val
step = clock.get()
val = self._updater(self.start_val, self.end_val, self.start_step, self.end_step, step)
4 changes: 0 additions & 4 deletions slm_lab/agent/algorithm/ppo.py
@@ -1,13 +1,11 @@
from copy import deepcopy
from slm_lab.agent import net
from slm_lab.agent.algorithm import policy_util
from slm_lab.agent.algorithm.actor_critic import ActorCritic
from slm_lab.agent.net import net_util
from slm_lab.lib import logger, math_util, util
from slm_lab.lib.decorator import lab_api
import math
import numpy as np
import pydash as ps
import torch

logger = logger.get_logger(__name__)
@@ -168,8 +166,6 @@ def calc_policy_loss(self, batch, pdparams, advs):
return policy_loss

def train(self):
if util.in_eval_lab_modes():
return np.nan
clock = self.body.env.clock
if self.to_train == 1:
net_util.copy(self.net, self.old_net) # update old net
4 changes: 2 additions & 2 deletions slm_lab/agent/algorithm/random.py
@@ -1,7 +1,7 @@
# The random agent algorithm
# For basic dev purpose
from slm_lab.agent.algorithm.base import Algorithm
from slm_lab.lib import logger, util
from slm_lab.lib import logger
from slm_lab.lib.decorator import lab_api
import numpy as np

@@ -29,7 +29,7 @@ def act(self, state):
def act(self, state):
'''Random action'''
body = self.body
if body.env.is_venv and not util.in_eval_lab_modes():
if body.env.is_venv:
action = np.array([body.action_space.sample() for _ in range(body.env.num_envs)])
else:
action = body.action_space.sample()
4 changes: 1 addition & 3 deletions slm_lab/agent/algorithm/reinforce.py
@@ -87,7 +87,7 @@ def init_nets(self, global_nets=None):
self.optim = net_util.get_optim(self.net, self.net.optim_spec)
self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
net_util.set_global_nets(self, global_nets)
self.post_init_nets()
self.end_init_nets()

@lab_api
def calc_pdparam(self, x, net=None):
@@ -145,8 +145,6 @@ def calc_policy_loss(self, batch, pdparams, advs):

@lab_api
def train(self):
if util.in_eval_lab_modes():
return np.nan
clock = self.body.env.clock
if self.to_train == 1:
batch = self.sample()
4 changes: 1 addition & 3 deletions slm_lab/agent/algorithm/sac.py
@@ -90,7 +90,7 @@ def init_nets(self, global_nets=None):
self.alpha_optim = net_util.get_optim(self.log_alpha, self.net.optim_spec)
self.alpha_lr_scheduler = net_util.get_lr_scheduler(self.alpha_optim, self.net.lr_scheduler_spec)
net_util.set_global_nets(self, global_nets)
self.post_init_nets()
self.end_init_nets()

@lab_api
def act(self, state):
@@ -187,8 +187,6 @@ def train_alpha(self, alpha_loss):

def train(self):
'''Train actor critic by computing the loss in batch efficiently'''
if util.in_eval_lab_modes():
return np.nan
clock = self.body.env.clock
if self.to_train == 1:
for _ in range(self.training_iter):