* [experimental] global multiprocess logger. This PR adds a feature that lets different processes add summaries (e.g. scalars) to the same TensorBoard file. Likewise, in a modularized program there is no need to pass the SummaryWriter instance around: we now provide a logging-style syntax, GlobalSummaryWriter.getSummaryWriter(), to get the writer from anywhere. Both cases rely on the new GlobalSummaryWriter class, whose usage is similar to SummaryWriter except that the global_step parameter of the add_x(...) methods must not be passed. This is by design: since the execution order of a multiprocess program is not guaranteed, the global step should be maintained by the GlobalSummaryWriter rather than determined by each child process.
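For example, a minimal sketch of the two usage patterns based on the demos in this commit (file and tag names here are illustrative, not part of the commit):

# train.py -- shared-instance pattern: create once, before forking workers
from tensorboardX import GlobalSummaryWriter
writer = GlobalSummaryWriter()
writer.add_scalar('loss', 0.3)   # no global_step: the writer tracks it per tag

# any_module.py -- logging-style pattern: no instance is passed around
from tensorboardX import GlobalSummaryWriter
writer = GlobalSummaryWriter.getSummaryWriter()  # same process-wide instance
writer.add_scalar('loss', 0.2)   # continues the step counter for 'loss'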
Showing 10 changed files with 325 additions and 8 deletions.
@@ -0,0 +1,9 @@
# This program shows that you can use the summary writer globally,
# so that you can use the writer like the python logging module.

# This file triggers global_1 and global_2 to do their job.
import global_1
import time
time.sleep(2)
import global_2
@@ -0,0 +1,49 @@
from tensorboardX import GlobalSummaryWriter
import multiprocessing as mp
import time
import numpy as np

# One writer instance, created before the worker processes are forked,
# so every child logs to the same event file.
w = GlobalSummaryWriter()


def train3():
    for i in range(100):
        w.add_scalar('many_write_in_func', np.random.randn())
        time.sleep(0.01 * np.random.randint(0, 10))


def train2(x):
    np.random.seed(x)
    w.add_scalar('few_write_per_func/1', np.random.randn())
    time.sleep(0.05 * np.random.randint(0, 10))
    w.add_scalar('few_write_per_func/2', np.random.randn())


def train(x):
    w.add_scalar('poolmap/1', x * np.random.randn())
    time.sleep(0.05 * np.random.randint(0, 10))
    w.add_scalar('poolmap/2', x * np.random.randn())


if __name__ == '__main__':
    with mp.Pool() as pool:
        pool.map(train, range(100))

    processes = []
    for i in range(4):
        p0 = mp.Process(target=train2, args=(i,))
        p1 = mp.Process(target=train3)
        processes.append(p0)
        processes.append(p1)
        p0.start()
        p1.start()

    for p in processes:
        p.join()

    w.close()
@@ -0,0 +1,13 @@
# called by demo_global_writer

from tensorboardX import GlobalSummaryWriter

# Fetch (or lazily create) the process-wide writer.
writer = GlobalSummaryWriter.getSummaryWriter()

writer.add_text('my_log', 'greeting from global1')

for i in range(100):
    writer.add_scalar('global1', i)

for i in range(100):
    writer.add_scalar('common', i)
@@ -0,0 +1,13 @@
# called by demo_global_writer

from tensorboardX import GlobalSummaryWriter

# getSummaryWriter() returns the same module-level writer that global_1
# created, so the 'common' tag below continues that file's step counter.
writer = GlobalSummaryWriter.getSummaryWriter()

writer.add_text('my_log', 'greeting from global2')

for i in range(100):
    writer.add_scalar('global2', i)

for i in range(100):
    writer.add_scalar('common', i)
@@ -0,0 +1,142 @@
from .writer import SummaryWriter
import multiprocessing as mp

_writer = None


class GlobalSummaryWriter(object):
    def __init__(self, logdir=None, comment='', purge_step=None, max_queue=10,
                 flush_secs=120, filename_suffix='', write_to_disk=True, log_dir=None,
                 coalesce_process=True, **kwargs):
        self.smw = SummaryWriter(logdir=logdir, comment=comment, purge_step=purge_step,
                                 max_queue=max_queue, flush_secs=flush_secs,
                                 filename_suffix=filename_suffix, write_to_disk=write_to_disk,
                                 log_dir=log_dir)
        self.lock = mp.Lock()
        # Per-tag step counters, shared across processes through a Manager.
        self.scalar_tag_to_step = mp.Manager().dict()
        self.image_tag_to_step = mp.Manager().dict()
        self.histogram_tag_to_step = mp.Manager().dict()
        self.text_tag_to_step = mp.Manager().dict()
        self.audio_tag_to_step = mp.Manager().dict()

    def add_scalar(self, tag, scalar_value, walltime=None):
        """Add scalar data to summary.

        Args:
            tag (string): Data identifier
            scalar_value (float): Value to save
            walltime (float): Optional override for the default walltime (time.time()) of the event
        """
        # Bump the counter and write under the same lock, so that no other
        # process can advance the step between the two operations.
        with self.lock:
            if tag in self.scalar_tag_to_step:
                self.scalar_tag_to_step[tag] += 1
            else:
                self.scalar_tag_to_step[tag] = 0

            self.smw.add_scalar(tag, scalar_value, self.scalar_tag_to_step[tag], walltime)

    # def add_histogram(self, tag, values, bins='tensorflow', walltime=None, max_bins=None):
    #     """Add histogram to summary.
    #
    #     Args:
    #         tag (string): Data identifier
    #         values (torch.Tensor, numpy.array): Values to build histogram
    #         bins (string): One of {'tensorflow', 'auto', 'fd', ...}.
    #             This determines how the bins are made. You can find
    #             other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
    #         walltime (float): Optional override for the default walltime (time.time()) of the event
    #     """
    #     with self.lock:
    #         if tag in self.histogram_tag_to_step:
    #             self.histogram_tag_to_step[tag] += 1
    #         else:
    #             self.histogram_tag_to_step[tag] = 0
    #         self.smw.add_histogram(tag,
    #                                values,
    #                                self.histogram_tag_to_step[tag],
    #                                bins=bins,
    #                                walltime=walltime,
    #                                max_bins=max_bins)

    def add_image(self, tag, img_tensor, walltime=None, dataformats='CHW'):
        """Add image data to summary.

        Note that this requires the ``pillow`` package.

        Args:
            tag (string): Data identifier
            img_tensor (torch.Tensor, numpy.array): An `uint8` or `float` Tensor of shape
                `[channel, height, width]`, where `channel` is 1, 3, or 4.
                The elements in img_tensor can either have values in [0, 1] (float32)
                or [0, 255] (uint8). Users are responsible for scaling the data to the
                correct range/type.
            walltime (float): Optional override for the default walltime (time.time()) of the event
            dataformats (string): Specifies the meaning of each dimension of the input tensor.

        Shape:
            img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
                convert a batch of tensors into 3xHxW format, or use ``add_images()`` and let us do the job.
                Tensors of shape :math:`(1, H, W)`, :math:`(H, W)`, or :math:`(H, W, 3)` are also suitable
                as long as the corresponding ``dataformats`` argument is passed, e.g. CHW, HWC, HW.
        """
        with self.lock:
            if tag in self.image_tag_to_step:
                self.image_tag_to_step[tag] += 1
            else:
                self.image_tag_to_step[tag] = 0

            self.smw.add_image(tag, img_tensor, self.image_tag_to_step[tag],
                               walltime=walltime, dataformats=dataformats)

    # def add_audio(self, tag, snd_tensor, sample_rate=44100, walltime=None):
    #     """Add audio data to summary.
    #
    #     Args:
    #         tag (string): Data identifier
    #         snd_tensor (torch.Tensor): Sound data
    #         sample_rate (int): sample rate in Hz
    #         walltime (float): Optional override for the default walltime (time.time()) of the event
    #
    #     Shape:
    #         snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1].
    #     """
    #     with self.lock:
    #         if tag in self.audio_tag_to_step:
    #             self.audio_tag_to_step[tag] += 1
    #         else:
    #             self.audio_tag_to_step[tag] = 0
    #
    #         self.smw.add_audio(tag, snd_tensor, self.audio_tag_to_step[tag],
    #                            sample_rate=sample_rate, walltime=walltime)

    def add_text(self, tag, text_string, walltime=None):
        """Add text data to summary.

        Args:
            tag (string): Data identifier
            text_string (string): String to save
            walltime (float): Optional override for the default walltime (time.time()) of the event
        """
        with self.lock:
            if tag in self.text_tag_to_step:
                self.text_tag_to_step[tag] += 1
            else:
                self.text_tag_to_step[tag] = 0

            self.smw.add_text(tag, text_string, global_step=self.text_tag_to_step[tag], walltime=walltime)

    @staticmethod
    def getSummaryWriter():
        # Lazily create the module-level writer on first use; later calls
        # (from any module in the same process) get the same instance.
        global _writer
        if not hasattr(_writer, "smw") or _writer.smw is None:
            _writer = GlobalSummaryWriter()

        print("Using the global logger in:", _writer.smw.file_writer.get_logdir())
        return _writer

    @property
    def file_writer(self):
        return self.smw._get_file_writer()

    def close(self):
        self.smw.flush()
        self.smw.close()
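A minimal sketch of the resulting behavior (tag names here are illustrative): each tag keeps its own auto-incrementing step counter, which is why the add_x(...) methods take no global_step.

from tensorboardX import GlobalSummaryWriter

w = GlobalSummaryWriter()
w.add_scalar('loss', 0.50)  # first write for 'loss'  -> recorded at step 0
w.add_scalar('loss', 0.40)  # counter auto-increments -> recorded at step 1
w.add_scalar('acc', 0.90)   # 'acc' has its own counter -> recorded at step 0
w.close()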
@@ -0,0 +1,91 @@
# The file name is intentional: pytest doesn't play well with multiprocessing.

from tensorboardX import GlobalSummaryWriter as SummaryWriter
from tensorboard.compat.tensorflow_stub.pywrap_tensorflow import PyRecordReader_New
from tensorboardX.proto import event_pb2
import multiprocessing as mp
import numpy as np
import unittest
import time


class GlobalWriterTest(unittest.TestCase):
    def test_flush(self):
        N_TEST = 5
        w = SummaryWriter(flush_secs=1)
        f = w.file_writer.event_writer._ev_writer._file_name
        for i in range(N_TEST):
            w.add_scalar('a', i)
        time.sleep(2)
        r = PyRecordReader_New(f)
        r.GetNext()  # metadata record, so skip it
        for _ in range(N_TEST):  # all of the data should have been flushed
            r.GetNext()

    def test_flush_timer_is_long_so_data_is_not_there(self):
        with self.assertRaises(BaseException):
            N_TEST = 5
            w = SummaryWriter(flush_secs=20)
            f = w.file_writer.event_writer._ev_writer._file_name
            for i in range(N_TEST):
                w.add_scalar('a', i)
            time.sleep(2)
            r = PyRecordReader_New(f)
            r.GetNext()  # metadata record, so skip it
            for _ in range(N_TEST):  # data is not flushed yet, so this raises
                r.GetNext()

    def test_flush_after_close(self):
        N_TEST = 5
        w = SummaryWriter(flush_secs=20)
        f = w.file_writer.event_writer._ev_writer._file_name
        for i in range(N_TEST):
            w.add_scalar('a', i)
        time.sleep(2)
        w.close()
        r = PyRecordReader_New(f)
        r.GetNext()  # metadata record, so skip it
        for _ in range(N_TEST):  # close() should flush all of the data
            r.GetNext()

    def test_auto_close(self):
        pass

    def test_writer(self):
        TEST_LEN = 100
        N_PROC = 4
        writer = SummaryWriter()
        event_filename = writer.file_writer.event_writer._ev_writer._file_name

        predefined_values = list(range(TEST_LEN))

        def train3():
            for i in range(TEST_LEN):
                writer.add_scalar('many_write_in_func', predefined_values[i])
                time.sleep(0.01 * np.random.randint(0, 10))

        processes = []
        for i in range(N_PROC):
            p1 = mp.Process(target=train3)
            processes.append(p1)
            p1.start()

        for p in processes:
            p.join()
        writer.close()

        collected_values = []
        r = PyRecordReader_New(event_filename)
        r.GetNext()  # metadata record, so skip it
        for _ in range(TEST_LEN * N_PROC):  # all of the data should have been flushed
            r.GetNext()
            ev = event_pb2.Event()
            value = ev.FromString(r.record()).summary.value
            collected_values.append(value[0].simple_value)

        # Each process wrote the values 0..TEST_LEN-1 once, so after sorting,
        # every value i must appear exactly N_PROC times.
        collected_values = sorted(collected_values)
        for i in range(TEST_LEN):
            for j in range(N_PROC):
                assert collected_values[i * N_PROC + j] == i