diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index 4d77e0e4d8c0..a7d3336e8439 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -398,7 +398,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, self.for_training = for_training self.inputs_need_grad = inputs_need_grad - self.binded = True self._grad_req = grad_req if not for_training: @@ -454,6 +453,8 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, if shared_module is not None and shared_module.optimizer_initialized: self.borrow_optimizer(shared_module) + self.binded = True + def reshape(self, data_shapes, label_shapes=None): """Reshapes the module for new input shapes.