
CTC Loss Refactor #5559

Closed
sbodenstein opened this issue Mar 24, 2017 · 29 comments

@sbodenstein
Contributor

We would like to use the WarpCTC loss plugin, but the current design makes it unusable for us. It would also be good to make it compatible with the new loss API. I propose changing the CTC symbol so that:

  • The input has dims of the form [batch, seq len, feature] or [seq len, batch, feature]. The current design uses the unusual flattened layout [batch x seq len, feature].
  • The output has dims [batch], so it returns a loss per example in the batch, which allows example weighting. Currently the expensively computed per-example loss is thrown away. (A shape sketch follows this list.)
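
For illustration, here is a minimal NumPy sketch of the proposed shapes (the names are purely illustrative, not an actual MXNet API):

import numpy as np

batch, seq_len, alphabet = 4, 10, 26

# Proposed input layout: a 3-D tensor of class scores, one row per example per timestep.
acts = np.random.rand(batch, seq_len, alphabet).astype(np.float32)

# Current plugin layout: batch and sequence length flattened into a single axis.
acts_flat = acts.reshape(batch * seq_len, alphabet)

# Proposed output: one loss value per example, shape (batch,), so per-example
# weighting becomes an elementwise product followed by a sum.
per_example_loss = np.random.rand(batch)   # stand-in for the CTC losses
example_weights = np.ones(batch)
total_loss = (per_example_loss * example_weights).sum()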

Some questions:

  • What should the option for switching between batch-first and seq-len-first layouts be called? This seems like a useful option for many sequence ops (e.g. the cuDNN RNN symbol). How about "layout", in analogy with Convolution, where "layout" can be either "BS" or "SB"? Or perhaps a boolean option would be simpler (e.g. "batch_first", which is either true or false).

I have time next week to refactor this, so I would like to get a go-ahead on the design (and check that no one else is working on this or on another CTC loss implementation), and make sure I have permission to break the current design.

@piiswrong, @taliesinb

@piiswrong
Contributor

piiswrong commented Mar 24, 2017

'TNC' or 'NTC'
The proposal sounds good. The tricky thing is backward compatibility.
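
(For reference, these layout letters follow MXNet's usual convention, assuming T = sequence length, N = batch, C = alphabet/feature size; converting between the two is just a transpose. A minimal NumPy sketch:)

import numpy as np

batch, seq_len, alphabet = 4, 10, 26
x_ntc = np.random.rand(batch, seq_len, alphabet)   # 'NTC': batch-major layout
x_tnc = x_ntc.transpose(1, 0, 2)                   # 'TNC': time-major layout
assert x_tnc.shape == (seq_len, batch, alphabet)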

@piiswrong
Contributor

BTW, while you are at it, consider removing the mallocs and cudaMallocs.

@sbodenstein
Contributor Author

How about a new symbol WarpCTCLoss, and deprecating the symbol WarpCTC (or just leaving it)?

> BTW, while you are at it, consider removing the mallocs and cudaMallocs

Sure.

@piiswrong
Contributor

ctc_loss should be good.

@sbodenstein
Contributor Author

I'm worried that MXNet will get its own CTC implementation one day, and then there will be a name conflict.

@piiswrong
Contributor

If we do implement our own, will it do the exact same thing? That is, is CTC loss clearly defined mathematically, or does everyone have their own variant?

@sbodenstein
Contributor Author

The definitions should be the same. There are cases where you might want two symbols:

  • WarpCTC might be faster than the MXNet implementation, so some users might still want to use WarpCTC (whilst the MXNet version might work on more platforms).
  • WarpCTC allows variable-length probability and label inputs. We might expose this in the WarpCTC symbol, but we don't want to be bound to implementing it in a future MXNet version.

If these are not major concerns, let's go with ctc_loss; otherwise warp_ctc_loss.

@piiswrong
Contributor

Ok then let's go with CTCLoss. You can put the implementation in src/operator/contrib and make it LOG(FATAL) << "Compile with baidu ctc to enable" when not linking against baidu ctc. Add a flag USE_BAIDUCTC in config.mk to turn it on.

@piiswrong
Contributor

BTW, since Baidu's warp-ctc also uses the Apache 2.0 license, we can do a deeper integration and absorb their code. That way it doesn't need to stay a plugin.

@sbodenstein
Contributor Author

That is a good point. It would simplify build issues for users. I will take a first stab at integrating it as a proper operator.

@sbodenstein
Contributor Author

sbodenstein commented Apr 2, 2017

@piiswrong: WarpCTC has a dependency on the header-only ModernGPU library https://github.com/moderngpu/moderngpu/wiki (see https://github.com/baidu-research/warp-ctc/tree/master/include/contrib/moderngpu), which is licensed under 3-clause BSD.

How should this dependency be handled?

Also: if we want to integrate the WarpCTC library into MXNet, should we make minimal changes to the WarpCTC code (so that incorporating any upstream updates/fixes stays easy), or should some rewriting be done (e.g. getting rid of flags like CTC_DISABLE_OMP and replacing them with the appropriate MXNet equivalents)?

@sbodenstein
Contributor Author

@piiswrong: after thinking about this more, having a native CTC op will actually be the best way forward for MXNet. Let me know what to do about the moderngpu dependency, though; I don't want to start on this until I know how to handle it.

@sbodenstein
Contributor Author

@piiswrong: I would like to get going with this. Any update?

@piiswrong
Contributor

moderngpu is header-only, so we can copy that too. We just need a separate license file for the subdirectory.

@sbodenstein
Contributor Author

Ok, great!

@tqchen
Member

tqchen commented Apr 7, 2017

A reminder: moderngpu possibly duplicates functionality already available in cub.

@sbodenstein
Contributor Author

sbodenstein commented Apr 9, 2017

@tqchen: I will commit the first version using moderngpu, and removing the moderngpu dependency in favour of cub can be a secondary project. It should be faster as well.

I might not have time for this: is there a place where one can propose projects that others who are interested can pick up?

@buaacszj

buaacszj commented Sep 7, 2017

Is there any update about this refactor?

Now I'm developing an OCR model, and I find that WarpCTC is not convenient, since its return value is a probability distribution and I don't know how to print its loss. Meanwhile, the current implementation of ctc_loss is very slow (more than about 70 times slower than WarpCTC).

@szha
Member

szha commented Sep 7, 2017

@buaacszj yes, I've been working on the CTC loss in contrib. Things should stabilize after #7727.

@buaacszj

buaacszj commented Sep 7, 2017

@szha That's pretty cool! So the new CTC loss has better performance than the old one, right? Is it as fast as WarpCTC?

@szha
Member

szha commented Sep 7, 2017

@buaacszj the change I mentioned above is based on @sbodenstein's work on the contrib CTC; my changes mainly improve usability, such as removing the effect of undefined gradients inside the padding area and improving numerical stability by applying log softmax to the log probabilities. I haven't done a thorough analysis of its speed. Do you have ideas on how it could be improved?
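
(For illustration, the numerical-stability point amounts to the standard max-subtraction trick for log softmax; a minimal NumPy sketch, not the actual operator code:)

import numpy as np

def log_softmax(x, axis=-1):
    # Subtracting the per-row max before exponentiating keeps exp() from
    # overflowing; this is the usual trick for a numerically stable (log-)softmax.
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

logits = np.array([[1000.0, 1001.0, 999.0]])   # a naive softmax would overflow here
print(log_softmax(logits))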

@buaacszj

buaacszj commented Sep 7, 2017

As far as I can see, the contrib CTC is implemented inside MXNet without depending on Baidu's warp-ctc. But as I tested before, its performance is very slow when alphabet_size is more than 5000. Does the contrib CTC use the same algorithm as warp-ctc?

@sbodenstein
Contributor Author

sbodenstein commented Sep 7, 2017

> As far as I can see, the contrib CTC is implemented inside MXNet without depending on Baidu's warp-ctc.

@buaacszj: the contrib CTC implementation is basically just Warp-CTC (with one or two small changes). The Warp-CTC code was copied and adapted so that it could be compiled together with MXNet (rather than having to be compiled separately, which was annoying). Any speed difference would be surprising; could you give the exact configuration in which you saw these large differences, and was this on GPU or CPU?

Also: @szha did some speed tests against cuDNN (#7442). I believe those tested the contrib CTC symbol, which was quite competitive with cuDNN's CTC.

@buaacszj

buaacszj commented Sep 8, 2017

I am using two K40m GPUs.
The network is CNN+LSTM+CTC, and the batch size is 128.
When alphabet_size is 15, the contrib CTC is fast enough, and I didn't test warp-ctc. But when alphabet_size is more than 5000, it costs about 380-400 seconds per batch, while warpctc only costs 5-6 seconds per batch.
Here is the code for calling the two APIs.

contrib CTC:
ctc = mx.contrib.symbol.ctc_loss(data=fc1_transpose, label=labels, name='ctc_loss')
ctc_loss = mx.sym.MakeLoss(ctc)

softmax_class = mx.symbol.SoftmaxActivation(data=fc1)
softmax_loss = mx.sym.MakeLoss(softmax_class)
softmax_loss = mx.sym.BlockGrad(softmax_loss)

group = mx.sym.Group([softmax_loss, ctc_loss])

warpctc:
fc1_forwarpctc = mx.symbol.reshape(data=fc1_transpose, shape=(-1, num_classes))
labels = mx.sym.Reshape(data=labels, shape=(-1,))
labels = mx.sym.Cast(data=labels, dtype='int32')
ctc = mx.sym.WarpCTC(data=fc1_forwarpctc, label=labels, label_length=label_len, input_length=SEQ_LENGTH)

@buaacszj

buaacszj commented Sep 8, 2017

@sbodenstein Please tell me if you need extra information.

@sbodenstein
Contributor Author

sbodenstein commented Sep 8, 2017

@buaacszj: a few more pieces of information, please:

  • Which cuDNN version are you using (i.e. are you using cuDNN v7)?
  • Which commit of MXNet are you using?
  • Could you also give the shapes of your inputs to the CTC layer (e.g. sequence length, etc.)?

My hypothesis is that it is actually using the cuDNN v7 CTC loss, which is slower for very large alphabet sizes.

@buaacszj

@sbodenstein sorry for the late reply.

Here is the information about my CUDA and cuDNN:

  • Cuda compilation tools, release 8.0, V8.0.61
  • libcudnn.so and libcudnn.so.5 in cuda/lib64

  1. I believe I am using CUDA 8.0 and cuDNN 5.
  2. I am using the master branch of MXNet.
  3. The sequence length of the CTC layer is 140.

Have you tested whether the contrib CTC runs as fast as Baidu's warp-ctc for large alphabet sizes?

@szha
Member

szha commented Dec 22, 2017

@apache/mxnet-committers: This issue has been inactive for the past 90 days. It has no label and needs triage.

For general "how-to" questions, our user forum (and Chinese version) is a good place to get help.

@yzhliu
Member

yzhliu commented Apr 11, 2018

Looks like the usability issue has been resolved by @szha's PR #7727. If you still have performance issues with CTC loss in MXNet, feel free to open another issue. @buaacszj

yzhliu closed this as completed on Apr 11, 2018