Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Properly handle the race condition #164

Merged
merged 2 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
python-version: ${{ matrix.python-version }}

- name: Update pip
run: python -m pip install -U coverage flake8 pip pytest pytest-coverage
run: python -m pip install -U coverage flake8 pip pytest pytest-coverage pytest-benchmark

- name: Flake8
run: flake8 sqlitedict.py tests
Expand All @@ -35,5 +35,8 @@ jobs:
- name: Run tests
run: pytest tests --cov=sqlitedict

- name: Run benchmarks
run: pytest benchmarks

- name: Run doctests
run: python -m doctest README.rst
15 changes: 15 additions & 0 deletions benchmarks/test_insert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import tempfile

from sqlitedict import SqliteDict


def insert():
    """Benchmark workload: reopen the dict and write a single key, 100 times.

    Each iteration constructs a fresh ``SqliteDict`` over the same temporary
    file, so the measurement covers open/close overhead as well as the write.
    """
    with tempfile.NamedTemporaryFile() as tmp:
        db_path = tmp.name
        for value in range(100):
            with SqliteDict(db_path) as mapping:
                mapping["tmp"] = value
                mapping.commit()


def test(benchmark):
    # `benchmark` is the pytest-benchmark fixture: it calls insert() repeatedly
    # and reports timing statistics for the run.
    benchmark(insert)
103 changes: 55 additions & 48 deletions sqlitedict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@
import os
import sys
import tempfile
import threading
import logging
import time
import traceback
from base64 import b64decode, b64encode

from threading import Thread

__version__ = '2.0.0'


Expand Down Expand Up @@ -172,7 +170,6 @@ def __init__(self, filename=None, tablename='unnamed', flag='c',
self.decode = decode
self.encode_key = encode_key
self.decode_key = decode_key
self.timeout = timeout
self._outer_stack = outer_stack

logger.debug("opening Sqlite table %r in %r" % (tablename, filename))
Expand All @@ -193,7 +190,6 @@ def _new_conn(self):
self.filename,
autocommit=self.autocommit,
journal_mode=self.journal_mode,
timeout=self.timeout,
outer_stack=self._outer_stack,
)

Expand Down Expand Up @@ -382,30 +378,49 @@ def __del__(self):
pass


class SqliteMultithread(Thread):
class SqliteMultithread(threading.Thread):
"""
Wrap sqlite connection in a way that allows concurrent requests from multiple threads.

This is done by internally queueing the requests and processing them sequentially
in a separate thread (in the same order they arrived).

"""
def __init__(self, filename, autocommit, journal_mode, timeout, outer_stack=True):
def __init__(self, filename, autocommit, journal_mode, outer_stack=True):
    """Set up the worker thread that will own the sqlite connection.

    The lock is acquired here, before the thread starts, and released by
    run() once the connection attempt has finished (successfully or not).

    :param filename: path of the sqlite database file to open.
    :param autocommit: if True, the connection is opened in autocommit mode.
    :param journal_mode: sqlite journal mode to apply to the connection.
    :param outer_stack: if True, record the caller's stack for each request
        so errors raised inside this thread can be traced back to the caller.
    """
    super(SqliteMultithread, self).__init__()
    self.filename = filename
    self.autocommit = autocommit
    self.journal_mode = journal_mode
    # use request queue of unlimited size
    self.reqs = Queue()
    self.daemon = True
    self._outer_stack = outer_stack
    self.log = logging.getLogger('sqlitedict.SqliteMultithread')

    #
    # Parts of this object's state get accessed from different threads, so
    # we use synchronization to avoid race conditions. For example,
    # .exception is set inside the new daemon thread that we spawned, but
    # it gets read from the main thread. This is particularly important
    # during initialization: the Thread needs some time to actually start
    # working, and until this happens, any calls to e.g.
    # check_raise_error() will prematurely return None, meaning all is
    # well. If that connection happens to fail, we'll never know about it,
    # and instead wait for a result that never arrives (effectively,
    # deadlocking). Locking solves this problem by eliminating the race
    # condition.
    #
    self._lock = threading.Lock()
    self._lock.acquire()
    self.exception = None

    self.start()

def run(self):
def _connect(self):
"""Connect to the underlying database.

Raises an exception on failure. Returns the connection and cursor on success.
"""
try:
if self.autocommit:
conn = sqlite3.connect(self.filename, isolation_level=None, check_same_thread=False)
Expand All @@ -427,7 +442,17 @@ def run(self):
self.exception = sys.exc_info()
raise

