From 15de3168262b56c1b13a70b4abcabcaaa36a850b Mon Sep 17 00:00:00 2001
From: Youngwook Kim
Date: Wed, 5 Dec 2018 08:53:20 +0900
Subject: [PATCH 1/3] Custom evaluation metrics

---
 tensor2tensor/data_generators/problem.py | 13 ++++++++++++
 tensor2tensor/utils/metrics.py           | 27 ++++++++----------------
 tensor2tensor/utils/t2t_model.py         | 10 ++++-----
 3 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py
index 03ec21ea1..b89aac526 100644
--- a/tensor2tensor/data_generators/problem.py
+++ b/tensor2tensor/data_generators/problem.py
@@ -367,6 +367,19 @@ def eval_metrics(self):
         metrics.Metrics.ACC_PER_SEQ, metrics.Metrics.NEG_LOG_PERPLEXITY
     ]
 
+  def eval_metric_fns(self, model_hparams):
+    metric_names = self.eval_metrics()
+    if not all([m in metrics.METRICS_FNS for m in metric_names]):
+      error_str = ("Unrecognized metric. Problem %s specified metrics "
+                   "%s. Recognized metrics are %s.")
+      raise ValueError(error_str % (self.name,
+                                    metric_names,
+                                    list(metrics.METRICS_FNS.keys())))
+    return {
+        metric_name: metrics.METRICS_FNS[metric_name]
+        for metric_name in metric_names
+    }
+
   def eval_hooks(self, features, logits, hparams):
     del features, logits, hparams
     return []
diff --git a/tensor2tensor/utils/metrics.py b/tensor2tensor/utils/metrics.py
index aac82e868..a93a717e5 100644
--- a/tensor2tensor/utils/metrics.py
+++ b/tensor2tensor/utils/metrics.py
@@ -602,15 +602,9 @@ def weights_fn_for_mp(problem_task_id):
     problem_name = problem_instance.name
     if problem_instance.was_reversed:
       problem_name += "_rev"
-    metrics = problem_instance.eval_metrics()
+    metrics = problem_instance.eval_metric_fns(model_hparams)
     if hasattr(model_hparams.problem, "task_list"):
-      metrics = model_hparams.problem.eval_metrics()
-      if not all([m in METRICS_FNS for m in metrics]):
-        error_str = ("Unrecognized metric. Problem %s specified metrics "
-                     "%s. Recognized metrics are %s.")
-        raise ValueError(error_str % (problem_name,
-                                      metrics,
-                                      list(METRICS_FNS.keys())))
+      metrics = model_hparams.problem.eval_metric_fns(model_hparams)
 
     tm = problem_instance.get_hparams(model_hparams).modality["targets"]
     if not isinstance(tm, dict):
@@ -622,8 +616,7 @@ def weights_fn_for_mp(problem_task_id):
         ptid = problem_instance.task_id  # pylint: disable=cell-var-from-loop
         weights_fn = weights_fn_for_mp(ptid)
 
-      for metric in metrics:
-        metric_fn = METRICS_FNS[metric]
+      for metric, metric_fn in metrics.items():
         overload_eval_metric_name = getattr(
             model_hparams, "overload_eval_metric_name", None)
         if len(problems) == 1 and overload_eval_metric_name:
@@ -642,16 +635,16 @@ def weights_fn_for_mp(problem_task_id):
 
 def create_eager_metrics_for_problem(problem, model_hparams):
   """See create_eager_metrics."""
-  metric_names = problem.eval_metrics()
+  metric_fns = problem.eval_metric_fns(model_hparams)
   tm = problem.get_hparams(model_hparams).modality["targets"]
-  return create_eager_metrics(metric_names, weights_fn=tm.targets_weights_fn)
+  return create_eager_metrics(metric_fns, weights_fn=tm.targets_weights_fn)
 
 
-def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
+def create_eager_metrics(metric_fns, weights_fn=common_layers.weights_all):
   """Create metrics accumulators and averager for Eager mode.
 
   Args:
-    metric_names: list from Metrics enum
+    metric_names: dict.
     weights_fn: function that takes labels and returns a weights mask. Defaults
       to weights of all 1, i.e. common_layers.weights_all. Use
       common_layers.weights_nonzero if labels have 0-padding.
@@ -660,11 +653,9 @@ def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
     (accum_fn(predictions, targets) => None,
      result_fn() => dict
   """
-  metric_fns = dict(
-      [(name, METRICS_FNS[name]) for name in metric_names])
   tfe_metrics = dict()
 
-  for name in metric_names:
+  for name in metric_fns:
     tfe_metrics[name] = tfe.metrics.Mean(name=name)
 
   def metric_accum(predictions, targets):
@@ -675,7 +666,7 @@ def metric_accum(predictions, targets):
 
   def metric_means():
     avgs = {}
-    for name in metric_names:
+    for name in metric_fns:
       avgs[name] = tfe_metrics[name].result().numpy()
     return avgs
 
diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py
index 393d66cb0..95a3ed701 100644
--- a/tensor2tensor/utils/t2t_model.py
+++ b/tensor2tensor/utils/t2t_model.py
@@ -1721,7 +1721,7 @@ def create_tpu_eval_metrics_fn(problem, model_hparams):
   """Create the metrics_fn that TPUEstimatorSpec expects."""
 
   metric_fns = []
-  eval_metrics = problem.eval_metrics()
+  eval_metrics = problem.eval_metric_fns(model_hparams)
 
   tm = _create_target_modality(problem.get_hparams(model_hparams).modality)
   if isinstance(tm, dict):
@@ -1739,12 +1739,12 @@ def wrapped_metric_fn(logits, labels, features, weights_fn=weights_fn):
 
         return wrapped_metric_fn
 
-      for metric in eval_metrics:
+      for metric, metric_fn in eval_metrics.items():
         if metric in TPU_METRIC_BLACKLIST:
           log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
           continue
         name = "%s/metrics-%s/%s" % (k, problem.name, metric)
-        metric_fns.append((name, make_metric_fn(metrics.METRICS_FNS[metric])))
+        metric_fns.append((name, make_metric_fn(metric_fn)))
 
   else:
     weights_fn = tm.targets_weights_fn
@@ -1759,12 +1759,12 @@ def wrapped_metric_fn(logits, labels, features):
 
       return wrapped_metric_fn
 
-    for metric in eval_metrics:
+    for metric, metric_fn in eval_metrics.items():
       if metric in TPU_METRIC_BLACKLIST:
         log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
         continue
       name = "metrics-%s/%s" % (problem.name, metric)
-      metric_fns.append((name, make_metric_fn(metrics.METRICS_FNS[metric])))
+      metric_fns.append((name, make_metric_fn(metric_fn)))
 
   def all_metrics_fn(**kwargs):
     """Construct metrics dictionary."""

From c71c328daab37a2277164f19f551a4cd84c659a8 Mon Sep 17 00:00:00 2001
From: Youngwook Kim
Date: Wed, 2 Jan 2019 12:32:37 +0900
Subject: [PATCH 2/3] Fix Python 2 compatibility issue

---
 tensor2tensor/utils/metrics.py   | 2 +-
 tensor2tensor/utils/t2t_model.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tensor2tensor/utils/metrics.py b/tensor2tensor/utils/metrics.py
index a93a717e5..42de33d3c 100644
--- a/tensor2tensor/utils/metrics.py
+++ b/tensor2tensor/utils/metrics.py
@@ -616,7 +616,7 @@ def weights_fn_for_mp(problem_task_id):
         ptid = problem_instance.task_id  # pylint: disable=cell-var-from-loop
         weights_fn = weights_fn_for_mp(ptid)
 
-      for metric, metric_fn in metrics.items():
+      for metric, metric_fn in six.iteritems(metrics):
         overload_eval_metric_name = getattr(
             model_hparams, "overload_eval_metric_name", None)
         if len(problems) == 1 and overload_eval_metric_name:
diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py
index 95a3ed701..6829a8366 100644
--- a/tensor2tensor/utils/t2t_model.py
+++ b/tensor2tensor/utils/t2t_model.py
@@ -1739,7 +1739,7 @@ def wrapped_metric_fn(logits, labels, features, weights_fn=weights_fn):
 
         return wrapped_metric_fn
 
-      for metric, metric_fn in eval_metrics.items():
+      for metric, metric_fn in six.iteritems(eval_metrics):
         if metric in TPU_METRIC_BLACKLIST:
           log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
           continue
@@ -1759,7 +1759,7 @@ def wrapped_metric_fn(logits, labels, features):
 
       return wrapped_metric_fn
 
-    for metric, metric_fn in eval_metrics.items():
+    for metric, metric_fn in six.iteritems(eval_metrics):
       if metric in TPU_METRIC_BLACKLIST:
         log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
         continue

From e4a0afa0e4994a710c67f4a6cb547ae8f99e95b2 Mon Sep 17 00:00:00 2001
From: Youngwook Kim
Date: Wed, 2 Jan 2019 22:00:41 +0900
Subject: [PATCH 3/3] Fix notebook test

---
 tensor2tensor/utils/metrics.py | 26 +++++++++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/tensor2tensor/utils/metrics.py b/tensor2tensor/utils/metrics.py
index 42de33d3c..846f73252 100644
--- a/tensor2tensor/utils/metrics.py
+++ b/tensor2tensor/utils/metrics.py
@@ -637,14 +637,34 @@ def create_eager_metrics_for_problem(problem, model_hparams):
   """See create_eager_metrics."""
   metric_fns = problem.eval_metric_fns(model_hparams)
   tm = problem.get_hparams(model_hparams).modality["targets"]
-  return create_eager_metrics(metric_fns, weights_fn=tm.targets_weights_fn)
+  return create_eager_metrics_internal(
+      metric_fns, weights_fn=tm.targets_weights_fn)
 
 
-def create_eager_metrics(metric_fns, weights_fn=common_layers.weights_all):
+def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
   """Create metrics accumulators and averager for Eager mode.
 
   Args:
-    metric_names: dict.
+    metric_names: list from Metrics enum
     weights_fn: function that takes labels and returns a weights mask. Defaults
       to weights of all 1, i.e. common_layers.weights_all. Use
       common_layers.weights_nonzero if labels have 0-padding.
+
+  Returns:
+    (accum_fn(predictions, targets) => None,
+     result_fn() => dict
+  """
+  metric_fns = dict(
+      [(name, METRICS_FNS[name]) for name in metric_names])
+  return create_eager_metrics_internal(metric_fns, weights_fn)
+
+
+def create_eager_metrics_internal(metric_fns,
+                                  weights_fn=common_layers.weights_all):
+  """Create metrics accumulators and averager for Eager mode.
+
+  Args:
+    metric_fns: dict of metric name to metric function
+    weights_fn: function that takes labels and returns a weights mask. Defaults
+      to weights of all 1, i.e. common_layers.weights_all. Use
+      common_layers.weights_nonzero if labels have 0-padding.
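
Usage sketch: with this series applied, a Problem subclass can override
eval_metric_fns() to return custom metric callables alongside the built-in
ones, instead of being limited to names in the Metrics enum. The example below
is hypothetical: the class name, its registration, and the token_accuracy
metric are invented for illustration, and it assumes the
(predictions, labels, weights_fn) -> (values, weights) contract followed by the
functions in tensor2tensor/utils/metrics.py.

import tensorflow as tf

from tensor2tensor.data_generators import text_problems
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import registry


@registry.register_problem
class MyCustomMetricProblem(text_problems.Text2TextProblem):
  """Hypothetical problem that adds one custom eval metric."""

  def eval_metric_fns(self, model_hparams):
    # Start from the standard name -> function mapping built from
    # eval_metrics(), then register an extra callable under a new name.
    metric_fns = super(MyCustomMetricProblem,
                       self).eval_metric_fns(model_hparams)

    def token_accuracy(predictions, labels,
                       weights_fn=common_layers.weights_nonzero):
      # Per-token accuracy, returning (values, weights) like the built-in
      # metric functions; padding positions are masked out by weights_fn.
      padded_predictions, padded_labels = common_layers.pad_with_zeros(
          predictions, labels)
      weights = weights_fn(padded_labels)
      outputs = tf.to_int32(tf.argmax(padded_predictions, axis=-1))
      padded_labels = tf.to_int32(padded_labels)
      return tf.to_float(tf.equal(outputs, padded_labels)), weights

    metric_fns["token_accuracy"] = token_accuracy
    return metric_fns

At eval time the extra entry flows through create_evaluation_metrics() (and
create_tpu_eval_metrics_fn() on TPU) exactly like the built-in metrics, since
both now iterate over the name -> function dict instead of looking names up in
METRICS_FNS.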