@@ -1790,6 +1790,98 @@ def padded_cross_entropy(logits,
17901790    return  tf .reduce_sum (xent  *  weights ), tf .reduce_sum (weights )
17911791
17921792
def padded_cross_entropy_mixture(logits,
                                 labels,
                                 label_smoothing,
                                 num_mixtures,
                                 weights_fn=weights_nonzero,
                                 reduce_sum=False,
                                 cutoff=0.0,
                                 gaussian=False,
                                 return_best_logits=False):
  """Compute cross-entropy assuming 0s are padding.

  Computes a loss numerator (the sum of losses), and loss denominator
  (the number of non-padding tokens).

  Computes cross-entropy for each mixture, and returns the corresponding values
  for the mixture with the highest probability (i.e. the lowest summed
  negative log-likelihood per example).

  Args:
    logits: `Tensor` with shape `[batch * num_mixtures, timesteps, 1, 1,
      vocab_size]`, optionally a FactoredTensor. The mixtures are laid out
      tiled along the first dimension, matching the `tf.tile` applied to
      `labels` below.
    labels: an integer `Tensor` with shape `[batch, timesteps, 1, 1]`.
      NOTE(review): the 4-element tile multiples below imply rank 4 — confirm
      against callers.
    label_smoothing: a floating point `Scalar`.
    num_mixtures: an integer, the number of mixture components.
    weights_fn: A function from labels to weights.
    reduce_sum: a Boolean, whether to sum at the end or not.
    cutoff: a float, at which point to have no loss.
    gaussian: If true, use a Gaussian distribution for label smoothing
    return_best_logits: If true, return the logits of the mixture with highest
      probabilities for an example

  Returns:
    loss_numerator: a `Scalar`.  Sum of losses.
    loss_denominator: a `Scalar`.  The number of non-padding target tokens.
    best_logits: only when `return_best_logits` is True; `Tensor` of shape
      `[batch, timesteps, 1, 1, vocab_size]` holding, per example, the logits
      of the lowest-loss mixture component.

  Raises:
    ValueError: in case of unsupported argument types.
  """
  logit_shapes = shape_list(
      logits)  # batch_size * num_mixtures, timesteps, 1, 1, vocab_size
  # First dim packs batch and mixture together; recover the true batch size.
  # Note `/` produces a float tensor for dynamic dims, hence the int32 cast.
  batch_size = tf.cast(logit_shapes[0] / num_mixtures, dtype=tf.int32)
  timesteps = logit_shapes[1]
  vocab_size = logit_shapes[4]

  # Target shape [num_mixtures, batch, timesteps, 1, 1] — captured BEFORE
  # labels is tiled, so it reflects the un-tiled label shape.
  new_shape_for_xent = [num_mixtures] + shape_list(labels)
  # Replicate labels once per mixture so they align with the packed logits.
  labels = tf.tile(labels, [num_mixtures, 1, 1, 1])

  xent, weights = padded_cross_entropy(
      logits, labels, label_smoothing, weights_fn, reduce_sum, cutoff, gaussian)

  # reshape xent and weights to have the num_mixtures as first dimension
  xent = tf.reshape(xent, new_shape_for_xent)
  weights = tf.reshape(weights, new_shape_for_xent[:-1])

  # sum up sentence neg log probs over the timestep axis:
  # [num_mixtures, batch, timesteps, 1, 1] -> [num_mixtures, batch, 1, 1]
  xent = tf.reduce_sum(xent, axis=2)

  # if we need to compute the best logits
  if return_best_logits:
    # argmin over the mixture axis: index of the lowest-loss mixture per
    # example.
    best_mixture_indices = tf.cast(tf.argmin(xent, 0), dtype=tf.int32)
    individual_element_indices = tf.range(batch_size)
    # Pairs of (mixture_index, batch_index) used as gather_nd coordinates
    # into the first two axes of the unpacked logits.
    stacked_mixture_element_indices = tf.stack(
        (tf.squeeze(best_mixture_indices), individual_element_indices), -1)
    # Unpack logits to [num_mixtures, batch, timesteps, 1, 1, vocab_size].
    best_logits = tf.reshape(logits,
                             [num_mixtures, -1, timesteps, 1, 1, vocab_size])
    best_logits = tf.gather_nd(best_logits, stacked_mixture_element_indices)
    best_logits = tf.reshape(best_logits,
                             [batch_size, timesteps, 1, 1, vocab_size])

  # Runtime shape assertion guards the reduction below; the control
  # dependency forces it to execute before the reduces.
  with tf.control_dependencies([
      tf.assert_equal(
          tf.shape(xent)[:3], [num_mixtures, batch_size, 1],
          message="Each batch element should have a probability value for each mixture element"
      )
  ]):
    # Keep the best (minimum) loss across mixtures; weights are identical per
    # mixture (labels were tiled), so the mean recovers the per-example value.
    xent = tf.reduce_min(xent, axis=0)
    weights = tf.reduce_mean(weights, axis=0)

  with tf.control_dependencies([
      tf.assert_equal(
          tf.shape(xent)[0], [batch_size],
          message="There should be batch_size elements after selecting best mixture probabilities"
      )
  ]):
    summed_xent = tf.reduce_sum(xent)
    summed_weights = tf.reduce_sum(weights)

  if return_best_logits:
    return summed_xent, summed_weights, best_logits
  else:
    return summed_xent, summed_weights
1883+ 
1884+ 
def _weights_one_third(labels):
  """Returns Tensor of shape [batch, height, width]. Each element is 1/3."""
  # Drop the trailing (channel) axis of `labels` and give every remaining
  # position an equal weight of one third.
  spatial_shape = tf.shape(labels)[:-1]
  return tf.ones(spatial_shape) / 3.
0 commit comments