self._sqlitedict_thread_initialized = True
return conn, cursor

def run(self):
#
# Nb. this is what actually runs inside the new daemon thread.
# self._lock is locked at this stage - see the initializer function.
#
try:
conn, cursor = self._connect()
finally:
self._lock.release()

res = None
while True:
Expand All @@ -443,7 +468,9 @@ def run(self):
try:
cursor.execute(req, arg)
except Exception:
self.exception = (e_type, e_value, e_tb) = sys.exc_info()
with self._lock:
self.exception = (e_type, e_value, e_tb) = sys.exc_info()

inner_stack = traceback.extract_stack()

# An exception occurred in our thread, but we may not
Expand Down Expand Up @@ -499,29 +526,29 @@ def check_raise_error(self):
calls to the `execute*` methods to check for and raise an exception in a
previous call to the MainThread.
"""
if self.exception:
e_type, e_value, e_tb = self.exception
with self._lock:
if self.exception:
e_type, e_value, e_tb = self.exception

# clear self.exception, if the caller decides to handle such
# exception, we should not repeatedly re-raise it.
self.exception = None
# clear self.exception, if the caller decides to handle such
# exception, we should not repeatedly re-raise it.
self.exception = None

self.log.error('An exception occurred from a previous statement, view '
'the logging namespace "sqlitedict" for outer stack.')
self.log.error('An exception occurred from a previous statement, view '
'the logging namespace "sqlitedict" for outer stack.')

# The third argument to raise is the traceback object, and it is
# substituted instead of the current location as the place where
# the exception occurred, this is so that when using debuggers such
# as `pdb', or simply evaluating the naturally raised traceback, we
# retain the original (inner) location of where the exception
# occurred.
reraise(e_type, e_value, e_tb)
# The third argument to raise is the traceback object, and it is
# substituted instead of the current location as the place where
# the exception occurred, this is so that when using debuggers such
# as `pdb', or simply evaluating the naturally raised traceback, we
# retain the original (inner) location of where the exception
# occurred.
reraise(e_type, e_value, e_tb)

def execute(self, req, arg=None, res=None):
"""
`execute` calls are non-blocking: just queue up the request and return immediately.
"""
self._wait_for_initialization()
self.check_raise_error()
stack = None

Expand Down Expand Up @@ -589,26 +616,6 @@ def close(self, force=False):
self.select_one('--close--')
self.join()

def _wait_for_initialization(self):
"""
Polls the 'initialized' flag to be set by the started Thread in run().
"""
# A race condition may occur without waiting for initialization:
# __init__() finishes with the start() call, but the Thread needs some time to actually start working.
# If opening the database file fails in run(), an exception will occur and self.exception will be set.
# But if we run check_raise_error() before run() had a chance to set self.exception, it will report
# a false negative: An exception occured and the thread terminates but self.exception is unset.
# This leads to a deadlock while waiting for the results of execute().
# By waiting for the Thread to set the initialized flag, we can ensure the thread has successfully
# opened the file - and possibly set self.exception to be detected by check_raise_error().

start_time = time.time()
while time.time() - start_time < self.timeout:
if self._sqlitedict_thread_initialized or self.exception:
return
time.sleep(0.1)
raise TimeoutError("SqliteMultithread failed to flag initialization within %0.0f seconds." % self.timeout)


#
# This is here for .github/workflows/release.yml
Expand Down
8 changes: 7 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,18 @@ def test_complex_struct(self):


class TablenamesTest(unittest.TestCase):
def tearDown(self):
    """Delete any database files this test class may have left on disk."""
    leftovers = [
        norm_file(os.path.join('tests/db', name))
        for name in ('tablenames-test-1.sqlite', 'tablenames-test-2.sqlite')
    ]
    for path in leftovers:
        if os.path.isfile(path):
            os.unlink(path)

def test_tablenames(self):
def test_tablenames_unnamed(self):
    """A dict created without an explicit tablename uses the 'unnamed' table."""
    fname = norm_file('tests/db/tablenames-test-1.sqlite')
    # Use a context manager so the dict (and its worker thread/connection)
    # is closed instead of leaked; this also matches the style of the
    # named-table test below.
    with SqliteDict(fname):
        pass
    self.assertEqual(SqliteDict.get_tablenames(fname), ['unnamed'])

def test_tablenams_named(self):
fname = norm_file('tests/db/tablenames-test-2.sqlite')
with SqliteDict(fname, tablename='table1'):
self.assertEqual(SqliteDict.get_tablenames(fname), ['table1'])
Expand Down