Skip to content

Commit

Permalink
[Fit-API] Address PR comments (apache#14885)
Browse files Browse the repository at this point in the history
* address comments

* update checkpoint

* test symbol save

* address comments

* add resume

* update doc and resume checkpoint

* update docs

* trigger ci

* trigger ci
  • Loading branch information
roywei authored and haohuw committed Jun 23, 2019
1 parent 15d1a45 commit e608313
Show file tree
Hide file tree
Showing 4 changed files with 636 additions and 257 deletions.
162 changes: 97 additions & 65 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@

import copy
import warnings
import weakref

from .event_handler import MetricHandler, ValidationHandler, LoggingHandler
from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
from .... import gluon, autograd
from ....context import Context, cpu, gpu, num_gpus
Expand All @@ -40,16 +39,18 @@ class Estimator(object):
Parameters
----------
net : Block
The model used for training.
loss : gluon.loss.Loss or list of gluon.loss.Loss
Loss(objective functions) to calculate during training
Loss(objective functions) to calculate during training.
metrics : EvalMetric or list of EvalMetric
Metrics for evaluating models
Metrics for evaluating models.
initializer : Initializer
initializer to initialize the network
Initializer to initialize the network.
trainer : Trainer
Trainer to apply optimizer on network parameters
Trainer to apply optimizer on network parameters.
context : Context or list of Context
device(s) to run the training on
Device(s) to run the training on.
"""

def __init__(self, net,
Expand All @@ -70,7 +71,7 @@ def __init__(self, net,
def _check_loss(self, loss):
if isinstance(loss, gluon.loss.Loss):
loss = [loss]
elif isinstance(loss, list) or all([isinstance(l, gluon.loss.Loss) for l in loss]):
elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
loss = loss
else:
raise ValueError("loss must be a Loss or a list of Loss, "
Expand Down Expand Up @@ -122,19 +123,23 @@ def _check_context(self, context):

def _initialize(self, initializer):
# initialize the network
if initializer:
if self._is_initialized():
# if already initialized, re-init with user specified initializer
warnings.warn("Network already initialized, re-initializing with %s. "
"You don't need to pass initializer if you already "
"initialized your net." % type(initializer).__name__)
self.net.initialize(init=initializer, ctx=self.context, force_reinit=True)
if not self._is_initialized():
# net is partially or not initialized,
# initialize with user specified initializer
# if initializer is None, default initializer will be used
# do not re-init layers already initialized
if initializer:
self.net.initialize(init=initializer, ctx=self.context)
else:
# initialize with user specified initializer
self.net.initialize(init=initializer, ctx=self.context, force_reinit=False)
else:
if not self._is_initialized():
self.net.initialize(ctx=self.context)
elif initializer:
# net is fully initialized, and user passed not None initializer
# do not force reinitialize, give warning
warnings.warn("Network already fully initialized, skipping initialization. "
"You don't need to pass initializer if you already "
"initialized your net. "
"You can use net.initialize(init=your_initializer, force_reinit=True)"
"to force re-initialize.")

def _check_trainer(self, trainer):
# handle trainer
Expand All @@ -157,11 +162,11 @@ def _is_initialized(self):
return False
return True

def _get_data_and_label(self, batch, ctx):
def _get_data_and_label(self, batch, ctx, batch_axis=0):
data = batch[0]
label = batch[1]
data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0)
label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=batch_axis)
label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=batch_axis)
return data, label

def prepare_loss_and_metrics(self):
Expand All @@ -183,33 +188,36 @@ def prepare_loss_and_metrics(self):
self.train_metrics.append(Loss(loss.name.rstrip('1234567890')))
for metric in self.train_metrics:
val_metric = copy.deepcopy(metric)
metric.name = "Train " + metric.name
val_metric.name = "Validation " + val_metric.name
metric.name = "train " + metric.name
val_metric.name = "validation " + val_metric.name
self.val_metrics.append(val_metric)
return self.train_metrics, self.val_metrics

def evaluate(self,
val_data,
val_metrics):
val_metrics,
batch_axis=0):
"""Evaluate model on validation data
Parameters
----------
val_data : DataLoader
validation data with data and labels
Validation data loader with data and labels.
val_metrics : EvalMetric or list of EvalMetrics
metrics to update validation result
Metrics to update validation result.
batch_axis : int, default 0
Batch axis to split the validation data into devices.
"""
if not isinstance(val_data, gluon.data.DataLoader):
raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
"can transform your DataIter or any NDArray into Gluon DataLoader. "
"Refer to gluon.data.dataloader")

for metric in val_metrics:
metric.reset()

for _, batch in enumerate(val_data):
if not isinstance(val_data, gluon.data.DataLoader):
raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
"can transform your DataIter or any NDArray into Gluon DataLoader. "
"Refer to gluon.data.dataloader")
data, label = self._get_data_and_label(batch, self.context)
data, label = self._get_data_and_label(batch, self.context, batch_axis)
pred = [self.net(x) for x in data]
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
# update metrics
Expand All @@ -221,54 +229,65 @@ def evaluate(self,

def fit(self, train_data,
val_data=None,
epochs=1,
event_handlers=None):
"""Trains the model on a given dataset for a specified
number of epochs. Also, the batch size is inferred from the
DataLoader's batch_size.
epochs=None,
event_handlers=None,
batches=None,
batch_axis=0):
"""Trains the model with a given :py:class:`DataLoader` for a specified
number of epochs or batches. The batch size is inferred from the
data loader's batch_size.
Parameters
----------
train_data : DataLoader
training data with data and labels
val_data : DataLoader
validation data with data and labels
epochs : int, default 1
number of epochs to iterate on the training data.
batch_size : int
number of samples per gradient update.
default will be 32 per device
Training data loader with data and labels.
val_data : DataLoader, default None
Validation data loader with data and labels.
epochs : int, default None
Number of epochs to iterate on the training data.
You can only specify one and only one type of iteration(epochs or batches).
event_handlers : EventHandler or list of EventHandler
list of EventHandlers to apply during training
batch_fn : function
custom batch function to extract data and label
from a data batch and load into contexts(devices)
List of :py:class:`EventHandlers` to apply during training.
batches : int, default None
Number of batches to iterate on the training data.
You can only specify one and only one type of iteration(epochs or batches).
batch_axis : int, default 0
Batch axis to split the training data into devices.
"""
self.max_epochs = epochs
if not isinstance(train_data, gluon.data.DataLoader):
raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
"can transform your DataIter or any NDArray into Gluon DataLoader. "
"Refer to gluon.data.dataloader")

