diff --git a/scripts/rnnlm/get_best_model.py b/scripts/rnnlm/get_best_model.py
index e8c6bd8a2f4..333ed8dbfc7 100755
--- a/scripts/rnnlm/get_best_model.py
+++ b/scripts/rnnlm/get_best_model.py
@@ -3,14 +3,14 @@
 # Copyright 2017 Johns Hopkins University (author: Daniel Povey)
 # License: Apache 2.0.
 
-import os
 import argparse
-import sys
+import glob
 import re
+import sys
 
 parser = argparse.ArgumentParser(description="Works out the best iteration of RNNLM training "
-                                 "based on dev-set perplexity, and prints the number corresponding "
-                                 "to that iteration",
+                                             "based on dev-set perplexity, and prints the number corresponding "
+                                             "to that iteration",
                                  epilog="E.g. " + sys.argv[0] + " exp/rnnlm_a",
                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
@@ -19,8 +19,7 @@
 
 args = parser.parse_args()
 
-
-num_iters=None
+num_iters = None
 try:
     with open(args.rnnlm_dir + "/info.txt", encoding="latin-1") as f:
         for line in f:
@@ -36,15 +35,15 @@
     sys.exit(sys.argv[0] + ": could not get num_iters from {0}/info.txt".format(
         args.rnnlm_dir))
 
-best_objf=-2000
-best_iter=-1
+best_objf = -2000
+best_iter = -1
 for i in range(1, num_iters):
     this_logfile = "{0}/log/compute_prob.{1}.log".format(args.rnnlm_dir, i)
     try:
         f = open(this_logfile, 'r', encoding='latin-1')
     except:
         sys.exit(sys.argv[0] + ": could not open log-file {0}".format(this_logfile))
-    this_objf=-1000
+    this_objf = -1000
     for line in f:
         m = re.search('Overall objf .* (\S+)$', str(line))
         if m is not None:
@@ -53,6 +52,10 @@
             except Exception as e:
                 sys.exit(sys.argv[0] + ": line in file {0} could not be parsed: {1}, error is: {2}".format(
                     this_logfile, line, str(e)))
+    # verify this iteration still has model files present
+    if len(glob.glob("{0}/{1}.raw".format(args.rnnlm_dir, i))) == 0:
+        # this iteration has log files, but model files have been cleaned up, skip it
+        continue
     if this_objf == -1000:
         print(sys.argv[0] + ": warning: could not parse objective function from {0}".format(
             this_logfile), file=sys.stderr)
@@ -63,5 +66,4 @@
 
 if best_iter == -1:
     sys.exit(sys.argv[0] + ": error: could not get best iteration.")
-
 print(str(best_iter))
diff --git a/scripts/rnnlm/rnnlm_cleanup.py b/scripts/rnnlm/rnnlm_cleanup.py
new file mode 100644
index 00000000000..40cbee7a496
--- /dev/null
+++ b/scripts/rnnlm/rnnlm_cleanup.py
@@ -0,0 +1,160 @@
+#!/usr/bin/env python3
+
+# Copyright 2018 Tilde
+# License: Apache 2.0
+
+import sys
+
+import argparse
+import os
+import re
+import glob
+
+script_name = sys.argv[0]
+
+parser = argparse.ArgumentParser(description="Removes models from past training iterations of "
+                                             "RNNLM. Can use either 'keep_latest' (default) or "
+                                             "'keep_best' cleanup strategy, where former keeps "
+                                             "the models that are freshest, while latter keeps "
+                                             "the models with best training objective score on "
+                                             "dev set.",
+                                 epilog="E.g. " + script_name + " exp/rnnlm_a --keep_best",
" + script_name + " exp/rnnlm_a --keep_best", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + +parser.add_argument("rnnlm_dir", + help="Directory where the RNNLM has been trained") +parser.add_argument("--iters_to_keep", + help="Max number of iterations to keep", + type=int, + default=3) +parser.add_argument("--keep_latest", + help="Keeps the training iterations that are latest by age", + action="store_const", + const=True, + default=False) +parser.add_argument("--keep_best", + help="Keeps the training iterations that have the best objf", + action="store_const", + const=True, + default=False) + +args = parser.parse_args() + +# validate arguments +if args.keep_latest and args.keep_best: + sys.exit(script_name + ": can only use one of 'keep_latest' or 'keep_best', but not both") +elif not args.keep_latest and not args.keep_best: + sys.exit(script_name + ": no cleanup strategy specified: use 'keep_latest' or 'keep_best'") + + +class IterationInfo: + def __init__(self, model_files, objf, compute_prob_done): + self.model_files = model_files + self.objf = objf + self.compute_prob_done = compute_prob_done + + def __str__(self): + return "{model_files: %s, compute_prob: %s, objf: %2.3f}" % (self.model_files, + self.compute_prob_done, + self.objf) + + def __repr__(self): + return self.__str__() + + +def get_compute_prob_info(log_file): + # we want to know 3 things: iteration number, objf and whether compute prob is done + iteration = int(log_file.split(".")[-2]) + objf = -2000 + compute_prob_done = False + # roughly based on code in get_best_model.py + try: + f = open(log_file, "r", encoding="latin-1") + except: + print(script_name + ": warning: compute_prob log not found for iteration " + + str(iter) + ". Skipping", + file=sys.stderr) + return iteration, objf, compute_prob_done + for line in f: + objf_m = re.search('Overall objf .* (\S+)$', str(line)) + if objf_m is not None: + try: + objf = float(objf_m.group(1)) + except Exception as e: + sys.exit(script_name + ": line in file {0} could not be parsed: {1}, error is: {2}".format( + log_file, line, str(e))) + if "# Ended" in line: + compute_prob_done = True + if objf == -2000: + print(script_name + ": warning: could not parse objective function from " + log_file, file=sys.stderr) + return iteration, objf, compute_prob_done + + +def get_iteration_files(exp_dir): + iterations = dict() + compute_prob_logs = glob.glob(exp_dir + "/log/compute_prob.[0-9]*.log") + for log in compute_prob_logs: + iteration, objf, compute_prob_done = get_compute_prob_info(log) + if iteration == 0: + # iteration 0 is special, never consider it for cleanup + continue + if compute_prob_done: + # this iteration can be safely considered for cleanup + # gather all model files belonging to it + model_files = [] + # when there are multiple jobs per iteration, there can be several model files + # we need to potentially clean them all up without mixing them up + model_files.extend(glob.glob("{0}/word_embedding.{1}.mat".format(exp_dir, iteration))) + model_files.extend(glob.glob("{0}/word_embedding.{1}.[0-9]*.mat".format(exp_dir, iteration))) + model_files.extend(glob.glob("{0}/feat_embedding.{1}.mat".format(exp_dir, iteration))) + model_files.extend(glob.glob("{0}/feat_embedding.{1}.[0-9]*.mat".format(exp_dir, iteration))) + model_files.extend(glob.glob("{0}/{1}.raw".format(exp_dir, iteration))) + model_files.extend(glob.glob("{0}/{1}.[0-9]*.raw".format(exp_dir, iteration))) + # compute_prob logs outlive model files, only consider iterations that do still have model 
+            if len(model_files) > 0:
+                iterations[iteration] = IterationInfo(model_files, objf, compute_prob_done)
+    return iterations
+
+
+def remove_model_files_for_iter(iter_info):
+    for f in iter_info.model_files:
+        os.remove(f)
+
+
+def keep_latest(iteration_dict):
+    max_to_keep = args.iters_to_keep
+    kept = 0
+    iterations_in_reverse_order = reversed(sorted(iteration_dict))
+    for iter in iterations_in_reverse_order:
+        if kept < max_to_keep:
+            kept += 1
+        else:
+            remove_model_files_for_iter(iteration_dict[iter])
+
+
+def keep_best(iteration_dict):
+    iters_to_keep = args.iters_to_keep
+    best = []
+    for iter, iter_info in iteration_dict.items():
+        objf = iter_info.objf
+        if objf == -2000:
+            print(script_name + ": warning: objf unavailable for iter " + str(iter), file=sys.stderr)
+            continue
+        # add potential best, sort by objf, trim to iters_to_keep size
+        best.append((iter, objf))
+        best = sorted(best, key=lambda x: -x[1])
+        if len(best) > iters_to_keep:
+            throwaway = best[iters_to_keep:]
+            best = best[:iters_to_keep]
+            # remove iters that we know are not the best
+            for (iter, _) in throwaway:
+                remove_model_files_for_iter(iteration_dict[iter])
+
+
+# grab all the iterations mapped to their model files, objf score and compute_prob status
+iterations = get_iteration_files(args.rnnlm_dir)
+# apply chosen cleanup strategy
+if args.keep_latest:
+    keep_latest(iterations)
+else:
+    keep_best(iterations)
diff --git a/scripts/rnnlm/train_rnnlm.sh b/scripts/rnnlm/train_rnnlm.sh
index aedfc470ac9..d6d38f3d734 100755
--- a/scripts/rnnlm/train_rnnlm.sh
+++ b/scripts/rnnlm/train_rnnlm.sh
@@ -38,6 +38,11 @@
 num_egs_threads=10  # number of threads used for sampling, if we're using
 use_gpu=true  # use GPU for training
 use_gpu_for_diagnostics=false  # set true to use GPU for compute_prob_*.log
+# optional cleanup options
+cleanup=false  # add option --cleanup true to enable automatic cleanup of old models
+cleanup_strategy="keep_latest"  # determines cleanup strategy, use either "keep_latest" or "keep_best"
+cleanup_keep_iters=3  # number of iterations that will have their models retained
+
 trap 'for pid in $(jobs -pr); do kill -KILL $pid; done' INT QUIT TERM
 
 . utils/parse_options.sh
@@ -222,12 +227,16 @@
           nnet3-average $src_models $dir/$[x+1].raw '&&' \
           matrix-sum --average=true $src_matrices $dir/${embedding_type}_embedding.$[x+1].mat
       fi
+      # optionally, perform cleanup after training
+      if [ "$cleanup" = true ] ; then
+        python3 rnnlm/rnnlm_cleanup.py $dir --$cleanup_strategy --iters_to_keep $cleanup_keep_iters
+      fi
     )
-
     # the error message below is not that informative, but $cmd will
     # have printed a more specific one.
     [ -f $dir/.error ] && echo "$0: error with diagnostics on iteration $x of training" && exit 1;
   fi
+
   x=$[x+1]
   num_splits_processed=$[num_splits_processed+this_num_jobs]
 done
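
Usage sketch (an assumption, not part of the patch): besides the automatic call that --cleanup true enables in train_rnnlm.sh, rnnlm_cleanup.py can be run by hand on an existing experiment directory. Assuming it is invoked from an egs directory where rnnlm/ points at scripts/rnnlm, as train_rnnlm.sh does, and using exp/rnnlm_a only as an example path:

    # keep the model files (*.raw, word_embedding.*.mat, feat_embedding.*.mat) of the
    # 3 iterations with the best dev-set objf and delete those of all other iterations;
    # iteration 0 and iterations whose compute_prob log has not "# Ended" are left alone
    python3 rnnlm/rnnlm_cleanup.py exp/rnnlm_a --keep_best --iters_to_keep 3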