diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py index 9b568618566b..66c666659d0b 100644 --- a/python/mxnet/module/bucketing_module.py +++ b/python/mxnet/module/bucketing_module.py @@ -95,6 +95,7 @@ def __init__(self, sym_gen, default_bucket_key=None, logger=logging, self._curr_bucket_key = None self._params_dirty = False self._monitor = None + self._grad_req = None def _reset_bind(self): """Internal utility function to reset binding.""" @@ -331,6 +332,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, self.for_training = for_training self.inputs_need_grad = inputs_need_grad self.binded = True + self._grad_req = grad_req symbol, data_names, label_names = self._call_sym_gen(self._default_bucket_key) module = Module(symbol, data_names, label_names, logger=self.logger, @@ -340,7 +342,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, group2ctxs=self._group2ctxs, compression_params=self._compression_params) module.bind(data_shapes, label_shapes, for_training, inputs_need_grad, - force_rebind=False, shared_module=None, grad_req=grad_req) + force_rebind=False, shared_module=None, grad_req=self._grad_req) self._curr_module = module self._curr_bucket_key = self._default_bucket_key self._buckets[self._default_bucket_key] = module @@ -373,7 +375,8 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None): compression_params=self._compression_params) module.bind(data_shapes, label_shapes, self._curr_module.for_training, self._curr_module.inputs_need_grad, - force_rebind=False, shared_module=self._buckets[self._default_bucket_key]) + force_rebind=False, shared_module=self._buckets[self._default_bucket_key], + grad_req=self._grad_req) if self._monitor is not None: module.install_monitor(self._monitor) self._buckets[bucket_key] = module diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index d9d7175f540e..ae38a2297ded 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -874,6 +874,48 @@ def empty_fn(*args, **kwargs): train_data = MockTrainData(batches=2) mod.fit(train_data, num_epoch=1) +@with_seed() +def test_bucket_module_grad_req(): + batch_size = 2 + def sym_gen(_): + data = mx.symbol.Variable('data') + weight = mx.symbol.Variable('a', shape=(1,), init=mx.init.One()) + sym = mx.sym.make_loss(mx.sym.broadcast_mul(data, weight)) + return sym, ('data',), None + + mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10) + mod.bind(data_shapes=[['data', (batch_size, )]], for_training=True, grad_req='write') + mod.init_params() + + mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))], + label=None, + provide_data=[mx.io.DataDesc(name='data', shape=(batch_size, ), layout='N')], + bucket_key=10)) + assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == batch_size) + + mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))], + label=None, + provide_data=[mx.io.DataDesc(name='data', shape=(batch_size, ), layout='N')], + bucket_key=5)) + assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == batch_size) + + mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10) + mod.bind(data_shapes=[['data', (batch_size, )]], for_training=True, grad_req='add') + mod.init_params() + + mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))], + label=None, + provide_data=[mx.io.DataDesc(name='data', shape=(batch_size,), layout='N')], + bucket_key=10)) + assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == batch_size) + + mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))], + label=None, + provide_data=[mx.io.DataDesc(name='data', shape=(batch_size,), layout='N')], + bucket_key=5)) + assert mod._curr_module._grad_req == 'add' + assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == 2 * batch_size) + if __name__ == '__main__': import nose