Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit a8c246c

Browse files
Lukasz KaiserCopybara-Service
authored andcommitted
Correct modalities for TPU eval after recent changes.
PiperOrigin-RevId: 236380935
1 parent 5522545 commit a8c246c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tensor2tensor/utils/t2t_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1817,7 +1817,7 @@ def create_tpu_eval_metrics_fn(problem, model_hparams):
18171817
tm = _create_target_modality(problem.get_hparams(model_hparams).modality)
18181818
if isinstance(tm, dict):
18191819
for k, v in six.iteritems(tm):
1820-
weights_fn = v.targets_weights_fn
1820+
weights_fn = modalities.get_targets_weights_fn(v)
18211821

18221822
def make_metric_fn(metric_fn):
18231823
def wrapped_metric_fn(logits, labels, features, weights_fn=weights_fn):
@@ -1837,7 +1837,7 @@ def wrapped_metric_fn(logits, labels, features, weights_fn=weights_fn):
18371837
name = "%s/metrics-%s/%s" % (k, problem.name, metric)
18381838
metric_fns.append((name, make_metric_fn(metric_fn)))
18391839
else:
1840-
weights_fn = tm.targets_weights_fn
1840+
weights_fn = modalities.get_targets_weights_fn(tm)
18411841

18421842
def make_metric_fn(metric_fn):
18431843
def wrapped_metric_fn(logits, labels, features):

0 commit comments

Comments
 (0)