From 7013aee10ac18c58080d02fcbc5d5f0b168325ad Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Wed, 6 Mar 2019 09:33:54 -0800 Subject: [PATCH 1/7] base class for estimator and eventhandler --- example/gluon/estimator_example/mnist_cnn.py | 67 ++++++ python/mxnet/gluon/estimator/__init__.py | 21 ++ python/mxnet/gluon/estimator/estimator.py | 203 ++++++++++++++++++ python/mxnet/gluon/estimator/event_handler.py | 99 +++++++++ 4 files changed, 390 insertions(+) create mode 100644 example/gluon/estimator_example/mnist_cnn.py create mode 100644 python/mxnet/gluon/estimator/__init__.py create mode 100644 python/mxnet/gluon/estimator/estimator.py create mode 100644 python/mxnet/gluon/estimator/event_handler.py diff --git a/example/gluon/estimator_example/mnist_cnn.py b/example/gluon/estimator_example/mnist_cnn.py new file mode 100644 index 000000000000..27848dfffe95 --- /dev/null +++ b/example/gluon/estimator_example/mnist_cnn.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Gluon Estimator example on MNIST dataset with simple CNN""" + +import os +import sys + +from mxnet import metric +from mxnet import gluon +from mxnet.gluon import nn, data +from mxnet.gluon.estimator import estimator + +net = nn.Sequential() + +net.add(nn.Conv2D(32, kernel_size=3, activation='relu'), + nn.Conv2D(64, kernel_size=3, activation='relu'), + nn.MaxPool2D(pool_size=2), + nn.Dropout(0.25), + nn.Flatten(), + nn.Dense(128, activation="relu"), nn.Dropout(0.5), + nn.Dropout(0.5), + nn.Dense(10)) + + +def load_data_fashion_mnist(batch_size, resize=None, root=os.path.join( + '~', '.mxnet', 'datasets', 'fashion-mnist')): + root = os.path.expanduser(root) # Expand the user path '~'. + transformer = [] + if resize: + transformer += [data.vision.transforms.Resize(resize)] + transformer += [data.vision.transforms.ToTensor()] + transformer = data.vision.transforms.Compose(transformer) + mnist_train = data.vision.MNIST(root=root, train=True) + mnist_test = data.vision.MNIST(root=root, train=False) + num_workers = 0 if sys.platform.startswith('win32') else 4 + train_iter = data.DataLoader( + mnist_train.transform_first(transformer), batch_size, shuffle=True, + num_workers=num_workers) + test_iter = data.DataLoader( + mnist_test.transform_first(transformer), batch_size, shuffle=False, + num_workers=num_workers) + return train_iter, test_iter + + +batch_size = 128 +train_data, test_data = load_data_fashion_mnist(batch_size, resize=28) +loss = gluon.loss.SoftmaxCrossEntropyLoss() +acc = metric.Accuracy() +est = estimator.Estimator(net=net, loss=loss, metrics=acc) +est.fit(train_data=train_data, epochs=5) \ No newline at end of file diff --git a/python/mxnet/gluon/estimator/__init__.py b/python/mxnet/gluon/estimator/__init__.py new file mode 100644 index 000000000000..fd1a80502c05 --- /dev/null +++ b/python/mxnet/gluon/estimator/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=wildcard-import +"""Gluon Estimator Module""" +from .estimator import * +from .event_handler import * \ No newline at end of file diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py new file mode 100644 index 000000000000..663557b409a0 --- /dev/null +++ b/python/mxnet/gluon/estimator/estimator.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +"""Gluon Estimator""" + + +import warnings + +from .event_handler import LoggingHandler +from ... import * +from ... import gluon, autograd +from ...context import cpu, gpu, num_gpus +from ...metric import EvalMetric, Loss + +__all__ = ['Estimator'] + + +class Estimator(object): + """ + Estimator Class for easy model training + TODO: update doc + """ + + def __init__(self, net, + loss=None, + metrics=None, + initializer=None, + trainers=None, + context=None): + + self.net = net + if isinstance(loss, gluon.loss.Loss): + self.loss = [loss] + else: + self.loss = loss or [] + if isinstance(metrics, EvalMetric): + self.metrics = [metrics] + else: + self.metrics = metrics or [] + + self.initializer = initializer + # store training statistics + self.train_stats = {} + self.train_stats['epochs'] = [] + self.train_stats['learning_rate'] = [] + # time used for each epoch + self.train_stats['step'] = '' + for metric in self.metrics: + # record a history of metrics over each epoch + self.train_stats['train_' + metric.name] = [] + # only record the latest metric numbers after each batch + self.train_stats['batch_' + metric.name] = 0. + self.loss_metrics = [] + # using the metric wrapper for loss to record loss value + for loss in self.loss: + self.loss_metrics.append(Loss(loss.name)) + self.train_stats['train_' + loss.name] = [] + # only record the latest loss numbers after each batch + self.train_stats['batch_' + loss.name] = 0. + + # handle context + if isinstance(context, Context): + self.context = [context] + if not context: + if num_gpus() > 0: + # only use 1 GPU by default + if num_gpus() > 1: + warnings.warn("You have multiple GPUs, gpu(0) will be used by default." + "To utilize all your GPUs, specify context as a list of gpus, e.g. context=[mx.gpu(0), mx.gpu(2)] ") + self.context = [gpu(0)] + else: + self.context = [cpu()] + + # initialize the network + if self.initializer: + if self._is_initialized(): + # if already initialized, re-init with user specified initializer + warnings.warn("You have already initialized your net, it will be forced re-initialized " + "with the initializer you speficied. You don't need to pass initializer if you alraedy initialized your net.") + self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True) + else: + # initialize with user specified initializer + self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=False) + else: + if not self._is_initialized(): + self.net.initialize(ctx=self.context) + + # handle trainers + if isinstance(trainers, gluon.Trainer): + self.trainers = [trainers] + else: + self.trainers = trainers or [] + if not self.trainers: + warnings.warn("No trainer specified, default SGD optimizer with learning rate 0.001 is used.") + self.trainers = [gluon.Trainer(self.net.collect_params(), 'sgd', {'learning_rate': 0.001})] + + def _is_initialized(self): + param_dict = self.net.collect_params() + for param in param_dict: + try: + param_dict[param].list_ctx() + except RuntimeError: + return False + return True + + def _batch_fn(self, batch, ctx): + data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0) + label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0) + return data, label + + def fit(self, train_data, + val_data=None, + epochs=1, + batch_size=None, + event_handlers=None): + + if not batch_size: + batch_size = 32 * len(self.context) + + event_handlers = event_handlers or [] + # provide default logging handler + if not event_handlers or not any(isinstance(handler, LoggingHandler) for handler in event_handlers): + event_handlers.append(LoggingHandler(self)) + + # TODO: handle validation logic and update train stats + do_validation = False + if val_data: + do_validation = True + + # training begin + for handler in event_handlers: + handler.train_begin() + + for epoch in range(epochs): + # epoch begin + self.train_stats["epochs"].append(epoch) + self.train_stats["learning_rate"].append(self.trainers[0].learning_rate) + + for handler in event_handlers: + handler.epoch_begin() + + for metric in self.metrics + self.loss_metrics: + metric.reset() + + for i, batch in enumerate(train_data): + data, label = self._batch_fn(batch, self.context) + + # batch begin + for handler in event_handlers: + handler.batch_begin() + + with autograd.record(): + pred = [self.net(x) for x in data] + losses = [] + for loss in self.loss: + losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)]) + + for loss in losses: + for l in loss: + l.backward() + + # update metrics + for metric in self.metrics: + metric.update(label, pred) + self.train_stats['batch_' + metric.name] = metric.get()[1] + for loss, loss_metric, in zip(losses, self.loss_metrics): + loss_metric.update(0, [l for l in loss]) + self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1] + + self.train_stats['step'] = str(batch_size * (i + 1)) + '/' + str(len(train_data._dataset)) + + for trainer in self.trainers: + trainer.step(batch_size) + + # batch end + for handler in event_handlers: + handler.batch_end() + + for metric in self.metrics + self.loss_metrics: + self.train_stats['train_' + metric.name].append(metric.get()[1]) + # epoch end + for handler in event_handlers: + handler.epoch_end() + + # train end + for handler in event_handlers: + handler.train_end() diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py new file mode 100644 index 000000000000..5f688a1736bd --- /dev/null +++ b/python/mxnet/gluon/estimator/event_handler.py @@ -0,0 +1,99 @@ +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=arguments-differ, too-many-lines +# coding: utf-8 +"""Gluon EventHandlers for Estimators""" + +__all__ = ['EventHandler', 'LoggingHandler'] +import logging +import os +import time + + +class EventHandler(object): + def __init__(self, estimator): + self._estimator = estimator + + def train_begin(self): + pass + + def train_end(self): + pass + + def batch_begin(self): + pass + + def batch_end(self): + pass + + def epoch_begin(self): + pass + + def epoch_end(self): + pass + + +class LoggingHandler(EventHandler): + """Basic Logging Handler that applies to every Gluon estimator by default. + TODO: add doc + """ + + def __init__(self, estimator, log_name=None, file_name=None, file_location=None, ): + super(LoggingHandler, self).__init__(estimator) + log_name = log_name or 'Gluon Estimator' + self.logger = logging.getLogger(log_name) + self.logger.setLevel(logging.INFO) + streamhandler = logging.StreamHandler() + self.logger.addHandler(streamhandler) + # save logger to file only if file name or location is specified + if file_name or file_location: + file_name = file_name or log_name or 'estimator_log' + file_location = file_location or './' + filehandler = logging.FileHandler(os.path.join(file_location, file_name)) + self.logger.addHandler(filehandler) + + def train_begin(self): + pass + # logger.info(opt) + + def train_end(self): + pass + + def batch_begin(self): + self.batch_start = time.time() + + def batch_end(self): + batch_time = time.time() - self.batch_start + epoch = self._estimator.train_stats['epochs'][-1] + step = self._estimator.train_stats['step'] + msg = '[Epoch %d] [Step %s] time/step: %.3fs ' % (epoch, step, batch_time) + for key in self._estimator.train_stats.keys(): + if key.startswith('batch_'): + msg += key[6:] + ': ' + '%.4f ' % self._estimator.train_stats[key] + self.logger.info(msg) + + def epoch_begin(self): + self.epoch_start = time.time() + + def epoch_end(self): + epoch_time = time.time() - self.epoch_start + epoch = self._estimator.train_stats['epochs'][-1] + msg = 'Epoch %d finished in %.3fs: ' % (epoch, epoch_time) + for key in self._estimator.train_stats.keys(): + if key.startswith('train_') or key.startswith('test_'): + msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][epoch] + self.logger.info(msg) \ No newline at end of file From 51d011a127098567b13f47cab8c32a4daccfb0c5 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Wed, 6 Mar 2019 10:29:48 -0800 Subject: [PATCH 2/7] add license --- python/mxnet/gluon/estimator/event_handler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py index 5f688a1736bd..ef54060ce574 100644 --- a/python/mxnet/gluon/estimator/event_handler.py +++ b/python/mxnet/gluon/estimator/event_handler.py @@ -1,3 +1,4 @@ +# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file @@ -14,8 +15,8 @@ # specific language governing permissions and limitations # under the License. -# pylint: disable=arguments-differ, too-many-lines # coding: utf-8 +# pylint: disable=wildcard-import """Gluon EventHandlers for Estimators""" __all__ = ['EventHandler', 'LoggingHandler'] From 8af19e5c679ed804fb87543aa03718091c15c0e9 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Tue, 12 Mar 2019 14:11:38 -0700 Subject: [PATCH 3/7] add event handlers --- example/gluon/estimator_example/mnist_cnn.py | 67 ----- python/mxnet/gluon/estimator/__init__.py | 2 +- python/mxnet/gluon/estimator/estimator.py | 120 ++++++--- python/mxnet/gluon/estimator/event_handler.py | 231 +++++++++++++++++- 4 files changed, 309 insertions(+), 111 deletions(-) delete mode 100644 example/gluon/estimator_example/mnist_cnn.py diff --git a/example/gluon/estimator_example/mnist_cnn.py b/example/gluon/estimator_example/mnist_cnn.py deleted file mode 100644 index 27848dfffe95..000000000000 --- a/example/gluon/estimator_example/mnist_cnn.py +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# coding: utf-8 -# pylint: disable=wildcard-import -"""Gluon Estimator example on MNIST dataset with simple CNN""" - -import os -import sys - -from mxnet import metric -from mxnet import gluon -from mxnet.gluon import nn, data -from mxnet.gluon.estimator import estimator - -net = nn.Sequential() - -net.add(nn.Conv2D(32, kernel_size=3, activation='relu'), - nn.Conv2D(64, kernel_size=3, activation='relu'), - nn.MaxPool2D(pool_size=2), - nn.Dropout(0.25), - nn.Flatten(), - nn.Dense(128, activation="relu"), nn.Dropout(0.5), - nn.Dropout(0.5), - nn.Dense(10)) - - -def load_data_fashion_mnist(batch_size, resize=None, root=os.path.join( - '~', '.mxnet', 'datasets', 'fashion-mnist')): - root = os.path.expanduser(root) # Expand the user path '~'. - transformer = [] - if resize: - transformer += [data.vision.transforms.Resize(resize)] - transformer += [data.vision.transforms.ToTensor()] - transformer = data.vision.transforms.Compose(transformer) - mnist_train = data.vision.MNIST(root=root, train=True) - mnist_test = data.vision.MNIST(root=root, train=False) - num_workers = 0 if sys.platform.startswith('win32') else 4 - train_iter = data.DataLoader( - mnist_train.transform_first(transformer), batch_size, shuffle=True, - num_workers=num_workers) - test_iter = data.DataLoader( - mnist_test.transform_first(transformer), batch_size, shuffle=False, - num_workers=num_workers) - return train_iter, test_iter - - -batch_size = 128 -train_data, test_data = load_data_fashion_mnist(batch_size, resize=28) -loss = gluon.loss.SoftmaxCrossEntropyLoss() -acc = metric.Accuracy() -est = estimator.Estimator(net=net, loss=loss, metrics=acc) -est.fit(train_data=train_data, epochs=5) \ No newline at end of file diff --git a/python/mxnet/gluon/estimator/__init__.py b/python/mxnet/gluon/estimator/__init__.py index fd1a80502c05..58600dadffb4 100644 --- a/python/mxnet/gluon/estimator/__init__.py +++ b/python/mxnet/gluon/estimator/__init__.py @@ -18,4 +18,4 @@ # pylint: disable=wildcard-import """Gluon Estimator Module""" from .estimator import * -from .event_handler import * \ No newline at end of file +from .event_handler import * diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py index 663557b409a0..b0d404b8432d 100644 --- a/python/mxnet/gluon/estimator/estimator.py +++ b/python/mxnet/gluon/estimator/estimator.py @@ -19,22 +19,35 @@ # pylint: disable=wildcard-import """Gluon Estimator""" - import warnings from .event_handler import LoggingHandler -from ... import * from ... import gluon, autograd -from ...context import cpu, gpu, num_gpus +from ...context import Context, cpu, gpu, num_gpus +from ...io import DataIter from ...metric import EvalMetric, Loss __all__ = ['Estimator'] class Estimator(object): - """ - Estimator Class for easy model training - TODO: update doc + """Estimator Class for easy model training + + :py:class:`Estimator` can be used to facilitate the training & validation process + + + Parameters + ---------- + loss : Loss or list of Loss + Loss(objective functions) to calculate during training + metrics : EvalMetric or list of EvalMetric + Metrics for evaluating models + initializer : Initializer + initializer to initialize the network + trainers : Trainer or list of Trainer + Trainers to apply optimizers on network parameters + context : Context or list of Context + devices to run the training on """ def __init__(self, net, @@ -45,6 +58,8 @@ def __init__(self, net, context=None): self.net = net + self.stop_training = False + if isinstance(loss, gluon.loss.Loss): self.loss = [loss] else: @@ -59,7 +74,7 @@ def __init__(self, net, self.train_stats = {} self.train_stats['epochs'] = [] self.train_stats['learning_rate'] = [] - # time used for each epoch + # current step of the epoch self.train_stats['step'] = '' for metric in self.metrics: # record a history of metrics over each epoch @@ -68,11 +83,11 @@ def __init__(self, net, self.train_stats['batch_' + metric.name] = 0. self.loss_metrics = [] # using the metric wrapper for loss to record loss value - for loss in self.loss: - self.loss_metrics.append(Loss(loss.name)) - self.train_stats['train_' + loss.name] = [] + for l in self.loss: + self.loss_metrics.append(Loss(l.name)) + self.train_stats['train_' + l.name] = [] # only record the latest loss numbers after each batch - self.train_stats['batch_' + loss.name] = 0. + self.train_stats['batch_' + l.name] = 0. # handle context if isinstance(context, Context): @@ -82,7 +97,8 @@ def __init__(self, net, # only use 1 GPU by default if num_gpus() > 1: warnings.warn("You have multiple GPUs, gpu(0) will be used by default." - "To utilize all your GPUs, specify context as a list of gpus, e.g. context=[mx.gpu(0), mx.gpu(2)] ") + "To utilize all your GPUs, specify context as a list of gpus, " + "e.g. context=[mx.gpu(0), mx.gpu(2)] ") self.context = [gpu(0)] else: self.context = [cpu()] @@ -91,8 +107,9 @@ def __init__(self, net, if self.initializer: if self._is_initialized(): # if already initialized, re-init with user specified initializer - warnings.warn("You have already initialized your net, it will be forced re-initialized " - "with the initializer you speficied. You don't need to pass initializer if you alraedy initialized your net.") + warnings.warn("Network already initialized, re-initializing with %s. " + "You don't need to pass initializer if you already " + "initialized your net."% type(self.initializer).__name__) self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True) else: # initialize with user specified initializer @@ -107,8 +124,10 @@ def __init__(self, net, else: self.trainers = trainers or [] if not self.trainers: - warnings.warn("No trainer specified, default SGD optimizer with learning rate 0.001 is used.") - self.trainers = [gluon.Trainer(self.net.collect_params(), 'sgd', {'learning_rate': 0.001})] + warnings.warn("No trainer specified, default SGD optimizer " + "with learning rate 0.001 is used.") + self.trainers = [gluon.Trainer(self.net.collect_params(), + 'sgd', {'learning_rate': 0.001})] def _is_initialized(self): param_dict = self.net.collect_params() @@ -119,38 +138,62 @@ def _is_initialized(self): return False return True - def _batch_fn(self, batch, ctx): - data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0) - label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0) + def _batch_fn(self, batch, ctx, is_iterator=False): + if is_iterator: + data = batch.data[0] + label = batch.label[0] + else: + data = batch[0] + label = batch[1] + data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0) + label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0) return data, label def fit(self, train_data, val_data=None, epochs=1, batch_size=None, - event_handlers=None): - + event_handlers=None, + batch_fn=None): + """Main training loop + + Parameters + ---------- + train_data : DataLoader or DataIter + training data with data and labels + val_data : DataLoader or DataIter + validation data with data and labels + epochs : int, default 1 + number of epochs to iterate on the training data. + batch_size : int + number of samples per gradient update. + default will be 32 per device + event_handlers : EventHandler or list of EventHandler + list of EventHandlers to apply during training + batch_fn : function + custom batch function to extract data and label + from a data batch and load into contexts(devices) + """ + + + self.epochs = epochs if not batch_size: batch_size = 32 * len(self.context) event_handlers = event_handlers or [] # provide default logging handler - if not event_handlers or not any(isinstance(handler, LoggingHandler) for handler in event_handlers): + if not event_handlers or \ + not any(isinstance(handler, LoggingHandler) for handler in event_handlers): event_handlers.append(LoggingHandler(self)) - # TODO: handle validation logic and update train stats - do_validation = False - if val_data: - do_validation = True - # training begin for handler in event_handlers: handler.train_begin() for epoch in range(epochs): # epoch begin - self.train_stats["epochs"].append(epoch) - self.train_stats["learning_rate"].append(self.trainers[0].learning_rate) + self.train_stats['epochs'].append(epoch) + self.train_stats['learning_rate'].append(self.trainers[0].learning_rate) for handler in event_handlers: handler.epoch_begin() @@ -159,7 +202,16 @@ def fit(self, train_data, metric.reset() for i, batch in enumerate(train_data): - data, label = self._batch_fn(batch, self.context) + if not batch_fn: + if isinstance(train_data, gluon.data.DataLoader): + data, label = self._batch_fn(batch, self.context) + elif isinstance(train_data, DataIter): + data, label = self._batch_fn(batch, self.context, is_iterator=True) + else: + raise ValueError("You are using a custom iteration, please also provide " + "batch_fn to extract data and label") + else: + data, label = batch_fn(batch, self.context) # batch begin for handler in event_handlers: @@ -183,7 +235,10 @@ def fit(self, train_data, loss_metric.update(0, [l for l in loss]) self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1] - self.train_stats['step'] = str(batch_size * (i + 1)) + '/' + str(len(train_data._dataset)) + try: + self.train_stats['step'] = "{}/{}".format(batch_size * (i + 1), len(train_data._dataset)) + except AttributeError: + self.train_stats['step'] = i for trainer in self.trainers: trainer.step(batch_size) @@ -198,6 +253,9 @@ def fit(self, train_data, for handler in event_handlers: handler.epoch_end() + if self.stop_training: + break + # train end for handler in event_handlers: handler.train_end() diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py index ef54060ce574..fb1d815d3aef 100644 --- a/python/mxnet/gluon/estimator/event_handler.py +++ b/python/mxnet/gluon/estimator/event_handler.py @@ -23,9 +23,23 @@ import logging import os import time +import warnings + +import numpy as np class EventHandler(object): + """Basic for event handlers + + :py:class:`EventHandler` can perform user defined functions at + different stages of training: train begin, epoch begin, batch begin, + batch end, epoch end, train end. + + Parameters + ---------- + estimator : Estimator + The :py:class:`Estimator` to get training statistics + """ def __init__(self, estimator): self._estimator = estimator @@ -50,26 +64,35 @@ def epoch_end(self): class LoggingHandler(EventHandler): """Basic Logging Handler that applies to every Gluon estimator by default. - TODO: add doc + + :py:class:`LoggingHandler` logs hyper-parameters, training statistics, + and other useful information during training + + Parameters + ---------- + estimator : Estimator + The :py:class:`Estimator` to get training statistics + file_name : str + file name to save the logs + file_location: str + file location to save the logs """ - def __init__(self, estimator, log_name=None, file_name=None, file_location=None, ): + def __init__(self, estimator, file_name=None, file_location=None, ): super(LoggingHandler, self).__init__(estimator) - log_name = log_name or 'Gluon Estimator' - self.logger = logging.getLogger(log_name) + self.logger = logging.getLogger(__name__) self.logger.setLevel(logging.INFO) - streamhandler = logging.StreamHandler() - self.logger.addHandler(streamhandler) + stream_handler = logging.StreamHandler() + self.logger.addHandler(stream_handler) # save logger to file only if file name or location is specified if file_name or file_location: - file_name = file_name or log_name or 'estimator_log' + file_name = file_name or 'estimator_log' file_location = file_location or './' - filehandler = logging.FileHandler(os.path.join(file_location, file_name)) - self.logger.addHandler(filehandler) + file_handler = logging.FileHandler(os.path.join(file_location, file_name)) + self.logger.addHandler(file_handler) def train_begin(self): pass - # logger.info(opt) def train_end(self): pass @@ -93,8 +116,192 @@ def epoch_begin(self): def epoch_end(self): epoch_time = time.time() - self.epoch_start epoch = self._estimator.train_stats['epochs'][-1] - msg = 'Epoch %d finished in %.3fs: ' % (epoch, epoch_time) + msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time) for key in self._estimator.train_stats.keys(): if key.startswith('train_') or key.startswith('test_'): msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][epoch] - self.logger.info(msg) \ No newline at end of file + self.logger.info(msg) + + +class CheckpointHandler(EventHandler): + """Save the model after every epoch. + + :py:class:`CheckpointHandler` save the network parameters every epoch + + Parameters + ---------- + estimator : Estimator + The :py:class:`Estimator` to get training statistics + filepath : str + file name to save the parameters, it can contain directories, + for example: ./saved_model/resnet.params + monitor: str + the metrics to monitor + verbose: int, default 0 + verbosity mode + save_best_only: bool + if True, only save the parameters if monitored value improved + mode: str, default 'auto' + one of {auto, min, max}, if `save_best_only=True`, the comparison to make + and determine if the monitored value has improved + period: int, default 1 + intervals between saving the network + """ + + def __init__(self, estimator, + filepath, + monitor='val_loss', + verbose=0, + save_best_only=False, + mode='auto', + period=1): + super(CheckpointHandler, self).__init__(estimator) + self.monitor = monitor + self.verbose = verbose + self.filepath = filepath + self.save_best_only = save_best_only + self.period = period + self.epochs_since_last_save = 0 + self.logger = logging.getLogger(__name__) + + if mode not in ['auto', 'min', 'max']: + warnings.warn('ModelCheckpoint mode %s is unknown, ' + 'fallback to auto mode.' % (mode), + RuntimeWarning) + mode = 'auto' + + if mode == 'min': + self.monitor_op = np.less + self.best = np.Inf + elif mode == 'max': + self.monitor_op = np.greater + self.best = -np.Inf + else: + # use greater for accuracy and less otherwise + if 'acc' in self.monitor: + self.monitor_op = np.greater + self.best = -np.Inf + else: + self.monitor_op = np.less + self.best = np.Inf + + def epoch_end(self, ): + epoch = self._estimator.train_stats['epochs'][-1] + # add extension for weights + if '.params' not in self.filepath: + self.filepath += '.params' + self.epochs_since_last_save += 1 + if self.epochs_since_last_save >= self.period: + self.epochs_since_last_save = 0 + if self.save_best_only: + # check if monitor exists in train_stats + if self.monitor not in self._estimator.train_stats: + warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure' + 'you are passing one of the metric names as monitor', self.monitor)) + self._estimator.net.save_parameters(self.filepath) + else: + current = self._estimator.train_stats[self.monitor][-1] + if self.monitor_op(current, self.best): + if self.verbose > 0: + self.logger.info('\n[Epoch %d] %s improved from %0.5f to %0.5f,' + ' saving model to %s', + epoch, self.monitor, self.best, current, self.filepath) + self.best = current + self._estimator.net.save_parameters(self.filepath) + else: + if self.verbose > 0: + self.logger.info('\n[Epoch %d] %s did not improve from %0.5f, skipping save model', + epoch, self.monitor, self.best) + else: + if self.verbose > 0: + logging.info('\nEpoch %d: saving model to %s', epoch, self.filepath) + self._estimator.net.save_parameters(self.filepath) + + +class EarlyStoppingHandler(EventHandler): + """Early stop training if monitored value is not improving + + Parameters + ---------- + estimator : Estimator + The :py:class:`Estimator` to get training statistics + monitor: str + the metrics to monitor + min_delta: float, default 0 + minimal change in monitored value to be considered as an improvement + patience: int, default 0 + number of epochs to wait for improvement before terminate training + mode: str, default 'auto' + one of {auto, min, max}, the comparison to make + and determine if the monitored value has improved + baseline: float + baseline value to compare the monitored value with + """ + + def __init__(self, estimator, + monitor='val_loss', + min_delta=0, + patience=0, + mode='auto', + baseline=None): + super(EarlyStoppingHandler, self).__init__(estimator) + + self._estimator = estimator + self.monitor = monitor + self.baseline = baseline + self.patience = patience + self.min_delta = min_delta + self.wait = 0 + self.stopped_epoch = 0 + self.logger = logging.getLogger(__name__) + + if mode not in ['auto', 'min', 'max']: + warnings.warn(RuntimeWarning('EarlyStopping mode %s is unknown, ' + 'fallback to auto mode.', mode)) + mode = 'auto' + + if mode == 'min': + self.monitor_op = np.less + elif mode == 'max': + self.monitor_op = np.greater + else: + if 'acc' in self.monitor: + self.monitor_op = np.greater + else: + self.monitor_op = np.less + + if self.monitor_op == np.greater: + self.min_delta *= 1 + else: + self.min_delta *= -1 + + def train_begin(self): + self.wait = 0 + self.stopped_epoch = 0 + if self.baseline is not None: + self.best = self.baseline + else: + self.best = np.Inf if self.monitor_op == np.less else -np.Inf + + def epoch_end(self): + epoch = self._estimator.train_stats['epochs'][-1] + if self.monitor not in self._estimator.train_stats: + warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure' + 'you are passing one of the metric names as monitor', self.monitor)) + else: + current = self._estimator.train_stats[self.monitor][-1] + if current is None: + return + + if self.monitor_op(current - self.min_delta, self.best): + self.best = current + self.wait = 0 + else: + self.wait += 1 + if self.wait >= self.patience: + self.stopped_epoch = epoch + self._estimator.stop_training = True + + def train_end(self): + if self.stopped_epoch > 0: + self.logger.info('Epoch %d: early stopping due to %s not improving', self.stopped_epoch, self.monitor) From 67b468df86c11b77660a681804119bbb261378a8 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Tue, 12 Mar 2019 14:58:54 -0700 Subject: [PATCH 4/7] fix pylint --- python/mxnet/gluon/estimator/estimator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py index b0d404b8432d..5ec2e28e59b9 100644 --- a/python/mxnet/gluon/estimator/estimator.py +++ b/python/mxnet/gluon/estimator/estimator.py @@ -150,7 +150,6 @@ def _batch_fn(self, batch, ctx, is_iterator=False): return data, label def fit(self, train_data, - val_data=None, epochs=1, batch_size=None, event_handlers=None, From 62e67ddee5fb07632c8a6a2f82b8e046824f68b6 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 14 Mar 2019 16:47:56 -0700 Subject: [PATCH 5/7] improve arg check --- python/mxnet/gluon/estimator/estimator.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py index 5ec2e28e59b9..468d22e67698 100644 --- a/python/mxnet/gluon/estimator/estimator.py +++ b/python/mxnet/gluon/estimator/estimator.py @@ -64,10 +64,17 @@ def __init__(self, net, self.loss = [loss] else: self.loss = loss or [] + for loss in self.loss: + if not isinstance(loss, gluon.loss.Loss): + raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss") + if isinstance(metrics, EvalMetric): self.metrics = [metrics] else: self.metrics = metrics or [] + for metric in self.metrics: + if not isinstance(metric, EvalMetric): + raise ValueError("metrics must be a Metric or a list of Metric, refer to mxnet.metric.EvalMetric") self.initializer = initializer # store training statistics @@ -98,7 +105,7 @@ def __init__(self, net, if num_gpus() > 1: warnings.warn("You have multiple GPUs, gpu(0) will be used by default." "To utilize all your GPUs, specify context as a list of gpus, " - "e.g. context=[mx.gpu(0), mx.gpu(2)] ") + "e.g. context=[mx.gpu(0), mx.gpu(1)] ") self.context = [gpu(0)] else: self.context = [cpu()] From 07933f1e867c4997d050d2cc8c193f2299f9d025 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 14 Mar 2019 16:59:02 -0700 Subject: [PATCH 6/7] fix pylint --- python/mxnet/gluon/estimator/estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py index 468d22e67698..159f7e220427 100644 --- a/python/mxnet/gluon/estimator/estimator.py +++ b/python/mxnet/gluon/estimator/estimator.py @@ -64,7 +64,7 @@ def __init__(self, net, self.loss = [loss] else: self.loss = loss or [] - for loss in self.loss: + for l in self.loss: if not isinstance(loss, gluon.loss.Loss): raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss") From d551e442c6e6fa826ab9fe2778507f9b19156a8d Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Fri, 15 Mar 2019 11:39:23 -0700 Subject: [PATCH 7/7] add unit tests --- python/mxnet/gluon/estimator/event_handler.py | 2 +- .../unittest/test_gluon_event_handler.py | 92 +++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 tests/python/unittest/test_gluon_event_handler.py diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py index fb1d815d3aef..0162c36993f3 100644 --- a/python/mxnet/gluon/estimator/event_handler.py +++ b/python/mxnet/gluon/estimator/event_handler.py @@ -215,7 +215,7 @@ def epoch_end(self, ): else: if self.verbose > 0: logging.info('\nEpoch %d: saving model to %s', epoch, self.filepath) - self._estimator.net.save_parameters(self.filepath) + self._estimator.net.save_parameters(self.filepath) class EarlyStoppingHandler(EventHandler): diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py new file mode 100644 index 000000000000..a551594d6430 --- /dev/null +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import tempfile +import mxnet as mx +from mxnet import nd +from mxnet.gluon import nn, loss +from mxnet.gluon.estimator import estimator, event_handler + +def _get_test_network(): + net = nn.Sequential() + net.add(nn.Dense(128, activation='relu', in_units=100, flatten=False), + nn.Dense(64, activation='relu', in_units=128), + nn.Dense(10, activation='relu', in_units=64)) + return net + +def _get_test_data(): + return mx.io.NDArrayIter(data=nd.ones((32, 100)), label=nd.random.randint(0, 10, (32, 1))) + + +def test_checkpoint_handler(): + tmpdir = tempfile.mkdtemp() + file_path = os.path.join(tmpdir, "model.params") + test_data = _get_test_data() + + save_best_only = False + mode = 'auto' + + net = _get_test_network() + ce_loss = loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + est = estimator.Estimator(net, loss=ce_loss, metrics=acc) + checkpoint_handler = [event_handler.CheckpointHandler(est, file_path, + save_best_only=save_best_only, + mode=mode)] + est.fit(test_data, event_handlers=checkpoint_handler, epochs=1) + assert os.path.isfile(file_path) + os.remove(file_path) + +def test_early_stopping(): + test_data = _get_test_data() + + mode = 'max' + monitor = 'train_accuracy' + patience = 0 + + net = _get_test_network() + ce_loss = loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + est = estimator.Estimator(net, loss=ce_loss, metrics=acc) + early_stopping = [event_handler.EarlyStoppingHandler(est, monitor, + patience=patience, + mode=mode)] + est.fit(test_data, event_handlers=early_stopping, epochs=1) + + mode = 'auto' + monitor = 'train_accuracy' + patience = 2 + early_stopping = [event_handler.EarlyStoppingHandler(est, monitor, + patience=patience, + mode=mode)] + est.fit(test_data, event_handlers=early_stopping, epochs=1) + +def test_logging(): + tmpdir = tempfile.mkdtemp() + test_data = _get_test_data() + file_name = 'test_log' + output_dir = os.path.join(tmpdir, file_name) + + net = _get_test_network() + ce_loss = loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + est = estimator.Estimator(net, loss=ce_loss, metrics=acc) + logging_handler = [event_handler.LoggingHandler(est, file_name=file_name, file_location=tmpdir)] + est.fit(test_data, event_handlers=logging_handler, epochs=1) + assert os.path.isfile(output_dir) + os.remove(output_dir) \ No newline at end of file