Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add query logging callbacks and context manager #1043

Merged
merged 12 commits into from
Oct 9, 2023
77 changes: 74 additions & 3 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Connection(metaclass=ConnectionMeta):
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
'_log_listeners', '_termination_listeners', '_cancellations',
'_source_traceback', '__weakref__')
'_source_traceback', '_query_loggers', '__weakref__')

def __init__(self, protocol, transport, loop,
addr,
Expand Down Expand Up @@ -84,6 +84,7 @@ def __init__(self, protocol, transport, loop,
self._log_listeners = set()
self._cancellations = set()
self._termination_listeners = set()
self._query_loggers = set()

settings = self._protocol.get_settings()
ver_string = settings.server_version
Expand Down Expand Up @@ -221,6 +222,30 @@ def remove_termination_listener(self, callback):
"""
self._termination_listeners.discard(_Callback.from_callable(callback))

def add_query_logger(self, callback):
    """Register a callback that is invoked when queries are executed.

    :param callable callback:
        A callable or a coroutine function that takes two arguments:
        **connection**: the Connection the callback is registered with.
        **query**: a LoggedQuery holding the query, args, timeout, and
        elapsed.

    .. versionadded:: 0.28.0
    """
    # Loggers are stored in wrapped form; removal wraps the callback the
    # same way before discarding it from the set.
    wrapped = _Callback.from_callable(callback)
    self._query_loggers.add(wrapped)

def remove_query_logger(self, callback):
    """Unregister a previously added query logger callback.

    :param callable callback:
        The callable or coroutine function that was passed to
        :meth:`Connection.add_query_logger`.

    .. versionadded:: 0.28.0
    """
    # ``discard`` (rather than ``remove``) keeps this a no-op when the
    # callback is not in the set.
    wrapped = _Callback.from_callable(callback)
    self._query_loggers.discard(wrapped)

def get_server_pid(self):
    """Return the PID of the Postgres server the connection is bound to."""
    protocol = self._protocol
    return protocol.get_server_pid()
Expand Down Expand Up @@ -314,7 +339,11 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
self._check_open()

if not args:
return await self._protocol.query(query, timeout)
start = time.monotonic()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a context manager would probably be nicer/reduce duplication.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, change incoming.

result = await self._protocol.query(query, timeout)
elapsed = time.monotonic() - start
self._log_query(query, args, timeout, elapsed)
return result

_, status, _ = await self._execute(
query,
Expand Down Expand Up @@ -1667,6 +1696,20 @@ async def _execute(
)
return result

def logger(self, callback):
    """Return a context manager that logs queries to *callback*.

    On entering the returned context, *callback* is added as a query
    logger on this connection; on exit it is removed again.

    :param callable callback:
        The query logger to register for the duration of the ``with``
        block.
    """
    context = _LoggingContext(self, callback)
    return context

def _log_query(self, query, args, timeout, elapsed):
if not self._query_loggers:
return
con_ref = self._unwrap()
record = LoggedQuery(query, args, timeout, elapsed)
for cb in self._query_loggers:
if cb.is_async:
self._loop.create_task(cb.cb(con_ref, record))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing a con_ref is probably unnecessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I likely wouldn't use it, so happy to remove it, but I put it there so you could potentially log queries by host.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A concern here is potentially retaining references to freed connections. Other callbacks take it, of course, but that's an API decision I've come to regret. Perhaps we can pass the connection's addr and params instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works for me!

else:
self._loop.call_soon(cb.cb, con_ref, record)

async def __execute(
self,
query,
Expand All @@ -1681,20 +1724,27 @@ async def __execute(
executor = lambda stmt, timeout: self._protocol.bind_execute(
stmt, args, '', limit, return_status, timeout)
timeout = self._protocol._get_timeout(timeout)
return await self._do_execute(
start = time.monotonic()
result, stmt = await self._do_execute(
query,
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
elapsed = time.monotonic() - start
self._log_query(query, args, timeout, elapsed)
return result, stmt

async def _executemany(self, query, args, timeout):
    """Run *query* via ``bind_execute_many`` and report it to the loggers.

    :param query: the query text handed to ``_do_execute``.
    :param args: the arguments forwarded to ``bind_execute_many``.
    :param timeout: the timeout, normalized by the protocol's
        ``_get_timeout`` before use.
    :return: the first element of the pair returned by ``_do_execute``.
    """
    def run_batch(stmt, timeout):
        # Invoked by _do_execute with the statement and timeout to use.
        return self._protocol.bind_execute_many(stmt, args, '', timeout)

    effective_timeout = self._protocol._get_timeout(timeout)
    with self._stmt_exclusive_section:
        began = time.monotonic()
        outcome, _ = await self._do_execute(
            query, run_batch, effective_timeout)
        # Only successful executions reach this point and get logged.
        duration = time.monotonic() - began
        self._log_query(query, args, effective_timeout, duration)
    return outcome

async def _do_execute(
Expand Down Expand Up @@ -2323,6 +2373,27 @@ class _ConnectionProxy:
__slots__ = ()


# Record passed to query-logger callbacks: the query, its arguments, the
# timeout in effect, and the elapsed time measured around the execution.
LoggedQuery = collections.namedtuple(
    typename='LoggedQuery',
    field_names=('query', 'args', 'timeout', 'elapsed'),
)
LoggedQuery.__doc__ = 'Log record of an executed query.'


class _LoggingContext:
    """Context manager that keeps a query logger registered while active.

    Entering adds the callback to the connection via ``add_query_logger``;
    leaving removes it via ``remove_query_logger``, whether or not the
    body raised.
    """

    __slots__ = ('_conn', '_cb')

    def __init__(self, conn, callback):
        """Remember the connection and the callback to (un)register."""
        self._cb = callback
        self._conn = conn

    def __enter__(self):
        """Register the callback and return this context object."""
        connection, callback = self._conn, self._cb
        connection.add_query_logger(callback)
        return self

    def __exit__(self, *exc_info):
        """Unregister the callback; exceptions are not suppressed."""
        connection, callback = self._conn, self._cb
        connection.remove_query_logger(callback)


ServerCapabilities = collections.namedtuple(
'ServerCapabilities',
['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',
Expand Down
20 changes: 20 additions & 0 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import asyncio

from asyncpg import _testbase as tb


class TestQueryLogging(tb.ConnectedTestCase):
    """Tests for the query-logging hooks of a connection."""

    async def test_logging_context(self):
        """The logger is registered only inside the ``logger()`` block."""
        captured = asyncio.Queue()

        def save_record(conn, record):
            captured.put_nowait(record)

        connection = self.con
        with connection.logger(save_record):
            self.assertEqual(len(connection._query_loggers), 1)
            await connection.execute("SELECT 1")

        # The callback is scheduled on the loop rather than run inline, so
        # wait on the queue instead of inspecting it directly.
        logged = await captured.get()
        self.assertEqual(logged.query, "SELECT 1")
        self.assertEqual(len(connection._query_loggers), 0)