Skip to content

Commit

Permalink
Merge pull request #164 from RaRe-Technologies/synchro
Browse files Browse the repository at this point in the history
Properly handle the race condition
  • Loading branch information
mpenkov authored Nov 25, 2022
2 parents 98a7991 + 5d3bb03 commit 8606ca4
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 50 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
python-version: ${{ matrix.python-version }}

- name: Update pip
run: python -m pip install -U coverage flake8 pip pytest pytest-coverage
run: python -m pip install -U coverage flake8 pip pytest pytest-coverage pytest-benchmark

- name: Flake8
run: flake8 sqlitedict.py tests
Expand All @@ -35,5 +35,8 @@ jobs:
- name: Run tests
run: pytest tests --cov=sqlitedict

- name: Run benchmarks
run: pytest benchmarks

- name: Run doctests
run: python -m doctest README.rst
15 changes: 15 additions & 0 deletions benchmarks/test_insert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import tempfile

from sqlitedict import SqliteDict


def insert():
    """Open the dict, write a single value, and commit — 100 times over.

    This is the workload measured by the pytest-benchmark test below; each
    iteration pays the full open/write/commit/close cost of a SqliteDict.
    """
    with tempfile.NamedTemporaryFile() as tmp:
        iteration = 0
        while iteration < 100:
            with SqliteDict(tmp.name) as d:
                d["tmp"] = iteration
                d.commit()
            iteration += 1


def test(benchmark):
    """Benchmark the end-to-end insert/commit workload via pytest-benchmark."""
    benchmark(insert)
103 changes: 55 additions & 48 deletions sqlitedict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@
import os
import sys
import tempfile
import threading
import logging
import time
import traceback
from base64 import b64decode, b64encode

from threading import Thread

__version__ = '2.0.0'


Expand Down Expand Up @@ -172,7 +170,6 @@ def __init__(self, filename=None, tablename='unnamed', flag='c',
self.decode = decode
self.encode_key = encode_key
self.decode_key = decode_key
self.timeout = timeout
self._outer_stack = outer_stack

logger.debug("opening Sqlite table %r in %r" % (tablename, filename))
Expand All @@ -193,7 +190,6 @@ def _new_conn(self):
self.filename,
autocommit=self.autocommit,
journal_mode=self.journal_mode,
timeout=self.timeout,
outer_stack=self._outer_stack,
)

Expand Down Expand Up @@ -382,30 +378,49 @@ def __del__(self):
pass


class SqliteMultithread(Thread):
class SqliteMultithread(threading.Thread):
"""
Wrap sqlite connection in a way that allows concurrent requests from multiple threads.
This is done by internally queueing the requests and processing them sequentially
in a separate thread (in the same order they arrived).
"""
def __init__(self, filename, autocommit, journal_mode, timeout, outer_stack=True):
def __init__(self, filename, autocommit, journal_mode, outer_stack=True):
super(SqliteMultithread, self).__init__()
self.filename = filename
self.autocommit = autocommit
self.journal_mode = journal_mode
# use request queue of unlimited size
self.reqs = Queue()
self.daemon = True
self.exception = None
self._sqlitedict_thread_initialized = None
self._outer_stack = outer_stack
self.timeout = timeout
self.log = logging.getLogger('sqlitedict.SqliteMultithread')

#
# Parts of this object's state get accessed from different threads, so
# we use synchronization to avoid race conditions. For example,
# .exception gets set inside the new daemon thread that we spawned, but
# gets read from the main thread. This is particularly important
# during initialization: the Thread needs some time to actually start
# working, and until this happens, any calls to e.g.
# check_raise_error() will prematurely return None, meaning all is
# well. If that connection happens to fail, we'll never know about
# it, and instead wait for a result that never arrives (effectively,
# deadlocking). Locking solves this problem by eliminating the race
# condition.
#
self._lock = threading.Lock()
self._lock.acquire()
self.exception = None

self.start()

def run(self):
def _connect(self):
"""Connect to the underlying database.
Raises an exception on failure. Returns the connection and cursor on success.
"""
try:
if self.autocommit:
conn = sqlite3.connect(self.filename, isolation_level=None, check_same_thread=False)
Expand All @@ -427,7 +442,17 @@ def run(self):
self.exception = sys.exc_info()
raise

self._sqlitedict_thread_initialized = True
return conn, cursor

def run(self):
#
# Nb. this is what actually runs inside the new daemon thread.
# self._lock is locked at this stage - see the initializer function.
#
try:
conn, cursor = self._connect()
finally:
self._lock.release()

