Skip to content

Commit

Permalink
check against NaN loss in tacotron_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Nov 2, 2020
1 parent ef04d7f commit b8ac9ab
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
6 changes: 3 additions & 3 deletions .compute
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#!/bin/bash
yes | apt-get install sox
yes | apt-get install ffmpeg
yes | apt-get install espeak
yes | apt-get install espeak
yes | apt-get install tmux
yes | apt-get install zsh
sh -c "$(curl -fsSL https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)"
pip3 install https://download.pytorch.org/whl/cu100/torch-1.3.0%2Bcu100-cp36-cp36m-linux_x86_64.whl
sudo sh install.sh
pip install pytorch==1.3.0+cu100
python3 setup.py develop
# pip install pytorch==1.7.0+cu100
# python3 setup.py develop
# python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
# cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
# python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/
Expand Down
7 changes: 6 additions & 1 deletion TTS/tts/layers/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def forward(self, att_ws, ilens, olens):

@staticmethod
def _make_ga_mask(ilen, olen, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen)**2 /
(2 * (sigma**2)))
Expand Down Expand Up @@ -373,6 +373,11 @@ def forward(self, postnet_output, decoder_output, mel_input, linear_input,
return_dict['postnet_ssim_loss'] = postnet_ssim_loss

return_dict['loss'] = loss

# check if any loss is NaN
for key, loss in return_dict.items():
if torch.isnan(loss):
raise RuntimeError(f" [!] NaN loss with {key}.")
return return_dict


Expand Down

0 comments on commit b8ac9ab

Please sign in to comment.