22 changes: 12 additions & 10 deletions scripts/rnnlm/get_best_model.py
@@ -3,14 +3,14 @@
# Copyright 2017 Johns Hopkins University (author: Daniel Povey)
# License: Apache 2.0.

import os
import argparse
import sys
import glob
import re
import sys

parser = argparse.ArgumentParser(description="Works out the best iteration of RNNLM training "
"based on dev-set perplexity, and prints the number corresponding "
"to that iteration",
"based on dev-set perplexity, and prints the number corresponding "
"to that iteration",
epilog="E.g. " + sys.argv[0] + " exp/rnnlm_a",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

@@ -19,8 +19,7 @@

args = parser.parse_args()


num_iters=None
num_iters = None
try:
with open(args.rnnlm_dir + "/info.txt", encoding="latin-1") as f:
for line in f:
@@ -36,15 +35,15 @@
sys.exit(sys.argv[0] + ": could not get num_iters from {0}/info.txt".format(
args.rnnlm_dir))

best_objf=-2000
best_iter=-1
best_objf = -2000
best_iter = -1
for i in range(1, num_iters):
this_logfile = "{0}/log/compute_prob.{1}.log".format(args.rnnlm_dir, i)
try:
f = open(this_logfile, 'r', encoding='latin-1')
except:
sys.exit(sys.argv[0] + ": could not open log-file {0}".format(this_logfile))
this_objf=-1000
this_objf = -1000
for line in f:
m = re.search('Overall objf .* (\S+)$', str(line))
if m is not None:
@@ -53,6 +52,10 @@
except Exception as e:
sys.exit(sys.argv[0] + ": line in file {0} could not be parsed: {1}, error is: {2}".format(
this_logfile, line, str(e)))
# verify this iteration still has model files present
if len(glob.glob("{0}/{1}.raw".format(args.rnnlm_dir, i))) == 0:
# this iteration has log files, but model files have been cleaned up, skip it
continue
if this_objf == -1000:
print(sys.argv[0] + ": warning: could not parse objective function from {0}".format(
this_logfile), file=sys.stderr)
@@ -63,5 +66,4 @@
if best_iter == -1:
sys.exit(sys.argv[0] + ": error: could not get best iteration.")


print(str(best_iter))
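With the added glob check, get_best_model.py only reports iterations whose .raw model file is still on disk, so its answer stays consistent after old models have been cleaned up. A minimal usage sketch, following the script's own epilog (exp/rnnlm_a is just the example directory):

best_iter=$(rnnlm/get_best_model.py exp/rnnlm_a)
echo "best iteration: $best_iter"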
160 changes: 160 additions & 0 deletions scripts/rnnlm/rnnlm_cleanup.py
@@ -0,0 +1,160 @@
#!/usr/bin/env python3

# Copyright 2018 Tilde
# License: Apache 2.0

import sys

import argparse
import os
import re
import glob

script_name = sys.argv[0]

parser = argparse.ArgumentParser(description="Removes models from past training iterations of "
"RNNLM. Can use either the 'keep_latest' (default) or "
"the 'keep_best' cleanup strategy: the former keeps "
"the most recent models, while the latter keeps "
"the models with the best objective score on the "
"dev set.",
epilog="E.g. " + script_name + " exp/rnnlm_a --keep_best",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument("rnnlm_dir",
help="Directory where the RNNLM has been trained")
parser.add_argument("--iters_to_keep",
help="Max number of iterations to keep",
type=int,
default=3)
parser.add_argument("--keep_latest",
help="Keeps the training iterations that are latest by age",
action="store_const",
const=True,
default=False)
parser.add_argument("--keep_best",
help="Keeps the training iterations that have the best objf",
action="store_const",
const=True,
default=False)

args = parser.parse_args()

# validate arguments
if args.keep_latest and args.keep_best:
sys.exit(script_name + ": can only use one of 'keep_latest' or 'keep_best', but not both")
elif not args.keep_latest and not args.keep_best:
sys.exit(script_name + ": no cleanup strategy specified: use 'keep_latest' or 'keep_best'")


class IterationInfo:
def __init__(self, model_files, objf, compute_prob_done):
self.model_files = model_files
self.objf = objf
self.compute_prob_done = compute_prob_done

def __str__(self):
return "{model_files: %s, compute_prob: %s, objf: %2.3f}" % (self.model_files,
self.compute_prob_done,
self.objf)

def __repr__(self):
return self.__str__()


def get_compute_prob_info(log_file):
# we want to know 3 things: iteration number, objf and whether compute prob is done
iteration = int(log_file.split(".")[-2])
objf = -2000
compute_prob_done = False
# roughly based on code in get_best_model.py
try:
f = open(log_file, "r", encoding="latin-1")
except:
print(script_name + ": warning: compute_prob log not found for iteration " +
str(iteration) + ". Skipping",
file=sys.stderr)
return iteration, objf, compute_prob_done
for line in f:
objf_m = re.search('Overall objf .* (\S+)$', str(line))
if objf_m is not None:
try:
objf = float(objf_m.group(1))
except Exception as e:
sys.exit(script_name + ": line in file {0} could not be parsed: {1}, error is: {2}".format(
log_file, line, str(e)))
if "# Ended" in line:
compute_prob_done = True
if objf == -2000:
print(script_name + ": warning: could not parse objective function from " + log_file, file=sys.stderr)
return iteration, objf, compute_prob_done


def get_iteration_files(exp_dir):
iterations = dict()
compute_prob_logs = glob.glob(exp_dir + "/log/compute_prob.[0-9]*.log")
for log in compute_prob_logs:
iteration, objf, compute_prob_done = get_compute_prob_info(log)
if iteration == 0:
# iteration 0 is special, never consider it for cleanup
continue
if compute_prob_done:
# this iteration can be safely considered for cleanup
# gather all model files belonging to it
model_files = []
# when there are multiple jobs per iteration, there can be several model files
# we need to potentially clean them all up without mixing them up
model_files.extend(glob.glob("{0}/word_embedding.{1}.mat".format(exp_dir, iteration)))
model_files.extend(glob.glob("{0}/word_embedding.{1}.[0-9]*.mat".format(exp_dir, iteration)))
model_files.extend(glob.glob("{0}/feat_embedding.{1}.mat".format(exp_dir, iteration)))
model_files.extend(glob.glob("{0}/feat_embedding.{1}.[0-9]*.mat".format(exp_dir, iteration)))
model_files.extend(glob.glob("{0}/{1}.raw".format(exp_dir, iteration)))
model_files.extend(glob.glob("{0}/{1}.[0-9]*.raw".format(exp_dir, iteration)))
# compute_prob logs outlive model files, only consider iterations that do still have model files
if len(model_files) > 0:
iterations[iteration] = IterationInfo(model_files, objf, compute_prob_done)
return iterations


def remove_model_files_for_iter(iter_info):
for f in iter_info.model_files:
os.remove(f)


def keep_latest(iteration_dict):
max_to_keep = args.iters_to_keep
kept = 0
iterations_in_reverse_order = reversed(sorted(iteration_dict))
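# walk from the newest iteration down: keep the first iters_to_keep, remove model files for the rest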
for iter in iterations_in_reverse_order:
if kept < max_to_keep:
kept += 1
else:
remove_model_files_for_iter(iteration_dict[iter])


def keep_best(iteration_dict):
iters_to_keep = args.iters_to_keep
best = []
for iter, iter_info in iteration_dict.items():
objf = iter_info.objf
if objf == -2000:
print(script_name + ": warning: objf unavailable for iter " + str(iter), file=sys.stderr)
continue
# add potential best, sort by objf, trim to iters_to_keep size
best.append((iter, objf))
best = sorted(best, key=lambda x: -x[1])
if len(best) > iters_to_keep:
throwaway = best[iters_to_keep:]
best = best[:iters_to_keep]
# remove iters that we know are not the best
for (iter, _) in throwaway:
remove_model_files_for_iter(iteration_dict[iter])


# grab all the iterations mapped to their model files, objf score and compute_prob status
iterations = get_iteration_files(args.rnnlm_dir)
# apply chosen cleanup strategy
if args.keep_latest:
keep_latest(iterations)
else:
keep_best(iterations)
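A sketch of invoking the new script by hand, assuming the experiment directory exp/rnnlm_a from the epilog above; the flags are the ones defined in the argparse block:

# keep the models from the 3 most recent completed iterations (the --iters_to_keep default)
python3 rnnlm/rnnlm_cleanup.py exp/rnnlm_a --keep_latest
# or keep the 5 iterations with the best dev-set objective
python3 rnnlm/rnnlm_cleanup.py exp/rnnlm_a --keep_best --iters_to_keep 5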
11 changes: 10 additions & 1 deletion scripts/rnnlm/train_rnnlm.sh
@@ -38,6 +38,11 @@ num_egs_threads=10 # number of threads used for sampling, if we're using
use_gpu=true # use GPU for training
use_gpu_for_diagnostics=false # set true to use GPU for compute_prob_*.log

# optional cleanup options
cleanup=false # add option --cleanup true to enable automatic cleanup of old models
cleanup_strategy="keep_latest" # determines cleanup strategy, use either "keep_latest" or "keep_best"
cleanup_keep_iters=3 # number of iterations that will have their models retained

trap 'for pid in $(jobs -pr); do kill -KILL $pid; done' INT QUIT TERM
. utils/parse_options.sh

@@ -222,12 +227,16 @@ while [ $x -lt $num_iters ]; do
nnet3-average $src_models $dir/$[x+1].raw '&&' \
matrix-sum --average=true $src_matrices $dir/${embedding_type}_embedding.$[x+1].mat
fi
# optionally, clean up models from earlier iterations
if [ "$cleanup" = true ] ; then
python3 rnnlm/rnnlm_cleanup.py $dir --$cleanup_strategy --iters_to_keep $cleanup_keep_iters
fi
)

# the error message below is not that informative, but $cmd will
# have printed a more specific one.
[ -f $dir/.error ] && echo "$0: error with diagnostics on iteration $x of training" && exit 1;
fi

x=$[x+1]
num_splits_processed=$[num_splits_processed+this_num_jobs]
done
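Assuming utils/parse_options.sh maps --cleanup-strategy and --cleanup-keep-iters onto the cleanup_strategy and cleanup_keep_iters variables above (the usual Kaldi convention), automatic cleanup could then be enabled along these lines:

rnnlm/train_rnnlm.sh --cleanup true --cleanup-strategy keep_best --cleanup-keep-iters 3 exp/rnnlm_a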