From 95d57ee24e79aa68a2eb35096dc145faa8434bc3 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Fri, 18 Oct 2024 12:41:46 -0700 Subject: [PATCH] Allow customizing connection state reset A coroutine can be passed to the new `reset` argument of `create_pool` to control what happens to the connection when it is returned back to the pool by `release()`. By default `Connection.reset()` is called. Additionally, `Connection.get_reset_query` is renamed from `Connection._get_reset_query` to enable an alternative way of customizing the reset process via subclassing. Closes: #780 Closes: #1146 --- asyncpg/connection.py | 45 +++++++++++++++++++++++++++++++++++++------ asyncpg/pool.py | 36 ++++++++++++++++++++++++++++++---- tests/test_pool.py | 17 +++++++++++++++- 3 files changed, 87 insertions(+), 11 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 79711c0c..3a86466c 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1515,11 +1515,10 @@ def terminate(self): self._abort() self._cleanup() - async def reset(self, *, timeout=None): + async def _reset(self): self._check_open() self._listeners.clear() self._log_listeners.clear() - reset_query = self._get_reset_query() if self._protocol.is_in_transaction() or self._top_xact is not None: if self._top_xact is None or not self._top_xact._managed: @@ -1531,10 +1530,36 @@ async def reset(self, *, timeout=None): }) self._top_xact = None - reset_query = 'ROLLBACK;\n' + reset_query + await self.execute("ROLLBACK") + + async def reset(self, *, timeout=None): + """Reset the connection state. + + Calling this will reset the connection session state to a state + resembling that of a newly obtained connection. Namely, an open + transaction (if any) is rolled back, open cursors are closed, + all `LISTEN `_ + registrations are removed, all session configuration + variables are reset to their default values, and all advisory locks + are released. + + Note that the above describes the default query returned by + :meth:`Connection.get_reset_query`. If one overloads the method + by subclassing ``Connection``, then this method will do whatever + the overloaded method returns, except open transactions are always + terminated and any callbacks registered by + :meth:`Connection.add_listener` or :meth:`Connection.add_log_listener` + are removed. - if reset_query: - await self.execute(reset_query, timeout=timeout) + :param float timeout: + A timeout for resetting the connection. If not specified, defaults + to no timeout. + """ + async with compat.timeout(timeout): + await self._reset() + reset_query = self.get_reset_query() + if reset_query: + await self.execute(reset_query) def _abort(self): # Put the connection into the aborted state. @@ -1695,7 +1720,15 @@ def _unwrap(self): con_ref = self._proxy return con_ref - def _get_reset_query(self): + def get_reset_query(self): + """Return the query sent to server on connection release. + + The query returned by this method is used by :meth:`Connection.reset`, + which is, in turn, used by :class:`~asyncpg.pool.Pool` before making + the connection available to another acquirer. + + .. versionadded:: 0.30.0 + """ if self._reset_query is not None: return self._reset_query diff --git a/asyncpg/pool.py b/asyncpg/pool.py index a18dd3b3..e3898d53 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -210,7 +210,12 @@ async def release(self, timeout): if budget is not None: budget -= time.monotonic() - started - await self._con.reset(timeout=budget) + if self._pool._reset is not None: + async with compat.timeout(budget): + await self._con._reset() + await self._pool._reset(self._con) + else: + await self._con.reset(timeout=budget) except (Exception, asyncio.CancelledError) as ex: # If the `reset` call failed, terminate the connection. # A new one will be created when `acquire` is called @@ -313,7 +318,7 @@ class Pool: __slots__ = ( '_queue', '_loop', '_minsize', '_maxsize', - '_init', '_connect', '_connect_args', '_connect_kwargs', + '_init', '_connect', '_reset', '_connect_args', '_connect_kwargs', '_holders', '_initialized', '_initializing', '_closing', '_closed', '_connection_class', '_record_class', '_generation', '_setup', '_max_queries', '_max_inactive_connection_lifetime' @@ -327,6 +332,7 @@ def __init__(self, *connect_args, connect=None, setup=None, init=None, + reset=None, loop, connection_class, record_class, @@ -393,6 +399,7 @@ def __init__(self, *connect_args, self._setup = setup self._init = init + self._reset = reset self._max_queries = max_queries self._max_inactive_connection_lifetime = \ @@ -1036,6 +1043,7 @@ def create_pool(dsn=None, *, connect=None, setup=None, init=None, + reset=None, loop=None, connection_class=connection.Connection, record_class=protocol.Record, @@ -1125,7 +1133,7 @@ def create_pool(dsn=None, *, :param coroutine setup: A coroutine to prepare a connection right before it is returned - from :meth:`Pool.acquire() `. An example use + from :meth:`Pool.acquire()`. An example use case would be to automatically set up notifications listeners for all connections of a pool. @@ -1137,6 +1145,25 @@ def create_pool(dsn=None, *, or :meth:`Connection.set_type_codec() <\ asyncpg.connection.Connection.set_type_codec>`. + :param coroutine reset: + A coroutine to reset a connection before it is returned to the pool by + :meth:`Pool.release()`. The function is supposed + to reset any changes made to the database session so that the next + acquirer gets the connection in a well-defined state. + + The default implementation calls :meth:`Connection.reset() <\ + asyncpg.connection.Connection.reset>`, which runs the following:: + + SELECT pg_advisory_unlock_all(); + CLOSE ALL; + UNLISTEN *; + RESET ALL; + + The exact reset query is determined by detected server capabilities, + and a custom *reset* implementation can obtain the default query + by calling :meth:`Connection.get_reset_query() <\ + asyncpg.connection.Connection.get_reset_query>`. + :param loop: An asyncio event loop instance. If ``None``, the default event loop will be used. @@ -1165,7 +1192,7 @@ def create_pool(dsn=None, *, Added the *record_class* parameter. .. versionchanged:: 0.30.0 - Added the *connect* parameter. + Added the *connect* and *reset* parameters. """ return Pool( dsn, @@ -1178,6 +1205,7 @@ def create_pool(dsn=None, *, connect=connect, setup=setup, init=init, + reset=reset, max_inactive_connection_lifetime=max_inactive_connection_lifetime, **connect_kwargs, ) diff --git a/tests/test_pool.py b/tests/test_pool.py index 5bd70bd9..3f10ae5c 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -137,6 +137,9 @@ async def setup(con): async def test_pool_07(self): cons = set() connect_called = 0 + init_called = 0 + setup_called = 0 + reset_called = 0 async def connect(*args, **kwargs): nonlocal connect_called @@ -144,13 +147,21 @@ async def connect(*args, **kwargs): return await pg_connection.connect(*args, **kwargs) async def setup(con): + nonlocal setup_called if con._con not in cons: # `con` is `PoolConnectionProxy`. raise RuntimeError('init was not called before setup') + setup_called += 1 async def init(con): + nonlocal init_called if con in cons: raise RuntimeError('init was called more than once') cons.add(con) + init_called += 1 + + async def reset(con): + nonlocal reset_called + reset_called += 1 async def user(pool): async with pool.acquire() as con: @@ -162,12 +173,16 @@ async def user(pool): max_size=5, connect=connect, init=init, - setup=setup) as pool: + setup=setup, + reset=reset) as pool: users = asyncio.gather(*[user(pool) for _ in range(10)]) await users self.assertEqual(len(cons), 5) self.assertEqual(connect_called, 5) + self.assertEqual(init_called, 5) + self.assertEqual(setup_called, 10) + self.assertEqual(reset_called, 10) async def bad_connect(*args, **kwargs): return 1