diff --git a/aiosqlite/core.py b/aiosqlite/core.py index 58c227d..58c3fec 100644 --- a/aiosqlite/core.py +++ b/aiosqlite/core.py @@ -10,9 +10,8 @@ import sqlite3 from functools import partial from pathlib import Path -from queue import Empty, Queue +from queue import Empty, Queue, SimpleQueue from threading import Thread - from typing import ( Any, AsyncIterator, @@ -21,6 +20,7 @@ Iterable, Literal, Optional, + Tuple, Type, Union, ) @@ -37,6 +37,21 @@ IsolationLevel = Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]] +def set_result(fut: asyncio.Future, result: Any) -> None: + """Set the result of a future if it hasn't been set already.""" + if not fut.done(): + fut.set_result(result) + + +def set_exception(fut: asyncio.Future, e: BaseException) -> None: + """Set the exception of a future if it hasn't been set already.""" + if not fut.done(): + fut.set_exception(e) + + +_STOP_RUNNING_SENTINEL = object() + + class Connection(Thread): def __init__( self, @@ -48,7 +63,7 @@ def __init__( self._running = True self._connection: Optional[sqlite3.Connection] = None self._connector = connector - self._tx: Queue = Queue() + self._tx: SimpleQueue[Tuple[asyncio.Future, Callable[[], Any]]] = SimpleQueue() self._iter_chunk_size = iter_chunk_size if loop is not None: @@ -57,6 +72,11 @@ def __init__( DeprecationWarning, ) + def _stop_running(self): + self._running = False + # PEP 661 is not accepted yet, so we cannot type a sentinel + self._tx.put_nowait(_STOP_RUNNING_SENTINEL) # type: ignore[arg-type] + @property def _conn(self) -> sqlite3.Connection: if self._connection is None: @@ -83,29 +103,20 @@ def run(self) -> None: # Continues running until all queue items are processed, # even after connection is closed (so we can finalize all # futures) - try: - future, function = self._tx.get(timeout=0.1) - except Empty: - if self._running: - continue + + tx_item = self._tx.get() + if tx_item is _STOP_RUNNING_SENTINEL: break + + future, function = tx_item + try: LOG.debug("executing %s", function) result = function() LOG.debug("operation %s completed", function) - - def set_result(fut, result): - if not fut.done(): - fut.set_result(result) - future.get_loop().call_soon_threadsafe(set_result, future, result) except BaseException as e: # noqa B036 LOG.debug("returning exception %s", e) - - def set_exception(fut, e): - if not fut.done(): - fut.set_exception(e) - future.get_loop().call_soon_threadsafe(set_exception, future, e) async def _execute(self, fn, *args, **kwargs): @@ -128,7 +139,7 @@ async def _connect(self) -> "Connection": self._tx.put_nowait((future, self._connector)) self._connection = await future except Exception: - self._running = False + self._stop_running() self._connection = None raise @@ -169,7 +180,7 @@ async def close(self) -> None: LOG.info("exception occurred while closing connection") raise finally: - self._running = False + self._stop_running() self._connection = None @contextmanager