res = None
while True:
Expand All @@ -443,7 +468,9 @@ def run(self):
try:
cursor.execute(req, arg)
except Exception:
self.exception = (e_type, e_value, e_tb) = sys.exc_info()
with self._lock:
self.exception = (e_type, e_value, e_tb) = sys.exc_info()

inner_stack = traceback.extract_stack()

# An exception occurred in our thread, but we may not
Expand Down Expand Up @@ -499,29 +526,29 @@ def check_raise_error(self):
calls to the `execute*` methods to check for and raise an exception in a
previous call to the MainThread.
"""
if self.exception:
e_type, e_value, e_tb = self.exception
with self._lock:
if self.exception:
e_type, e_value, e_tb = self.exception

# clear self.exception, if the caller decides to handle such
# exception, we should not repeatedly re-raise it.
self.exception = None
# clear self.exception, if the caller decides to handle such
# exception, we should not repeatedly re-raise it.
self.exception = None

self.log.error('An exception occurred from a previous statement, view '
'the logging namespace "sqlitedict" for outer stack.')
self.log.error('An exception occurred from a previous statement, view '
'the logging namespace "sqlitedict" for outer stack.')

# The third argument to raise is the traceback object, and it is
# substituted instead of the current location as the place where
# the exception occurred, this is so that when using debuggers such
# as `pdb', or simply evaluating the naturally raised traceback, we
# retain the original (inner) location of where the exception
# occurred.
reraise(e_type, e_value, e_tb)
# The third argument to raise is the traceback object, and it is
# substituted instead of the current location as the place where
# the exception occurred, this is so that when using debuggers such
# as `pdb', or simply evaluating the naturally raised traceback, we
# retain the original (inner) location of where the exception
# occurred.
reraise(e_type, e_value, e_tb)

def execute(self, req, arg=None, res=None):
"""
`execute` calls are non-blocking: just queue up the request and return immediately.
"""
self._wait_for_initialization()
self.check_raise_error()
stack = None

Expand Down Expand Up @@ -589,26 +616,6 @@ def close(self, force=False):
self.select_one('--close--')
self.join()

def _wait_for_initialization(self):
    """
    Poll the 'initialized' flag until it is set by the started Thread in run().

    Returns as soon as either the flag or ``self.exception`` is set;
    raises ``TimeoutError`` if neither happens within ``self.timeout`` seconds.
    """
    # A race condition may occur without waiting for initialization:
    # __init__() finishes with the start() call, but the Thread needs some time to actually start working.
    # If opening the database file fails in run(), an exception will occur and self.exception will be set.
    # But if we run check_raise_error() before run() had a chance to set self.exception, it will report
    # a false negative: an exception occurred and the thread terminates but self.exception is unset.
    # This leads to a deadlock while waiting for the results of execute().
    # By waiting for the Thread to set the initialized flag, we can ensure the thread has successfully
    # opened the file - and possibly set self.exception to be detected by check_raise_error().

    start_time = time.time()
    # NOTE(review): 0.1s poll granularity means the effective wait can
    # overshoot self.timeout by up to one sleep interval.
    while time.time() - start_time < self.timeout:
        if self._sqlitedict_thread_initialized or self.exception:
            return
        time.sleep(0.1)
    raise TimeoutError("SqliteMultithread failed to flag initialization within %0.0f seconds." % self.timeout)


#
# This is here for .github/workflows/release.yml
Expand Down
8 changes: 7 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,18 @@ def test_complex_struct(self):


class TablenamesTest(unittest.TestCase):
def tearDown(self):
    """Remove any database files left on disk by the tablenames tests."""
    leftovers = ('tablenames-test-1.sqlite', 'tablenames-test-2.sqlite')
    for name in leftovers:
        db_path = norm_file(os.path.join('tests/db', name))
        if os.path.isfile(db_path):
            os.unlink(db_path)

def test_tablenames(self):
def test_tablenames_unnamed(self):
    """A dict created without an explicit tablename lands in 'unnamed'."""
    db_file = norm_file('tests/db/tablenames-test-1.sqlite')
    SqliteDict(db_file)
    tables = SqliteDict.get_tablenames(db_file)
    self.assertEqual(tables, ['unnamed'])

def test_tablenams_named(self):
fname = norm_file('tests/db/tablenames-test-2.sqlite')
with SqliteDict(fname, tablename='table1'):
self.assertEqual(SqliteDict.get_tablenames(fname), ['table1'])
Expand Down

0 comments on commit 8606ca4

Please sign in to comment.