Skip to content

Commit

Permalink
Fix EventFileWriter deadlock on exception in background thread (tenso…
Browse files Browse the repository at this point in the history
…rflow#6168)

## Motivation for features / changes

To address tensorflow#6167

## Technical description of changes

This is a bug fix for a possible deadlock when writing events through
`EventFileWriter`. The PR adds logic in `_AsyncWriterThread` to catch
exceptions and record them for the calling thread, and adds logic in
`_AsyncWriter` to re-raise any exception captured in `_AsyncWriterThread`.

## Detailed steps to verify changes work correctly (as executed by you)

Added a new unit test that fails on master and passes with this fix.

## Alternate designs / implementations considered

* Instead of popping an item from the queue on exception, it's possible
to make `wait`/`flush` methods re-check the status periodically
* Instead of raising an exception in the foreground thread, it's
possible to ignore the raised exception altogether and just start
dropping events
* It's possible to drop the data after it cannot be added to the queue
for a certain period of time

Signed-off-by: Mik Vyatskov <[email protected]>
  • Loading branch information
crassirostris authored and yatbear committed Mar 27, 2023
1 parent e8c54b6 commit 7420d5d
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tensorboard/summary/writer/event_file_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,33 @@ def __init__(self, record_writer, max_queue_size=20, flush_secs=120):
def write(self, bytestring):
    """Enqueue the given bytes to be written asynchronously.

    Raises the exception captured from the background worker thread, if
    any, and IOError if the writer has already been closed.
    """
    with self._lock:
        # Verify the worker is alive before enqueueing. This must happen
        # under the lock: otherwise several threads could pass the check
        # and then all block on a full queue with no worker draining it,
        # which would deadlock.
        self._check_worker_status()
        if self._closed:
            raise IOError("Writer is closed")
        self._byte_queue.put(bytestring)
        # The worker may have died while we were blocked on put(); check
        # again so the failure surfaces now rather than on a later call.
        self._check_worker_status()

def flush(self):
    """Block until every bytestring enqueued before this call is written.

    Raises the exception captured from the background worker thread, if
    any, and IOError if the writer has already been closed.
    """
    with self._lock:
        # A dead worker would never drain the queue, so join() below
        # would hang forever; surface the worker's failure first. Done
        # under the lock to keep the check-and-wait atomic.
        self._check_worker_status()
        if self._closed:
            raise IOError("Writer is closed")
        self._byte_queue.join()
        self._writer.flush()
        # The worker may have failed while we waited on join(); re-check
        # so the error is reported now instead of on a later call.
        self._check_worker_status()

def close(self):
"""Closes the underlying writer, flushing any pending writes first."""
Expand All @@ -190,6 +203,14 @@ def close(self):
self._writer.flush()
self._writer.close()

def _check_worker_status(self):
"""Makes sure the worker thread is still running and raises exception
thrown in the worker thread otherwise.
"""
exception = self._worker.exception
if exception is not None:
raise exception


class _AsyncWriterThread(threading.Thread):
"""Thread that processes asynchronous writes for _AsyncWriter."""
Expand All @@ -205,6 +226,7 @@ def __init__(self, queue, record_writer, flush_secs):
"""
threading.Thread.__init__(self)
self.daemon = True
self.exception = None
self._queue = queue
self._record_writer = record_writer
self._flush_secs = flush_secs
Expand All @@ -218,6 +240,22 @@ def stop(self):
self.join()

def run(self):
    """Worker entry point: records and re-raises any failure.

    On error, stores the exception on ``self.exception`` so the
    foreground thread can surface it, then drains the queue so that a
    foreground thread blocked on ``put`` (full queue) or on ``join``
    (pending tasks) is released instead of deadlocking.
    """
    try:
        self._run()
    except Exception as err:
        self.exception = err
        # Pop every remaining item, marking each done, to unblock any
        # foreground thread waiting on the queue.
        try:
            while True:
                self._queue.get(block=False)
                self._queue.task_done()
        except queue.Empty:
            pass
        raise

def _run(self):
# Here wait on the queue until an data appears, or till the next
# time to flush the writer, whichever is earlier. If we have an
# data, write it. If not, an empty queue exception will be raised
Expand Down
46 changes: 46 additions & 0 deletions tensorboard/summary/writer/event_file_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

import glob
import os
import threading
import time
from typing import Optional
from unittest.mock import MagicMock

from tensorboard.summary.writer.event_file_writer import EventFileWriter
from tensorboard.summary.writer.event_file_writer import _AsyncWriter
from tensorboard.compat.proto import event_pb2
Expand Down Expand Up @@ -132,6 +137,47 @@ def test_write_after_async_writer_closed(self):
with open(filename, "rb") as f:
self.assertEqual(f.read(), bytes_to_write)

def test_exception_in_background_thread_while_waiting_to_put(self):
    """Regression test: an exception raised in the background writer
    thread must propagate to a foreground thread blocked on ``write``
    with a full queue, instead of deadlocking that thread.
    """
    record_writer_mock = MagicMock()
    w = _AsyncWriter(record_writer_mock, max_queue_size=10)

    cv = threading.Condition()
    # Gate that keeps the background thread parked inside the mocked
    # write() until the test is ready to let it fail.
    writing_can_proceed: bool = False
    # Timestamp of the most recent successful w.write() call; used to
    # detect when the writing routine has stopped making progress
    # (i.e. is blocked on the full queue).
    last_write_timestamp: Optional[float] = None

    def writing_routine():
        nonlocal last_write_timestamp
        # 30 messages should be enough to fill the queue even if some of
        # the events are dequeued in the background thread.
        for _ in range(30):
            w.write(b"x" * 64)
            last_write_timestamp = time.time()

    def on_write(*args, **kwargs) -> None:
        # Runs in the background writer thread: block until released,
        # then fail, exercising the exception-propagation path.
        with cv:
            cv.wait_for(lambda: writing_can_proceed)
        raise Exception()

    record_writer_mock.write.side_effect = on_write
    thread = threading.Thread(target=writing_routine, daemon=True)
    thread.start()

    with cv:
        # Wait until the writing routine is blocked on writing: no
        # successful write for at least one second is taken to mean the
        # queue is full and the thread is stuck in put().
        while (
            last_write_timestamp is None
            or time.time() < last_write_timestamp + 1
        ):
            cv.wait(0.1)
        writing_can_proceed = True
        cv.notify_all()

    # If the thread joins successfully, it means that the exception was
    # successfully propagated. 10 seconds should be more than enough to
    # conclude that the thread is hanging if it has not joined by then.
    thread.join(timeout=10)
    self.assertFalse(thread.is_alive())


if __name__ == "__main__":
tb_test.main()

0 comments on commit 7420d5d

Please sign in to comment.