Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add query logging callbacks and context manager #1043

Merged
merged 12 commits into from
Oct 9, 2023
128 changes: 118 additions & 10 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import asyncpg
import collections
import collections.abc
import contextlib
import functools
import itertools
import inspect
Expand Down Expand Up @@ -53,7 +54,7 @@ class Connection(metaclass=ConnectionMeta):
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
'_log_listeners', '_termination_listeners', '_cancellations',
'_source_traceback', '__weakref__')
'_source_traceback', '_query_loggers', '__weakref__')

def __init__(self, protocol, transport, loop,
addr,
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(self, protocol, transport, loop,
self._log_listeners = set()
self._cancellations = set()
self._termination_listeners = set()
self._query_loggers = set()

settings = self._protocol.get_settings()
ver_string = settings.server_version
Expand Down Expand Up @@ -224,6 +226,30 @@ def remove_termination_listener(self, callback):
"""
self._termination_listeners.discard(_Callback.from_callable(callback))

def add_query_logger(self, callback):
"""Add a logger that will be called when queries are executed.

:param callable callback:
A callable or a coroutine function receiving one argument:
**record**: a LoggedQuery containing `query`, `args`, `timeout`,
`elapsed`, `exception`, `conn_addr`, and
`conn_params`.

.. versionadded:: 0.29.0
"""
self._query_loggers.add(_Callback.from_callable(callback))

def remove_query_logger(self, callback):
"""Remove a query logger callback.

:param callable callback:
The callable or coroutine function that was passed to
:meth:`Connection.add_query_logger`.

.. versionadded:: 0.29.0
"""
self._query_loggers.discard(_Callback.from_callable(callback))

def get_server_pid(self):
"""Return the PID of the Postgres server the connection is bound to."""
return self._protocol.get_server_pid()
Expand Down Expand Up @@ -317,7 +343,12 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
self._check_open()

if not args:
return await self._protocol.query(query, timeout)
if self._query_loggers:
with self._time_and_log(query, args, timeout):
result = await self._protocol.query(query, timeout)
else:
result = await self._protocol.query(query, timeout)
return result

_, status, _ = await self._execute(
query,
Expand Down Expand Up @@ -1487,6 +1518,7 @@ def _cleanup(self):
self._mark_stmts_as_closed()
self._listeners.clear()
self._log_listeners.clear()
self._query_loggers.clear()
self._clean_tasks()

def _clean_tasks(self):
Expand Down Expand Up @@ -1770,6 +1802,63 @@ async def _execute(
)
return result

@contextlib.contextmanager
def query_logger(self, callback):
"""Context manager that adds `callback` to the list of query loggers,
and removes it upon exit.

:param callable callback:
A callable or a coroutine function receiving one argument:
**record**: a LoggedQuery containing `query`, `args`, `timeout`,
`elapsed`, `exception`, `conn_addr`, and
`conn_params`.

Example:

.. code-block:: pycon

>>> class QuerySaver:
def __init__(self):
self.queries = []
def __call__(self, record):
self.queries.append(record.query)
>>> with con.query_logger(QuerySaver()):
>>> await con.execute("SELECT 1")
>>> print(log.queries)
['SELECT 1']

.. versionadded:: 0.29.0
"""
self.add_query_logger(callback)
yield
self.remove_query_logger(callback)

@contextlib.contextmanager
def _time_and_log(self, query, args, timeout):
start = time.monotonic()
exception = None
try:
yield
except BaseException as ex:
exception = ex
raise
finally:
elapsed = time.monotonic() - start
record = LoggedQuery(
query=query,
args=args,
timeout=timeout,
elapsed=elapsed,
exception=exception,
conn_addr=self._addr,
conn_params=self._params,
)
for cb in self._query_loggers:
if cb.is_async:
self._loop.create_task(cb.cb(record))
else:
self._loop.call_soon(cb.cb, record)

async def __execute(
self,
query,
Expand All @@ -1790,13 +1879,24 @@ async def __execute(
timeout=timeout,
)
timeout = self._protocol._get_timeout(timeout)
return await self._do_execute(
query,
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
if self._query_loggers:
with self._time_and_log(query, args, timeout):
result, stmt = await self._do_execute(
query,
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
else:
result, stmt = await self._do_execute(
query,
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
return result, stmt

async def _executemany(self, query, args, timeout):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
Expand All @@ -1807,7 +1907,8 @@ async def _executemany(self, query, args, timeout):
)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
result, _ = await self._do_execute(query, executor, timeout)
with self._time_and_log(query, args, timeout):
result, _ = await self._do_execute(query, executor, timeout)
return result

async def _do_execute(
Expand Down Expand Up @@ -2440,6 +2541,13 @@ class _ConnectionProxy:
__slots__ = ()


LoggedQuery = collections.namedtuple(
'LoggedQuery',
['query', 'args', 'timeout', 'elapsed', 'exception', 'conn_addr',
'conn_params'])
LoggedQuery.__doc__ = 'Log record of an executed query.'


ServerCapabilities = collections.namedtuple(
'ServerCapabilities',
['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',
Expand Down
51 changes: 51 additions & 0 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import asyncio

from asyncpg import _testbase as tb
from asyncpg import exceptions


class LogCollector:
def __init__(self):
self.records = []

def __call__(self, record):
self.records.append(record)


class TestQueryLogging(tb.ConnectedTestCase):

async def test_logging_context(self):
queries = asyncio.Queue()

def query_saver(record):
queries.put_nowait(record)

log = LogCollector()

with self.con.query_logger(query_saver):
self.assertEqual(len(self.con._query_loggers), 1)
await self.con.execute("SELECT 1")
with self.con.query_logger(log):
self.assertEqual(len(self.con._query_loggers), 2)
await self.con.execute("SELECT 2")

r1 = await queries.get()
r2 = await queries.get()
self.assertEqual(r1.query, "SELECT 1")
self.assertEqual(r2.query, "SELECT 2")
self.assertEqual(len(log.records), 1)
self.assertEqual(log.records[0].query, "SELECT 2")
self.assertEqual(len(self.con._query_loggers), 0)

async def test_error_logging(self):
log = LogCollector()
with self.con.query_logger(log):
with self.assertRaises(exceptions.UndefinedColumnError):
await self.con.execute("SELECT x")

await asyncio.sleep(0) # wait for logging
self.assertEqual(len(log.records), 1)
self.assertEqual(
type(log.records[0].exception),
exceptions.UndefinedColumnError
)