diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index ab7018f58e1f..ac7c3d3825ab 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -59,9 +59,29 @@ class Estimator(object):
         Trainer to apply optimizer on network parameters.
     context : Context or list of Context
         Device(s) to run the training on.
-    evaluation_loss: gluon.loss.loss
-        Loss (objective) function to calculate during evaluation. If set evaluation_loss
+    evaluation_loss : gluon.loss.Loss
+        Loss (objective) function to calculate during validation. If evaluation_loss is
         None, it will use the same loss function as self.loss
+    eval_net : gluon.Block
+        The model used for validation. The validation model does not need to belong
+        to the same model class as the training model, but the two models typically
+        share the same architecture, so the validation model can reuse the
+        parameters of the training model.
+
+        A code example constructing an eval_net that shares parameters with the
+        training net is given below:
+
+        >>> net = _get_train_network()
+        >>> eval_net = _get_test_network(params=net.collect_params())
+        >>> net.initialize(ctx=ctx)
+        >>> est = Estimator(net, loss, eval_net=eval_net)
+
+        A proper namespace match is required for weight sharing between the two
+        networks. Most networks inheriting :py:class:`Block` share parameters
+        correctly; an exception is Sequential networks, where a Block scope must be
+        specified for correct weight sharing. For naming in the MXNet Gluon API, see
+        https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html
+        for further information.
     """
@@ -89,7 +109,8 @@ def __init__(self, net,
                  initializer=None, trainer=None,
                  context=None,
-                 evaluation_loss=None):
+                 evaluation_loss=None,
+                 eval_net=None):
         self.net = net
         self.loss = self._check_loss(loss)
         self._train_metrics = _check_metrics(metrics)
@@ -98,6 +119,9 @@ def __init__(self, net,
         self.evaluation_loss = self.loss
         if evaluation_loss is not None:
             self.evaluation_loss = self._check_loss(evaluation_loss)
+        self.eval_net = self.net
+        if eval_net is not None:
+            self.eval_net = eval_net
 
         self.logger = logging.Logger(name='Estimator', level=logging.INFO)
         self.logger.addHandler(logging.StreamHandler(sys.stdout))
@@ -234,7 +258,7 @@ def evaluate_batch(self,
             Batch axis to split the validation data into devices.
""" data, label = self._get_data_and_label(val_batch, self.context, batch_axis) - pred = [self.net(x) for x in data] + pred = [self.eval_net(x) for x in data] loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)] # update metrics for metric in val_metrics: diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index 21f949a0bba6..dba3f122a9b6 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -29,11 +29,16 @@ from nose.tools import assert_raises -def _get_test_network(): - net = nn.Sequential() +def _get_test_network(params=None): + net = nn.Sequential(params=params) net.add(nn.Dense(4, activation='relu', flatten=False)) return net +def _get_test_network_with_namescope(params=None): + net = nn.Sequential(params=params) + with net.name_scope(): + net.add(nn.Dense(4, activation='relu', flatten=False)) + return net def _get_test_data(): batch_size = 4 @@ -371,3 +376,69 @@ def test_default_handlers(): assert isinstance(handlers[0], GradientUpdateHandler) assert isinstance(handlers[1], MetricHandler) assert isinstance(handlers[4], LoggingHandler) + +def test_eval_net(): + ''' test estimator with a different evaluation net ''' + ''' test weight sharing of sequential networks without namescope ''' + net = _get_test_network() + eval_net = _get_test_network(params=net.collect_params()) + dataloader, dataiter = _get_test_data() + num_epochs = 1 + ctx = mx.cpu() + loss = gluon.loss.L2Loss() + evaluation_loss = gluon.loss.L2Loss() + acc = mx.metric.Accuracy() + net.initialize(ctx=ctx) + trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) + est = Estimator(net=net, + loss=loss, + metrics=acc, + trainer=trainer, + context=ctx, + evaluation_loss=evaluation_loss, + eval_net=eval_net) + + with assert_raises(RuntimeError): + est.fit(train_data=dataloader, + val_data=dataloader, + epochs=num_epochs) + + ''' test weight sharing of sequential networks with namescope ''' + net = _get_test_network_with_namescope() + eval_net = _get_test_network_with_namescope(params=net.collect_params()) + net.initialize(ctx=ctx) + trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) + est = Estimator(net=net, + loss=loss, + metrics=acc, + trainer=trainer, + context=ctx, + evaluation_loss=evaluation_loss, + eval_net=eval_net) + + est.fit(train_data=dataloader, + val_data=dataloader, + epochs=num_epochs) + + ''' test weight sharing of two resnets ''' + net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx) + net.output = gluon.nn.Dense(10) + eval_net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx) + eval_net.output = gluon.nn.Dense(10, params=net.collect_params()) + dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10))) + dataloader = gluon.data.DataLoader(dataset=dataset, batch_size=5) + net.initialize(ctx=ctx) + eval_net.initialize(ctx=ctx) + trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) + est = Estimator(net=net, + loss=loss, + metrics=acc, + trainer=trainer, + context=ctx, + evaluation_loss=evaluation_loss, + eval_net=eval_net) + + est.fit(train_data=dataloader, + val_data=dataloader, + epochs=num_epochs) +