Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add event handlers in validation handler (#17322)
Browse files Browse the repository at this point in the history
* Add event handlers in validation handler

* update doc string
  • Loading branch information
liuzh47 authored and roywei committed Jan 23, 2020
1 parent a252a17 commit 359da76
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
14 changes: 11 additions & 3 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,19 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
Priority level of the ValidationHandler. Priority level is sorted in
ascending order. The lower the number is, the higher priority level the
handler is.
event_handlers : EventHandler or list of EventHandlers
List of :py:class:`EventHandler` to apply during validaiton. This argument
is used by self.eval_fn function in order to process customized event
handlers.
"""

def __init__(self,
val_data,
eval_fn,
epoch_period=1,
batch_period=None,
priority=-1000):
priority=-1000,
event_handlers=None):
self.val_data = val_data
self.eval_fn = eval_fn
self.epoch_period = epoch_period
Expand All @@ -198,6 +203,7 @@ def __init__(self,
# order to be called among all callbacks
# validation metrics need to be calculated before other callbacks can access them
self.priority = priority
self.event_handlers = event_handlers

def train_begin(self, estimator, *args, **kwargs):
# reset epoch and batch counter
Expand All @@ -207,12 +213,14 @@ def train_begin(self, estimator, *args, **kwargs):
def batch_end(self, estimator, *args, **kwargs):
self.current_batch += 1
if self.batch_period and self.current_batch % self.batch_period == 0:
self.eval_fn(val_data=self.val_data)
self.eval_fn(val_data=self.val_data, batch_axis=estimator.batch_axis,
event_handlers=self.event_handlers)

def epoch_end(self, estimator, *args, **kwargs):
self.current_epoch += 1
if self.epoch_period and self.current_epoch % self.epoch_period == 0:
self.eval_fn(val_data=self.val_data, batch_axis=estimator.batch_axis)
self.eval_fn(val_data=self.val_data, batch_axis=estimator.batch_axis,
event_handlers=self.event_handlers)


class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
Expand Down
23 changes: 23 additions & 0 deletions tests/python/unittest/test_gluon_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from mxnet.gluon import nn, loss
from mxnet.gluon.contrib.estimator import estimator, event_handler
from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler
from mxnet.gluon.contrib.estimator.event_handler import ValidationHandler
from mxnet.gluon.contrib.estimator import EpochEnd
from mxnet.gluon.data.dataset import Dataset
try:
from StringIO import StringIO
Expand All @@ -48,6 +50,13 @@ def __getitem__(self, idx):
def __len__(self):
return self._length

class TestHandler(EpochEnd):
def __init__(self):
pass

def epoch_end(self, estimator, *args, **kwargs):
estimator.run_test_handler = True

def _get_test_network(net=nn.Sequential()):
net.add(nn.Dense(128, activation='relu', flatten=False),
nn.Dense(64, activation='relu'),
Expand Down Expand Up @@ -301,3 +310,17 @@ def test_validation_handler_batch_axis():
est.fit(test_data, val_data=val_data,
epochs=3, batch_axis=1)

def test_validation_handler():
test_data = _get_test_data()

net = _get_test_network()
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
val_handler = ValidationHandler(val_data=test_data,
eval_fn=est.evaluate,
event_handlers=TestHandler())

est.fit(train_data=test_data, val_data=test_data,
event_handlers=[val_handler], epochs=2)
assert est.run_test_handler == True

0 comments on commit 359da76

Please sign in to comment.