@@ -681,38 +681,38 @@ def metric_means():
681681  return  metric_accum , metric_means 
682682
683683
684- def  word_error_rate (raw_predictions , labels , lookup = None ,
684+ def  word_error_rate (raw_predictions ,
685+                     labels ,
686+                     lookup = None ,
685687                    weights_fn = common_layers .weights_nonzero ):
686-   """ 
687-   :param raw_predictions: 
688-   :param labels: 
689-   :param lookup: 
690-     A tf.constant mapping indices to output tokens. 
691-   :param weights_fn: 
692-   :return: 
688+   """Calculate word error rate. 
689+ 
690+   Args: 
691+     raw_predictions: The raw predictions. 
692+     labels: The actual labels. 
693+     lookup: A tf.constant mapping indices to output tokens. 
694+     weights_fn: Weighting function. 
695+ 
696+   Returns: 
693697    The word error rate. 
694698  """ 
695699
696700  def  from_tokens (raw , lookup_ ):
697701    gathered  =  tf .gather (lookup_ , tf .cast (raw , tf .int32 ))
698-     joined  =  tf .regex_replace (tf .reduce_join (gathered , axis = 1 ), b' <EOS>.*'  , b''  )
699-     cleaned  =  tf .regex_replace (joined , b'_'  , b' '  )
700-     tokens  =  tf .string_split (cleaned , ' ' )
702+     joined  =  tf .regex_replace (tf .reduce_join (gathered , axis = 1 ), b" <EOS>.*"  , b""  )
703+     cleaned  =  tf .regex_replace (joined , b"_"  , b" "  )
704+     tokens  =  tf .string_split (cleaned , " " )
701705    return  tokens 
702706
703707  def  from_characters (raw , lookup_ ):
704-     """ 
705-     Convert ascii+2 encoded codes to string-tokens. 
706-     """ 
708+     """Convert ascii+2 encoded codes to string-tokens.""" 
707709    corrected  =  tf .bitcast (
708-       tf .clip_by_value (
709-         tf .subtract (raw , 2 ), 0 , 255 
710-       ), tf .uint8 )
710+         tf .clip_by_value (tf .subtract (raw , 2 ), 0 , 255 ), tf .uint8 )
711711
712712    gathered  =  tf .gather (lookup_ , tf .cast (corrected , tf .int32 ))[:, :, 0 ]
713713    joined  =  tf .reduce_join (gathered , axis = 1 )
714-     cleaned  =  tf .regex_replace (joined , b' \0 '  , b''  )
715-     tokens  =  tf .string_split (cleaned , ' ' )
714+     cleaned  =  tf .regex_replace (joined , b" \0 "  , b""  )
715+     tokens  =  tf .string_split (cleaned , " " )
716716    return  tokens 
717717
718718  if  lookup  is  None :
@@ -727,18 +727,16 @@ def from_characters(raw, lookup_):
727727  with  tf .variable_scope ("word_error_rate" , values = [raw_predictions , labels ]):
728728
729729    raw_predictions  =  tf .squeeze (
730-       tf .argmax (raw_predictions , axis = - 1 ), axis = (2 , 3 ))
730+          tf .argmax (raw_predictions , axis = - 1 ), axis = (2 , 3 ))
731731    labels  =  tf .squeeze (labels , axis = (2 , 3 ))
732732
733733    reference  =  convert_fn (labels , lookup )
734734    predictions  =  convert_fn (raw_predictions , lookup )
735735
736736    distance  =  tf .reduce_sum (
737-       tf .edit_distance (predictions , reference , normalize = False )
738-     )
737+         tf .edit_distance (predictions , reference , normalize = False ))
739738    reference_length  =  tf .cast (
740-       tf .size (reference .values , out_type = tf .int32 ), dtype = tf .float32 
741-     )
739+         tf .size (reference .values , out_type = tf .int32 ), dtype = tf .float32 )
742740
743741    return  distance  /  reference_length , reference_length 
744742
0 commit comments