
Commit 30b7b1c

[CI] Upgrade unity image tag to 20240917-153130-9f281758 (#17410)
* Upgrade docker images to tag `20240917-153130-9f281758`
* Fix a dynamo test case (operand order in the expected Relax module)
* Building torch requires C++17, so bump the generated CMake standard
* Temporarily skip JAX GPU tests that fail with XlaRuntimeError
Parent: 4e70e4a

File tree (4 files changed: +41 -7 lines)

  ci/jenkins/unity_jenkinsfile.groovy
  src/contrib/msc/plugin/torch_codegen.cc
  tests/python/relax/test_frontend_dynamo.py
  tests/python/relax/test_frontend_stablehlo.py

ci/jenkins/unity_jenkinsfile.groovy (4 additions & 4 deletions)

@@ -30,14 +30,14 @@
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils

 // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. -->
-ci_lint = 'tlcpack/ci-lint:20240105-165030-51bdaec6'
-ci_gpu = 'tlcpack/ci-gpu:20240105-165030-51bdaec6'
-ci_cpu = 'tlcpack/ci-cpu:20240105-165030-51bdaec6'
+ci_lint = 'tlcpack/ci_lint:20240917-153130-9f281758'
+ci_gpu = 'tlcpack/ci_gpu:20240917-153130-9f281758'
+ci_cpu = 'tlcpack/ci_cpu:20240917-153130-9f281758'
 ci_wasm = 'tlcpack/ci-wasm:v0.72'
 ci_i386 = 'tlcpack/ci-i386:v0.75'
 ci_qemu = 'tlcpack/ci-qemu:v0.11'
 ci_arm = 'tlcpack/ci-arm:v0.08'
-ci_hexagon = 'tlcpack/ci-hexagon:20240105-165030-51bdaec6'
+ci_hexagon = 'tlcpack/ci_hexagon:20240917-153130-9f281758'
 // <--- End of regex-scanned config.

 // Parameters to allow overriding (in Jenkins UI), the images
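
The NOTE in this block says docker/dev_common.sh scans these assignments with a regex. The real scanner is a shell script; purely as an illustration of what such a scan extracts, here is a hypothetical Python sketch (the pattern and function name are assumptions, not the actual implementation):

import re

# Matches lines such as: ci_gpu = 'tlcpack/ci_gpu:20240917-153130-9f281758'
IMAGE_RE = re.compile(r"^(ci_\w+)\s*=\s*'([^':]+):([^']+)'")

def scan_images(jenkinsfile_text: str) -> dict:
    """Return {variable: (repository, tag)} for every regex-scanned image line."""
    images = {}
    for line in jenkinsfile_text.splitlines():
        m = IMAGE_RE.match(line.strip())
        if m:
            images[m.group(1)] = (m.group(2), m.group(3))
    return images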

src/contrib/msc/plugin/torch_codegen.cc (1 addition & 1 deletion)

@@ -219,7 +219,7 @@ void TorchPluginCodeGen::CodeGenCmake(const std::set<String>& devices) {
   flags.Set("PLUGIN_SUPPORT_TORCH", "");
   CodeGenPreCmake(devices, flags);
   stack_.line()
-      .line("set(CMAKE_CXX_STANDARD 14)")
+      .line("set(CMAKE_CXX_STANDARD 17)")
       .line("list(APPEND CMAKE_PREFIX_PATH \"" + config()->torch_prefix + "\")")
       .line("find_package(Torch REQUIRED)");
   Array<String> includes, libs;
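
Context for the C++14 to C++17 bump: recent PyTorch releases build libtorch and its C++ extensions with C++17, so the CMake project this codegen emits has to request at least that standard. A quick local sanity check (a sketch, not part of this commit) that prints the flags the installed libtorch was built with:

import torch

# The reported build info includes the CXX_FLAGS libtorch was compiled with;
# the plugin's CMAKE_CXX_STANDARD must meet or exceed the standard shown there.
print(torch.__version__)
print(torch.__config__.show())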

tests/python/relax/test_frontend_dynamo.py (1 addition & 1 deletion)

@@ -223,7 +223,7 @@ def subgraph_1(
     ) -> R.Tensor((10,), dtype="float32"):
         # block 0
         with R.dataflow():
-            lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_11, inp_01)
+            lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_01, inp_11)
             gv1: R.Tensor((10,), dtype="float32") = lv5
             R.output(gv1)
         return gv1
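
Why the operand swap matters: the dynamo frontend tests compare the traced module against this expected module with TVM's structural equality, which is sensitive to operand order even for a commutative op like multiply, so the expected IR has to mirror the order the newer torch/dynamo trace produces. A minimal sketch of that sensitivity (assuming a TVM build with the Relax TVMScript frontend; the module names here are made up):

import tvm
from tvm.script import ir as I, relax as R

@I.ir_module
class OrderA:
    @R.function
    def main(a: R.Tensor((10,), "float32"), b: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
        gv = R.multiply(a, b)
        return gv

@I.ir_module
class OrderB:
    @R.function
    def main(a: R.Tensor((10,), "float32"), b: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
        gv = R.multiply(b, a)  # same math, swapped operands
        return gv

# Numerically equivalent, but not structurally equal: operand order matters.
print(tvm.ir.structural_equal(OrderA, OrderB))  # expected: False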

tests/python/relax/test_frontend_stablehlo.py (35 additions & 1 deletion)

@@ -196,6 +196,10 @@ def main(


 @tvm.testing.requires_gpu
+@pytest.mark.skip(
+    reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
+)
+# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
 def test_unary():
     import jax


@@ -229,6 +233,10 @@ def _round(x):


 @tvm.testing.requires_gpu
+@pytest.mark.skip(
+    reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
+)
+# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
 def test_binary():
     import jax


@@ -250,6 +258,10 @@ def fn(x, y):


 @tvm.testing.requires_gpu
+@pytest.mark.skip(
+    reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
+)
+# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
 def test_const():
     import jax


@@ -260,6 +272,10 @@ def fn(x):


 @tvm.testing.requires_gpu
+@pytest.mark.skip(
+    reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
+)
+# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
 def test_maximum():
     import jax
     import jax.numpy as jnp

@@ -271,6 +287,10 @@ def fn(x, y):


 @tvm.testing.requires_gpu
+@pytest.mark.skip(
+    reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
+)
+# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
 def test_minimum():
     import jax
     import jax.numpy as jnp

@@ -282,6 +302,10 @@ def fn(x, y):


 @tvm.testing.requires_gpu
+@pytest.mark.skip(
+    reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
+)
+# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
 def test_reduce():
     import jax
     import jax.numpy as jnp

@@ -293,6 +317,10 @@ def fn(x):


 @tvm.testing.requires_gpu
+@pytest.mark.skip(
+    reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
+)
+# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
 def test_reduce_window():
     import jax
     from flax import linen as nn

@@ -304,6 +332,10 @@ def fn(x):


 @tvm.testing.requires_gpu
+@pytest.mark.skip(
+    reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
+)
+# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
 def test_dot_general():
     import jax


@@ -314,8 +346,10 @@ def fn(x, y):
     check_correctness(jax.jit(fn), input_shapes)


-@pytest.mark.skip()
 @tvm.testing.requires_gpu
+@pytest.mark.skip(
+    reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
+)
 # TODO(yongwww): fix flaky error of "invalid device ordinal"
 def test_conv():
     import jax
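
Since the same skip decorator is now repeated on every GPU test in this file, it could be hoisted into one shared marker; a sketch of that possible refactor (not part of this commit), assuming the module-level imports already include pytest and tvm.testing:

import pytest
import tvm.testing

# One shared marker so the skip reason lives in a single place.
skip_xla_dnn_init = pytest.mark.skip(
    reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: "
    "DNN library initialization failed."
)

@tvm.testing.requires_gpu
@skip_xla_dnn_init
# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
def test_unary():
    ...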
