When I import jaxmarl, some error happens. #41

Closed
zyh1999 opened this issue Nov 22, 2023 · 6 comments

@zyh1999

zyh1999 commented Nov 22, 2023

No description provided.

@zyh1999
Author

zyh1999 commented Nov 22, 2023

I downloaded jax 0.4.11 and jaxlib-0.4.11+cuda11.cudnn86-cp38-cp38-manylinux2014_x86_64.whl, and when I import jaxmarl,
the following error occurs:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/disk1/zyh/jaxmarl/jaxmarl/__init__.py", line 1, in <module>
    from .registration import make, registered_envs
  File "/disk1/zyh/jaxmarl/jaxmarl/registration.py", line 1, in <module>
    from .environments import (
  File "/disk1/zyh/jaxmarl/jaxmarl/environments/__init__.py", line 2, in <module>
    from .mpe import (
  File "/disk1/zyh/jaxmarl/jaxmarl/environments/mpe/__init__.py", line 5, in <module>
    from .simple_push import SimplePushMPE
  File "/disk1/zyh/jaxmarl/jaxmarl/environments/mpe/simple_push.py", line 13, in <module>
    OBS_COLOUR = jnp.concatenate([COLOUR_1, COLOUR_2])
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1837, in concatenate
    arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1837, in <listcomp>
    arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 613, in concatenate
    return concatenate_p.bind(*operands, dimension=dimension)
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/dispatch.py", line 132, in apply_primitive
    compiled_fun = xla_primitive_callable(
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/util.py", line 284, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/util.py", line 277, in cached
    return f(*args, **kwargs)
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/dispatch.py", line 223, in xla_primitive_callable
    compiled = _xla_callable_uncached(
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/dispatch.py", line 253, in _xla_callable_uncached
    return computation.compile().unsafe_call
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/disk1/zyh/miniconda3/envs/jax_marl/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
My cudatoolkit version is 11.4
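The jaxlib wheel name above encodes what that build expects: the local version label `+cuda11.cudnn86` marks a build for CUDA 11 and cuDNN 8.6, and `cp38` marks CPython 3.8. A system cuDNN older than the one the wheel was built against is one plausible cause of `DNN library initialization failed`, although nothing in this thread confirms it. The sketch below pulls those fields out of the filename and compares them with a cuDNN version you supply. The helper names `parse_jaxlib_wheel` and `cudnn_satisfies`, and the "same major, at least the same minor" rule, are illustrative assumptions, not JAX's own compatibility check.

```python
import re


def parse_jaxlib_wheel(filename):
    """Split a CUDA jaxlib wheel filename into version, CUDA/cuDNN tags and Python tag."""
    m = re.match(
        r"jaxlib-(?P<version>[\d.]+)\+cuda(?P<cuda>\d+)\.cudnn(?P<cudnn>\d+)"
        r"-(?P<python>cp\d+)-",
        filename,
    )
    if m is None:
        raise ValueError(f"not a CUDA jaxlib wheel name: {filename}")
    cudnn = m.group("cudnn")
    return {
        "jaxlib": m.group("version"),
        "cuda_major": int(m.group("cuda")),
        # Assumption: "cudnn86" means major 8, minor 6 (first digit is the major).
        "cudnn_min": (int(cudnn[0]), int(cudnn[1:])),
        "python": m.group("python"),
    }


def cudnn_satisfies(installed, required):
    """Hypothetical rule: same cuDNN major version and at least the required minor."""
    return installed[0] == required[0] and installed[1] >= required[1]


info = parse_jaxlib_wheel("jaxlib-0.4.11+cuda11.cudnn86-cp38-cp38-manylinux2014_x86_64.whl")
print(info)  # {'jaxlib': '0.4.11', 'cuda_major': 11, 'cudnn_min': (8, 6), 'python': 'cp38'}
```

To use it, pass the `(major, minor)` of the cuDNN actually installed as `installed` and `info["cudnn_min"]` as `required`.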

zyh1999 changed the title from "May I ask which version of CUDAToolkit you are using?" to "When I import jaxmarl, some error happens, do I need to download cudnn from the official Nvidia website for this package?" on Nov 22, 2023
zyh1999 changed the title from "When I import jaxmarl, some error happens, do I need to download cudnn from the official Nvidia website for this package?" to "When I import jaxmarl, some error happens." on Nov 22, 2023
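On the question raised in the intermediate title (whether cuDNN has to be downloaded from NVIDIA separately), a first step is to see whether any cuDNN is installed at all, and which version. The following is a minimal sketch. It assumes cuDNN 8.x's `cudnn_version.h` header sits in one of a few guessed locations (the path list is an assumption; adjust it for your system or conda env). At runtime JAX loads the `libcudnn` shared library, not the header, so a header found here is a hint rather than proof of what actually gets loaded.

```python
import os
import re
from pathlib import Path

# Assumed locations of cudnn_version.h (cuDNN 8.x); adjust for your install.
CANDIDATE_HEADERS = [
    Path("/usr/include/cudnn_version.h"),
    Path("/usr/local/cuda/include/cudnn_version.h"),
]
if "CONDA_PREFIX" in os.environ:
    CANDIDATE_HEADERS.append(Path(os.environ["CONDA_PREFIX"]) / "include" / "cudnn_version.h")


def parse_cudnn_version(header_text):
    """Return (major, minor, patch) from the #define lines, or None if any is missing."""
    parts = []
    for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        match = re.search(rf"#define\s+{name}\s+(\d+)", header_text)
        if match is None:
            return None
        parts.append(int(match.group(1)))
    return tuple(parts)


def find_cudnn_header_version():
    """Return (header_path, version) for the first parsable header found, else None."""
    for header in CANDIDATE_HEADERS:
        if header.is_file():
            version = parse_cudnn_version(header.read_text())
            if version is not None:
                return header, version
    return None


print(find_cudnn_header_version())
```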
@amacrutherford
Collaborator

Ah, looks like a JAX issue? Does your JAX install work on its own, without importing jaxmarl?
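A minimal sketch of such a check, run in a fresh interpreter without importing jaxmarl. It prints the installed `jax`/`jaxlib` versions (read from package metadata, which assumes a pip-style install), the backend JAX picked, and the devices it sees, then compiles one small jitted function on that backend.

```python
from importlib.metadata import version

import jax
import jax.numpy as jnp

# Report what JAX itself sees, without importing jaxmarl.
print("jax", version("jax"), "| jaxlib", version("jaxlib"))
print("default backend:", jax.default_backend())
print("devices:", jax.devices())

# A small jitted op forces an XLA compile on the default backend.
x = jnp.arange(3.0)
y = jax.jit(lambda v: v * 2.0 + 1.0)(x)
print(y)
```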

@zyh1999
Author

zyh1999 commented Nov 22, 2023

When I import only jax, some basic operations work. But once I import jaxmarl, the error above occurs.
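Basic operations reportedly work, yet the traceback fails inside `backend.compile`. One way to narrow this down is to split a concatenation into JAX's ahead-of-time stages. `jax.jit(...).lower(...)` traces and lowers without compiling, and `.compile()` then builds the executable for the current backend. The sketch below uses two small arrays standing in for `COLOUR_1` and `COLOUR_2` from the traceback (their values here are illustrative). If the error appears only at `.compile()`, the problem lies in the backend setup (for example CUDA/cuDNN) rather than in jaxmarl's code.

```python
import jax
import jax.numpy as jnp

colour_1 = jnp.array([0.1, 0.9, 0.1])
colour_2 = jnp.array([0.1, 0.1, 0.9])

concat = jax.jit(lambda a, b: jnp.concatenate([a, b]))

# Stage 1: trace and lower to an XLA computation; no backend compile happens yet.
lowered = concat.lower(colour_1, colour_2)
print(lowered.as_text()[:300])

# Stage 2: compile for the current backend. The traceback above fails at this
# kind of step (backend.compile in jax/_src/dispatch.py).
compiled = lowered.compile()
out = compiled(colour_1, colour_2)
print(jax.default_backend(), out)
```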

@amacrutherford
Collaborator

Judging from the error message, it does look like a JAX issue. You should be able to reproduce the error by running the following:

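import jax.numpy as jnp

# Mirrors the import-time concatenation in jaxmarl/environments/mpe/simple_push.py (line 13 in the traceback)
COLOUR_1 = jnp.array([0.1, 0.9, 0.1])
COLOUR_2 = jnp.array([0.1, 0.1, 0.9])
OBS_COLOUR = jnp.concatenate([COLOUR_1, COLOUR_2])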

I would double-check your JAX installation, or otherwise take a look at this issue thread in the JAX repo: jax-ml/jax#15361
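To check whether the failure is specific to the GPU backend, the same computation as the reproduction snippet can be run twice: once pinned to the CPU device, and once on the default backend. This is a sketch that assumes `jax.default_device` is available in the installed JAX version. If the CPU run succeeds and the default run fails with the DNN error, the JAX install itself is fine and the CUDA/cuDNN libraries are the likely culprit.

```python
import jax
import jax.numpy as jnp


def make_obs_colour():
    """Same computation as the reproduction snippet."""
    colour_1 = jnp.array([0.1, 0.9, 0.1])
    colour_2 = jnp.array([0.1, 0.1, 0.9])
    return jnp.concatenate([colour_1, colour_2])


# Run once with the CPU device as the default placement target.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    on_cpu = make_obs_colour()
print("cpu ok:", on_cpu)

# Run again on whatever backend JAX picked by default (GPU if one is found).
try:
    on_default = make_obs_colour()
    print(jax.default_backend(), "ok:", on_default)
except Exception as err:  # e.g. the XlaRuntimeError from the traceback
    print(jax.default_backend(), "failed:", err)
```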

@amacrutherford
Collaborator

Hey! How did you get on?

@alexunderch
Contributor

When installing from source (using a Docker image to reproduce):

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

# install python
ARG DEBIAN_FRONTEND=noninteractive
ARG PYTHON_VERSION=3.10

RUN apt-get update && \
  DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
  software-properties-common \
  build-essential \
  curl \
  git \
  vim \
  wget \
  && apt-get clean \
  && rm -rf /var/lib/apt/lists/*

RUN add-apt-repository ppa:deadsnakes/ppa
RUN apt-get update && apt-get install -y -qq python${PYTHON_VERSION} \
    python${PYTHON_VERSION}-dev \
    python${PYTHON_VERSION}-distutils

# Set python aliases
RUN update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python get-pip.py

# default workdir
WORKDIR /home/workdir

#installing jaxmarl from source
RUN git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL && pip install --ignore-installed -e .

CMD ["/bin/bash"]

If I try to import jaxmarl in VS Code from outside the repository itself, the editor cannot resolve it, but the code runs fine.
The solution is to add a line like this:

export PYTHONPATH=./JaxMARL:$PYTHONPATH
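For a script launched from the directory that contains the clone, roughly the same effect as the `export` line can be had in-process by prepending the clone to `sys.path`. This is a sketch: it only affects that one Python process, not the editor's static analysis. The `./JaxMARL` location is an assumption taken from the Dockerfile's `git clone` into `/home/workdir`, and `prepend_to_path` is a hypothetical helper.

```python
import importlib.util
import os
import sys


def prepend_to_path(repo_dir):
    """Hypothetical helper: prepend repo_dir to sys.path if missing, mirroring
    `export PYTHONPATH=./JaxMARL:$PYTHONPATH` for the current process only."""
    repo_dir = os.path.abspath(repo_dir)
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)
    return repo_dir


repo = prepend_to_path("./JaxMARL")  # assumed clone location
spec = importlib.util.find_spec("jaxmarl")
print("searched first:", repo)
print("jaxmarl resolves to:", spec.origin if spec else "not found")
```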

Maybe add a line with this hint to the README?
