RNN Symbol #2401

Closed

sbodenstein opened this issue Jun 12, 2016 · 33 comments

@sbodenstein
Contributor

sbodenstein commented Jun 12, 2016

I would like to help implement an RNN symbol (perhaps sym.RNN) that uses the cuDNN RNN implementation. This issue should be a place for everyone working on this to discuss and decide on the various design and implementation issues.

As a starting point, the symbol operates on a rank-3 tensor of dimensions (batch, seqLength, inputDim) and returns a tensor of dimensions (batch, seqLength, num_outputs); a shape sketch follows the list. It takes the following args:

  • "layer_num": how many rnn layers to stack (positive int)
  • "rnn_type": one of ("rnn_relu", "rnn_tanh", "lstm", "gru") corresponding to the cuDNN enum (CUDNN_RNN_RELU, CUDNN_RNN_TANH, CUDNN_LSTM, CUDNN_GRU)
  • "bidirectional": boolean
  • "num_outputs": output RNN vector size (positive int)

There are two major tasks:

  • a cuDNN version of the symbol
  • a MShadow implementation for CPU + non cuDNN

I am happy to do the first task, but would really appreciate other contributors also working on the MShadow implementation, which is the harder task.

I would also very much like to hear anyone's thoughts on this before diving in: @piiswrong, @tqchen, @pluskid

@piiswrong
Contributor

piiswrong commented Jun 13, 2016

Thank you very much for the contribution. This is an important feature.
The first should be straightforward, like the other cuDNN ops.
An easy way to do the second is to chain up other operators on the C++ side. You just need to dispatch the TBlobs correctly.

@pluskid
Contributor

pluskid commented Jun 13, 2016

@sbodenstein Thanks a lot!
@piiswrong does it have to be on the C++ side? I guess one could also construct the RNN subgraph in Python and compose it with other parts of the computational graph.
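
A rough sketch of that idea (purely illustrative: the unrolled_rnn helper below is hypothetical and builds a plain tanh RNN from existing symbols, sharing weights across time steps by reusing the same Variables):

import mxnet as mx

def unrolled_rnn(seq_len, num_hidden, prefix='rnn_'):
    # Shared parameters, reused at every time step.
    w_ih = mx.sym.Variable(prefix + 'i2h_weight')
    b_ih = mx.sym.Variable(prefix + 'i2h_bias')
    w_hh = mx.sym.Variable(prefix + 'h2h_weight')
    b_hh = mx.sym.Variable(prefix + 'h2h_bias')
    h = mx.sym.Variable(prefix + 'init_h')

    outputs = []
    for t in range(seq_len):
        x_t = mx.sym.Variable('x_%d' % t)  # one (batch, inputDim) slice per step
        i2h = mx.sym.FullyConnected(data=x_t, weight=w_ih, bias=b_ih,
                                    num_hidden=num_hidden)
        h2h = mx.sym.FullyConnected(data=h, weight=w_hh, bias=b_hh,
                                    num_hidden=num_hidden)
        h = mx.sym.Activation(data=i2h + h2h, act_type='tanh')
        outputs.append(h)
    # The grouped per-step outputs compose freely with the rest of the graph.
    return mx.sym.Group(outputs)

net = unrolled_rnn(seq_len=3, num_hidden=16)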

@piiswrong
Contributor

piiswrong commented Jun 13, 2016

Yes, you can do that, but Python doesn't know whether cuDNN is available or not, and the symbol graph will be different with and without cuDNN.

@sbodenstein
Contributor Author

@piiswrong: yes, it does look straightforward to hook up the cuDNN version. I will try to get a first version out in the next two weeks.

Is there an example somewhere of defining a layer by chaining together C++ ops, or some other example that would help with this? (And am I assuming correctly that this involves chaining ops with derivatives, so that one doesn't have to write a derivative expression explicitly?)

@sbodenstein
Contributor Author

sbodenstein commented Jun 19, 2016

@piiswrong: there is a design issue with this symbol. Because "layer_num" is a parameter, the number of weights in this symbol can grow without bound.

One solution (used by cuDNN and other interfaces to it like Torch) is to use only a single weight vector (whose size depends on the recurrent type, etc).

The other solution is to expose all the parameters (a single GRU layer has 6 different weights and 6 biases...) and, like the Concat symbol, let each of these take a list of symbols, one per layer. The advantage of this is that one can more easily initialize specific weights/biases (though one could write a Python function that takes an RNN symbol weight and converts it into the individual weights and biases if needed for initialization, etc.).

What do others think?
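
For a sense of scale, a small sketch of how large that single packed parameter vector would be for a stacked GRU, assuming the 6 weight matrices and 6 bias vectors per layer mentioned above (element count only; this does not reproduce cuDNN's internal packing order, and the treatment of stacked/bidirectional layers is an assumption):

def gru_param_count(input_dim, num_hidden, num_layers, bidirectional=False):
    directions = 2 if bidirectional else 1
    total = 0
    for layer in range(num_layers):
        # Assumption: layers after the first see the (possibly bidirectional)
        # output of the previous layer as their input.
        in_dim = input_dim if layer == 0 else num_hidden * directions
        per_direction = 3 * (in_dim * num_hidden        # input-side weights
                             + num_hidden * num_hidden  # recurrent weights
                             + 2 * num_hidden)          # two bias vectors per gate
        total += per_direction * directions
    return total

print(gru_param_count(input_dim=128, num_hidden=256, num_layers=2))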

@sbodenstein
Contributor Author

Also of recent interest: Baidu just released this: https://github.com/baidu-research/persistent-rnn
The interface is modelled after cuDNN.

@piiswrong
Contributor

@sbodenstein actually, you can have a layer with a dynamic number of params depending on a parameter. Just return as many as you want in list_input.

@sxjscience
Member

sxjscience commented Jun 26, 2016

@sbodenstein What's the current progress? I can also help implement this OP and do some benchmarks.

@sbodenstein
Contributor Author

@sxjscience: I have had conferences + travel in the last 2 weeks, and haven't had much time to finish the cuDNN op (but have made some progress, will work on it this week again).

Can you start with the CPU version? I will gladly help once I've finished the cuDNN version.

@sxjscience
Member

@sbodenstein OK, we are planning to use the cuDNN version of LSTM in the DeepMark benchmark. Really appreciate your work!
I'm also thinking about the argument issue. Would it be simpler on the C++ side if we combined the weights?

@sbodenstein
Contributor Author

@sxjscience and @piiswrong: cuDNN expects a single parameter array, so if we expose multiple argument ports on the symbol, it will be on us to create a single array in the backend, which might involve copying. That might be a strong argument for having only a single parameter port. Thoughts?
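
For illustration, a numpy sketch of the packing/copying this would imply; the concatenation order below is arbitrary, since cuDNN's real layout has to be queried per matrix (via cudnnGetRNNLinLayerMatrixParams / cudnnGetRNNLinLayerBiasParams):

import numpy as np

def pack_params(weights, biases):
    # Flatten per-layer weight matrices and bias vectors into one 1-D
    # parameter array (all weights, then all biases; illustration only).
    return np.concatenate([w.ravel() for w in weights] +
                          [b.ravel() for b in biases])

weights = [np.random.randn(16, 8), np.random.randn(16, 16)]
biases = [np.zeros(16), np.zeros(16)]
packed = pack_params(weights, biases)  # note: this is a copy
print(packed.size)  # 16*8 + 16*16 + 16 + 16 = 416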

@piiswrong
Contributor

Sure, one parameter array sounds fine

@sbodenstein
Contributor Author

Great. Let's go with that for now.

@sxjscience
Member

sxjscience commented Jul 5, 2016

@piiswrong @pluskid @sbodenstein I'm wondering whether it's necessary to write non-cuDNN-based RNN symbols in C++. Torch seems to have included only the cuDNN version of LSTM.

@sbodenstein
Contributor Author

@sxjscience: it would be very convenient to have a C++ version, as this interface will probably be the simplest RNN interface for users (hence quite popular), and it would be very annoying not to be able to evaluate or deploy this on the CPU (you would need to insert a complicated computational graph to replicate it).

Regarding Torch: the Torch cuDNN package only has cuDNN implementations; it's up to other packages (like nn) to provide non-cuDNN implementations as well. The rnn package wants to add a compatible version, at least for LSTM: Element-Research/rnn#239

@tqchen
Member

tqchen commented Jul 5, 2016

I think we can start with the cuDNN version first, then add compatible GPU/CPU code as a second step.

@sbodenstein
Contributor Author

@tqchen: yes, having this first would also make writing a cuDNN-compatible CPU/GPU version easier.

I should have a first version of the cuDNN symbol by the end of this week.

@sxjscience
Member

@sbodenstein Any update on this?

@antinucleon
Contributor

@sxjscience I will do it today

@sbodenstein
Contributor Author

@antinucleon: I'm working on this today again, and have a few days free in which to finish it.

To see what I have so far:

https://github.com/sbodenstein/mxnet/blob/feature/RNN_Symbol_cuDNN/src/operator/rnn-inl.h
https://github.com/sbodenstein/mxnet/blob/feature/RNN_Symbol_cuDNN/src/operator/cudnn_rnn-inl.h

Let's set a date: I should have a finished version by Wednesday. Also, any comments would be appreciated!

@sxjscience
Member

Great!

@sbodenstein
Contributor Author

@sxjscience @antinucleon: I've finished this and am doing some testing against cudnn.torch now. I should have a pull request later today, once I'm happy that everything is correct.

We can discuss some of the outstanding design questions in the PR.

@sxjscience
Member

@sbodenstein Thanks very much for your work!

@antinucleon
Contributor

Awesome, will check it in the morning.

@antinucleon
Contributor

Thank you @sbodenstein !
Let's discuss after you create PR.

@sbodenstein
Contributor Author

sbodenstein commented Jul 20, 2016

@antinucleon: the forward pass in inference mode is now consistent with Torch, but I'm having problems with the backward pass. In particular, the problems are caused by the dropout descriptor, which requires a permanent state buffer attached at initialization (see cudnnSetDropoutDescriptor); according to the docs, "No other function should be writing to the memory pointed at by states argument while this function is running. The user is expected not to change memory pointed at by states for the duration of the computation." If you use NULL for the states (and remove the dropout rate as an optional parameter), cudnnRNNForwardTraining fails.

It's not simply a piece of workspace memory like in the other examples, which get generic chunks of memory during a forward or backward pass. There is also no cuDNN dropout layer, so I don't know how best to solve this from the other examples. Should I just cudaMalloc some memory? I would really appreciate your opinion on how to deal with this.

@piiswrong
Contributor

@sbodenstein Put it in an additional invisible output. See batchnorm for example. Could you open a PR first to make discussion easier?

@sxjscience
Member

@sbodenstein @piiswrong @antinucleon I suggest wrapping the RNN in the scripting language. We can write a stand-alone CuDNNRNN operator (which only calls the RNN API in cuDNN) and build another Python class that calls this operator together with other operators like dropout. To decide whether to use CuDNNRNN, we can call an API function that reports which cuDNN version the library was compiled against.
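
A rough sketch of that dispatch idea; everything here is hypothetical (the cudnn_compiled check and the stand-alone CuDNNRNN operator are placeholders, and the fallback shows a single time step only):

import mxnet as mx

def cudnn_compiled():
    # Placeholder: MXNet would need to expose whether it was built with
    # cuDNN (and which version) for the frontend to make this decision.
    return False

def rnn_step(data, num_hidden):
    if cudnn_compiled():
        # Would call the stand-alone cuDNN-backed operator proposed above
        # (operator name hypothetical).
        raise NotImplementedError('stand-alone CuDNNRNN operator')
    # Fallback: compose the step from existing operators instead; a full
    # RNN would chain many such steps, plus dropout between layers.
    i2h = mx.sym.FullyConnected(data=data, num_hidden=num_hidden, name='i2h')
    return mx.sym.Activation(data=i2h, act_type='tanh', name='step_out')

sym = rnn_step(mx.sym.Variable('data'), num_hidden=32)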

@magic282

Hi, I noticed that there is a new symbol API, RNN, which seems to use the cuDNN v5 feature. Is it finished now? I found that there are two source files, rnn-inl.h and cudnn_rnn-inl.h, and in rnn-inl.h I saw a TODO mark.

Also, will masking be supported in the RNN symbol? By the way, are there any usage examples for this new symbol?
Thanks.

@sbodenstein
Contributor Author

@magic282: The cuDNN version is finished, but the TODO refers to adding a CPU version. There are also no examples yet.

Masking is not supported. How would you like masking to work with this operator?

@magic282

@sbodenstein Since I don't really know how to use the new symbol, I have no strong opinion about the mask.
For now, the masking I implemented looks like this:

# stack LSTM
for i in range(self.num_of_layer):
    # No dropout on the input of the first layer.
    if i == 0:
        dp_ratio = 0.
    else:
        dp_ratio = self.dropout
    next_state = lstm(self.state_dim, indata=hidden,
                      prev_state=last_states[i],
                      param=param_cells[i],
                      seqidx=seq_idx, layeridx=i, dropout=dp_ratio)

    if self.use_masking:
        # Where mask == 0 (padding), keep the previous h/c unchanged;
        # where mask == 1, take the newly computed state.
        prev_state_h = last_states[i].h
        prev_state_c = last_states[i].c
        new_h = mx.sym.broadcast_mul(1.0 - mask, prev_state_h) + mx.sym.broadcast_mul(mask, next_state.h)
        new_c = mx.sym.broadcast_mul(1.0 - mask, prev_state_c) + mx.sym.broadcast_mul(mask, next_state.c)
        next_state = LSTMState(c=new_c, h=new_h)

    hidden = next_state.h
    last_states[i] = next_state

@sbodenstein
Contributor Author

@magic282: masking isn't directly supported by the cuDNN API yet, although variable-length sequences are (others, such as the maintainer of the Torch rnn package, have commented on this: soumith/cudnn.torch#210).

There are two major masking approaches:

  1. passing the mask info as an extra input tensor (like your example)
  2. using zero-masking (or some other sentinel value), as Torch does (e.g. https://github.com/Element-Research/rnn#maskzeroninputdim) and Keras supports (see this discussion: Masks for RNNs keras-team/keras#176).

I prefer the latter. We could make a feature request to NVIDIA regarding this.
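
To make the contrast concrete, a small numpy sketch of the zero-masking idea (illustration only; a real implementation would also carry the hidden state through masked steps unchanged, as in the snippet above):

import numpy as np

def apply_zero_masking(outputs, inputs):
    # Time steps whose input vector is all zeros are treated as padding
    # and their outputs are zeroed as well.
    # outputs, inputs: (batch, seqLength, dim) arrays.
    mask = (np.abs(inputs).sum(axis=2, keepdims=True) != 0).astype(outputs.dtype)
    return outputs * mask

inputs = np.random.randn(2, 4, 3)
inputs[0, 2:, :] = 0.0                  # pad the tail of the first sequence
outputs = np.random.randn(2, 4, 5)
masked = apply_zero_masking(outputs, inputs)
print(masked[0, 2:, :])                 # all zeros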

@magic282

@sbodenstein I think both of them are OK with me. By the way, have you tested the performance gain of the cuDNN RNN using MXNet?
