Skip to content

Commit

Permalink
Python BucketingModule bind() with grad_req = 'add' (apache#13984)
Browse files Browse the repository at this point in the history
* remember grad_req from bind and apply it to sub-modules

* unit-test for gradient accumulation with bucketing modules
  • Loading branch information
slyforce authored and haohuw committed Jun 23, 2019
1 parent bf83e02 commit 7ebfe7a
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
7 changes: 5 additions & 2 deletions python/mxnet/module/bucketing_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
self._curr_bucket_key = None
self._params_dirty = False
self._monitor = None
self._grad_req = None

def _reset_bind(self):
"""Internal utility function to reset binding."""
Expand Down Expand Up @@ -331,6 +332,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
self.for_training = for_training
self.inputs_need_grad = inputs_need_grad
self.binded = True
self._grad_req = grad_req

symbol, data_names, label_names = self._call_sym_gen(self._default_bucket_key)
module = Module(symbol, data_names, label_names, logger=self.logger,
Expand All @@ -340,7 +342,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
group2ctxs=self._group2ctxs,
compression_params=self._compression_params)
module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
force_rebind=False, shared_module=None, grad_req=grad_req)
force_rebind=False, shared_module=None, grad_req=self._grad_req)
self._curr_module = module
self._curr_bucket_key = self._default_bucket_key
self._buckets[self._default_bucket_key] = module
Expand Down Expand Up @@ -373,7 +375,8 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
compression_params=self._compression_params)
module.bind(data_shapes, label_shapes, self._curr_module.for_training,
self._curr_module.inputs_need_grad,
force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
force_rebind=False, shared_module=self._buckets[self._default_bucket_key],
grad_req=self._grad_req)
if self._monitor is not None:
module.install_monitor(self._monitor)
self._buckets[bucket_key] = module
Expand Down
42 changes: 42 additions & 0 deletions tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,48 @@ def empty_fn(*args, **kwargs):
train_data = MockTrainData(batches=2)
mod.fit(train_data, num_epoch=1)

@with_seed()
def test_bucket_module_grad_req():
batch_size = 2
def sym_gen(_):
data = mx.symbol.Variable('data')
weight = mx.symbol.Variable('a', shape=(1,), init=mx.init.One())
sym = mx.sym.make_loss(mx.sym.broadcast_mul(data, weight))
return sym, ('data',), None

mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10)
mod.bind(data_shapes=[['data', (batch_size, )]], for_training=True, grad_req='write')
mod.init_params()

mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))],
label=None,
provide_data=[mx.io.DataDesc(name='data', shape=(batch_size, ), layout='N')],
bucket_key=10))
assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == batch_size)

mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))],
label=None,
provide_data=[mx.io.DataDesc(name='data', shape=(batch_size, ), layout='N')],
bucket_key=5))
assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == batch_size)

mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10)
mod.bind(data_shapes=[['data', (batch_size, )]], for_training=True, grad_req='add')
mod.init_params()

mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))],
label=None,
provide_data=[mx.io.DataDesc(name='data', shape=(batch_size,), layout='N')],
bucket_key=10))
assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == batch_size)

mod.forward_backward(mx.io.DataBatch(data=[mx.nd.ones((batch_size,))],
label=None,
provide_data=[mx.io.DataDesc(name='data', shape=(batch_size,), layout='N')],
bucket_key=5))
assert mod._curr_module._grad_req == 'add'
assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == 2 * batch_size)


if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 7ebfe7a

Please sign in to comment.