# must specify one and only one of epochs or batches
if (not epochs) == (not batches):
raise ValueError(
"Fit only support exactly one type of iteration, "
"train by number of epochs or number of batches."
"Please specify one and only one of: epochs or batches.")

self.max_epoch = epochs
self.max_batch = batches

# provide default handlers
event_handlers = self._prepare_default_handlers(val_data, event_handlers)

train_begin, epoch_begin, batch_begin, \
batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)

# only pass a weak reference to all event handlers
estimator_ref = weakref.proxy(self)
# pass a reference to all event handlers
estimator_ref = self
# training begin
for handler in train_begin:
handler.train_begin(estimator_ref)

for epoch in range(epochs):
while True:
# epoch begin
for handler in epoch_begin:
handler.epoch_begin(estimator_ref)

for i, batch in enumerate(train_data):
if not isinstance(train_data, gluon.data.DataLoader):
raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
"can transform your DataIter or any NDArray into Gluon DataLoader. "
"Refer to gluon.data.dataloader")
data, label = self._get_data_and_label(batch, self.context)
data, label = self._get_data_and_label(batch, self.context, batch_axis)

batch_size = batch[0].shape[0]

Expand All @@ -285,15 +304,22 @@ def fit(self, train_data,

self.trainer.step(batch_size)
# batch end

batch_end_result = []
for handler in batch_end:
if handler.batch_end(estimator_ref, batch=batch,
pred=pred, label=label, loss=loss):
break
batch_end_result.append(handler.batch_end(estimator_ref, batch=batch,
pred=pred, label=label, loss=loss))
# if any handler signaled to stop
if any(batch_end_result):
break

# epoch end
epoch_end_result = []
for handler in epoch_end:
if handler.epoch_end(estimator_ref):
break
epoch_end_result.append(handler.epoch_end(estimator_ref))
# if any handler signaled to stop
if any(epoch_end_result):
break

# train end
for handler in train_end:
Expand All @@ -304,6 +330,9 @@ def _prepare_default_handlers(self, val_data, event_handlers):
default_handlers = []
train_metrics, val_metrics = self.prepare_loss_and_metrics()

# no need to add to default handler check as StoppingHandler does not use metrics
event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))

if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
event_handlers.append(MetricHandler(train_metrics=train_metrics))
default_handlers.append("MetricHandler")
Expand All @@ -319,13 +348,14 @@ def _prepare_default_handlers(self, val_data, event_handlers):
default_handlers.append("LoggingHandler")

# if there is a mix of user defined event handlers and default event handlers
# they should have the save set of loss and metrics
# they should have the same set of loss and metrics
if default_handlers:
msg = "You are training with the following default event handlers: %s. " \
"They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
"Please use the same set of metrics for all your other handlers." % \
", ".join(default_handlers)
warnings.warn(msg)
# check if all handlers has the same set of references to loss and metrics
references = []
for handler in event_handlers:
for attribute in dir(handler):
Expand All @@ -335,8 +365,10 @@ def _prepare_default_handlers(self, val_data, event_handlers):
references += reference
else:
references.append(reference)
# remove None metric references
references = set([ref for ref in references if ref])
for metric in references:
if metric and metric not in train_metrics + val_metrics:
if metric not in train_metrics + val_metrics:
msg = "We have added following default handlers for you: %s and used " \
"estimator.prepare_loss_and_metrics() to pass metrics to " \
"those handlers. Please use the same set of metrics " \
Expand Down
Loading

0 comments on commit e608313

Please sign in to comment.