From b2697ffdf18f7acd88a35e9a0a252c3b6fb25070 Mon Sep 17 00:00:00 2001 From: Dan Watson Date: Mon, 9 Oct 2023 16:15:30 -0400 Subject: [PATCH] Add query logging callbacks and context manager (#1043) --- asyncpg/connection.py | 128 ++++++++++++++++++++++++++++++++++++++---- tests/test_logging.py | 51 +++++++++++++++++ 2 files changed, 169 insertions(+), 10 deletions(-) create mode 100644 tests/test_logging.py diff --git a/asyncpg/connection.py b/asyncpg/connection.py index ba0e14fe..0367e365 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -9,6 +9,7 @@ import asyncpg import collections import collections.abc +import contextlib import functools import itertools import inspect @@ -53,7 +54,7 @@ class Connection(metaclass=ConnectionMeta): '_intro_query', '_reset_query', '_proxy', '_stmt_exclusive_section', '_config', '_params', '_addr', '_log_listeners', '_termination_listeners', '_cancellations', - '_source_traceback', '__weakref__') + '_source_traceback', '_query_loggers', '__weakref__') def __init__(self, protocol, transport, loop, addr, @@ -87,6 +88,7 @@ def __init__(self, protocol, transport, loop, self._log_listeners = set() self._cancellations = set() self._termination_listeners = set() + self._query_loggers = set() settings = self._protocol.get_settings() ver_string = settings.server_version @@ -224,6 +226,30 @@ def remove_termination_listener(self, callback): """ self._termination_listeners.discard(_Callback.from_callable(callback)) + def add_query_logger(self, callback): + """Add a logger that will be called when queries are executed. + + :param callable callback: + A callable or a coroutine function receiving one argument: + **record**: a LoggedQuery containing `query`, `args`, `timeout`, + `elapsed`, `exception`, `conn_addr`, and + `conn_params`. + + .. versionadded:: 0.29.0 + """ + self._query_loggers.add(_Callback.from_callable(callback)) + + def remove_query_logger(self, callback): + """Remove a query logger callback. + + :param callable callback: + The callable or coroutine function that was passed to + :meth:`Connection.add_query_logger`. + + .. versionadded:: 0.29.0 + """ + self._query_loggers.discard(_Callback.from_callable(callback)) + def get_server_pid(self): """Return the PID of the Postgres server the connection is bound to.""" return self._protocol.get_server_pid() @@ -317,7 +343,12 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: self._check_open() if not args: - return await self._protocol.query(query, timeout) + if self._query_loggers: + with self._time_and_log(query, args, timeout): + result = await self._protocol.query(query, timeout) + else: + result = await self._protocol.query(query, timeout) + return result _, status, _ = await self._execute( query, @@ -1487,6 +1518,7 @@ def _cleanup(self): self._mark_stmts_as_closed() self._listeners.clear() self._log_listeners.clear() + self._query_loggers.clear() self._clean_tasks() def _clean_tasks(self): @@ -1770,6 +1802,63 @@ async def _execute( ) return result + @contextlib.contextmanager + def query_logger(self, callback): + """Context manager that adds `callback` to the list of query loggers, + and removes it upon exit. + + :param callable callback: + A callable or a coroutine function receiving one argument: + **record**: a LoggedQuery containing `query`, `args`, `timeout`, + `elapsed`, `exception`, `conn_addr`, and + `conn_params`. + + Example: + + .. code-block:: pycon + + >>> class QuerySaver: + def __init__(self): + self.queries = [] + def __call__(self, record): + self.queries.append(record.query) + >>> with con.query_logger(QuerySaver()): + >>> await con.execute("SELECT 1") + >>> print(log.queries) + ['SELECT 1'] + + .. versionadded:: 0.29.0 + """ + self.add_query_logger(callback) + yield + self.remove_query_logger(callback) + + @contextlib.contextmanager + def _time_and_log(self, query, args, timeout): + start = time.monotonic() + exception = None + try: + yield + except BaseException as ex: + exception = ex + raise + finally: + elapsed = time.monotonic() - start + record = LoggedQuery( + query=query, + args=args, + timeout=timeout, + elapsed=elapsed, + exception=exception, + conn_addr=self._addr, + conn_params=self._params, + ) + for cb in self._query_loggers: + if cb.is_async: + self._loop.create_task(cb.cb(record)) + else: + self._loop.call_soon(cb.cb, record) + async def __execute( self, query, @@ -1790,13 +1879,24 @@ async def __execute( timeout=timeout, ) timeout = self._protocol._get_timeout(timeout) - return await self._do_execute( - query, - executor, - timeout, - record_class=record_class, - ignore_custom_codec=ignore_custom_codec, - ) + if self._query_loggers: + with self._time_and_log(query, args, timeout): + result, stmt = await self._do_execute( + query, + executor, + timeout, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + else: + result, stmt = await self._do_execute( + query, + executor, + timeout, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + return result, stmt async def _executemany(self, query, args, timeout): executor = lambda stmt, timeout: self._protocol.bind_execute_many( @@ -1807,7 +1907,8 @@ async def _executemany(self, query, args, timeout): ) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: - result, _ = await self._do_execute(query, executor, timeout) + with self._time_and_log(query, args, timeout): + result, _ = await self._do_execute(query, executor, timeout) return result async def _do_execute( @@ -2440,6 +2541,13 @@ class _ConnectionProxy: __slots__ = () +LoggedQuery = collections.namedtuple( + 'LoggedQuery', + ['query', 'args', 'timeout', 'elapsed', 'exception', 'conn_addr', + 'conn_params']) +LoggedQuery.__doc__ = 'Log record of an executed query.' + + ServerCapabilities = collections.namedtuple( 'ServerCapabilities', ['advisory_locks', 'notifications', 'plpgsql', 'sql_reset', diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 00000000..a9af94c4 --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,51 @@ +import asyncio + +from asyncpg import _testbase as tb +from asyncpg import exceptions + + +class LogCollector: + def __init__(self): + self.records = [] + + def __call__(self, record): + self.records.append(record) + + +class TestQueryLogging(tb.ConnectedTestCase): + + async def test_logging_context(self): + queries = asyncio.Queue() + + def query_saver(record): + queries.put_nowait(record) + + log = LogCollector() + + with self.con.query_logger(query_saver): + self.assertEqual(len(self.con._query_loggers), 1) + await self.con.execute("SELECT 1") + with self.con.query_logger(log): + self.assertEqual(len(self.con._query_loggers), 2) + await self.con.execute("SELECT 2") + + r1 = await queries.get() + r2 = await queries.get() + self.assertEqual(r1.query, "SELECT 1") + self.assertEqual(r2.query, "SELECT 2") + self.assertEqual(len(log.records), 1) + self.assertEqual(log.records[0].query, "SELECT 2") + self.assertEqual(len(self.con._query_loggers), 0) + + async def test_error_logging(self): + log = LogCollector() + with self.con.query_logger(log): + with self.assertRaises(exceptions.UndefinedColumnError): + await self.con.execute("SELECT x") + + await asyncio.sleep(0) # wait for logging + self.assertEqual(len(log.records), 1) + self.assertEqual( + type(log.records[0].exception), + exceptions.UndefinedColumnError + )