This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Fit-API] Address PR comments #14885

Merged 9 commits on May 14, 2019
162 changes: 97 additions & 65 deletions python/mxnet/gluon/contrib/estimator/estimator.py
@@ -21,9 +21,8 @@

import copy
import warnings
import weakref

from .event_handler import MetricHandler, ValidationHandler, LoggingHandler
from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
from .... import gluon, autograd
from ....context import Context, cpu, gpu, num_gpus
@@ -40,16 +39,18 @@ class Estimator(object):

Parameters
----------
net : Block
The model used for training.
loss : gluon.loss.Loss or list of gluon.loss.Loss
Loss(objective functions) to calculate during training
Loss (objective functions) to calculate during training.
metrics : EvalMetric or list of EvalMetric
Metrics for evaluating models
Metrics for evaluating models.
initializer : Initializer
initializer to initialize the network
Initializer to initialize the network.
trainer : Trainer
Trainer to apply optimizer on network parameters
Trainer to apply optimizer on network parameters.
context : Context or list of Context
device(s) to run the training on
Device(s) to run the training on.
"""

def __init__(self, net,
@@ -70,7 +71,7 @@ def __init__(self, net,
def _check_loss(self, loss):
if isinstance(loss, gluon.loss.Loss):
loss = [loss]
elif isinstance(loss, list) or all([isinstance(l, gluon.loss.Loss) for l in loss]):
elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
loss = loss
else:
raise ValueError("loss must be a Loss or a list of Loss, "
@@ -122,19 +123,23 @@ def _check_context(self, context):

def _initialize(self, initializer):
# initialize the network
if initializer:
if self._is_initialized():
# if already initialized, re-init with user specified initializer
warnings.warn("Network already initialized, re-initializing with %s. "
"You don't need to pass initializer if you already "
"initialized your net." % type(initializer).__name__)
self.net.initialize(init=initializer, ctx=self.context, force_reinit=True)
if not self._is_initialized():
# net is partially or not initialized,
# initialize with user specified initializer
# if initializer is None, default initializer will be used
# do not re-init layers already initialized
if initializer:
self.net.initialize(init=initializer, ctx=self.context)
else:
# initialize with user specified initializer
self.net.initialize(init=initializer, ctx=self.context, force_reinit=False)
else:
if not self._is_initialized():
self.net.initialize(ctx=self.context)
elif initializer:
# net is fully initialized, and user passed a non-None initializer
# do not force reinitialize, give warning
warnings.warn("Network already fully initialized, skipping initialization. "
"You don't need to pass initializer if you already "
"initialized your net. "
"You can use net.initialize(init=your_initializer, force_reinit=True)"
"to force re-initialize.")

def _check_trainer(self, trainer):
# handle trainer
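The warning above points users at manual re-initialization instead of silently overwriting parameters. A sketch of that path, assuming `net` and a context `ctx` from the earlier example:

```python
import mxnet as mx

# explicit, user-driven re-initialization; force_reinit=True overwrites
# any previously initialized parameters
net.initialize(init=mx.init.Xavier(), ctx=ctx, force_reinit=True)
```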
@@ -157,11 +162,11 @@ def _is_initialized(self):
return False
return True

def _get_data_and_label(self, batch, ctx):
def _get_data_and_label(self, batch, ctx, batch_axis=0):
data = batch[0]
label = batch[1]
data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0)
label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=batch_axis)
label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=batch_axis)
return data, label

def prepare_loss_and_metrics(self):
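To illustrate what `_get_data_and_label` does with a non-default `batch_axis`, a small sketch (the shapes and two-CPU context list are made up for the example):

```python
import mxnet as mx
from mxnet import gluon

ctx = [mx.cpu(0), mx.cpu(1)]
# e.g. an RNN-style layout (seq_len, batch, feature); the batch dimension is axis 1
data = mx.nd.zeros((10, 8, 16))
shards = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=1)
# each shard keeps the seq_len and feature dims; the batch of 8 is split 4 + 4
print([s.shape for s in shards])   # [(10, 4, 16), (10, 4, 16)]
```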
@@ -183,33 +188,36 @@ def prepare_loss_and_metrics(self):
self.train_metrics.append(Loss(loss.name.rstrip('1234567890')))
for metric in self.train_metrics:
val_metric = copy.deepcopy(metric)
metric.name = "Train " + metric.name
val_metric.name = "Validation " + val_metric.name
metric.name = "train " + metric.name
val_metric.name = "validation " + val_metric.name
self.val_metrics.append(val_metric)
return self.train_metrics, self.val_metrics

def evaluate(self,
val_data,
val_metrics):
val_metrics,
batch_axis=0):
"""Evaluate model on validation data

Parameters
----------
val_data : DataLoader
validation data with data and labels
Validation data loader with data and labels.
val_metrics : EvalMetric or list of EvalMetrics
metrics to update validation result
Metrics to update validation result.
batch_axis : int, default 0
Batch axis along which to split the validation data across devices.
"""
if not isinstance(val_data, gluon.data.DataLoader):
raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
"can transform your DataIter or any NDArray into Gluon DataLoader. "
"Refer to gluon.data.dataloader")

for metric in val_metrics:
metric.reset()

for _, batch in enumerate(val_data):
if not isinstance(val_data, gluon.data.DataLoader):
raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
"can transform your DataIter or any NDArray into Gluon DataLoader. "
"Refer to gluon.data.dataloader")
data, label = self._get_data_and_label(batch, self.context)
data, label = self._get_data_and_label(batch, self.context, batch_axis)
pred = [self.net(x) for x in data]
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
# update metrics
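A hypothetical call sequence for `evaluate`, assuming the `est` from the earlier sketch and a Gluon `DataLoader` named `val_data`:

```python
# metrics come from the estimator itself so names line up ("validation accuracy", ...)
train_metrics, val_metrics = est.prepare_loss_and_metrics()

est.evaluate(val_data, val_metrics)            # batch_axis defaults to 0
for metric in val_metrics:
    name, value = metric.get()
    print(name, value)
```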
@@ -221,54 +229,65 @@ def evaluate(self,

def fit(self, train_data,
val_data=None,
epochs=1,
event_handlers=None):
"""Trains the model on a given dataset for a specified
number of epochs. Also, the batch size is inferred from the
DataLoader's batch_size.
epochs=None,
event_handlers=None,
batches=None,
batch_axis=0):
"""Trains the model with a given :py:class:`DataLoader` for a specified
number of epochs or batches. The batch size is inferred from the
data loader's batch_size.

Parameters
----------
train_data : DataLoader
training data with data and labels
val_data : DataLoader
validation data with data and labels
epochs : int, default 1
number of epochs to iterate on the training data.
batch_size : int
number of samples per gradient update.
default will be 32 per device
Training data loader with data and labels.
val_data : DataLoader, default None
Validation data loader with data and labels.
epochs : int, default None
Number of epochs to iterate on the training data.
You must specify one and only one type of iteration (epochs or batches).
event_handlers : EventHandler or list of EventHandler
list of EventHandlers to apply during training
batch_fn : function
custom batch function to extract data and label
from a data batch and load into contexts(devices)
List of :py:class:`EventHandler` objects to apply during training.
batches : int, default None
Number of batches to iterate on the training data.
You must specify one and only one type of iteration (epochs or batches).
batch_axis : int, default 0
Batch axis along which to split the training data across devices.
"""
self.max_epochs = epochs
if not isinstance(train_data, gluon.data.DataLoader):
raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
"can transform your DataIter or any NDArray into Gluon DataLoader. "
"Refer to gluon.data.dataloader")

# must specify one and only one of epochs or batches
if (not epochs) == (not batches):
raise ValueError(
"Fit only support exactly one type of iteration, "
"train by number of epochs or number of batches."
"Please specify one and only one of: epochs or batches.")

self.max_epoch = epochs
self.max_batch = batches

# provide default handlers
event_handlers = self._prepare_default_handlers(val_data, event_handlers)

train_begin, epoch_begin, batch_begin, \
batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)

# only pass a weak reference to all event handlers
estimator_ref = weakref.proxy(self)
# pass a reference to all event handlers
estimator_ref = self
# training begin
for handler in train_begin:
handler.train_begin(estimator_ref)

for epoch in range(epochs):
while True:
# epoch begin
for handler in epoch_begin:
handler.epoch_begin(estimator_ref)

for i, batch in enumerate(train_data):
if not isinstance(train_data, gluon.data.DataLoader):
raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
"can transform your DataIter or any NDArray into Gluon DataLoader. "
"Refer to gluon.data.dataloader")
data, label = self._get_data_and_label(batch, self.context)
data, label = self._get_data_and_label(batch, self.context, batch_axis)

batch_size = batch[0].shape[0]

@@ -285,15 +304,22 @@ def fit(self, train_data,

self.trainer.step(batch_size)
# batch end

batch_end_result = []
for handler in batch_end:
if handler.batch_end(estimator_ref, batch=batch,
pred=pred, label=label, loss=loss):
break
batch_end_result.append(handler.batch_end(estimator_ref, batch=batch,
pred=pred, label=label, loss=loss))
# if any handler signaled to stop
if any(batch_end_result):
break

# epoch end
epoch_end_result = []
for handler in epoch_end:
if handler.epoch_end(estimator_ref):
break
epoch_end_result.append(handler.epoch_end(estimator_ref))
# if any handler signaled to stop
if any(epoch_end_result):
break

# train end
for handler in train_end:
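The loop above stops when any handler's `batch_end`/`epoch_end` returns a truthy value. A sketch of a custom handler using that signal; the class name and threshold logic are invented for illustration:

```python
from mxnet.gluon.contrib.estimator.event_handler import BatchEnd

class StopAfterNBatches(BatchEnd):
    """Ask the training loop to stop after a fixed number of batches."""

    def __init__(self, max_batches):
        self.max_batches = max_batches
        self.seen = 0

    def batch_end(self, estimator, *args, **kwargs):
        self.seen += 1
        return self.seen >= self.max_batches   # truthy return -> stop training
```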
@@ -304,6 +330,9 @@ def _prepare_default_handlers(self, val_data, event_handlers):
default_handlers = []
train_metrics, val_metrics = self.prepare_loss_and_metrics()

# StoppingHandler is not added to the default_handlers check below, as it does not use metrics
event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))

if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
event_handlers.append(MetricHandler(train_metrics=train_metrics))
default_handlers.append("MetricHandler")
@@ -319,13 +348,14 @@ def _prepare_default_handlers(self, val_data, event_handlers):
default_handlers.append("LoggingHandler")

# if there is a mix of user defined event handlers and default event handlers
# they should have the save set of loss and metrics
# they should have the same set of loss and metrics
if default_handlers:
msg = "You are training with the following default event handlers: %s. " \
"They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
"Please use the same set of metrics for all your other handlers." % \
", ".join(default_handlers)
warnings.warn(msg)
# check if all handlers have the same set of references to loss and metrics
references = []
for handler in event_handlers:
for attribute in dir(handler):
@@ -335,8 +365,10 @@ def _prepare_default_handlers(self, val_data, event_handlers):
references += reference
else:
references.append(reference)
# remove None metric references
references = set([ref for ref in references if ref])
for metric in references:
if metric and metric not in train_metrics + val_metrics:
if metric not in train_metrics + val_metrics:
msg = "We have added following default handlers for you: %s and used " \
"estimator.prepare_loss_and_metrics() to pass metrics to " \
"those handlers. Please use the same set of metrics " \