diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 582098a..68e8f60 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -19,7 +19,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Update pip - run: python -m pip install -U coverage flake8 pip pytest pytest-coverage + run: python -m pip install -U coverage flake8 pip pytest pytest-coverage pytest-benchmark - name: Flake8 run: flake8 sqlitedict.py tests @@ -35,5 +35,8 @@ jobs: - name: Run tests run: pytest tests --cov=sqlitedict + - name: Run benchmarks + run: pytest benchmarks + - name: Run doctests run: python -m doctest README.rst diff --git a/benchmarks/test_insert.py b/benchmarks/test_insert.py new file mode 100644 index 0000000..0988021 --- /dev/null +++ b/benchmarks/test_insert.py @@ -0,0 +1,15 @@ +import tempfile + +from sqlitedict import SqliteDict + + +def insert(): + with tempfile.NamedTemporaryFile() as tmp: + for j in range(100): + with SqliteDict(tmp.name) as d: + d["tmp"] = j + d.commit() + + +def test(benchmark): + benchmark(insert) diff --git a/sqlitedict.py b/sqlitedict.py index b036ab3..635dfde 100755 --- a/sqlitedict.py +++ b/sqlitedict.py @@ -30,13 +30,11 @@ import os import sys import tempfile +import threading import logging -import time import traceback from base64 import b64decode, b64encode -from threading import Thread - __version__ = '2.0.0' @@ -172,7 +170,6 @@ def __init__(self, filename=None, tablename='unnamed', flag='c', self.decode = decode self.encode_key = encode_key self.decode_key = decode_key - self.timeout = timeout self._outer_stack = outer_stack logger.debug("opening Sqlite table %r in %r" % (tablename, filename)) @@ -193,7 +190,6 @@ def _new_conn(self): self.filename, autocommit=self.autocommit, journal_mode=self.journal_mode, - timeout=self.timeout, outer_stack=self._outer_stack, ) @@ -382,7 +378,7 @@ def __del__(self): pass -class SqliteMultithread(Thread): +class 
SqliteMultithread(threading.Thread): """ Wrap sqlite connection in a way that allows concurrent requests from multiple threads. @@ -390,7 +386,7 @@ class SqliteMultithread(Thread): in a separate thread (in the same order they arrived). """ - def __init__(self, filename, autocommit, journal_mode, timeout, outer_stack=True): + def __init__(self, filename, autocommit, journal_mode, outer_stack=True): super(SqliteMultithread, self).__init__() self.filename = filename self.autocommit = autocommit @@ -398,14 +394,33 @@ def __init__(self, filename, autocommit, journal_mode, timeout, outer_stack=True # use request queue of unlimited size self.reqs = Queue() self.daemon = True - self.exception = None - self._sqlitedict_thread_initialized = None self._outer_stack = outer_stack - self.timeout = timeout self.log = logging.getLogger('sqlitedict.SqliteMultithread') + + # + # Parts of this object's state get accessed from different threads, so + # we use synchronization to avoid race conditions. For example, + # .exception gets set inside the new daemon thread that we spawned, but + # gets read from the main thread. This is particularly important + # during initialization: the Thread needs some time to actually start + # working, and until this happens, any calls to e.g. + # check_raise_error() will prematurely return None, meaning all is + # well. If that connection happens to fail, we'll never know about + # it, and instead wait for a result that never arrives (effectively, + # deadlocking). Locking solves this problem by eliminating the race + # condition. + # + self._lock = threading.Lock() + self._lock.acquire() + self.exception = None + self.start() - def run(self): + def _connect(self): + """Connect to the underlying database. + + Raises an exception on failure. Returns the connection and cursor on success. 
+ """ try: if self.autocommit: conn = sqlite3.connect(self.filename, isolation_level=None, check_same_thread=False) @@ -427,7 +442,17 @@ def run(self): self.exception = sys.exc_info() raise - self._sqlitedict_thread_initialized = True + return conn, cursor + + def run(self): + # + # Nb. this is what actually runs inside the new daemon thread. + # self._lock is locked at this stage - see the initializer function. + # + try: + conn, cursor = self._connect() + finally: + self._lock.release() res = None while True: @@ -443,7 +468,9 @@ def run(self): try: cursor.execute(req, arg) except Exception: - self.exception = (e_type, e_value, e_tb) = sys.exc_info() + with self._lock: + self.exception = (e_type, e_value, e_tb) = sys.exc_info() + inner_stack = traceback.extract_stack() # An exception occurred in our thread, but we may not @@ -499,29 +526,29 @@ def check_raise_error(self): calls to the `execute*` methods to check for and raise an exception in a previous call to the MainThread. """ - if self.exception: - e_type, e_value, e_tb = self.exception + with self._lock: + if self.exception: + e_type, e_value, e_tb = self.exception - # clear self.exception, if the caller decides to handle such - # exception, we should not repeatedly re-raise it. - self.exception = None + # clear self.exception, if the caller decides to handle such + # exception, we should not repeatedly re-raise it. 
+ self.exception = None - self.log.error('An exception occurred from a previous statement, view ' - 'the logging namespace "sqlitedict" for outer stack.') + self.log.error('An exception occurred from a previous statement, view ' + 'the logging namespace "sqlitedict" for outer stack.') - # The third argument to raise is the traceback object, and it is - # substituted instead of the current location as the place where - # the exception occurred, this is so that when using debuggers such - # as `pdb', or simply evaluating the naturally raised traceback, we - # retain the original (inner) location of where the exception - # occurred. - reraise(e_type, e_value, e_tb) + # The third argument to raise is the traceback object, and it is + # substituted instead of the current location as the place where + # the exception occurred, this is so that when using debuggers such + # as `pdb', or simply evaluating the naturally raised traceback, we + # retain the original (inner) location of where the exception + # occurred. + reraise(e_type, e_value, e_tb) def execute(self, req, arg=None, res=None): """ `execute` calls are non-blocking: just queue up the request and return immediately. """ - self._wait_for_initialization() self.check_raise_error() stack = None @@ -589,26 +616,6 @@ def close(self, force=False): self.select_one('--close--') self.join() - def _wait_for_initialization(self): - """ - Polls the 'initialized' flag to be set by the started Thread in run(). - """ - # A race condition may occur without waiting for initialization: - # __init__() finishes with the start() call, but the Thread needs some time to actually start working. - # If opening the database file fails in run(), an exception will occur and self.exception will be set. - # But if we run check_raise_error() before run() had a chance to set self.exception, it will report - # a false negative: An exception occured and the thread terminates but self.exception is unset. 
- # This leads to a deadlock while waiting for the results of execute(). - # By waiting for the Thread to set the initialized flag, we can ensure the thread has successfully - # opened the file - and possibly set self.exception to be detected by check_raise_error(). - - start_time = time.time() - while time.time() - start_time < self.timeout: - if self._sqlitedict_thread_initialized or self.exception: - return - time.sleep(0.1) - raise TimeoutError("SqliteMultithread failed to flag initialization within %0.0f seconds." % self.timeout) - # # This is here for .github/workflows/release.yml diff --git a/tests/test_core.py b/tests/test_core.py index 160e306..740860b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -282,12 +282,18 @@ def test_complex_struct(self): class TablenamesTest(unittest.TestCase): + def tearDown(self): + for f in ('tablenames-test-1.sqlite', 'tablenames-test-2.sqlite'): + path = norm_file(os.path.join('tests/db', f)) + if os.path.isfile(path): + os.unlink(path) - def test_tablenames(self): + def test_tablenames_unnamed(self): fname = norm_file('tests/db/tablenames-test-1.sqlite') SqliteDict(fname) self.assertEqual(SqliteDict.get_tablenames(fname), ['unnamed']) + def test_tablenames_named(self): fname = norm_file('tests/db/tablenames-test-2.sqlite') with SqliteDict(fname, tablename='table1'): self.assertEqual(SqliteDict.get_tablenames(fname), ['table1'])