@@ -196,6 +196,10 @@ def main(
196196
197197
198198@tvm .testing .requires_gpu
199+ @pytest .mark .skip (
200+ reason = "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
201+ )
202+ # TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
199203def test_unary ():
200204 import jax
201205
@@ -229,6 +233,10 @@ def _round(x):
229233
230234
231235@tvm .testing .requires_gpu
236+ @pytest .mark .skip (
237+ reason = "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
238+ )
239+ # TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
232240def test_binary ():
233241 import jax
234242
@@ -250,6 +258,10 @@ def fn(x, y):
250258
251259
252260@tvm .testing .requires_gpu
261+ @pytest .mark .skip (
262+ reason = "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
263+ )
264+ # TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
253265def test_const ():
254266 import jax
255267
@@ -260,6 +272,10 @@ def fn(x):
260272
261273
262274@tvm .testing .requires_gpu
275+ @pytest .mark .skip (
276+ reason = "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
277+ )
278+ # TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
263279def test_maximum ():
264280 import jax
265281 import jax .numpy as jnp
@@ -271,6 +287,10 @@ def fn(x, y):
271287
272288
273289@tvm .testing .requires_gpu
290+ @pytest .mark .skip (
291+ reason = "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
292+ )
293+ # TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
274294def test_minimum ():
275295 import jax
276296 import jax .numpy as jnp
@@ -282,6 +302,10 @@ def fn(x, y):
282302
283303
284304@tvm .testing .requires_gpu
305+ @pytest .mark .skip (
306+ reason = "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
307+ )
308+ # TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
285309def test_reduce ():
286310 import jax
287311 import jax .numpy as jnp
@@ -293,6 +317,10 @@ def fn(x):
293317
294318
295319@tvm .testing .requires_gpu
320+ @pytest .mark .skip (
321+ reason = "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
322+ )
323+ # TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
296324def test_reduce_window ():
297325 import jax
298326 from flax import linen as nn
@@ -304,6 +332,10 @@ def fn(x):
304332
305333
306334@tvm .testing .requires_gpu
335+ @pytest .mark .skip (
336+ reason = "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
337+ )
338+ # TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
307339def test_dot_general ():
308340 import jax
309341
@@ -314,8 +346,10 @@ def fn(x, y):
314346 check_correctness (jax .jit (fn ), input_shapes )
315347
316348
317- @pytest .mark .skip ()
318349@tvm .testing .requires_gpu
350+ @pytest .mark .skip (
351+ reason = "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
352+ )
319353# TODO(yongwww): fix flaky error of "invalid device ordinal"
320354def test_conv ():
321355 import jax
0 commit comments