Skip to content

Commit

Permalink
Add visualdl callback function (#27565)
Browse files Browse the repository at this point in the history
* add visualdl callback
  • Loading branch information
LielinJiang authored Sep 30, 2020
1 parent 9b3ef59 commit a0f1dba
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 1 deletion.
112 changes: 111 additions & 1 deletion python/paddle/hapi/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.

import os
import numbers

from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.utils import try_import

from .progressbar import ProgressBar

__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint']
__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL']


def config_callbacks(callbacks=None,
Expand Down Expand Up @@ -471,3 +473,111 @@ def on_train_end(self, logs=None):
path = '{}/final'.format(self.save_dir)
print('save checkpoint at {}'.format(os.path.abspath(path)))
self.model.save(path)


class VisualDL(Callback):
"""VisualDL callback function
Args:
log_dir (str): The directory to save visualdl log file.
Examples:
.. code-block:: python
import paddle
from paddle.static import InputSpec
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train')
eval_dataset = paddle.vision.datasets.MNIST(mode='test')
net = paddle.vision.LeNet()
model = paddle.Model(net, inputs, labels)
optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
model.prepare(optimizer=optim,
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())
## uncomment following lines to fit model with visualdl callback function
# callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
# model.fit(train_dataset, eval_dataset, batch_size=64, callbacks=callback)
"""

def __init__(self, log_dir):
self.log_dir = log_dir
self.epochs = None
self.steps = None
self.epoch = 0

def _is_write(self):
return ParallelEnv().local_rank == 0

def on_train_begin(self, logs=None):
self.epochs = self.params['epochs']
assert self.epochs
self.train_metrics = self.params['metrics']
assert self.train_metrics
self._is_fit = True
self.train_step = 0

def on_epoch_begin(self, epoch=None, logs=None):
self.steps = self.params['steps']
self.epoch = epoch

def _updates(self, logs, mode):
if not self._is_write():
return
if not hasattr(self, 'writer'):
visualdl = try_import('visualdl')
self.writer = visualdl.LogWriter(self.log_dir)

metrics = getattr(self, '%s_metrics' % (mode))
current_step = getattr(self, '%s_step' % (mode))

if mode == 'train':
total_step = current_step
else:
total_step = self.epoch

for k in metrics:
if k in logs:
temp_tag = mode + '/' + k

if isinstance(logs[k], (list, tuple)):
temp_value = logs[k][0]
elif isinstance(logs[k], numbers.Number):
temp_value = logs[k]
else:
continue

self.writer.add_scalar(
tag=temp_tag, step=total_step, value=temp_value)

def on_train_batch_end(self, step, logs=None):
logs = logs or {}
self.train_step += 1

if self._is_write():
self._updates(logs, 'train')

def on_eval_begin(self, logs=None):
self.eval_steps = logs.get('steps', None)
self.eval_metrics = logs.get('metrics', [])
self.eval_step = 0
self.evaled_samples = 0

def on_train_end(self, logs=None):
if hasattr(self, 'writer'):
self.writer.close()
delattr(self, 'writer')

def on_eval_end(self, logs=None):
if self._is_write():
self._updates(logs, 'eval')

if (not hasattr(self, '_is_fit')) and hasattr(self, 'writer'):
self.writer.close()
delattr(self, 'writer')
28 changes: 28 additions & 0 deletions python/paddle/tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest
import time
import random
import tempfile
import shutil
import paddle

from paddle import Model
from paddle.static import InputSpec
Expand Down Expand Up @@ -102,6 +104,32 @@ def test_callback_verbose_2(self):
self.verbose = 2
self.run_callback()

def test_visualdl_callback(self):
# visualdl not support python3
if sys.version_info < (3, ):
return

inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]

train_dataset = paddle.vision.datasets.MNIST(mode='train')
eval_dataset = paddle.vision.datasets.MNIST(mode='test')

net = paddle.vision.LeNet()
model = paddle.Model(net, inputs, labels)

optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
model.prepare(
optimizer=optim,
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())

callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
model.fit(train_dataset,
eval_dataset,
batch_size=64,
callbacks=callback)


if __name__ == '__main__':
unittest.main()
1 change: 1 addition & 0 deletions python/unittest_py/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ PyGithub
coverage
pycrypto ; platform_system != "Windows"
mock
visualdl ; python_version>="3.5"

0 comments on commit a0f1dba

Please sign in to comment.