I can reproduce this both in a fresh conda env created with `conda install jax cuda-nvcc -c conda-forge -c nvidia`, which resolves to:
# Name       Version   Build                      Channel
cuda-nvcc    12.0.76   0                          nvidia
jax          0.4.1     pyhd8ed1ab_0               conda-forge
jaxlib       0.4.1     cuda112py310hf0bc174_201   conda-forge
and in a Docker container with this Dockerfile:
FROM nvcr.io/nvidia/tensorflow:22.12-tf2-py3
RUN pip install --upgrade pip
RUN pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
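Both setups run a CUDA build of jaxlib, so it can help to check which backend JAX actually selects and whether the same 4D convolution runs when pinned to the CPU. The sketch below uses `jax.default_backend()`, `jax.devices()` and the `jax.default_device` context manager; the shapes and the `tiny_conv4d` helper are hypothetical and not taken from the original example.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Report the JAX version and the backend/devices JAX picked by default.
print(jax.__version__, jax.default_backend(), jax.devices())


def tiny_conv4d():
    """A small stride-1 'VALID' 4D convolution (hypothetical shapes, all ones)."""
    x = jnp.ones((1, 1, 4, 4, 4, 4), jnp.float32)  # (N, C, 4 spatial axes)
    w = jnp.ones((1, 1, 2, 2, 2, 2), jnp.float32)  # (O, I, 4 spatial axes)
    return lax.conv_general_dilated(x, w, window_strides=(1, 1, 1, 1), padding="VALID")


# Pin the same computation to the CPU. If this succeeds while the default
# (GPU) run fails, the failure is specific to the GPU code path.
with jax.default_device(jax.devices("cpu")[0]):
    y_cpu = tiny_conv4d()
print(y_cpu.shape)
```

Running `tiny_conv4d()` once more outside the `with` block exercises the default backend, so the two results (or errors) can be compared directly.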
I'd like to use 4D convolution. This minimal example gives the following error:
If I change the example to 3D, I get the correct output.
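The original 4D and 3D snippets are not preserved here, so the sketch below only shows a plausible form of the call: `lax.conv_general_dilated` with its default layouts, where going from 3D to 4D changes nothing but the number of spatial axes and strides. It also compares each result with a brute-force NumPy reference, which is one way to pin down what "correct output" means. The helper names, shapes and tolerances are assumptions for illustration.

```python
import itertools

import numpy as np
import jax.numpy as jnp
from jax import lax


def lax_valid_conv(x, w):
    """Stride-1 'VALID' convolution over every axis of x, computed by lax.

    x and w carry only spatial axes; singleton batch/channel axes are added so
    the default layouts (N, C, *spatial) and (O, I, *spatial) apply. The same
    call works for 3 or 4 spatial axes.
    """
    y = lax.conv_general_dilated(
        jnp.asarray(x)[None, None],
        jnp.asarray(w)[None, None],
        window_strides=(1,) * x.ndim,
        padding="VALID",
    )
    return np.asarray(y[0, 0])


def naive_valid_corr(x, w):
    """Brute-force reference. lax does not flip the kernel, so the matching
    reference is a cross-correlation, not a flipped-kernel convolution."""
    out_shape = tuple(n - k + 1 for n, k in zip(x.shape, w.shape))
    out = np.zeros(out_shape, dtype=np.float64)
    for idx in itertools.product(*(range(n) for n in out_shape)):
        window = tuple(slice(i, i + k) for i, k in zip(idx, w.shape))
        out[idx] = np.sum(x[window] * w)
    return out


rng = np.random.default_rng(0)
for shape_x, shape_w in [((6, 6, 6), (3, 3, 3)), ((6, 6, 6, 6), (3, 3, 3, 3))]:
    x = rng.standard_normal(shape_x).astype(np.float32)
    w = rng.standard_normal(shape_w).astype(np.float32)
    ok = np.allclose(lax_valid_conv(x, w), naive_valid_corr(x, w), rtol=1e-3, atol=1e-3)
    print(f"{len(shape_x)}D match: {ok}")
```

On a setup where only the 4D case is broken, the second loop iteration should either raise or print `False`, which isolates the problem to the number of spatial dimensions.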
Am I doing something wrong, or is this a bug?
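If it turns out that the backend only rejects four spatial dimensions (the error text is not preserved above, so this is an assumption), one possible workaround is to unroll one spatial axis and sum 3D convolutions, which the 3D case shows do work. The sketch below covers only stride 1, `"VALID"` padding, no dilation and no feature groups; `conv3d_valid` and `conv4d_via_conv3d` are hypothetical helper names.

```python
import numpy as np
import jax
from jax import lax
import jax.numpy as jnp


def conv3d_valid(x, w):
    """Stride-1 'VALID' 3D convolution in lax's default layouts."""
    return lax.conv_general_dilated(x, w, window_strides=(1, 1, 1), padding="VALID")


def conv4d_via_conv3d(x, w):
    """Stride-1 'VALID' 4D convolution assembled from 3D convolutions.

    x: (N, C_in, T, D, H, W); w: (C_out, C_in, KT, KD, KH, KW).
    Output slice t along the first spatial axis is the sum over kernel
    offsets k of conv3d(x[:, :, t + k], w[:, :, k]). Both loops are
    unrolled in Python, so this emits t_out * KT separate 3D convolutions.
    """
    kt = w.shape[2]
    t_out = x.shape[2] - kt + 1
    slices = [
        sum(conv3d_valid(x[:, :, t + k], w[:, :, k]) for k in range(kt))
        for t in range(t_out)
    ]
    return jnp.stack(slices, axis=2)  # reinsert the T axis right after N, C


# Compare against a direct 4D convolution on the CPU, where it is assumed to run.
with jax.default_device(jax.devices("cpu")[0]):
    key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (2, 3, 5, 4, 4, 4))
    w = jax.random.normal(key_w, (2, 3, 2, 2, 2, 2))
    direct = lax.conv_general_dilated(x, w, (1, 1, 1, 1), "VALID")
    split = conv4d_via_conv3d(x, w)
print(direct.shape, np.allclose(direct, split, rtol=1e-4, atol=1e-4))
```

This trades one 4D convolution for `t_out * KT` 3D ones, so it is only practical when the output length and kernel size along that axis are small. Strides, padding or dilation along T would need extra index bookkeeping that this sketch leaves out.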