@@ -256,7 +256,7 @@ def standardize_images(x):
256256 x = tf .to_float (tf .reshape (x , [- 1 ] + x_shape [- 3 :]))
257257 x_mean = tf .reduce_mean (x , axis = [1 , 2 ], keepdims = True )
258258 x_variance = tf .reduce_mean (
259- tf .squared_difference ( x , x_mean ), axis = [1 , 2 ], keepdims = True )
259+ tf .square ( x - x_mean ), axis = [1 , 2 ], keepdims = True )
260260 num_pixels = tf .to_float (x_shape [- 2 ] * x_shape [- 3 ])
261261 x = (x - x_mean ) / tf .maximum (tf .sqrt (x_variance ), tf .rsqrt (num_pixels ))
262262 return tf .reshape (x , x_shape )
@@ -634,8 +634,7 @@ def layer_norm_compute(x, epsilon, scale, bias):
634634 """Layer norm raw computation."""
635635 epsilon , scale , bias = [cast_like (t , x ) for t in [epsilon , scale , bias ]]
636636 mean = tf .reduce_mean (x , axis = [- 1 ], keepdims = True )
637- variance = tf .reduce_mean (
638- tf .squared_difference (x , mean ), axis = [- 1 ], keepdims = True )
637+ variance = tf .reduce_mean (tf .square (x - mean ), axis = [- 1 ], keepdims = True )
639638 norm_x = (x - mean ) * tf .rsqrt (variance + epsilon )
640639 return norm_x * scale + bias
641640
@@ -691,8 +690,7 @@ def l2_norm(x, filters=None, epsilon=1e-6, name=None, reuse=None):
691690 "l2_norm_bias" , [filters ], initializer = tf .zeros_initializer ())
692691 epsilon , scale , bias = [cast_like (t , x ) for t in [epsilon , scale , bias ]]
693692 mean = tf .reduce_mean (x , axis = [- 1 ], keepdims = True )
694- l2norm = tf .reduce_sum (
695- tf .squared_difference (x , mean ), axis = [- 1 ], keepdims = True )
693+ l2norm = tf .reduce_sum (tf .square (x - mean ), axis = [- 1 ], keepdims = True )
696694 norm_x = (x - mean ) * tf .rsqrt (l2norm + epsilon )
697695 return norm_x * scale + bias
698696
@@ -3348,7 +3346,7 @@ def get_sorted_projections(x):
33483346
33493347 proj1 = get_sorted_projections (logits1 )
33503348 proj2 = get_sorted_projections (logits2 )
3351- dist = tf .reduce_mean (tf .squared_difference (proj1 , proj2 ))
3349+ dist = tf .reduce_mean (tf .square (proj1 - proj2 ))
33523350 if return_logits :
33533351 return dist , logits1 , logits2
33543352 return dist
0 commit comments