Skip to content

Commit

Permalink
add ctx to begin_state in rnn_layer (apache#7580)
Browse files Browse the repository at this point in the history
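* add ctx to begin_state: when forward() is called without states, pass the input's context (ctx=inputs.context) to begin_state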

* fix image classification example: add a --kvstore option (default 'device') and pass it to gluon.Trainer and Module.fit
  • Loading branch information
szha authored and crazy-cat committed Oct 26, 2017
1 parent 1cb9ed6 commit 3b9aa70
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions example/gluon/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
help='enable batch normalization or not in vgg. default is false.')
parser.add_argument('--use-pretrained', action='store_true',
help='enable using pretrained model from gluon.')
parser.add_argument('--kvstore', type=str, default='device',
help='kvstore to use for trainer/module.')
parser.add_argument('--log-interval', type=int, default=50, help='Number of batches to wait before logging.')
opt = parser.parse_args()

Expand Down Expand Up @@ -116,7 +118,8 @@ def train(epochs, ctx):
if isinstance(ctx, mx.Context):
ctx = [ctx]
net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': opt.lr, 'wd': opt.wd})
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': opt.lr, 'wd': opt.wd},
kvstore = opt.kvstore)
metric = mx.metric.Accuracy()
loss = gluon.loss.SoftmaxCrossEntropyLoss()

Expand Down Expand Up @@ -162,7 +165,8 @@ def train(epochs, ctx):
out = net(data)
softmax = mx.sym.SoftmaxOutput(out, name='softmax')
mod = mx.mod.Module(softmax, context=[mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()])
mod.fit(train_data, num_epoch=opt.epochs, batch_end_callback = mx.callback.Speedometer(batch_size, 1))
mod.fit(train_data, num_epoch=opt.epochs, kvstore=opt.kvstore,
batch_end_callback = mx.callback.Speedometer(batch_size, 1))
else:
if opt.mode == 'hybrid':
net.hybridize()
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
batch_size: int
Only required for `NDArray` API. Size of the batch ('N' in layout).
Dimension of the input.
func : callable, default `symbol.zeros`
func : callable, default `ndarray.zeros`
Function for creating initial state.
For Symbol API, func can be `symbol.zeros`, `symbol.uniform`,
Expand Down Expand Up @@ -172,7 +172,7 @@ def forward(self, inputs, states=None):
batch_size = inputs.shape[self._layout.find('N')]
skip_states = states is None
if skip_states:
states = self.begin_state(batch_size)
states = self.begin_state(batch_size, ctx=inputs.context)
if isinstance(states, ndarray.NDArray):
states = [states]
for state, info in zip(states, self.state_info(batch_size)):
Expand Down

0 comments on commit 3b9aa70

Please sign in to comment.