Using cuDNN for CTC Loss #7445
Comments
Thanks for raising this, @sbodenstein. I'm working on using the cudnn7 implementation of CTC for GPU.
@szha: do you agree that we should remove the WarpCTC CUDA implementation?
@sbodenstein I agree. There is only one catch: the current WarpCTC supports variable-length inputs, whereas cudnn7 only signals an intention to support them. To elaborate, the current cudnn7 CTC API for getting the workspace size looks like this:
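```cpp
// Workspace-size query as documented for cuDNN 7
// (from the cuDNN 7 documentation; check cudnn.h for the exact declaration):
cudnnStatus_t cudnnGetCTCLossWorkspaceSize(
    cudnnHandle_t                 handle,
    const cudnnTensorDescriptor_t probsDesc,      // (T, N, A) probabilities
    const cudnnTensorDescriptor_t gradientsDesc,  // same shape as probs
    const int                    *labels,         // flattened label batch
    const int                    *labelLengths,   // one length per sequence
    const int                    *inputLengths,   // one input length per sequence
    cudnnCTCLossAlgo_t            algo,
    cudnnCTCLossDescriptor_t      ctcLossDesc,
    size_t                       *sizeInBytes);   // out: required workspace size
```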
However, if I give any entry of `inputLengths` that is not equal to T (the padded maximum sequence length), the call fails, so my guess is that this argument is left there so that variable input lengths can be supported going forward, because asking for a list of Ts doesn't make much sense otherwise. Back to whether we should remove the WarpCTC implementation, I think we need to first clarify with the cuDNN team when the variable-length input support will actually land.
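To make the "list of Ts" concrete, here is a rough sketch of how the query has to be made today (names like `QueryCtcWorkspaceSize` are illustrative, error checking is omitted, and `probsDesc`/`gradsDesc` are assumed to already describe a (T, N, A) probabilities tensor):

```cpp
#include <cudnn.h>
#include <cstddef>
#include <vector>

// Sketch: query the CTC workspace size under cuDNN 7's current constraint
// that every entry of inputLengths equals the padded length T.
size_t QueryCtcWorkspaceSize(cudnnHandle_t handle,
                             cudnnTensorDescriptor_t probsDesc,   // (T, N, A)
                             cudnnTensorDescriptor_t gradsDesc,   // (T, N, A)
                             const std::vector<int>& labels,
                             const std::vector<int>& labelLengths,
                             int T, int batchSize) {
  cudnnCTCLossDescriptor_t ctcDesc;
  cudnnCreateCTCLossDescriptor(&ctcDesc);
  cudnnSetCTCLossDescriptor(ctcDesc, CUDNN_DATA_FLOAT);

  // The only accepted value today: T for every sequence in the batch,
  // i.e. literally a list of Ts.
  std::vector<int> inputLengths(batchSize, T);

  size_t workspaceBytes = 0;
  cudnnGetCTCLossWorkspaceSize(handle, probsDesc, gradsDesc,
                               labels.data(), labelLengths.data(),
                               inputLengths.data(),
                               CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
                               ctcDesc, &workspaceBytes);

  cudnnDestroyCTCLossDescriptor(ctcDesc);
  return workspaceBytes;
}
```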
Ah, that is annoying! One other limitation I noticed: the cuDNN docs state that label lengths must be less than 256, whilst WarpCTC places no such restriction on label length.
Good catch. Let me reflect this in the PR as well.
So the CTC of cudnn7 supports neither variable-length inputs nor label lengths longer than 256.
@szha: one annoying thing about adding cuDNN support: if you build against cuDNN (which you almost always want to do), you automatically have to use the cuDNN CTC implementation, which you might not want if you need to support variable-length inputs.
The current implementation still includes WarpCTC in the GPU version and only enables the cudnn version when all of its input requirements are met, since the cudnn version is strictly more limited.
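The guard amounts to something like this sketch (illustrative, not the literal MXNet code):

```cpp
#include <vector>

// Sketch of the dispatch guard: fall back to WarpCTC unless every
// documented cuDNN 7 restriction is satisfied for this batch.
bool CanUseCudnnCTC(const std::vector<int>& input_lengths,
                    const std::vector<int>& label_lengths,
                    int max_seq_len) {
  for (int len : input_lengths)
    if (len != max_seq_len) return false;  // no variable-length inputs
  for (int len : label_lengths)
    if (len >= 256) return false;          // label lengths must be < 256
  return true;
}
```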
Unfortunately I have to turn cudnn CTC off because of the API design. I have requested API changes from the NVIDIA folks, and hopefully we can incorporate it once they change the API.
@szha: why was that not enough? Why do you need to completely turn it off?
Input lengths not being supported means there's no way for me to enforce consistency between the CPU and GPU implementations. The cudnn version also doesn't reset the gradients in the padding area to zero, which can cause instability in training. It's too much of a risk to leave it be.
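For reference, "resetting the padding area gradients" would mean something like the following sketch (host-side for clarity, assuming a contiguous (T, N, A) row-major gradient layout; on device this would be a small kernel):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch: zero the gradient rows for time steps past each sequence's
// true length, so padded positions contribute nothing to the update.
void ZeroPaddedGradients(float* grad, const std::vector<int>& seq_lengths,
                         int T, int N, int A) {
  for (int n = 0; n < N; ++n) {
    for (int t = seq_lengths[n]; t < T; ++t) {
      float* row = grad + (static_cast<std::size_t>(t) * N + n) * A;
      std::fill(row, row + A, 0.0f);
    }
  }
}
```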
@apache/mxnet-committers: This issue has been inactive for the past 90 days. It has no label and needs triage. For general "how-to" questions, our user forum (and Chinese version) is a good place to get help.
cudnn integration for CTC requires action from the cudnn team to improve the API.
@piiswrong, @szha: Now that cuDNN 7 supports CTC loss, perhaps we should discard the current GPU implementation in contrib.ctc_loss (adapted from the WarpCTC implementation) and only use cuDNN for GPU? The main reasons:

- I don't think the maintenance effort is worthwhile if almost every single user training with CUDA will have cuDNN.

What are your thoughts?