Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 81 additions & 29 deletions egs/wsj/s5/steps/libs/nnet3/report/log_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,17 @@
"value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
"deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\]"])


g_normal_nonlin_regex_pattern = ''.join([".*progress.([0-9]+).log:component name=(.+) ",
"type=(.*)Component,.*",
"value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
"deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\]"])

g_normal_nonlin_regex_pattern_with_oderiv = ''.join([".*progress.([0-9]+).log:component name=(.+) ",
"type=(.*)Component,.*",
"value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
"deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
"oderiv-rms=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\]"])

class KaldiLogParseException(Exception):
""" An Exception class that throws an error when there is an issue in
parsing the log files. Extend this class if more granularity is needed.
Expand All @@ -54,10 +59,12 @@ def __init__(self, message = None):

# This function is used to fill stats_per_component_per_iter table with the
# results of regular expression.

def fill_nonlin_stats_table_with_regex_result(groups, gate_index, stats_table):
iteration = int(groups[0])
component_name = groups[1]
component_type = groups[2]
# for value-avg
value_percentiles = groups[3+gate_index*6]
value_mean = float(groups[4+gate_index*6])
value_stddev = float(groups[5+gate_index*6])
Expand All @@ -66,6 +73,7 @@ def fill_nonlin_stats_table_with_regex_result(groups, gate_index, stats_table):
value_5th = float(value_percentiles_split[4])
value_50th = float(value_percentiles_split[6])
value_95th = float(value_percentiles_split[9])
# for deriv-avg
deriv_percentiles = groups[6+gate_index*6]
deriv_mean = float(groups[7+gate_index*6])
deriv_stddev = float(groups[8+gate_index*6])
Expand All @@ -74,29 +82,68 @@ def fill_nonlin_stats_table_with_regex_result(groups, gate_index, stats_table):
deriv_5th = float(deriv_percentiles_split[4])
deriv_50th = float(deriv_percentiles_split[6])
deriv_95th = float(deriv_percentiles_split[9])
try:
if stats_table[component_name]['stats'].has_key(iteration):
stats_table[component_name]['stats'][iteration].extend(
[value_mean, value_stddev,
deriv_mean, deriv_stddev,
value_5th, value_50th, value_95th,
deriv_5th, deriv_50th, deriv_95th])
else:
stats_table[component_name]['stats'][iteration] = [
value_mean, value_stddev,
deriv_mean, deriv_stddev,
value_5th, value_50th, value_95th,
deriv_5th, deriv_50th, deriv_95th]
except KeyError:
stats_table[component_name] = {}
stats_table[component_name]['type'] = component_type
stats_table[component_name]['stats'] = {}
stats_table[component_name][
'stats'][iteration] = [value_mean, value_stddev,
deriv_mean, deriv_stddev,
value_5th, value_50th, value_95th,
deriv_5th, deriv_50th, deriv_95th]

if len(groups) <= 9:
try:
if stats_table[component_name]['stats'].has_key(iteration):
stats_table[component_name]['stats'][iteration].extend(
[value_mean, value_stddev,
deriv_mean, deriv_stddev,
value_5th, value_50th, value_95th,
deriv_5th, deriv_50th, deriv_95th])
else:
stats_table[component_name]['stats'][iteration] = [
value_mean, value_stddev,
deriv_mean, deriv_stddev,
value_5th, value_50th, value_95th,
deriv_5th, deriv_50th, deriv_95th]
except KeyError:
stats_table[component_name] = {}
stats_table[component_name]['type'] = component_type
stats_table[component_name]['stats'] = {}
stats_table[component_name][
'stats'][iteration] = [value_mean, value_stddev,
deriv_mean, deriv_stddev,
value_5th, value_50th, value_95th,
deriv_5th, deriv_50th, deriv_95th]
else:
#for oderiv-rms
oderiv_percentiles = groups[9+gate_index*6]
oderiv_mean = float(groups[10+gate_index*6])
oderiv_stddev = float(groups[11+gate_index*6])
oderiv_percentiles_split = re.split(',| ',oderiv_percentiles)
assert len(oderiv_percentiles_split) == 13
oderiv_5th = float(oderiv_percentiles_split[4])
oderiv_50th = float(oderiv_percentiles_split[6])
oderiv_95th = float(oderiv_percentiles_split[9])
try:
if stats_table[component_name]['stats'].has_key(iteration):
stats_table[component_name]['stats'][iteration].extend(
[value_mean, value_stddev,
deriv_mean, deriv_stddev,
oderiv_mean, oderiv_stddev,
value_5th, value_50th, value_95th,
deriv_5th, deriv_50th, deriv_95th,
oderiv_5th, oderiv_50th, oderiv_95th])
else:
stats_table[component_name]['stats'][iteration] = [
value_mean, value_stddev,
deriv_mean, deriv_stddev,
oderiv_mean, oderiv_stddev,
value_5th, value_50th, value_95th,
deriv_5th, deriv_50th, deriv_95th,
oderiv_5th, oderiv_50th, oderiv_95th]
except KeyError:
stats_table[component_name] = {}
stats_table[component_name]['type'] = component_type
stats_table[component_name]['stats'] = {}
stats_table[component_name][
'stats'][iteration] = [value_mean, value_stddev,
deriv_mean, deriv_stddev,
oderiv_mean, oderiv_stddev,
value_5th, value_50th, value_95th,
deriv_5th, deriv_50th, deriv_95th,
oderiv_5th, oderiv_50th, oderiv_95th]

def parse_progress_logs_for_nonlinearity_stats(exp_dir):

Expand All @@ -116,11 +163,18 @@ def parse_progress_logs_for_nonlinearity_stats(exp_dir):
stats_per_component_per_iter = {}

progress_log_lines = common_lib.get_command_stdout(
'grep -e "value-avg.*deriv-avg" {0}'.format(progress_log_files),
'grep -e "value-avg.*deriv-avg.*oderiv" {0}'.format(progress_log_files),
require_zero_status = False)

parse_regex = re.compile(g_normal_nonlin_regex_pattern)

if progress_log_lines:
# cases with oderiv-rms
parse_regex = re.compile(g_normal_nonlin_regex_pattern_with_oderiv)
else:
# cases with only value-avg and deriv-avg
progress_log_lines = common_lib.get_command_stdout(
'grep -e "value-avg.*deriv-avg" {0}'.format(progress_log_files),
require_zero_status = False)
parse_regex = re.compile(g_normal_nonlin_regex_pattern)

for line in progress_log_lines.split("\n"):
mat_obj = parse_regex.search(line)
Expand Down Expand Up @@ -333,7 +387,6 @@ def get_train_times(exp_dir):
train_times[iter] = max(values)
return train_times


def parse_prob_logs(exp_dir, key='accuracy', output="output"):
train_prob_files = "%s/log/compute_prob_train.*.log" % (exp_dir)
valid_prob_files = "%s/log/compute_prob_valid.*.log" % (exp_dir)
Expand Down Expand Up @@ -456,7 +509,6 @@ def parse_rnnlm_prob_logs(exp_dir, key='objf'):




def generate_acc_logprob_report(exp_dir, key="accuracy", output="output"):
try:
times = get_train_times(exp_dir)
Expand Down Expand Up @@ -488,4 +540,4 @@ def generate_acc_logprob_report(exp_dir, key="accuracy", output="output"):
total_time += times[iter]
report.append("Total training time is {0}\n".format(
str(datetime.timedelta(seconds=total_time))))
return ["\n".join(report), times, data]
return ["\n".join(report), times, data]
Loading