Skip to content

Commit

Permalink
hybridize rnn and add model graph (apache#13244)
Browse files Browse the repository at this point in the history
* hybridize rnn and add model graph

* trigger CI

* separate mxboard visualization

* add options and she-bang

* add defaults

* trigger CI

* rename export-model
  • Loading branch information
yifeim authored and stephenrawls committed Feb 16, 2019
1 parent cf19784 commit 973bc25
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 7 deletions.
26 changes: 24 additions & 2 deletions example/gluon/word_language_model/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ python train.py --cuda --tied --nhid 650 --emsize 650 --epochs 40 --dropout 0.5
```
python train.py --cuda --tied --nhid 1500 --emsize 1500 --epochs 60 --dropout 0.65 # Test ppl of 88.42
```

```
python train.py --export-model # hybridize and export model graph. See below for visualization options.
```

<br>

Expand All @@ -38,7 +40,8 @@ usage: train.py [-h] [--model MODEL] [--emsize EMSIZE] [--nhid NHID]
[--nlayers NLAYERS] [--lr LR] [--clip CLIP] [--epochs EPOCHS]
[--batch_size N] [--bptt BPTT] [--dropout DROPOUT] [--tied]
[--cuda] [--log-interval N] [--save SAVE] [--gctype GCTYPE]
[--gcthreshold GCTHRESHOLD]
[--gcthreshold GCTHRESHOLD] [--hybridize] [--static-alloc]
[--static-shape] [--export-model]
MXNet Autograd RNN/LSTM Language Model on Wikitext-2.
Expand All @@ -62,4 +65,23 @@ optional arguments:
`none` for now.
--gcthreshold GCTHRESHOLD
threshold for 2bit gradient compression
--hybridize whether to hybridize in mxnet>=1.3 (default=False)
--static-alloc whether to use static-alloc hybridize in mxnet>=1.3
(default=False)
--static-shape whether to use static-shape hybridize in mxnet>=1.3
(default=False)
--export-model export a symbol graph and exit (default=False)
```

You may visualize the exported graph (`model-symbol.json`) with `mxnet.viz.plot_network`, which does not require mxboard (it only relies on the `graphviz` Python package). Alternatively, if [mxboard](https://github.com/awslabs/mxboard) is installed, use the following approach for interactive visualization.
```python
#!python
import mxnet, mxboard
with mxboard.SummaryWriter(logdir='./model-graph') as sw:
    sw.add_graph(mxnet.sym.load('./model-symbol.json'))
```
```bash
#!/bin/bash
tensorboard --logdir=./model-graph/
```
![model graph](./model-graph.png?raw=true "rnn model graph")
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions example/gluon/word_language_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mxnet import gluon
from mxnet.gluon import nn, rnn

class RNNModel(gluon.Block):
class RNNModel(gluon.HybridBlock):
"""A model with an encoder, recurrent layer, and a decoder."""

def __init__(self, mode, vocab_size, num_embed, num_hidden,
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(self, mode, vocab_size, num_embed, num_hidden,

self.num_hidden = num_hidden

def forward(self, inputs, hidden):
def hybrid_forward(self, F, inputs, hidden):
emb = self.drop(self.encoder(inputs))
output, hidden = self.rnn(emb, hidden)
output = self.drop(output)
Expand Down
33 changes: 30 additions & 3 deletions example/gluon/word_language_model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@
takes `2bit` or `none` for now.')
parser.add_argument('--gcthreshold', type=float, default=0.5,
help='threshold for 2bit gradient compression')
parser.add_argument('--hybridize', action='store_true',
help='whether to hybridize in mxnet>=1.3 (default=False)')
parser.add_argument('--static-alloc', action='store_true',
help='whether to use static-alloc hybridize in mxnet>=1.3 (default=False)')
parser.add_argument('--static-shape', action='store_true',
help='whether to use static-shape hybridize in mxnet>=1.3 (default=False)')
parser.add_argument('--export-model', action='store_true',
help='export a symbol graph and exit (default=False)')
args = parser.parse_args()

print(args)
Expand All @@ -72,6 +80,15 @@
else:
context = mx.cpu(0)

if args.export_model:
args.hybridize = True

# optional parameters only for mxnet >= 1.3
hybridize_optional = dict(filter(lambda kv:kv[1],
{'static_alloc':args.static_alloc, 'static_shape':args.static_shape}.items()))
if args.hybridize:
print('hybridize_optional', hybridize_optional)

dirname = './data'
dirname = os.path.expanduser(dirname)
if not os.path.exists(dirname):
Expand Down Expand Up @@ -114,6 +131,8 @@
ntokens = len(vocab)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
args.nlayers, args.dropout, args.tied)
if args.hybridize:
model.hybridize(**hybridize_optional)
model.initialize(mx.init.Xavier(), ctx=context)

compression_params = None if args.gctype == 'none' else {'type': args.gctype, 'threshold': args.gcthreshold}
Expand All @@ -123,6 +142,8 @@
'wd': 0},
compression_params=compression_params)
loss = gluon.loss.SoftmaxCrossEntropyLoss()
if args.hybridize:
loss.hybridize(**hybridize_optional)

###############################################################################
# Training code
Expand Down Expand Up @@ -177,6 +198,10 @@ def train():
epoch, i, cur_L, math.exp(cur_L)))
total_L = 0.0

if args.export_model:
model.export('model')
return

val_L = eval(val_data)

print('[Epoch %d] time cost %.2fs, valid loss %.2f, valid ppl %.2f'%(
Expand All @@ -193,6 +218,8 @@ def train():

if __name__ == '__main__':
train()
model.load_parameters(args.save, context)
test_L = eval(test_data)
print('Best test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
if not args.export_model:
model.load_parameters(args.save, context)
test_L = eval(test_data)
print('Best test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))

0 comments on commit 973bc25

Please sign in to comment.