Skip to content

Commit 5bfd6af

Browse files
superbobryjax authors
authored and
jax authors
committed
Removed unnecessary skip in pallas_test.py::SoftmaxTest
The Triton bug, whatever it was, seems to have been fixed. PiperOrigin-RevId: 644293465
1 parent 3fd9326 commit 5bfd6af

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

Diff for: tests/pallas/pallas_test.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -2148,16 +2148,11 @@ class RmsNormInterpreterTest(PallasTest):
21482148

21492149
class SoftmaxTest(PallasTest):
21502150

2151-
@parameterized.parameters(
2152-
(shape, dtype)
2153-
for shape in [(1024, 125), (4, 1024, 125)]
2154-
for dtype in (jnp.bfloat16, jnp.float16, jnp.float32)
2151+
@parameterized.product(
2152+
shape=[(1024, 125), (4, 1024, 125)],
2153+
dtype=[jnp.bfloat16, jnp.float16, jnp.float32]
21552154
)
21562155
def test_softmax(self, shape, dtype):
2157-
# TODO(bchetioui): add Triton bug reference when filed
2158-
if dtype == jnp.bfloat16:
2159-
raise absltest.SkipTest("Disabled due to Triton lowering bug")
2160-
21612156
x = jax.random.normal(random.key(0), shape, dtype=dtype)
21622157

21632158
atol, rtol = {
@@ -2166,9 +2161,11 @@ def test_softmax(self, shape, dtype):
21662161
jnp.float32: (1e-7, 1e-6),
21672162
}[dtype]
21682163

2164+
# We upcast to float32 because NumPy <2.0 does not handle custom dtypes
2165+
# properly. See https://github.com/google/jax/issues/11014.
21692166
np.testing.assert_allclose(
2170-
softmax.softmax(x, axis=-1),
2171-
jax.nn.softmax(x, axis=-1),
2167+
softmax.softmax(x, axis=-1).astype(jnp.float32),
2168+
jax.nn.softmax(x, axis=-1).astype(jnp.float32),
21722169
atol=atol,
21732170
rtol=rtol,
21742171
)

0 commit comments

Comments
 (0)