Skip to content

Commit 61f6cc9

Browse files
Prevent deadlocks in EagerIterators by making prefetch optional. (#185)
Previously, if you provided a thread pool that was too small and an EagerIterator could not create a new preloading thread, the iterator would deadlock, since it would wait for the new thread to be created forever and not try to just do the work itself. This change instead uses preloading as an optional optimization, and if the preload has not yet been completed, computes the next value itself.
1 parent a79f984 commit 61f6cc9

File tree

2 files changed

+81
-4
lines changed

2 files changed

+81
-4
lines changed

python-spec/src/somacore/query/_eager_iter.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,29 @@ def __init__(
1414
self.iterator = iterator
1515
self._pool = pool or futures.ThreadPoolExecutor()
1616
self._own_pool = pool is None
17-
self._future = self._pool.submit(self.iterator.__next__)
17+
self._preload_future = self._pool.submit(self.iterator.__next__)
1818

1919
def __next__(self) -> _T:
20+
stopped = False
2021
try:
21-
res = self._future.result()
22-
self._future = self._pool.submit(self.iterator.__next__)
23-
return res
22+
if self._preload_future.cancel():
23+
# If `.cancel` returns True, cancellation was successful.
24+
# The self.iterator.__next__ call has not yet been started,
25+
# and will never be started, so we can compute next ourselves.
26+
# This prevents deadlocks if the thread pool is too small
27+
# and we can never create a preload thread.
28+
return next(self.iterator)
29+
# `.cancel` returned false, so the preload is already running.
30+
# Just wait for it.
31+
return self._preload_future.result()
2432
except StopIteration:
2533
self._cleanup()
34+
stopped = True
2635
raise
36+
finally:
37+
if not stopped:
38+
# If we have more to do, go for the next thing.
39+
self._preload_future = self._pool.submit(self.iterator.__next__)
2740

2841
def _cleanup(self) -> None:
2942
if self._own_pool:
+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import threading
2+
import unittest
3+
from concurrent import futures
4+
from unittest import mock
5+
6+
from somacore.query import _eager_iter
7+
8+
9+
class EagerIterTest(unittest.TestCase):
10+
def setUp(self):
11+
super().setUp()
12+
self.kiddie_pool = futures.ThreadPoolExecutor(1)
13+
"""Tiny thread pool for testing."""
14+
self.verify_pool = futures.ThreadPoolExecutor(1)
15+
"""Separate thread pool so verification is not blocked."""
16+
17+
def tearDown(self):
18+
self.verify_pool.shutdown(wait=False)
19+
self.kiddie_pool.shutdown(wait=False)
20+
super().tearDown()
21+
22+
def test_thread_starvation(self):
23+
sem = threading.Semaphore()
24+
try:
25+
# Monopolize the threadpool.
26+
sem.acquire()
27+
self.kiddie_pool.submit(sem.acquire)
28+
eager = _eager_iter.EagerIterator(iter("abc"), pool=self.kiddie_pool)
29+
got_a = self.verify_pool.submit(lambda: next(eager))
30+
self.assertEqual("a", got_a.result(0.1))
31+
got_b = self.verify_pool.submit(lambda: next(eager))
32+
self.assertEqual("b", got_b.result(0.1))
33+
got_c = self.verify_pool.submit(lambda: next(eager))
34+
self.assertEqual("c", got_c.result(0.1))
35+
with self.assertRaises(StopIteration):
36+
self.verify_pool.submit(lambda: next(eager)).result(0.1)
37+
finally:
38+
sem.release()
39+
40+
def test_nesting(self):
41+
inner = _eager_iter.EagerIterator(iter("abc"), pool=self.kiddie_pool)
42+
outer = _eager_iter.EagerIterator(inner, pool=self.kiddie_pool)
43+
self.assertEqual(
44+
"a, b, c", self.verify_pool.submit(", ".join, outer).result(0.1)
45+
)
46+
47+
def test_exceptions(self):
48+
flaky = mock.MagicMock()
49+
flaky.__next__.side_effect = [1, 2, ValueError(), 3, 4]
50+
51+
eager_flaky = _eager_iter.EagerIterator(flaky, pool=self.kiddie_pool)
52+
got_1 = self.verify_pool.submit(lambda: next(eager_flaky))
53+
self.assertEqual(1, got_1.result(0.1))
54+
got_2 = self.verify_pool.submit(lambda: next(eager_flaky))
55+
self.assertEqual(2, got_2.result(0.1))
56+
with self.assertRaises(ValueError):
57+
self.verify_pool.submit(lambda: next(eager_flaky)).result(0.1)
58+
got_3 = self.verify_pool.submit(lambda: next(eager_flaky))
59+
self.assertEqual(3, got_3.result(0.1))
60+
got_4 = self.verify_pool.submit(lambda: next(eager_flaky))
61+
self.assertEqual(4, got_4.result(0.1))
62+
for _ in range(5):
63+
with self.assertRaises(StopIteration):
64+
self.verify_pool.submit(lambda: next(eager_flaky)).result(0.1)

0 commit comments

Comments
 (0)