diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py
index 03ec21ea1..b89aac526 100644
--- a/tensor2tensor/data_generators/problem.py
+++ b/tensor2tensor/data_generators/problem.py
@@ -367,6 +367,19 @@ def eval_metrics(self):
         metrics.Metrics.ACC_PER_SEQ, metrics.Metrics.NEG_LOG_PERPLEXITY
     ]
 
+  def eval_metric_fns(self, model_hparams):
+    metric_names = self.eval_metrics()
+    if not all([m in metrics.METRICS_FNS for m in metric_names]):
+      error_str = ("Unrecognized metric. Problem %s specified metrics "
+                   "%s. Recognized metrics are %s.")
+      raise ValueError(error_str % (self.name,
+                                    metric_names,
+                                    list(metrics.METRICS_FNS.keys())))
+    return {
+        metric_name: metrics.METRICS_FNS[metric_name]
+        for metric_name in metric_names
+    }
+
   def eval_hooks(self, features, logits, hparams):
     del features, logits, hparams
     return []
diff --git a/tensor2tensor/utils/metrics.py b/tensor2tensor/utils/metrics.py
index aac82e868..846f73252 100644
--- a/tensor2tensor/utils/metrics.py
+++ b/tensor2tensor/utils/metrics.py
@@ -602,15 +602,9 @@ def weights_fn_for_mp(problem_task_id):
     problem_name = problem_instance.name
     if problem_instance.was_reversed:
       problem_name += "_rev"
-    metrics = problem_instance.eval_metrics()
+    metrics = problem_instance.eval_metric_fns(model_hparams)
     if hasattr(model_hparams.problem, "task_list"):
-      metrics = model_hparams.problem.eval_metrics()
-    if not all([m in METRICS_FNS for m in metrics]):
-      error_str = ("Unrecognized metric. Problem %s specified metrics "
-                   "%s. Recognized metrics are %s.")
-      raise ValueError(error_str % (problem_name,
-                                    metrics,
-                                    list(METRICS_FNS.keys())))
+      metrics = model_hparams.problem.eval_metric_fns(model_hparams)
 
     tm = problem_instance.get_hparams(model_hparams).modality["targets"]
     if not isinstance(tm, dict):
@@ -622,8 +616,7 @@ def weights_fn_for_mp(problem_task_id):
         ptid = problem_instance.task_id  # pylint: disable=cell-var-from-loop
         weights_fn = weights_fn_for_mp(ptid)
 
-      for metric in metrics:
-        metric_fn = METRICS_FNS[metric]
+      for metric, metric_fn in six.iteritems(metrics):
         overload_eval_metric_name = getattr(
             model_hparams, "overload_eval_metric_name", None)
         if len(problems) == 1 and overload_eval_metric_name:
@@ -642,9 +635,10 @@ def weights_fn_for_mp(problem_task_id):
 
 def create_eager_metrics_for_problem(problem, model_hparams):
   """See create_eager_metrics."""
-  metric_names = problem.eval_metrics()
+  metric_fns = problem.eval_metric_fns(model_hparams)
   tm = problem.get_hparams(model_hparams).modality["targets"]
-  return create_eager_metrics(metric_names, weights_fn=tm.targets_weights_fn)
+  return create_eager_metrics_internal(
+      metric_fns, weights_fn=tm.targets_weights_fn)
 
 
 def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
@@ -662,9 +656,26 @@ def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
   """
   metric_fns = dict(
       [(name, METRICS_FNS[name]) for name in metric_names])
+  return create_eager_metrics_internal(metric_fns, weights_fn)
+
+
+def create_eager_metrics_internal(metric_fns,
+                                  weights_fn=common_layers.weights_all):
+  """Create metrics accumulators and averager for Eager mode.
+
+  Args:
+    metric_fns: dict<metric name, metric function>
+    weights_fn: function that takes labels and returns a weights mask. Defaults
+      to weights of all 1, i.e. common_layers.weights_all. Use
+      common_layers.weights_nonzero if labels have 0-padding.
+
+  Returns:
+    (accum_fn(predictions, targets) => None,
+     result_fn() => dict<str metric_name, float avg_val>)
+  """
   tfe_metrics = dict()
 
-  for name in metric_names:
+  for name in metric_fns:
     tfe_metrics[name] = tfe.metrics.Mean(name=name)
 
   def metric_accum(predictions, targets):
@@ -675,7 +686,7 @@ def metric_accum(predictions, targets):
 
   def metric_means():
     avgs = {}
-    for name in metric_names:
+    for name in metric_fns:
       avgs[name] = tfe_metrics[name].result().numpy()
     return avgs
 
diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py
index 393d66cb0..6829a8366 100644
--- a/tensor2tensor/utils/t2t_model.py
+++ b/tensor2tensor/utils/t2t_model.py
@@ -1721,7 +1721,7 @@ def create_tpu_eval_metrics_fn(problem, model_hparams):
   """Create the metrics_fn that TPUEstimatorSpec expects."""
 
   metric_fns = []
-  eval_metrics = problem.eval_metrics()
+  eval_metrics = problem.eval_metric_fns(model_hparams)
 
   tm = _create_target_modality(problem.get_hparams(model_hparams).modality)
   if isinstance(tm, dict):
@@ -1739,12 +1739,12 @@ def wrapped_metric_fn(logits, labels, features, weights_fn=weights_fn):
 
         return wrapped_metric_fn
 
-      for metric in eval_metrics:
+      for metric, metric_fn in six.iteritems(eval_metrics):
         if metric in TPU_METRIC_BLACKLIST:
           log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
           continue
         name = "%s/metrics-%s/%s" % (k, problem.name, metric)
-        metric_fns.append((name, make_metric_fn(metrics.METRICS_FNS[metric])))
+        metric_fns.append((name, make_metric_fn(metric_fn)))
   else:
     weights_fn = tm.targets_weights_fn
 
@@ -1759,12 +1759,12 @@ def wrapped_metric_fn(logits, labels, features):
 
       return wrapped_metric_fn
 
-    for metric in eval_metrics:
+    for metric, metric_fn in six.iteritems(eval_metrics):
      if metric in TPU_METRIC_BLACKLIST:
        log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
        continue
      name = "metrics-%s/%s" % (problem.name, metric)
-      metric_fns.append((name, make_metric_fn(metrics.METRICS_FNS[metric])))
+      metric_fns.append((name, make_metric_fn(metric_fn)))
 
   def all_metrics_fn(**kwargs):
     """Construct metrics dictionary."""
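
Usage sketch (not part of the patch): after this change, Problem.eval_metric_fns() returns a dict mapping metric names to metric functions, so callers iterate the dict directly instead of looking each name up in metrics.METRICS_FNS. The problem and hparams-set names below are illustrative assumptions; any registered pair behaves the same.

# Minimal sketch, assuming tensor2tensor is installed; the problem name
# "translate_ende_wmt32k" and hparams set "transformer_base" are assumptions.
from tensor2tensor import problems
from tensor2tensor.utils import trainer_lib

problem = problems.problem("translate_ende_wmt32k")
model_hparams = trainer_lib.create_hparams("transformer_base")

# New accessor added in problem.py above: {metric name: metric function}.
metric_fns = problem.eval_metric_fns(model_hparams)
for metric_name, metric_fn in sorted(metric_fns.items()):
  print(metric_name, metric_fn)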