StopIteration: Caught StopIteration in replica 0 on device 0. #123
Comments
I think this is the problem.
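For context, a likely culprit (an assumption on my part, consistent with the traceback and the init_mems change mentioned later in this thread) is the param = next(self.parameters()) pattern in init_mems(): on PyTorch >= 1.5, modules replicated by nn.DataParallel no longer expose their parameters through .parameters(), so next() on the empty iterator raises the StopIteration that replica 0 reports. Below is a minimal, hypothetical sketch of that pattern and one defensive rewrite; TinyModel is a made-up stand-in, not the repo's MemTransformerLM.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):  # hypothetical stand-in, not the repo's MemTransformerLM
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def init_mems(self):
        # Pattern that breaks inside a DataParallel replica on PyTorch >= 1.5:
        #   param = next(self.parameters())   # raises StopIteration in the replica
        # Defensive variant: give next() a fallback; concrete weight attributes
        # remain reachable on replicas even when .parameters() yields nothing.
        param = next(self.parameters(), self.linear.weight)
        return [torch.empty(0, dtype=param.dtype, device=param.device)]

    def forward(self, x):
        mems = self.init_mems()              # would crash replicas with the old pattern
        return self.linear(x) + mems[0].sum()  # sum of an empty tensor is 0.0


if __name__ == "__main__" and torch.cuda.device_count() > 1:
    # Needs more than one GPU to actually exercise the replica code path.
    model = nn.DataParallel(TinyModel().cuda())
    out = model(torch.randn(4, 8).cuda())
    print(out.shape)
```

The alternative, as the next comment notes, is simply to stay on an older PyTorch where replicas still expose their parameters.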
I ran into the same error. I wonder if this has been solved? Thanks.
You can downgrade your torch to 1.4.0, which works fine for me (hint: you might have to downgrade your CUDA toolkit to an older version, too).
I can confirm with the following env:
name: pt1.4
channels:
- pytorch
- salilab
- conda-forge
- bioconda
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=1_llvm
- ca-certificates=2020.12.5=ha878542_0
- certifi=2020.12.5=py38h578d9bd_1
- cudatoolkit=10.1.243=h036e899_8
- freetype=2.10.4=h0708190_1
- jpeg=9d=h36c2ea0_0
- lcms2=2.12=hddcbb42_0
- ld_impl_linux-64=2.35.1=hea4e1c9_2
- libblas=3.9.0=8_openblas
- libcblas=3.9.0=8_openblas
- libffi=3.3=h58526e2_2
- libgcc-ng=9.3.0=h2828fa1_18
- libgfortran-ng=9.3.0=hff62375_18
- libgfortran5=9.3.0=hff62375_18
- liblapack=3.9.0=8_openblas
- libopenblas=0.3.12=pthreads_h4812303_1
- libpng=1.6.37=h21135ba_2
- libstdcxx-ng=9.3.0=h6de172a_18
- libtiff=4.2.0=hdc55705_0
- libwebp-base=1.2.0=h7f98852_2
- llvm-openmp=11.1.0=h4bd325d_0
- lz4-c=1.9.3=h9c3ff4c_0
- mkl=2020.4=h726a3e6_304
- ncurses=6.2=h58526e2_4
- ninja=1.10.2=h4bd325d_0
- numpy=1.20.1=py38h18fd61f_0
- olefile=0.46=pyh9f0ad1d_1
- openssl=1.1.1j=h7f98852_0
- pillow=8.1.2=py38ha0e1e83_0
- pip=21.0.1=pyhd8ed1ab_0
- python=3.8.8=hffdb5ce_0_cpython
- python_abi=3.8=1_cp38
- pytorch=1.4.0=py3.8_cuda10.1.243_cudnn7.6.3_0
- readline=8.0=he28a2e2_2
- setuptools=49.6.0=py38h578d9bd_3
- six=1.15.0=pyh9f0ad1d_0
- sqlite=3.35.2=h74cdb3f_0
- tk=8.6.10=h21135ba_1
- torchvision=0.5.0=py38_cu101
- wheel=0.36.2=pyhd3deb0d_0
- xz=5.2.5=h516909a_1
- zlib=1.2.11=h516909a_1010
- zstd=1.4.9=ha95c52a_0
This works with more than 1 GPU card (otherwise one gets a division-by-zero error) and with mem_transformer.py line 754 changed to loss = self.crit(pred_hid.reshape(-1, pred_hid.size(-1)), target.reshape(-1)). With that change it can run through for …
When running bash run_wt103_base.sh train --work_dir TRAIN_wt103, the same problem happens to me as well. The changes involve def init_mems(self) and, (3), changing line 754 from loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1)) to loss = self.crit(pred_hid.reshape(-1, pred_hid.size(-1)), target.reshape(-1)). Note all the changes are made in mem_transformer.py.
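For reference, a small standalone illustration (my own example, not taken from the repo) of why the line-754 change helps: .view() only works when the tensor's memory layout is compatible with the requested shape, while .reshape() falls back to a copy when it is not. The assumption here is that pred_hid ends up non-contiguous on newer PyTorch versions, which the transpose below mimics.

```python
import torch

# Made-up stand-in for pred_hid; transposing makes it non-contiguous,
# mimicking a layout that .view() cannot flatten without a copy.
pred_hid = torch.randn(10, 4, 16).transpose(0, 1)    # shape (4, 10, 16), non-contiguous

try:
    flat = pred_hid.view(-1, pred_hid.size(-1))       # RuntimeError: view size is not compatible ...
except RuntimeError as err:
    print("view failed:", err)

flat = pred_hid.reshape(-1, pred_hid.size(-1))        # succeeds, copying if necessary
print(flat.shape)                                     # torch.Size([40, 16])
```

The same substitution applies to the target tensor in the loss call, as described in the comments above.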
Thanks a lot for this solution. I also hit the same bug when using pytorch=2.0.0, and this solution works well for me.