@@ -2148,16 +2148,11 @@ class RmsNormInterpreterTest(PallasTest):
 
 class SoftmaxTest(PallasTest):
 
-  @parameterized.parameters(
-      (shape, dtype)
-      for shape in [(1024, 125), (4, 1024, 125)]
-      for dtype in (jnp.bfloat16, jnp.float16, jnp.float32)
+  @parameterized.product(
+      shape=[(1024, 125), (4, 1024, 125)],
+      dtype=[jnp.bfloat16, jnp.float16, jnp.float32]
   )
   def test_softmax(self, shape, dtype):
-    # TODO(bchetioui): add Triton bug reference when filed
-    if dtype == jnp.bfloat16:
-      raise absltest.SkipTest("Disabled due to Triton lowering bug")
-
     x = jax.random.normal(random.key(0), shape, dtype=dtype)
 
     atol, rtol = {
@@ -2166,9 +2161,11 @@ def test_softmax(self, shape, dtype):
         jnp.float32: (1e-7, 1e-6),
     }[dtype]
 
+    # We upcast to float32 because NumPy <2.0 does not handle custom dtypes
+    # properly. See https://github.com/google/jax/issues/11014.
     np.testing.assert_allclose(
-        softmax.softmax(x, axis=-1),
-        jax.nn.softmax(x, axis=-1),
+        softmax.softmax(x, axis=-1).astype(jnp.float32),
+        jax.nn.softmax(x, axis=-1).astype(jnp.float32),
         atol=atol,
         rtol=rtol,
     )
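For reference, a minimal standalone sketch (not part of this commit) of how absl's parameterized.product decorator works: each keyword argument lists the values for one parameter, and the decorator generates one test case per combination, which is what the change above relies on instead of a generator expression passed to parameterized.parameters. The test class and assertions below are hypothetical illustration only.

# Minimal sketch, assuming absl-py is installed.
from absl.testing import absltest
from absl.testing import parameterized


class ProductExampleTest(parameterized.TestCase):

  @parameterized.product(
      shape=[(1024, 125), (4, 1024, 125)],
      dtype_name=["bfloat16", "float16", "float32"],
  )
  def test_runs_per_combination(self, shape, dtype_name):
    # 2 shapes x 3 dtype names = 6 generated test cases.
    # Hypothetical body: just checks the parameters were passed through.
    self.assertIn(len(shape), (2, 3))
    self.assertIn(dtype_name, ("bfloat16", "float16", "float32"))


if __name__ == "__main__":
  absltest.main()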