@@ -602,15 +602,9 @@ def weights_fn_for_mp(problem_task_id):
     problem_name = problem_instance.name
     if problem_instance.was_reversed:
       problem_name += "_rev"
-    metrics = problem_instance.eval_metrics()
+    metrics = problem_instance.eval_metric_fns(model_hparams)
     if hasattr(model_hparams.problem, "task_list"):
-      metrics = model_hparams.problem.eval_metrics()
-    if not all([m in METRICS_FNS for m in metrics]):
-      error_str = ("Unrecognized metric. Problem %s specified metrics "
-                   "%s. Recognized metrics are %s.")
-      raise ValueError(error_str % (problem_name,
-                                    metrics,
-                                    list(METRICS_FNS.keys())))
+      metrics = model_hparams.problem.eval_metric_fns(model_hparams)
 
     tm = problem_instance.get_hparams(model_hparams).modality["targets"]
     if not isinstance(tm, dict):
@@ -622,8 +616,7 @@ def weights_fn_for_mp(problem_task_id):
         ptid = problem_instance.task_id  # pylint: disable=cell-var-from-loop
         weights_fn = weights_fn_for_mp(ptid)
 
-      for metric in metrics:
-        metric_fn = METRICS_FNS[metric]
+      for metric, metric_fn in six.iteritems(metrics):
         overload_eval_metric_name = getattr(
             model_hparams, "overload_eval_metric_name", None)
         if len(problems) == 1 and overload_eval_metric_name:
@@ -642,9 +635,10 @@ def weights_fn_for_mp(problem_task_id):
 
 def create_eager_metrics_for_problem(problem, model_hparams):
   """See create_eager_metrics."""
-  metric_names = problem.eval_metrics()
+  metric_fns = problem.eval_metric_fns(model_hparams)
   tm = problem.get_hparams(model_hparams).modality["targets"]
-  return create_eager_metrics(metric_names, weights_fn=tm.targets_weights_fn)
+  return create_eager_metrics_internal(
+      metric_fns, weights_fn=tm.targets_weights_fn)
 
 
 def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
@@ -662,9 +656,26 @@ def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
   """
   metric_fns = dict(
       [(name, METRICS_FNS[name]) for name in metric_names])
+  return create_eager_metrics_internal(metric_fns, weights_fn)
+
+
+def create_eager_metrics_internal(metric_fns,
+                                  weights_fn=common_layers.weights_all):
+  """Create metrics accumulators and averager for Eager mode.
+
+  Args:
+    metric_fns: dict<metric name, metric function>
+    weights_fn: function that takes labels and returns a weights mask. Defaults
+      to weights of all 1, i.e. common_layers.weights_all. Use
+      common_layers.weights_nonzero if labels have 0-padding.
+
+  Returns:
+    (accum_fn(predictions, targets) => None,
+     result_fn() => dict<str metric_name, float avg_val>)
+  """
   tfe_metrics = dict()
 
-  for name in metric_names:
+  for name in metric_fns:
     tfe_metrics[name] = tfe.metrics.Mean(name=name)
 
   def metric_accum(predictions, targets):
@@ -675,7 +686,7 @@ def metric_accum(predictions, targets):
 
   def metric_means():
     avgs = {}
-    for name in metric_names:
+    for name in metric_fns:
       avgs[name] = tfe_metrics[name].result().numpy()
     return avgs
 
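For context, a minimal usage sketch of the refactored Eager path, not part of this commit: `my_problem`, `model_hparams`, and `eval_batches` are hypothetical placeholders, and the import paths assume the usual tensor2tensor layout.

# Sketch only: `my_problem`, `model_hparams`, and `eval_batches` are placeholders.
import tensorflow as tf

from tensor2tensor.layers import common_layers
from tensor2tensor.utils import metrics

tf.enable_eager_execution()

# Problems now expose eval_metric_fns(model_hparams), which returns a dict of
# {metric name: metric function} instead of a list of metric names.
metric_fns = my_problem.eval_metric_fns(model_hparams)

# Build the Eager-mode accumulator and averager directly from those functions.
accum_fn, result_fn = metrics.create_eager_metrics_internal(
    metric_fns, weights_fn=common_layers.weights_nonzero)

for predictions, targets in eval_batches:
  accum_fn(predictions, targets)  # updates one tfe.metrics.Mean per metric

print(result_fn())  # {metric name: running average as a float}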