This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Include eval_net, the validation model, in the Estimator API #16957

Merged 2 commits on Dec 10, 2019
32 changes: 28 additions & 4 deletions python/mxnet/gluon/contrib/estimator/estimator.py
@@ -59,9 +59,29 @@ class Estimator(object):
Trainer to apply optimizer on network parameters.
context : Context or list of Context
Device(s) to run the training on.
-evaluation_loss: gluon.loss.loss
-Loss (objective) function to calculate during evaluation. If set evaluation_loss
+evaluation_loss : gluon.loss.loss
+Loss (objective) function to calculate during validation. If evaluation_loss is set to
None, it will use the same loss function as self.loss
eval_net : gluon.Block
The model used for validation. It does not need to belong to the same model class
as the training model, but the two typically share the same architecture, so the
validation model can reuse the parameters of the training model.

The example below constructs an eval_net that shares its network parameters with
the training net:

>>> net = _get_train_network()
>>> eval_net = _get_test_network(params=net.collect_params())
>>> net.initialize(ctx=ctx)
>>> est = Estimator(net, loss, eval_net=eval_net)

Proper namespace matching is required for weight sharing between the two networks.
Most networks inheriting :py:class:`Block` can share their parameters correctly. An
exception is Sequential networks, for which a Block scope must be specified for
correct weight sharing. For naming in the MXNet Gluon API, refer to
https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html
for further information.

"""

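To make the docstring's caveat concrete, here is a minimal sketch of the sharing pattern, assuming the MXNet 1.x Gluon API (build_net is a hypothetical helper standing in for the _get_train_network/_get_test_network pair above):

>>> def build_net(params=None):
...     net = nn.Sequential(params=params)
...     with net.name_scope():
...         net.add(nn.Dense(4, activation='relu'))
...     return net
>>> net = build_net()
>>> eval_net = build_net(params=net.collect_params())
>>> net.initialize(ctx=mx.cpu())
>>> est = Estimator(net, loss=gluon.loss.L2Loss(), eval_net=eval_net)

Because the Dense layer is added inside name_scope, its parameter name is derived from the shared ParameterDict, so eval_net resolves to the same underlying weights and needs no separate initialize() call.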
@@ -89,7 +109,8 @@ def __init__(self, net,
initializer=None,
trainer=None,
context=None,
-evaluation_loss=None):
+evaluation_loss=None,
+eval_net=None):
self.net = net
self.loss = self._check_loss(loss)
self._train_metrics = _check_metrics(metrics)
@@ -98,6 +119,9 @@
self.evaluation_loss = self.loss
if evaluation_loss is not None:
self.evaluation_loss = self._check_loss(evaluation_loss)
self.eval_net = self.net
if eval_net is not None:
self.eval_net = eval_net

self.logger = logging.Logger(name='Estimator', level=logging.INFO)
self.logger.addHandler(logging.StreamHandler(sys.stdout))
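
The fallback mirrors evaluation_loss: when eval_net is omitted, validation simply runs through the training network. A usage sketch (other_net is a hypothetical, separately constructed Block):

>>> est = Estimator(net, loss)                      # est.eval_net is net
>>> est = Estimator(net, loss, eval_net=other_net)  # validation forwards through other_net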
@@ -234,7 +258,7 @@ def evaluate_batch(self,
Batch axis to split the validation data into devices.
"""
data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
-pred = [self.net(x) for x in data]
+pred = [self.eval_net(x) for x in data]
loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
# update metrics
for metric in val_metrics:
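Condensed, the evaluation path now reads as the sketch below, with gluon.utils.split_and_load standing in for the estimator's internal _get_data_and_label helper (an assumption; that helper is not shown in this diff):

>>> data = gluon.utils.split_and_load(val_batch[0], est.context, batch_axis=0)
>>> label = gluon.utils.split_and_load(val_batch[1], est.context, batch_axis=0)
>>> pred = [est.eval_net(x) for x in data]
>>> loss = [est.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
>>> for metric in val_metrics:
...     metric.update(label, pred)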
75 changes: 73 additions & 2 deletions tests/python/unittest/test_gluon_estimator.py
@@ -29,11 +29,16 @@
from nose.tools import assert_raises


-def _get_test_network():
-net = nn.Sequential()
+def _get_test_network(params=None):
+net = nn.Sequential(params=params)
net.add(nn.Dense(4, activation='relu', flatten=False))
return net

def _get_test_network_with_namescope(params=None):
net = nn.Sequential(params=params)
with net.name_scope():
net.add(nn.Dense(4, activation='relu', flatten=False))
return net
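# Note: without name_scope, each new Dense layer takes a fresh global prefix
# (dense0_, dense1_, ...), so two Sequential nets built by the first helper
# never produce matching parameter names. With name_scope, the layer's
# parameter prefix is derived from the (possibly shared) ParameterDict, which
# is what lets the second helper share weights between networks.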

def _get_test_data():
batch_size = 4
@@ -371,3 +376,69 @@ def test_default_handlers():
assert isinstance(handlers[0], GradientUpdateHandler)
assert isinstance(handlers[1], MetricHandler)
assert isinstance(handlers[4], LoggingHandler)

def test_eval_net():
''' test estimator with a different evaluation net '''
# test weight sharing of sequential networks without namescope
net = _get_test_network()
eval_net = _get_test_network(params=net.collect_params())
dataloader, dataiter = _get_test_data()
num_epochs = 1
ctx = mx.cpu()
loss = gluon.loss.L2Loss()
evaluation_loss = gluon.loss.L2Loss()
acc = mx.metric.Accuracy()
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
est = Estimator(net=net,
loss=loss,
metrics=acc,
trainer=trainer,
context=ctx,
evaluation_loss=evaluation_loss,
eval_net=eval_net)

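# Without name_scope, eval_net's Dense layer registers under a new global
# prefix that is absent from the shared ParameterDict, so its weights are
# never initialized and fit() is expected to fail: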
with assert_raises(RuntimeError):
est.fit(train_data=dataloader,
val_data=dataloader,
epochs=num_epochs)

# test weight sharing of sequential networks with namescope
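# With name_scope, the child layer's parameter name is derived from the shared
# ParameterDict (the same full name in both nets), so eval_net reuses net's
# weights and needs no separate initialization before fit().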
net = _get_test_network_with_namescope()
eval_net = _get_test_network_with_namescope(params=net.collect_params())
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
est = Estimator(net=net,
loss=loss,
metrics=acc,
trainer=trainer,
context=ctx,
evaluation_loss=evaluation_loss,
eval_net=eval_net)

est.fit(train_data=dataloader,
val_data=dataloader,
epochs=num_epochs)

# test weight sharing of two resnets
net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx)
net.output = gluon.nn.Dense(10)
eval_net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx)
eval_net.output = gluon.nn.Dense(10, params=net.collect_params())
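# Only the output layers attempt parameter sharing here; both resnets are
# initialized explicitly below, so fit() exercises the eval_net path
# end-to-end.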
dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10)))
dataloader = gluon.data.DataLoader(dataset=dataset, batch_size=5)
net.initialize(ctx=ctx)
eval_net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
est = Estimator(net=net,
loss=loss,
metrics=acc,
trainer=trainer,
context=ctx,
evaluation_loss=evaluation_loss,
eval_net=eval_net)

est.fit(train_data=dataloader,
val_data=dataloader,
epochs=num_epochs)