Global logger (#556)
* [experimental] global multiprocess logger

This PR adds a feature that lets different processes add summaries (e.g. scalars) to the same TensorBoard event file. In a modularized program, there is no need to pass the SummaryWriter instance around; we now provide a logging-style syntax, GlobalSummaryWriter.getSummaryWriter(), to get the writer from anywhere. Both use cases depend on the new GlobalSummaryWriter class. Its usage is similar to SummaryWriter's, except that the global_step parameter of add_x(...) shall not be passed. This is by design: because the execution order of a multiprocess program is not guaranteed, the global step should be maintained by the GlobalSummaryWriter rather than determined by each child process.
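
For concreteness, a minimal sketch of the two usage styles, distilled from the examples added in this commit (tag names and values are illustrative):

from tensorboardX import GlobalSummaryWriter

# Style 1: create the writer once and share it across processes.
w = GlobalSummaryWriter()
w.add_scalar('loss', 0.1)  # no global_step argument: the writer keeps a per-tag step counter

# Style 2: logging-style access from anywhere in a modularized program.
# The first call creates the writer; later calls return the same instance.
writer = GlobalSummaryWriter.getSummaryWriter()
writer.add_text('my_log', 'greeting from some module')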
lanpa authored Feb 22, 2020
1 parent d877dfe commit b59430b
Showing 10 changed files with 325 additions and 8 deletions.
1 change: 1 addition & 0 deletions .travis.yml
@@ -67,6 +67,7 @@ script:
- sleep 5
- python -c "import visdom; v = visdom.Visdom()"
- py.test --cov=tensorboardX tests/
- pytest tests/tset_multiprocess_write.py # pytest has issues with multiprocessing, so the file is renamed to "tset"
- python examples/demo.py
- python examples/demo_graph.py
- python examples/demo_embedding.py
9 changes: 9 additions & 0 deletions examples/demo_global_writer.py
@@ -0,0 +1,9 @@
# This program shows that the summary writer can be used globally,
# much like the Python logging module.

# This file triggers global_1 and global_2 to do their job.
import global_1
import time
time.sleep(2)
import global_2

49 changes: 49 additions & 0 deletions examples/demo_multiprocessing.py
@@ -0,0 +1,49 @@
from tensorboardX import GlobalSummaryWriter
import multiprocessing as mp
import time
import os
import psutil
import torch
import numpy as np

w = GlobalSummaryWriter()


def train3():
    for i in range(100):
        w.add_scalar('many_write_in_func', np.random.randn())
        time.sleep(0.01*np.random.randint(0, 10))

def train2(x):
    np.random.seed(x)
    w.add_scalar('few_write_per_func/1', np.random.randn())
    time.sleep(0.05*np.random.randint(0, 10))
    w.add_scalar('few_write_per_func/2', np.random.randn())

def train(x):
    w.add_scalar('poolmap/1', x*np.random.randn())
    time.sleep(0.05*np.random.randint(0, 10))
    w.add_scalar('poolmap/2', x*np.random.randn())


if __name__ == '__main__':
    with mp.Pool() as pool:
        pool.map(train, range(100))

    processes = []
    for i in range(4):
        p0 = mp.Process(target=train2, args=(i,))
        p1 = mp.Process(target=train3)
        processes.append(p0)
        processes.append(p1)
        p0.start()
        p1.start()

    for p in processes:
        p.join()

    w.close()
13 changes: 13 additions & 0 deletions examples/global_1.py
@@ -0,0 +1,13 @@
# called by demo_global_writer

from tensorboardX import GlobalSummaryWriter

writer = GlobalSummaryWriter.getSummaryWriter()

writer.add_text('my_log', 'greeting from global1')

for i in range(100):
    writer.add_scalar('global1', i)

for i in range(100):
    writer.add_scalar('common', i)
13 changes: 13 additions & 0 deletions examples/global_2.py
@@ -0,0 +1,13 @@
# called by demo_global_writer

from tensorboardX import GlobalSummaryWriter

writer = GlobalSummaryWriter.getSummaryWriter()

writer.add_text('my_log', 'greeting from global2')

for i in range(100):
    writer.add_scalar('global2', i)

for i in range(100):
    writer.add_scalar('common', i)
2 changes: 2 additions & 0 deletions run_pytest.sh
@@ -8,3 +8,5 @@ if [ `ps|grep visdom |wc -l` = "1" ]
fi

PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python pytest

pytest tests/tset_global_writer.py
2 changes: 1 addition & 1 deletion tensorboardX/__init__.py
@@ -4,5 +4,5 @@
from .record_writer import RecordWriter
from .torchvis import TorchVis
from .writer import FileWriter, SummaryWriter
-
+from .global_writer import GlobalSummaryWriter
__version__ = "2.0" # will be overwritten if run setup.py
11 changes: 4 additions & 7 deletions tensorboardX/event_file_writer.py
@@ -22,7 +22,7 @@
import socket
import threading
import time
-
+import multiprocessing
import six

from .proto import event_pb2
@@ -102,7 +102,7 @@ def __init__(self, logdir, max_queue_size=10, flush_secs=120, filename_suffix=''
"""
self._logdir = logdir
directory_check(self._logdir)
self._event_queue = six.moves.queue.Queue(max_queue_size)
self._event_queue = multiprocessing.Queue(max_queue_size)
self._ev_writer = EventsWriter(os.path.join(
self._logdir, "events"), filename_suffix)
self._flush_secs = flush_secs
@@ -145,7 +145,6 @@ def flush(self):
        disk.
        """
        if not self._closed:
-            self._event_queue.join()
            self._ev_writer.flush()

    def close(self):
@@ -157,6 +156,7 @@ def close(self):
            self.flush()
            self._worker.stop()
            self._ev_writer.close()
+            self._event_queue.close()
            self._closed = True


@@ -201,15 +201,12 @@ def run(self):
                else:
                    data = self._queue.get(False)

-                if data == self._shutdown_signal:
+                if type(data) == type(self._shutdown_signal):
                    return
                self._record_writer.write_event(data)
                self._has_pending_data = True
            except six.moves.queue.Empty:
                pass
-            finally:
-                if data:
-                    self._queue.task_done()

            now = time.time()
            if now > self._next_flush_time:
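
The queue swap above is the heart of the change: a multiprocessing.Queue lets events enqueued in child processes reach the writer thread in the parent, which a plain (thread-only) six.moves.queue.Queue cannot do. It also explains the deletions: multiprocessing.Queue has no task_done()/join(), so flush() can no longer join the queue and the worker's finally block is dropped. A minimal standalone sketch of the cross-process handoff (illustrative only, not tensorboardX code):

import multiprocessing

def child(q):
    q.put('an event produced in a child process')

if __name__ == '__main__':
    q = multiprocessing.Queue(10)  # crosses process boundaries, unlike queue.Queue
    p = multiprocessing.Process(target=child, args=(q,))
    p.start()
    print(q.get())  # the parent receives the item the child enqueued
    p.join()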
142 changes: 142 additions & 0 deletions tensorboardX/global_writer.py
@@ -0,0 +1,142 @@
from .writer import SummaryWriter
from multiprocessing import Value
import multiprocessing as mp

global _writer
_writer = None


class GlobalSummaryWriter(object):
    def __init__(self, logdir=None, comment='', purge_step=None, max_queue=10,
                 flush_secs=120, filename_suffix='', write_to_disk=True, log_dir=None, coalesce_process=True, **kwargs):
        self.smw = SummaryWriter(logdir=logdir, comment=comment, purge_step=purge_step, max_queue=max_queue,
                                 flush_secs=flush_secs, filename_suffix=filename_suffix, write_to_disk=write_to_disk,
                                 log_dir=log_dir)
        self.lock = mp.Lock()
        self.scalar_tag_to_step = mp.Manager().dict()
        self.image_tag_to_step = mp.Manager().dict()
        self.histogram_tag_to_step = mp.Manager().dict()
        self.text_tag_to_step = mp.Manager().dict()
        self.audio_tag_to_step = mp.Manager().dict()

    def add_scalar(self, tag, scalar_value, walltime=None):
        """Add scalar data to summary.
        Args:
            tag (string): Data identifier
            scalar_value (float): Value to save
            walltime (float): Optional override default walltime (time.time()) of event
        """
        with self.lock:
            if tag in self.scalar_tag_to_step:
                self.scalar_tag_to_step[tag] += 1
            else:
                self.scalar_tag_to_step[tag] = 0

            self.smw.add_scalar(tag, scalar_value, self.scalar_tag_to_step[tag], walltime)

    # def add_histogram(self, tag, values, bins='tensorflow', walltime=None, max_bins=None):
    #     """Add histogram to summary.

    #     Args:
    #         tag (string): Data identifier
    #         values (torch.Tensor, numpy.array): Values to build histogram
    #         bins (string): One of {'tensorflow','auto', 'fd', ...}.
    #           This determines how the bins are made. You can find
    #           other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
    #         walltime (float): Optional override default walltime (time.time()) of event

    #     """
    #     with self.new_tag_mutex.get_lock():
    #         if tag in self.histogram_tag_to_step:
    #             self.histogram_tag_to_step[tag] += 1
    #         else:
    #             self.histogram_tag_to_step[tag] = 0
    #         self.smw.add_histogram(tag,
    #                                values,
    #                                self.histogram_tag_to_step[tag],
    #                                bins=bins,
    #                                walltime=walltime,
    #                                max_bins=max_bins)

    def add_image(self, tag, img_tensor, walltime=None, dataformats='CHW'):
        """Add image data to summary.
        Note that this requires the ``pillow`` package.
        Args:
            tag (string): Data identifier
            img_tensor (torch.Tensor, numpy.array): A `uint8` or `float`
                Tensor of shape `[channel, height, width]` where `channel` is 1, 3, or 4.
                The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8).
                Users are responsible for scaling the data to the correct range/type.
            walltime (float): Optional override default walltime (time.time()) of event.
            dataformats (string): This parameter specifies the meaning of each dimension of the input tensor.
        Shape:
            img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
            convert a batch of tensors into 3xHxW format or use ``add_images()`` and let us do the job.
            Tensors with :math:`(1, H, W)`, :math:`(H, W)`, or :math:`(H, W, 3)` are also suitable as long as
            the corresponding ``dataformats`` argument is passed, e.g. CHW, HWC, HW.
        """
        with self.lock:
            if tag in self.image_tag_to_step:
                self.image_tag_to_step[tag] += 1
            else:
                self.image_tag_to_step[tag] = 0

            self.smw.add_image(tag, img_tensor, self.image_tag_to_step[tag], walltime=walltime, dataformats=dataformats)

    # def add_audio(self, tag, snd_tensor, sample_rate=44100, walltime=None):
    #     """Add audio data to summary.

    #     Args:
    #         tag (string): Data identifier
    #         snd_tensor (torch.Tensor): Sound data
    #         sample_rate (int): sample rate in Hz
    #         walltime (float): Optional override default walltime (time.time()) of event
    #     Shape:
    #         snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1].
    #     """

    #     with self.new_tag_mutex.get_lock():
    #         if tag in self.audio_tag_to_step:
    #             self.audio_tag_to_step[tag] += 1
    #         else:
    #             self.audio_tag_to_step[tag] = 0

    #         self.smw.add_audio(tag, snd_tensor, self.audio_tag_to_step[tag], sample_rate=44100, walltime=walltime)

    def add_text(self, tag, text_string, walltime=None):
        """Add text data to summary.
        Args:
            tag (string): Data identifier
            text_string (string): String to save
            walltime (float): Optional override default walltime (time.time()) of event
        """
        with self.lock:
            if tag in self.text_tag_to_step:
                self.text_tag_to_step[tag] += 1
            else:
                self.text_tag_to_step[tag] = 0

            self.smw.add_text(tag, text_string, global_step=self.text_tag_to_step[tag], walltime=walltime)

    @staticmethod
    def getSummaryWriter():
        global _writer
        if not hasattr(_writer, "smw") or _writer.smw is None:
            _writer = GlobalSummaryWriter()

        print("Using the global logger in:", _writer.smw.file_writer.get_logdir())
        return _writer

    @property
    def file_writer(self):
        return self.smw._get_file_writer()

    def close(self):
        self.smw.flush()
        self.smw.close()
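
All the add_* methods above share one pattern: a per-tag step counter kept in a Manager dict and incremented under a process-shared lock, so concurrent writers get distinct, monotonically increasing steps. A minimal standalone sketch of that pattern (names are hypothetical, not part of this commit):

import multiprocessing as mp

def bump(tag_to_step, lock, tag):
    # Mirrors GlobalSummaryWriter.add_scalar: the first write of a tag gets
    # step 0; each later write, from any process, gets the next step.
    with lock:
        if tag in tag_to_step:
            tag_to_step[tag] += 1
        else:
            tag_to_step[tag] = 0
        return tag_to_step[tag]

if __name__ == '__main__':
    manager = mp.Manager()
    tag_to_step = manager.dict()  # proxy to a dict living in the manager's server process
    lock = mp.Lock()
    workers = [mp.Process(target=bump, args=(tag_to_step, lock, 'loss')) for _ in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    print(bump(tag_to_step, lock, 'loss'))  # prints 4: one step per worker, then this call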
91 changes: 91 additions & 0 deletions tests/tset_multiprocess_write.py
@@ -0,0 +1,91 @@
# the file name is intentional: pytest doesn't play well with multiprocessing

from tensorboardX import GlobalSummaryWriter as SummaryWriter
from tensorboard.compat.tensorflow_stub.pywrap_tensorflow import PyRecordReader_New
from tensorboardX.proto import event_pb2
import multiprocessing as mp
import numpy as np
import pytest
import unittest
import time


class GlobalWriterTest(unittest.TestCase):
    def test_flush(self):
        N_TEST = 5
        w = SummaryWriter(flush_secs=1)
        f = w.file_writer.event_writer._ev_writer._file_name
        for i in range(N_TEST):
            w.add_scalar('a', i)
        time.sleep(2)
        r = PyRecordReader_New(f)
        r.GetNext()  # meta data, so skip
        for _ in range(N_TEST):  # all of the data should be flushed
            r.GetNext()

    def test_flush_timer_is_long_so_data_is_not_there(self):
        with self.assertRaises(BaseException):
            N_TEST = 5
            w = SummaryWriter(flush_secs=20)
            f = w.file_writer.event_writer._ev_writer._file_name
            for i in range(N_TEST):
                w.add_scalar('a', i)
            time.sleep(2)
            r = PyRecordReader_New(f)
            r.GetNext()  # meta data, so skip
            for _ in range(N_TEST):  # missing data
                r.GetNext()

    def test_flush_after_close(self):
        N_TEST = 5
        w = SummaryWriter(flush_secs=20)
        f = w.file_writer.event_writer._ev_writer._file_name
        for i in range(N_TEST):
            w.add_scalar('a', i)
        time.sleep(2)
        w.close()
        r = PyRecordReader_New(f)
        r.GetNext()  # meta data, so skip
        for _ in range(N_TEST):  # all of the data should be flushed
            r.GetNext()

    def test_auto_close(self):
        pass

    def test_writer(self):
        TEST_LEN = 100
        N_PROC = 4
        writer = SummaryWriter()
        event_filename = writer.file_writer.event_writer._ev_writer._file_name

        predefined_values = list(range(TEST_LEN))

        def train3():
            for i in range(TEST_LEN):
                writer.add_scalar('many_write_in_func', predefined_values[i])
                time.sleep(0.01*np.random.randint(0, 10))

        processes = []
        for i in range(N_PROC):
            p1 = mp.Process(target=train3)
            processes.append(p1)
            p1.start()

        for p in processes:
            p.join()
        writer.close()

        collected_values = []
        r = PyRecordReader_New(event_filename)
        r.GetNext()  # meta data, so skip
        for _ in range(TEST_LEN*N_PROC):  # all of the data should be flushed
            r.GetNext()
            ev = event_pb2.Event()
            value = ev.FromString(r.record()).summary.value
            collected_values.append(value[0].simple_value)

        collected_values = sorted(collected_values)
        for i in range(TEST_LEN):
            for j in range(N_PROC):
                assert collected_values[i*N_PROC+j] == i