Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Include eval_net the validation model in the gluon estimator api (#16957
Browse files Browse the repository at this point in the history
)

* Include eval_net the validation model in the estimator api

* fix small issue
  • Loading branch information
liuzh47 authored and leezu committed Dec 10, 2019
1 parent 9f5b8bc commit e18e4ce
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 6 deletions.
32 changes: 28 additions & 4 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,29 @@ class Estimator(object):
Trainer to apply optimizer on network parameters.
context : Context or list of Context
Device(s) to run the training on.
evaluation_loss: gluon.loss.loss
Loss (objective) function to calculate during evaluation. If set evaluation_loss
evaluation_loss : gluon.loss.loss
Loss (objective) function to calculate during validation. If set evaluation_loss
None, it will use the same loss function as self.loss
eval_net : gluon.Block
The model used for validation. The validation model does not necessarily belong to
the same model class as the training model. But the two models typically share the
same architecture. Therefore the validation model can reuse parameters of the
training model.
The code example of consruction of eval_net sharing the same network parameters as
the training net is given below:
>>> net = _get_train_network()
>>> eval_net = _get_test_network(params=net.collect_params())
>>> net.initialize(ctx=ctx)
>>> est = Estimator(net, loss, eval_net=eval_net)
Proper namespace match is required for weight sharing between two networks. Most networks
inheriting :py:class:`Block` can share their parameters correctly. An exception is
Sequential networks that Block scope must be specified for correct weight sharing. For
the naming in mxnet Gluon API, please refer to the site
(https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html)
for future information.
"""

Expand Down Expand Up @@ -89,7 +109,8 @@ def __init__(self, net,
initializer=None,
trainer=None,
context=None,
evaluation_loss=None):
evaluation_loss=None,
eval_net=None):
self.net = net
self.loss = self._check_loss(loss)
self._train_metrics = _check_metrics(metrics)
Expand All @@ -98,6 +119,9 @@ def __init__(self, net,
self.evaluation_loss = self.loss
if evaluation_loss is not None:
self.evaluation_loss = self._check_loss(evaluation_loss)
self.eval_net = self.net
if eval_net is not None:
self.eval_net = eval_net

self.logger = logging.Logger(name='Estimator', level=logging.INFO)
self.logger.addHandler(logging.StreamHandler(sys.stdout))
Expand Down Expand Up @@ -234,7 +258,7 @@ def evaluate_batch(self,
Batch axis to split the validation data into devices.
"""
data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
pred = [self.net(x) for x in data]
pred = [self.eval_net(x) for x in data]
loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
# update metrics
for metric in val_metrics:
Expand Down
75 changes: 73 additions & 2 deletions tests/python/unittest/test_gluon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,16 @@
from nose.tools import assert_raises


def _get_test_network():
net = nn.Sequential()
def _get_test_network(params=None):
net = nn.Sequential(params=params)
net.add(nn.Dense(4, activation='relu', flatten=False))
return net

def _get_test_network_with_namescope(params=None):
net = nn.Sequential(params=params)
with net.name_scope():
net.add(nn.Dense(4, activation='relu', flatten=False))
return net

def _get_test_data():
batch_size = 4
Expand Down Expand Up @@ -371,3 +376,69 @@ def test_default_handlers():
assert isinstance(handlers[0], GradientUpdateHandler)
assert isinstance(handlers[1], MetricHandler)
assert isinstance(handlers[4], LoggingHandler)

def test_eval_net():
''' test estimator with a different evaluation net '''
''' test weight sharing of sequential networks without namescope '''
net = _get_test_network()
eval_net = _get_test_network(params=net.collect_params())
dataloader, dataiter = _get_test_data()
num_epochs = 1
ctx = mx.cpu()
loss = gluon.loss.L2Loss()
evaluation_loss = gluon.loss.L2Loss()
acc = mx.metric.Accuracy()
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
est = Estimator(net=net,
loss=loss,
metrics=acc,
trainer=trainer,
context=ctx,
evaluation_loss=evaluation_loss,
eval_net=eval_net)

with assert_raises(RuntimeError):
est.fit(train_data=dataloader,
val_data=dataloader,
epochs=num_epochs)

''' test weight sharing of sequential networks with namescope '''
net = _get_test_network_with_namescope()
eval_net = _get_test_network_with_namescope(params=net.collect_params())
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
est = Estimator(net=net,
loss=loss,
metrics=acc,
trainer=trainer,
context=ctx,
evaluation_loss=evaluation_loss,
eval_net=eval_net)

est.fit(train_data=dataloader,
val_data=dataloader,
epochs=num_epochs)

''' test weight sharing of two resnets '''
net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx)
net.output = gluon.nn.Dense(10)
eval_net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx)
eval_net.output = gluon.nn.Dense(10, params=net.collect_params())
dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10)))
dataloader = gluon.data.DataLoader(dataset=dataset, batch_size=5)
net.initialize(ctx=ctx)
eval_net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
est = Estimator(net=net,
loss=loss,
metrics=acc,
trainer=trainer,
context=ctx,
evaluation_loss=evaluation_loss,
eval_net=eval_net)

est.fit(train_data=dataloader,
val_data=dataloader,
epochs=num_epochs)

0 comments on commit e18e4ce

Please sign in to comment.