
Multiprocessing with Jax on Slurm fails #23770

Open
NonsansWD opened this issue Sep 19, 2024 · 0 comments
Labels
bug Something isn't working


@NonsansWD

Description

Hello, unfortunately I am having an issue running a script that uses JAX on SLURM. My job requests only one GPU, but the node itself has multiple GPUs, and that seems to matter: it is the only difference I could find between the server and my personal computer, where the exact same conda environment works without any problem. I need to run on the server anyway, since my own machine quickly runs out of memory. The error message I am getting looks like this:

2024-09-19 21:50:58.963539: W external/xla/xla/service/platform_util.cc:199] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_DEVICE_UNAVAILABLE: CUDA-capable device(s) is/are busy or unavailable
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/**/.conda/envs/dtest/lib/python3.9/site-packages/jax/_src/xla_bridge.py:879 in         │
│ backends                                                                                         │
│                                                                                                  │
│    876 │   default_priority = -1000                                                              │
│    877 │   for platform, priority, fail_quietly in platform_registrations:                       │
│    878 │     try:                                                                                │
│ ❱  879 │   │   backend = _init_backend(platform)                                                 │
│    880 │   │   _backends[platform] = backend                                                     │
│    881 │   │                                                                                     │
│    882 │   │   if priority > default_priority:                                                   │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/site-packages/jax/_src/xla_bridge.py:970 in         │
│ _init_backend                                                                                    │
│                                                                                                  │
│    967 │   logger.warning(f"Platform '{platform}' is experimental and not all JAX "              │
│    968 │   │   │   │      "functionality may be correctly supported!")                           │
│    969   logger.debug("Initializing backend '%s'", platform)                                     │
│ ❱  970   backend = registration.factory()                                                        │
│    971   # TODO(skye): consider raising more descriptive errors directly from backend            │
│    972   # factories instead of returning None.                                                  │
│    973   if backend is None:                                                                     │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/site-packages/jax/_src/xla_bridge.py:668 in factory │
│                                                                                                  │
│    665 │     updated_options.update(options)                                                     │
│    666 │   updated_options.update(_options_from_jax_configs(plugin_name))                        │
│    667 │   if distributed.global_state.client is None:                                           │
│ ❱  668 │     return xla_client.make_c_api_client(plugin_name, updated_options, None)             │
│    669 │                                                                                         │
│    670 │   distribute_options = {                                                                │
│    671 │   │   'node_id': distributed.global_state.process_id,                                   │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/site-packages/jaxlib/xla_client.py:200 in           │
│ make_c_api_client                                                                                │
│                                                                                                  │
│   197   """                                                                                      │
│   198   if options is None:                                                                      │
│   199 │   options = {}                                                                           │
│ ❱ 200   return _xla.get_c_api_client(plugin_name, options, distributed_client)                   │
│   201                                                                                            │
│   202                                                                                            │
│   203 def make_tpu_client(                                                                       │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
XlaRuntimeError: INTERNAL: no supported devices found for platform CUDA

During handling of the above exception, another exception occurred:

SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:1                                                                                    │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/multiprocessing/spawn.py:116 in spawn_main          │
│                                                                                                  │
│   113 │   │   resource_tracker._resource_tracker._fd = tracker_fd                                │
│   114 │   │   fd = pipe_handle                                                                   │
│   115 │   │   parent_sentinel = os.dup(pipe_handle)                                              │
│ ❱ 116 │   exitcode = _main(fd, parent_sentinel)                                                  │
│   117 │   sys.exit(exitcode)                                                                     │
│   118                                                                                            │
│   119                                                                                            │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/multiprocessing/spawn.py:125 in _main               │
│                                                                                                  │
│   122 │   │   process.current_process()._inheriting = True                                       │
│   123 │   │   try:                                                                               │
│   124 │   │   │   preparation_data = reduction.pickle.load(from_parent)                          │
│ ❱ 125 │   │   │   prepare(preparation_data)                                                      │
│   126 │   │   │   self = reduction.pickle.load(from_parent)                                      │
│   127 │   │   finally:                                                                           │
│   128 │   │   │   del process.current_process()._inheriting                                      │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/multiprocessing/spawn.py:236 in prepare             │
│                                                                                                  │
│   233 │   if 'init_main_from_name' in data:                                                      │
│   234 │   │   _fixup_main_from_name(data['init_main_from_name'])                                 │
│   235 │   elif 'init_main_from_path' in data:                                                    │
│ ❱ 236 │   │   _fixup_main_from_path(data['init_main_from_path'])                                 │
│   237                                                                                            │
│   238 # Multiprocessing module helpers to fix up the main module in                              │
│   239 # spawned subprocesses                                                                     │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/multiprocessing/spawn.py:287 in                     │
│ _fixup_main_from_path                                                                            │
│                                                                                                  │
│   284 │   # non-main code that needs to be executed                                              │
│   285 │   old_main_modules.append(current_main)                                                  │
│   286 │   main_module = types.ModuleType("__mp_main__")                                          │
│ ❱ 287 │   main_content = runpy.run_path(main_path,                                               │
│   288 │   │   │   │   │   │   │   │     run_name="__mp_main__")                                  │
│   289 │   main_module.__dict__.update(main_content)                                              │
│   290 │   sys.modules['__main__'] = sys.modules['__mp_main__'] = main_module                     │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/runpy.py:288 in run_path                            │
│                                                                                                  │
│   285 │   │   # Not a valid sys.path entry, so run the code directly                             │
│   286 │   │   # execfile() doesn't help as we want to allow compiled files                       │
│   287 │   │   code, fname = _get_code_from_file(run_name, path_name)                             │
│ ❱ 288 │   │   return _run_module_code(code, init_globals, run_name,                              │
│   289 │   │   │   │   │   │   │   │   pkg_name=pkg_name, script_name=fname)                      │
│   290 │   else:                                                                                  │
│   291 │   │   # Finder is defined for path, so add it to                                         │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/runpy.py:97 in _run_module_code                     │
│                                                                                                  │
│    94 │   fname = script_name if mod_spec is None else mod_spec.origin                           │
│    95 │   with _TempModule(mod_name) as temp_module, _ModifiedArgv0(fname):                      │
│    96 │   │   mod_globals = temp_module.module.__dict__                                          │
│ ❱  97 │   │   _run_code(code, mod_globals, init_globals,                                         │
│    98 │   │   │   │     mod_name, mod_spec, pkg_name, script_name)                               │
│    99 │   # Copy the globals of the temporary module, as they                                    │
│   100 │   # may be cleared when the temporary module goes away                                   │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/runpy.py:87 in _run_code                            │
│                                                                                                  │
│    84 │   │   │   │   │      __loader__ = loader,                                                │
│    85 │   │   │   │   │      __package__ = pkg_name,                                             │
│    86 │   │   │   │   │      __spec__ = mod_spec)                                                │
│ ❱  87 │   exec(code, run_globals)                                                                │
│    88 │   return run_globals                                                                     │
│    89                                                                                            │
│    90 def _run_module_code(code, init_globals=None,                                              │
│                                                                                                  │
│ /scratch/**/viper_debug/viper_rl/scripts/train_dreamer.py:25 in <module>                     │
│                                                                                                  │
│    22                                                                                            │
│    23 from viper_rl.dreamerv3 import embodied                                                    │
│    24 from viper_rl.dreamerv3.embodied import wrappers                                           │
│ ❱  25 from train_videogpt import collect_data                                                    │
│    26 from flax.training import checkpoints                                                      │
│    27                                                                                            │
│    28 def main(argv=None):                                                                       │
│                                                                                                  │
│ /scratch/**/viper_debug/viper_rl/scripts/train_videogpt.py:25 in <module>                    │
│                                                                                                  │
│    22 directory = directory.parent                                                               │
│    23 sys.path.append(str(directory.parent))                                                     │
│    24                                                                                            │
│ ❱  25 from viper_rl.videogpt.models import AE, VideoGPT                                          │
│    26 from viper_rl.videogpt.sampler import VideoGPTSampler                                      │
│    27 from viper_rl.videogpt.data import load_dataset                                            │
│    28 from viper_rl.videogpt.train_utils import init_model_state_videogpt, get_first_device, P   │
│                                                                                                  │
│ /scratch/**/viper_debug/viper_rl/viper_rl/videogpt/models/__init__.py:12 in <module>         │
│                                                                                                  │
│     9                                                                                            │
│    10 from .vqgan import VQGAN                                                                   │
│    11 from .videogpt import VideoGPT                                                             │
│ ❱  12 from .stylegan_disc import StyleGANDisc                                                    │
│    13 from .vqgan import VQGAN                                                                   │
│    14                                                                                            │
│    15                                                                                            │
│                                                                                                  │
│ /scratch/**/viper_debug/viper_rl/viper_rl/videogpt/models/stylegan_disc.py:174 in <module>   │
│                                                                                                  │
│   171 │   return jnp.concatenate((x, y_std), axis=3)                                             │
│   172                                                                                            │
│   173                                                                                            │
│ ❱ 174 class DiscriminatorBlock(nn.Module):                                                       │
│   175 │   in_features: int                                                                       │
│   176 │   out_features: int                                                                      │
│   177 │   activation_function: ActivationFunction = jnn.leaky_relu                               │
│                                                                                                  │
│ /scratch/**/viper_debug/viper_rl/viper_rl/videogpt/models/stylegan_disc.py:178 in            │
│ DiscriminatorBlock                                                                               │
│                                                                                                  │
│   175 │   in_features: int                                                                       │
│   176 │   out_features: int                                                                      │
│   177 │   activation_function: ActivationFunction = jnn.leaky_relu                               │
│ ❱ 178 │   resample_kernel: jnp.ndarray = jnp.array([1, 3, 3, 1])                                 │
│   179 │   dtype: jnp.dtype = jnp.float32                                                         │
│   180 │                                                                                          │
│   181 │   def setup(self):                                                                       │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:3214 in   │
│ array                                                                                            │
│                                                                                                  │
│   3211   else:                                                                                   │
│   3212 │   raise TypeError(f"Unexpected input type for array: {type(object)}")                   │
│   3213                                                                                           │
│ ❱ 3214   out_array: Array = lax_internal._convert_element_type(                                  │
│   3215 │     out, dtype, weak_type=weak_type)                                                    │
│   3216   if ndmin > ndim(out_array):                                                             │
│   3217 │   out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))                │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/site-packages/jax/_src/lax/lax.py:559 in            │
│ _convert_element_type                                                                            │
│                                                                                                  │
│    556 │   │      isinstance(core.get_aval(operand), core.ConcreteArray))):                      │
│    557 │   return type_cast(Array, operand)                                                      │
│    558   else:                                                                                   │
│ ❱  559 │   return convert_element_type_p.bind(operand, new_dtype=new_dtype,                      │
│    560 │   │   │   │   │   │   │   │   │      weak_type=bool(weak_type))                         │
│    561                                                                                           │
│    562 def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:              │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/site-packages/jax/_src/core.py:416 in bind          │
│                                                                                                  │
│    413   def bind(self, *args, **params):                                                        │
│    414 │   assert (not config.enable_checks.value or                                             │
│    415 │   │   │   all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args     │
│ ❱  416 │   return self.bind_with_trace(find_top_trace(args), args, params)                       │
│    417                                                                                           │
│    418   def bind_with_trace(self, trace, args, params):                                         │
│    419 │   with pop_level(trace.level):                                                          │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/site-packages/jax/_src/core.py:420 in               │
│ bind_with_trace                                                                                  │
│                                                                                                  │
│    417                                                                                           │
│    418   def bind_with_trace(self, trace, args, params):                                         │
│    419 │   with pop_level(trace.level):                                                          │
│ ❱  420 │     out = trace.process_primitive(self, map(trace.full_raise, args), params)            │
│    421 │   return map(full_lower, out) if self.multiple_results else full_lower(out)             │
│    422                                                                                           │
│    423   def def_impl(self, impl):                                                               │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/site-packages/jax/_src/core.py:921 in               │
│ process_primitive                                                                                │
│                                                                                                  │
│    918 │     from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks  # py  │
│    919 │     return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **para  │
│    920 │   else:                                                                                 │
│ ❱  921 │     return primitive.impl(*tracers, **params)                                           │
│    922                                                                                           │
│    923   def process_call(self, primitive, f, tracers, params):                                  │
│    924 │   if config.debug_key_reuse.value:                                                      │
│                                                                                                  │
│ /home/**/.conda/envs/dtest/lib/python3.9/site-packages/jax/_src/dispatch.py:87 in            │
│ apply_primitive                                                                                  │
│                                                                                                  │
│    84   # triggering the disable jit path instead of messing around with it here.                │
│    85   prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)                            │
│    86   try:                                                                                     │
│ ❱  87 │   outs = fun(*args)                                                                      │
│    88   finally:                                                                                 │
│    89 │   lib.jax_jit.swap_thread_local_state_disable_jit(prev)                                  │
│    90   return outs                                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Unable to initialize backend 'cuda': INTERNAL: no supported devices found for platform CUDA (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
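For reference, my understanding is that JAX picks its backend once, at first import, so the `JAX_PLATFORMS=cpu` override suggested in the error only helps if it is already in the environment before the process first imports jax. A minimal sketch of that requirement (no jax needed here; the child just echoes the variable a real worker would see before its `import jax`):

```python
import os
import subprocess
import sys

# Launch a fresh interpreter with JAX_PLATFORMS already set. A real
# worker would `import jax` at this point and get the CPU backend
# instead of trying to open a second CUDA context on a busy device.
env = dict(os.environ, JAX_PLATFORMS="cpu")
child = subprocess.run(
    [sys.executable, "-c", "import os; print(os.environ['JAX_PLATFORMS'])"],
    env=env, capture_output=True, text=True,
)
print(child.stdout.strip())  # cpu
```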

I have already tried setting the jax_platforms config option to gpu after this error happened, but that raises essentially the same error, just without the full traceback. To explain this a little more:

When I activate the conda environment and import JAX for some test calculations, everything works fine and jax.devices() shows the CUDA device with ID 0. When I then run the script that triggers this issue and check jax.devices() afterwards, it only shows CpuDevice and no CUDA device any longer, and after that, as explained, I can't switch the platform back to GPU. However, even with CpuDevice it still crashes, because it keeps trying to run on a CUDA device. I suspect that once the error occurs, the process is stuck in a bad state no matter what I do: setting the JAX platform to cpu did not work either, and jax.devices() still showed only the CUDA device. I am also wondering why jax.devices() only ever shows one of the two, never both. This issue seems strongly related to JAX, so I am hoping I can get help here. For clarification, the code I am trying to run belongs to the video prediction model VIPER from https://github.com/neuronphysics/viper_rl, and the error occurs specifically when running the train_dreamer.py script in the scripts folder. Any help would be much appreciated, as fixing this is a high priority for me.
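One workaround I am experimenting with (a sketch, not viper_rl code; `configure_worker_env` is my own hypothetical helper) is to build the environment for each spawned worker before it starts, so that its re-import of jax either stays on CPU or is pinned to a single GPU without preallocating it:

```python
import os

def configure_worker_env(gpu_index=None):
    """Sketch: build the environment for a worker process so that its
    `import jax` initializes cleanly instead of fighting the parent
    for the device."""
    env = dict(os.environ)
    if gpu_index is None:
        # Keep helper / data-loading workers off the GPU entirely.
        env["JAX_PLATFORMS"] = "cpu"
    else:
        # Pin the worker to one device and stop JAX from preallocating
        # most of its memory, which can otherwise leave the device
        # "busy or unavailable" for sibling processes.
        env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
        env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    return env
```

All three variables only take effect if they are set before the worker first imports jax, which is why they go into the environment rather than into jax.config after the fact.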

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.30
jaxlib: 0.4.30
numpy: 1.24.1
python: 3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:50:21) [GCC 12.3.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1

NVIDIA-SMI 550.90.07 Driver Version: 550.90.07 CUDA Version: 12.4

NonsansWD added the bug label Sep 19, 2024