CTC Loss Refactor #5559
For the layout option: 'TNC' or 'NTC'.
BTW, while you are at it, consider removing the mallocs and cudaMallocs.
How about a new symbol?
`ctc_loss` should be good.
I'm worried that MXNet will get its own CTC implementation one day, and then there will be a name conflict.
If we do implement our own, will it do the exact same thing? That is, is CTC loss clearly defined mathematically, or does everyone have their own variant?
The definitions should be the same, though there are cases where you might want two symbols. If these are not major concerns, let's go with `CTCLoss`.
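For reference, CTC loss is well defined in the literature. From Graves et al. (2006), with $y^t_k$ the softmax output for label $k$ at timestep $t$ and $\mathcal{B}$ the collapsing map that removes repeats and blanks:

```latex
% CTC loss of target sequence l given input x: negative log probability,
% summed over every frame-level alignment \pi that collapses to l under B.
\mathcal{L}(x, l) = -\ln p(l \mid x)
                  = -\ln \sum_{\pi \in \mathcal{B}^{-1}(l)} \prod_{t=1}^{T} y^{t}_{\pi_t}
```

Implementations differ in numerics (log-space recursions) and conventions (eg which index is the blank), not in this definition.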
Ok, then let's go with `CTCLoss`. You can put the implementation in `src/operator/contrib` and make it `LOG(FATAL) << "Compile with baidu ctc to enable"` when not linking against Baidu CTC. Add a flag `USE_BAIDUCTC` in `config.mk` to turn it on.
BTW, since Baidu CTC also uses the Apache 2.0 license, we can do a deeper integration and absorb their code. That way it doesn't need to stay a plugin.
That is a good point. It would simplify build issues for users. I will take a first stab at integrating it as a proper operator.
@piiswrong: WarpCTC has a dependency on the header-only ModernGPU library (https://github.com/moderngpu/moderngpu/wiki; see https://github.com/baidu-research/warp-ctc/tree/master/include/contrib/moderngpu), which is licensed under 3-clause BSD. How should this dependency be handled? Also: if we want to integrate the WarpCTC library into MXNet, should we make minimal changes to the WarpCTC code (so that incorporating any updates/fixes is easy), or should some rewriting be done (eg getting rid of some of its compile flags)?
@piiswrong: after thinking about this more, having a native CTC op will actually be the best way forward for MXNet. Let me know what to do about the moderngpu dependency, though; I don't want to start on this until that is decided.
@piiswrong: I would like to get going with this. Any update?
moderngpu is header-only, so we can copy that too. Just need a separate license for the subdirectory.
Ok, great!
A reminder: moderngpu may duplicate functionality that is already in cub.
@tqchen: I will commit the first version using moderngpu, and it can be a secondary project to try to get rid of the moderngpu dependency by using cub instead. It should be faster as well. I might not have time for this: is there a place where one can propose projects that others who are interested can pick up?
Is there any update on this refactor? I'm now developing an OCR model, and I find that warpctc is not convenient, since its return value is a probability distribution and I don't know how to print its loss. Meanwhile, the current implementation of ctc_loss is very slow (more than 70 times slower than warpctc).
@szha That's pretty cool! So the new CTC loss has better performance than the old one, right? Is it as fast as warpctc?
@buaacszj the change I mentioned above is based on @sbodenstein's work on contrib CTC, and my changes mainly improve usability, such as removing the effect of undefined gradients inside the padding area, and improving numerical stability by doing log softmax for the log probabilities. I haven't done a thorough analysis of its speed. Do you have ideas on how it can be improved?
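To illustrate the log-softmax point, here is a minimal NumPy sketch (names are illustrative, not MXNet internals): shifting by the per-row max before exponentiating keeps `exp()` from overflowing, which matters when the activations feeding CTC are large.

```python
import numpy as np

def log_softmax(x, axis=-1):
    """Numerically stable log softmax: x - logsumexp(x)."""
    x_max = np.max(x, axis=axis, keepdims=True)      # shift by max to avoid overflow
    shifted = x - x_max
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

# activations shaped (seq_len, batch, alphabet_size), as CTC expects;
# scaled up so that a naive exp() would overflow float32
acts = 50.0 * np.random.randn(80, 32, 6000).astype(np.float32)
log_probs = log_softmax(acts)                        # finite even for large activations
assert np.isfinite(log_probs).all()
assert np.allclose(np.exp(log_probs).sum(axis=-1), 1.0, atol=1e-3)
```

Warp-CTC runs its forward-backward recursions in log space for the same reason.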
As I see it, contrib CTC is implemented inside MXNet without depending on Baidu's warpctc. But as tested before, its performance is very slow when the alphabet_size is more than 5000. Does contrib CTC use the same algorithm as warpctc?
@buaacszj: the contrib CTC implementation is basically just Warp-CTC (with one or two smallish changes). The Warp-CTC implementation was copied and adapted so that it could be compiled together with MXNet (rather than annoyingly having to compile it separately). If it has any speed difference, that would be surprising (could you give the exact configuration in which you saw these large speed differences? And was this on GPU or CPU?). Also: @szha did some speed tests against cuDNN (#7442). I think this tested the contrib symbol CTC, which was quite competitive with cuDNN CTC.
I am using two K40m GPUs.
contrib CTC:
warpctc:
@sbodenstein Please tell me if you need extra information.
@buaacszj: a few more pieces of information would help: which CUDA and cuDNN versions are you using?
My hypothesis is that it is actually using the cuDNN v7 CTC loss, which is slower for very large alphabet sizes.
@sbodenstein Sorry for replying late. Here is the information about my CUDA and cuDNN:
Have you tested whether contrib CTC runs as fast as Baidu's warp-ctc for large alphabet sizes?
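One way to settle this is a direct timing at the problematic alphabet size. A minimal sketch, assuming a build where the contrib op is exposed as `mx.nd.contrib.ctc_loss` (the shapes and iteration count here are arbitrary):

```python
import time
import mxnet as mx

seq_len, batch, alphabet = 100, 64, 6000             # large alphabet, as in the OCR case
ctx = mx.gpu(0)

# unnormalized activations; label values drawn from [1, alphabet) since
# index 0 is the blank under the default blank_label='first' convention
data = mx.nd.random.uniform(shape=(seq_len, batch, alphabet), ctx=ctx)
label = mx.nd.floor(mx.nd.random.uniform(1, alphabet, shape=(batch, 20), ctx=ctx))

mx.nd.contrib.ctc_loss(data, label).wait_to_read()   # warm-up
start = time.time()
for _ in range(50):
    loss = mx.nd.contrib.ctc_loss(data, label)
mx.nd.waitall()                                      # block until all kernels finish
print('avg per call: %.2f ms' % ((time.time() - start) / 50 * 1000))
```

Running the same loop with Warp-CTC bindings on the same shapes would give a like-for-like comparison.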
@apache/mxnet-committers: This issue has been inactive for the past 90 days. It has no label and needs triage. For general "how-to" questions, our user forum (and Chinese version) is a good place to get help.
We would like to use the WarpCTC loss plugin, but the current design makes it unusable for us. It would also be good to make it compatible with the new loss API. I propose to change the CTC symbol so that:

1. It accepts inputs of shape `[batch, seq len, feature]` or `[seq len, batch, feature]`. The current design uses the strange `[batch x seq len, feature]`.
2. It outputs a loss of shape `[batch]`, so it outputs a loss per example in the batch, which allows example weighting. Currently it throws away the expensively computed loss.

Some questions:

- Should there be an option `"layout"`, in analogy with `Convolution`, where `"layout"` can be either `"BS"` or `"SB"`? Or perhaps an option that is `True` or `False` might be simpler (eg `"batch_first"`, which is either true or false).

I have time next week to refactor this. So I would like to get a go-ahead with the design (and check that no one else is working on this, or another CTC loss implementation), and just make sure I have permission to break the current design.
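To make the interface concrete, here is a sketch of the proposed call, using the 'NTC'/'TNC' layout strings suggested in the comments above. The `layout` parameter and the per-example output are part of the proposal, not an existing MXNet API:

```python
import mxnet as mx

data = mx.sym.Variable('data')    # (batch, seq_len, alphabet_size) when layout='NTC'
label = mx.sym.Variable('label')  # (batch, max_label_len), padded

# Proposed: loss has shape (batch,), one value per example, so callers can
# apply per-example weights instead of losing the computed loss values.
loss = mx.sym.contrib.CTCLoss(data=data, label=label, layout='NTC')
weighted = loss * mx.sym.Variable('sample_weight')   # example weighting
```

With `layout='TNC'` the data would instead be `(seq_len, batch, alphabet_size)`, which matches Warp-CTC's native ordering.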
@piiswrong, @taliesinb