Skip to content

Commit 560dbfe

Browse files
stefan-falk
authored and kpe committed
internal merge of PR tensorflow#1242
PiperOrigin-RevId: 223252032
1 parent eb046c0 commit 560dbfe

File tree

3 files changed

+35
-38
lines changed

3 files changed

+35
-38
lines changed

tensor2tensor/data_generators/speech_recognition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,6 @@ def preprocess_example(self, example, mode, hparams):
140140
def eval_metrics(self):
141141
defaults = super(SpeechRecognitionProblem, self).eval_metrics()
142142
return defaults + [
143-
metrics.Metrics.EDIT_DISTANCE,
144-
metrics.Metrics.WORD_ERROR_RATE
143+
metrics.Metrics.EDIT_DISTANCE,
144+
metrics.Metrics.WORD_ERROR_RATE
145145
]

tensor2tensor/utils/metrics.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -681,38 +681,38 @@ def metric_means():
681681
return metric_accum, metric_means
682682

683683

684-
def word_error_rate(raw_predictions, labels, lookup=None,
684+
def word_error_rate(raw_predictions,
685+
labels,
686+
lookup=None,
685687
weights_fn=common_layers.weights_nonzero):
686-
"""
687-
:param raw_predictions:
688-
:param labels:
689-
:param lookup:
690-
A tf.constant mapping indices to output tokens.
691-
:param weights_fn:
692-
:return:
688+
"""Calculate word error rate.
689+
690+
Args:
691+
raw_predictions: The raw predictions.
692+
labels: The actual labels.
693+
lookup: A tf.constant mapping indices to output tokens.
694+
weights_fn: Weighting function.
695+
696+
Returns:
693697
The word error rate.
694698
"""
695699

696700
def from_tokens(raw, lookup_):
697701
gathered = tf.gather(lookup_, tf.cast(raw, tf.int32))
698-
joined = tf.regex_replace(tf.reduce_join(gathered, axis=1), b'<EOS>.*', b'')
699-
cleaned = tf.regex_replace(joined, b'_', b' ')
700-
tokens = tf.string_split(cleaned, ' ')
702+
joined = tf.regex_replace(tf.reduce_join(gathered, axis=1), b"<EOS>.*", b"")
703+
cleaned = tf.regex_replace(joined, b"_", b" ")
704+
tokens = tf.string_split(cleaned, " ")
701705
return tokens
702706

703707
def from_characters(raw, lookup_):
704-
"""
705-
Convert ascii+2 encoded codes to string-tokens.
706-
"""
708+
"""Convert ascii+2 encoded codes to string-tokens."""
707709
corrected = tf.bitcast(
708-
tf.clip_by_value(
709-
tf.subtract(raw, 2), 0, 255
710-
), tf.uint8)
710+
tf.clip_by_value(tf.subtract(raw, 2), 0, 255), tf.uint8)
711711

712712
gathered = tf.gather(lookup_, tf.cast(corrected, tf.int32))[:, :, 0]
713713
joined = tf.reduce_join(gathered, axis=1)
714-
cleaned = tf.regex_replace(joined, b'\0', b'')
715-
tokens = tf.string_split(cleaned, ' ')
714+
cleaned = tf.regex_replace(joined, b"\0", b"")
715+
tokens = tf.string_split(cleaned, " ")
716716
return tokens
717717

718718
if lookup is None:
@@ -727,18 +727,16 @@ def from_characters(raw, lookup_):
727727
with tf.variable_scope("word_error_rate", values=[raw_predictions, labels]):
728728

729729
raw_predictions = tf.squeeze(
730-
tf.argmax(raw_predictions, axis=-1), axis=(2, 3))
730+
tf.argmax(raw_predictions, axis=-1), axis=(2, 3))
731731
labels = tf.squeeze(labels, axis=(2, 3))
732732

733733
reference = convert_fn(labels, lookup)
734734
predictions = convert_fn(raw_predictions, lookup)
735735

736736
distance = tf.reduce_sum(
737-
tf.edit_distance(predictions, reference, normalize=False)
738-
)
737+
tf.edit_distance(predictions, reference, normalize=False))
739738
reference_length = tf.cast(
740-
tf.size(reference.values, out_type=tf.int32), dtype=tf.float32
741-
)
739+
tf.size(reference.values, out_type=tf.int32), dtype=tf.float32)
742740

743741
return distance / reference_length, reference_length
744742

tensor2tensor/utils/metrics_test.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,18 @@ def testSequenceEditDistanceMetric(self):
108108
def testWordErrorRateMetric(self):
109109

110110
ref = np.asarray([
111-
# a b c
112-
[97, 34, 98, 34, 99],
113-
[97, 34, 98, 34, 99],
114-
[97, 34, 98, 34, 99],
115-
[97, 34, 98, 34, 99],
111+
# a b c
112+
[97, 34, 98, 34, 99],
113+
[97, 34, 98, 34, 99],
114+
[97, 34, 98, 34, 99],
115+
[97, 34, 98, 34, 99],
116116
])
117117

118118
hyp = np.asarray([
119-
[97, 34, 98, 34, 99], # a b c
120-
[97, 34, 98, 0, 0], # a b
121-
[97, 34, 98, 34, 100], # a b d
122-
[0, 0, 0, 0, 0] # empty
119+
[97, 34, 98, 34, 99], # a b c
120+
[97, 34, 98, 0, 0], # a b
121+
[97, 34, 98, 34, 100], # a b d
122+
[0, 0, 0, 0, 0] # empty
123123
])
124124

125125
labels = np.reshape(ref, ref.shape + (1, 1))
@@ -130,9 +130,8 @@ def testWordErrorRateMetric(self):
130130
predictions[i, j, 0, 0, idx] = 1
131131

132132
with self.test_session() as session:
133-
actual_wer, actual_ref_len = session.run(
134-
metrics.word_error_rate(predictions, labels)
135-
)
133+
actual_wer, unused_actual_ref_len = session.run(
134+
metrics.word_error_rate(predictions, labels))
136135

137136
expected_wer = 0.417
138137
places = 3

0 commit comments

Comments (0)