[MXNET-1121] Example to demonstrate the inference workflow using RNN #13680
@@ -39,3 +39,48 @@ Alternatively, The script [unit_test_inception_inference.sh](<https://github.com
```
./unit_test_inception_inference.sh
```

### [simple_rnn.cpp](<https://github.com/apache/incubator-mxnet/blob/master/cpp-package/example/inference/simple_rnn.cpp>)
This example demonstrates a sequence prediction workflow with a pre-trained RNN model using the MXNet C++ API. It shows how a pre-trained RNN model can be loaded and used to generate an output sequence.

The example performs the following tasks:

- Load the pre-trained RNN model.
- Load the dictionary file that contains the word-to-index mapping.
- Convert the input string to a vector of indices, padded to match the input data length.
- Run the forward pass and predict the output string.

The example uses a pre-trained RNN model that was trained on a dataset of speeches given by Obama.

The model consists of:
- Embedding layer with an embedding size of 650
- 3 LSTM layers with a hidden dimension of 650 and a sequence length of 35
- FullyConnected layer
- SoftmaxOutput

The model was trained for 100 epochs.

Review comment: A simple image would be helpful.

Reply: I can get a PDF version of the model (generated using the visualization API). Need some suggestions from @aaronmarkham on embedding it in the README.

The model files are available here:
- [obama-speaks-symbol.json](<https://s3.amazonaws.com/mxnet-cpp/RNN_model/obama-speaks-symbol.json>)
- [obama-speaks-0100.params](<https://s3.amazonaws.com/mxnet-cpp/RNN_model/obama-speaks-0100.params>)
- [obama.dictionary.txt](<https://s3.amazonaws.com/mxnet-cpp/RNN_model/obama.dictionary.txt>)

Review comment: Why are these in the mxnet-cpp bucket rather than the MXNet pre-trained models bucket?

Reply: The mxnet-cpp bucket is not public by default, but its contents are made publicly readable. This is similar to the mxnet-scala bucket used for the Scala examples.

Each line of the dictionary file contains a word and a unique index for that word, separated by a space. The dictionary holds 14293 words generated from the training dataset.
The example downloads the above files while running.

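The dictionary format described above and the "convert and pad" task can be sketched with the C++ standard library alone. This is a minimal illustration, not the code in `simple_rnn.cpp`: the function names `LoadDictionary` and `ToPaddedIndices` are hypothetical, and the choice of index 0 for both padding and unknown words is an assumption, since the README does not say how the example handles either case.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

// Reads "word index" pairs, one pair per line, separated by whitespace.
// Reading stops at the first entry that does not match this format.
std::unordered_map<std::string, int> LoadDictionary(std::istream& in) {
  std::unordered_map<std::string, int> word_to_index;
  std::string word;
  int index = 0;
  while (in >> word >> index) {
    word_to_index[word] = index;
  }
  return word_to_index;
}

// Splits `input` on whitespace, maps each word to its index, and returns
// exactly `sequence_length` indices (truncating or padding as needed).
// ASSUMPTION: pad_index and unknown_index default to 0 for illustration only.
std::vector<int> ToPaddedIndices(
    const std::string& input,
    const std::unordered_map<std::string, int>& word_to_index,
    std::size_t sequence_length, int pad_index = 0, int unknown_index = 0) {
  assert(sequence_length > 0);
  std::vector<int> indices;
  std::istringstream words(input);
  std::string word;
  while (indices.size() < sequence_length && words >> word) {
    auto it = word_to_index.find(word);
    indices.push_back(it != word_to_index.end() ? it->second : unknown_index);
  }
  indices.resize(sequence_length, pad_index);  // pad short inputs
  return indices;
}
```

With the sequence length of 35 given in the model description, a call would look like `ToPaddedIndices(input, dictionary, 35)`. The sketch does not normalize case or strip punctuation, so a token such as `morning.` only matches if the dictionary contains it in exactly that form.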
The example's command line parameters are as shown below:

```
./simple_rnn --help
Usage:
simple_rnn
[--input] Input string sequence.
[--gpu] Specify this option if workflow needs to be run in gpu context.

./simple_rnn

or

./simple_rnn --input "Good morning. I appreciate the opportunity to speak here"
```

The example will output the seqence of 35 words as follows: | ||
leleamol marked this conversation as resolved.
Show resolved
Hide resolved
|
||
``` | ||
[waters elected Amendment Amendment Amendment Amendment retirement maximize maximize maximize acr citi sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio sophisticatio ] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is the output bad? Can we give example of a well trained model? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried to get better output by changing model hyperparameters but couldn't get it. It would require a good amount of input data processing as well. All these efforts would require dedicated time and out of scope for this example. The example aims towards loading the model and running forward pass. Improving on the model would be a separate task. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am working on implementing the RNN model using C++ API. I can work on improving the accuracy of that model and use it in this example later. |
||
``` | ||
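One plausible way a 35-word output arises from a model ending in SoftmaxOutput is to take, for each of the 35 positions, the word index with the highest probability and map it back through the dictionary. The sketch below illustrates that decoding step with the standard library only. It is an assumption about the general technique, not a copy of `simple_rnn.cpp`: the names `ArgMaxPerStep` and `IndicesToSentence` are hypothetical, and the flattened row-major `[steps x vocabulary]` layout is assumed for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <string>
#include <unordered_map>
#include <vector>

// For each time step, returns the index of the highest-probability word.
// ASSUMPTION: `probs` is a flattened row-major [num_steps x vocab_size] array.
std::vector<int> ArgMaxPerStep(const std::vector<float>& probs,
                               std::size_t vocab_size) {
  assert(vocab_size > 0 && probs.size() % vocab_size == 0);
  std::vector<int> best;
  for (std::size_t start = 0; start < probs.size(); start += vocab_size) {
    auto first = probs.begin() + static_cast<std::ptrdiff_t>(start);
    auto last = first + static_cast<std::ptrdiff_t>(vocab_size);
    best.push_back(
        static_cast<int>(std::distance(first, std::max_element(first, last))));
  }
  return best;
}

// Inverts the word-to-index dictionary and joins the decoded words with spaces.
// Indices without a dictionary entry are skipped.
std::string IndicesToSentence(
    const std::vector<int>& indices,
    const std::unordered_map<std::string, int>& word_to_index) {
  std::unordered_map<int, std::string> index_to_word;
  for (const auto& entry : word_to_index) {
    index_to_word[entry.second] = entry.first;
  }
  std::string sentence;
  for (int index : indices) {
    auto it = index_to_word.find(index);
    if (it == index_to_word.end()) continue;
    if (!sentence.empty()) sentence += ' ';
    sentence += it->second;
  }
  return sentence;
}
```

Greedy per-position selection like this can repeat the same word many times when the model assigns it the top probability at consecutive positions.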
Alternatively, run the [unit_test_simple_rnn.sh](<https://github.com/apache/incubator-mxnet/blob/master/cpp-package/example/inference/unit_test_simple_rnn.sh>) script.

Review comment: This could probably be simplified to make it easier to read and follow. @aaronmarkham - Can you please help us here with the documentation? Thanks.