From 8f8671a21318fa5dfee0672353baf3d8e392aa9d Mon Sep 17 00:00:00 2001 From: Dan Watson Date: Fri, 30 Jun 2023 15:21:02 -0400 Subject: [PATCH 1/9] Add query logging callbacks and context manager --- asyncpg/connection.py | 77 +++++++++++++++++++++++++++++++++++++++++-- tests/test_logging.py | 20 +++++++++++ 2 files changed, 94 insertions(+), 3 deletions(-) create mode 100644 tests/test_logging.py diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 0b13d356..4163a131 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -51,7 +51,7 @@ class Connection(metaclass=ConnectionMeta): '_intro_query', '_reset_query', '_proxy', '_stmt_exclusive_section', '_config', '_params', '_addr', '_log_listeners', '_termination_listeners', '_cancellations', - '_source_traceback', '__weakref__') + '_source_traceback', '_query_loggers', '__weakref__') def __init__(self, protocol, transport, loop, addr, @@ -84,6 +84,7 @@ def __init__(self, protocol, transport, loop, self._log_listeners = set() self._cancellations = set() self._termination_listeners = set() + self._query_loggers = set() settings = self._protocol.get_settings() ver_string = settings.server_version @@ -221,6 +222,30 @@ def remove_termination_listener(self, callback): """ self._termination_listeners.discard(_Callback.from_callable(callback)) + def add_query_logger(self, callback): + """Add a logger that will be called when queries are executed. + + :param callable callback: + A callable or a coroutine function receiving two arguments: + **connection**: a Connection the callback is registered with. + **query**: a LoggedQuery containing the query, args, timeout, and + elapsed. + + .. versionadded:: 0.28.0 + """ + self._query_loggers.add(_Callback.from_callable(callback)) + + def remove_query_logger(self, callback): + """Remove a query logger callback. + + :param callable callback: + The callable or coroutine function that was passed to + :meth:`Connection.add_query_logger`. + + .. versionadded:: 0.28.0 + """ + self._query_loggers.discard(_Callback.from_callable(callback)) + def get_server_pid(self): """Return the PID of the Postgres server the connection is bound to.""" return self._protocol.get_server_pid() @@ -314,7 +339,11 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: self._check_open() if not args: - return await self._protocol.query(query, timeout) + start = time.monotonic() + result = await self._protocol.query(query, timeout) + elapsed = time.monotonic() - start + self._log_query(query, args, timeout, elapsed) + return result _, status, _ = await self._execute( query, @@ -1667,6 +1696,20 @@ async def _execute( ) return result + def logger(self, callback): + return _LoggingContext(self, callback) + + def _log_query(self, query, args, timeout, elapsed): + if not self._query_loggers: + return + con_ref = self._unwrap() + record = LoggedQuery(query, args, timeout, elapsed) + for cb in self._query_loggers: + if cb.is_async: + self._loop.create_task(cb.cb(con_ref, record)) + else: + self._loop.call_soon(cb.cb, con_ref, record) + async def __execute( self, query, @@ -1681,20 +1724,27 @@ async def __execute( executor = lambda stmt, timeout: self._protocol.bind_execute( stmt, args, '', limit, return_status, timeout) timeout = self._protocol._get_timeout(timeout) - return await self._do_execute( + start = time.monotonic() + result, stmt = await self._do_execute( query, executor, timeout, record_class=record_class, ignore_custom_codec=ignore_custom_codec, ) + elapsed = time.monotonic() - start + self._log_query(query, args, timeout, elapsed) + return result, stmt async def _executemany(self, query, args, timeout): executor = lambda stmt, timeout: self._protocol.bind_execute_many( stmt, args, '', timeout) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: + start = time.monotonic() result, _ = await self._do_execute(query, executor, timeout) + elapsed = time.monotonic() - start + self._log_query(query, args, timeout, elapsed) return result async def _do_execute( @@ -2323,6 +2373,27 @@ class _ConnectionProxy: __slots__ = () +LoggedQuery = collections.namedtuple( + 'LoggedQuery', + ['query', 'args', 'timeout', 'elapsed']) +LoggedQuery.__doc__ = 'Log record of an executed query.' + + +class _LoggingContext: + __slots__ = ('_conn', '_cb') + + def __init__(self, conn, callback): + self._conn = conn + self._cb = callback + + def __enter__(self): + self._conn.add_query_logger(self._cb) + return self + + def __exit__(self, *exc_info): + self._conn.remove_query_logger(self._cb) + + ServerCapabilities = collections.namedtuple( 'ServerCapabilities', ['advisory_locks', 'notifications', 'plpgsql', 'sql_reset', diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 00000000..8a7cc51c --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,20 @@ +import asyncio + +from asyncpg import _testbase as tb + + +class TestQueryLogging(tb.ConnectedTestCase): + + async def test_logging_context(self): + queries = asyncio.Queue() + + def query_saver(conn, record): + queries.put_nowait(record) + + with self.con.logger(query_saver): + self.assertEqual(len(self.con._query_loggers), 1) + await self.con.execute("SELECT 1") + + record = await queries.get() + self.assertEqual(record.query, "SELECT 1") + self.assertEqual(len(self.con._query_loggers), 0) From f64a8eac4a1f0393e7c8987d77d3d9c9acc330e4 Mon Sep 17 00:00:00 2001 From: Dan Watson Date: Thu, 6 Jul 2023 21:03:32 -0400 Subject: [PATCH 2/9] Wrap timing in a context manager --- asyncpg/connection.py | 33 +++++++++++++++------------------ asyncpg/utils.py | 26 ++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 4163a131..d59482a7 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -339,10 +339,9 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: self._check_open() if not args: - start = time.monotonic() - result = await self._protocol.query(query, timeout) - elapsed = time.monotonic() - start - self._log_query(query, args, timeout, elapsed) + with utils.timer() as t: + result = await self._protocol.query(query, timeout) + self._log_query(query, args, timeout, t.elapsed) return result _, status, _ = await self._execute( @@ -1724,16 +1723,15 @@ async def __execute( executor = lambda stmt, timeout: self._protocol.bind_execute( stmt, args, '', limit, return_status, timeout) timeout = self._protocol._get_timeout(timeout) - start = time.monotonic() - result, stmt = await self._do_execute( - query, - executor, - timeout, - record_class=record_class, - ignore_custom_codec=ignore_custom_codec, - ) - elapsed = time.monotonic() - start - self._log_query(query, args, timeout, elapsed) + with utils.timer() as t: + result, stmt = await self._do_execute( + query, + executor, + timeout, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + self._log_query(query, args, timeout, t.elapsed) return result, stmt async def _executemany(self, query, args, timeout): @@ -1741,10 +1739,9 @@ async def _executemany(self, query, args, timeout): stmt, args, '', timeout) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: - start = time.monotonic() - result, _ = await self._do_execute(query, executor, timeout) - elapsed = time.monotonic() - start - self._log_query(query, args, timeout, elapsed) + with utils.timer() as t: + result, _ = await self._do_execute(query, executor, timeout) + self._log_query(query, args, timeout, t.elapsed) return result async def _do_execute( diff --git a/asyncpg/utils.py b/asyncpg/utils.py index 3940e04d..bd3be330 100644 --- a/asyncpg/utils.py +++ b/asyncpg/utils.py @@ -6,6 +6,7 @@ import re +import time def _quote_ident(ident): @@ -43,3 +44,28 @@ async def _mogrify(conn, query, args): # Finally, replace $n references with text values. return re.sub( r'\$(\d+)\b', lambda m: textified[int(m.group(1)) - 1], query) + + +class timer: + __slots__ = ('start', 'elapsed') + + def __init__(self): + self.start = time.monotonic() + self.elapsed = None + + @property + def current(self): + return time.monotonic() - self.start + + def restart(self): + self.start = time.monotonic() + + def stop(self): + self.elapsed = self.current + + def __enter__(self): + self.restart() + return self + + def __exit__(self, *exc): + self.stop() From 937890d4d84d0701b0d247f1fa03057acb48e59a Mon Sep 17 00:00:00 2001 From: Dan Watson Date: Thu, 6 Jul 2023 21:37:44 -0400 Subject: [PATCH 3/9] More docs, test multiple loggers --- asyncpg/connection.py | 31 ++++++++++++++++++++++++++++--- tests/test_logging.py | 11 ++++++++++- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 493ec15b..b6150519 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -231,7 +231,7 @@ def add_query_logger(self, callback): **query**: a LoggedQuery containing the query, args, timeout, and elapsed. - .. versionadded:: 0.28.0 + .. versionadded:: 0.29.0 """ self._query_loggers.add(_Callback.from_callable(callback)) @@ -242,7 +242,7 @@ def remove_query_logger(self, callback): The callable or coroutine function that was passed to :meth:`Connection.add_query_logger`. - .. versionadded:: 0.28.0 + .. versionadded:: 0.29.0 """ self._query_loggers.discard(_Callback.from_callable(callback)) @@ -1696,6 +1696,31 @@ async def _execute( return result def logger(self, callback): + """Context manager that adds `callback` to the list of query loggers, + and removes it upon exit. + + :param callable callback: + A callable or a coroutine function receiving two arguments: + **connection**: a Connection the callback is registered with. + **query**: a LoggedQuery containing the query, args, timeout, and + elapsed. + + Example: + + .. code-block:: pycon + + >>> class QuerySaver: + def __init__(self): + self.queries = [] + def __call__(self, conn, record): + self.queries.append(record.query) + >>> with con.logger(QuerySaver()) as log: + >>> await con.execute("SELECT 1") + >>> print(log.queries) + ['SELECT 1'] + + .. versionadded:: 0.29.0 + """ return _LoggingContext(self, callback) def _log_query(self, query, args, timeout, elapsed): @@ -2389,7 +2414,7 @@ def __init__(self, conn, callback): def __enter__(self): self._conn.add_query_logger(self._cb) - return self + return self._cb def __exit__(self, *exc_info): self._conn.remove_query_logger(self._cb) diff --git a/tests/test_logging.py b/tests/test_logging.py index 8a7cc51c..73664bb6 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -11,10 +11,19 @@ async def test_logging_context(self): def query_saver(conn, record): queries.put_nowait(record) + class QuerySaver: + def __init__(self): + self.queries = [] + def __call__(self, conn, record): + self.queries.append(record.query) + with self.con.logger(query_saver): self.assertEqual(len(self.con._query_loggers), 1) - await self.con.execute("SELECT 1") + with self.con.logger(QuerySaver()) as log: + self.assertEqual(len(self.con._query_loggers), 2) + await self.con.execute("SELECT 1") record = await queries.get() self.assertEqual(record.query, "SELECT 1") + self.assertEqual(log.queries, ["SELECT 1"]) self.assertEqual(len(self.con._query_loggers), 0) From 107f1962137a436be179305a4f69da88e0804b8b Mon Sep 17 00:00:00 2001 From: Dan Watson Date: Thu, 6 Jul 2023 21:38:29 -0400 Subject: [PATCH 4/9] Appease flake8 --- tests/test_logging.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_logging.py b/tests/test_logging.py index 73664bb6..dbb3eeb8 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -14,6 +14,7 @@ def query_saver(conn, record): class QuerySaver: def __init__(self): self.queries = [] + def __call__(self, conn, record): self.queries.append(record.query) From 92baac18d56bb107c9baaaef9831b82f21e2a783 Mon Sep 17 00:00:00 2001 From: Dan Watson Date: Thu, 6 Jul 2023 23:41:33 -0400 Subject: [PATCH 5/9] Log query errors, clean up context managers --- asyncpg/connection.py | 84 +++++++++++++++++++++---------------------- asyncpg/utils.py | 26 -------------- tests/test_logging.py | 46 ++++++++++++++++-------- 3 files changed, 74 insertions(+), 82 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index b6150519..389afcef 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -9,6 +9,7 @@ import asyncpg import collections import collections.abc +import contextlib import functools import itertools import inspect @@ -226,10 +227,9 @@ def add_query_logger(self, callback): """Add a logger that will be called when queries are executed. :param callable callback: - A callable or a coroutine function receiving two arguments: - **connection**: a Connection the callback is registered with. - **query**: a LoggedQuery containing the query, args, timeout, and - elapsed. + A callable or a coroutine function receiving one argument: + **record**: a LoggedQuery containing `query`, `args`, `timeout`, + `elapsed`, `addr`, `params`, and `exception`. .. versionadded:: 0.29.0 """ @@ -339,9 +339,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: self._check_open() if not args: - with utils.timer() as t: + with self._time_and_log(query, args, timeout): result = await self._protocol.query(query, timeout) - self._log_query(query, args, timeout, t.elapsed) return result _, status, _ = await self._execute( @@ -1412,6 +1411,7 @@ def _cleanup(self): self._mark_stmts_as_closed() self._listeners.clear() self._log_listeners.clear() + self._query_loggers.clear() self._clean_tasks() def _clean_tasks(self): @@ -1695,15 +1695,15 @@ async def _execute( ) return result + @contextlib.contextmanager def logger(self, callback): """Context manager that adds `callback` to the list of query loggers, and removes it upon exit. :param callable callback: - A callable or a coroutine function receiving two arguments: - **connection**: a Connection the callback is registered with. - **query**: a LoggedQuery containing the query, args, timeout, and - elapsed. + A callable or a coroutine function receiving one argument: + **record**: a LoggedQuery containing `query`, `args`, `timeout`, + `elapsed`, `addr`, and `params`. Example: @@ -1721,18 +1721,35 @@ def __call__(self, conn, record): .. versionadded:: 0.29.0 """ - return _LoggingContext(self, callback) - - def _log_query(self, query, args, timeout, elapsed): - if not self._query_loggers: - return - con_ref = self._unwrap() - record = LoggedQuery(query, args, timeout, elapsed) - for cb in self._query_loggers: - if cb.is_async: - self._loop.create_task(cb.cb(con_ref, record)) - else: - self._loop.call_soon(cb.cb, con_ref, record) + self.add_query_logger(callback) + yield callback + self.remove_query_logger(callback) + + @contextlib.contextmanager + def _time_and_log(self, query, args, timeout): + start = time.monotonic() + exception = None + try: + yield + except Exception as ex: + exception = ex + raise + finally: + elapsed = time.monotonic() - start + record = LoggedQuery( + query=query, + args=args, + timeout=timeout, + elapsed=elapsed, + addr=self._addr, + params=self._params, + exception=exception, + ) + for cb in self._query_loggers: + if cb.is_async: + self._loop.create_task(cb.cb(record)) + else: + self._loop.call_soon(cb.cb, record) async def __execute( self, @@ -1748,7 +1765,7 @@ async def __execute( executor = lambda stmt, timeout: self._protocol.bind_execute( stmt, args, '', limit, return_status, timeout) timeout = self._protocol._get_timeout(timeout) - with utils.timer() as t: + with self._time_and_log(query, args, timeout): result, stmt = await self._do_execute( query, executor, @@ -1756,7 +1773,6 @@ async def __execute( record_class=record_class, ignore_custom_codec=ignore_custom_codec, ) - self._log_query(query, args, timeout, t.elapsed) return result, stmt async def _executemany(self, query, args, timeout): @@ -1764,9 +1780,8 @@ async def _executemany(self, query, args, timeout): stmt, args, '', timeout) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: - with utils.timer() as t: + with self._time_and_log(query, args, timeout): result, _ = await self._do_execute(query, executor, timeout) - self._log_query(query, args, timeout, t.elapsed) return result async def _do_execute( @@ -2401,25 +2416,10 @@ class _ConnectionProxy: LoggedQuery = collections.namedtuple( 'LoggedQuery', - ['query', 'args', 'timeout', 'elapsed']) + ['query', 'args', 'timeout', 'elapsed', 'exception', 'addr', 'params']) LoggedQuery.__doc__ = 'Log record of an executed query.' -class _LoggingContext: - __slots__ = ('_conn', '_cb') - - def __init__(self, conn, callback): - self._conn = conn - self._cb = callback - - def __enter__(self): - self._conn.add_query_logger(self._cb) - return self._cb - - def __exit__(self, *exc_info): - self._conn.remove_query_logger(self._cb) - - ServerCapabilities = collections.namedtuple( 'ServerCapabilities', ['advisory_locks', 'notifications', 'plpgsql', 'sql_reset', diff --git a/asyncpg/utils.py b/asyncpg/utils.py index bd3be330..3940e04d 100644 --- a/asyncpg/utils.py +++ b/asyncpg/utils.py @@ -6,7 +6,6 @@ import re -import time def _quote_ident(ident): @@ -44,28 +43,3 @@ async def _mogrify(conn, query, args): # Finally, replace $n references with text values. return re.sub( r'\$(\d+)\b', lambda m: textified[int(m.group(1)) - 1], query) - - -class timer: - __slots__ = ('start', 'elapsed') - - def __init__(self): - self.start = time.monotonic() - self.elapsed = None - - @property - def current(self): - return time.monotonic() - self.start - - def restart(self): - self.start = time.monotonic() - - def stop(self): - self.elapsed = self.current - - def __enter__(self): - self.restart() - return self - - def __exit__(self, *exc): - self.stop() diff --git a/tests/test_logging.py b/tests/test_logging.py index dbb3eeb8..b40ba335 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -1,6 +1,15 @@ import asyncio from asyncpg import _testbase as tb +from asyncpg import exceptions + + +class LogCollector: + def __init__(self): + self.records = [] + + def __call__(self, record): + self.records.append(record) class TestQueryLogging(tb.ConnectedTestCase): @@ -8,23 +17,32 @@ class TestQueryLogging(tb.ConnectedTestCase): async def test_logging_context(self): queries = asyncio.Queue() - def query_saver(conn, record): + def query_saver(record): queries.put_nowait(record) - class QuerySaver: - def __init__(self): - self.queries = [] - - def __call__(self, conn, record): - self.queries.append(record.query) - with self.con.logger(query_saver): self.assertEqual(len(self.con._query_loggers), 1) - with self.con.logger(QuerySaver()) as log: + await self.con.execute("SELECT 1") + with self.con.logger(LogCollector()) as log: self.assertEqual(len(self.con._query_loggers), 2) - await self.con.execute("SELECT 1") - - record = await queries.get() - self.assertEqual(record.query, "SELECT 1") - self.assertEqual(log.queries, ["SELECT 1"]) + await self.con.execute("SELECT 2") + + r1 = await queries.get() + r2 = await queries.get() + self.assertEqual(r1.query, "SELECT 1") + self.assertEqual(r2.query, "SELECT 2") + self.assertEqual(len(log.records), 1) + self.assertEqual(log.records[0].query, "SELECT 2") self.assertEqual(len(self.con._query_loggers), 0) + + async def test_error_logging(self): + with self.con.logger(LogCollector()) as log: + with self.assertRaises(exceptions.UndefinedColumnError): + await self.con.execute("SELECT x") + + await asyncio.sleep(0) # wait for logging + self.assertEqual(len(log.records), 1) + self.assertEqual( + type(log.records[0].exception), + exceptions.UndefinedColumnError + ) From 67dbc42c992b59a225fa8f59c36e3254aa4a9120 Mon Sep 17 00:00:00 2001 From: Dan Watson Date: Fri, 7 Jul 2023 09:13:58 -0400 Subject: [PATCH 6/9] Fix docs --- asyncpg/connection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 389afcef..5cec8467 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1703,7 +1703,7 @@ def logger(self, callback): :param callable callback: A callable or a coroutine function receiving one argument: **record**: a LoggedQuery containing `query`, `args`, `timeout`, - `elapsed`, `addr`, and `params`. + `elapsed`, `addr`, `params`, and `exception`. Example: @@ -1712,7 +1712,7 @@ def logger(self, callback): >>> class QuerySaver: def __init__(self): self.queries = [] - def __call__(self, conn, record): + def __call__(self, record): self.queries.append(record.query) >>> with con.logger(QuerySaver()) as log: >>> await con.execute("SELECT 1") From a21d9c50c49a7e0408a6ac6b9883b01540b70ac0 Mon Sep 17 00:00:00 2001 From: Dan Watson Date: Fri, 7 Jul 2023 14:48:16 -0400 Subject: [PATCH 7/9] Renamed LoggedQuery attrs --- asyncpg/connection.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 5cec8467..865ed3d9 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -229,7 +229,8 @@ def add_query_logger(self, callback): :param callable callback: A callable or a coroutine function receiving one argument: **record**: a LoggedQuery containing `query`, `args`, `timeout`, - `elapsed`, `addr`, `params`, and `exception`. + `elapsed`, `exception`, `conn_addr`, and + `conn_params`. .. versionadded:: 0.29.0 """ @@ -1703,7 +1704,8 @@ def logger(self, callback): :param callable callback: A callable or a coroutine function receiving one argument: **record**: a LoggedQuery containing `query`, `args`, `timeout`, - `elapsed`, `addr`, `params`, and `exception`. + `elapsed`, `exception`, `conn_addr`, and + `conn_params`. Example: @@ -1741,9 +1743,9 @@ def _time_and_log(self, query, args, timeout): args=args, timeout=timeout, elapsed=elapsed, - addr=self._addr, - params=self._params, exception=exception, + conn_addr=self._addr, + conn_params=self._params, ) for cb in self._query_loggers: if cb.is_async: @@ -2416,7 +2418,8 @@ class _ConnectionProxy: LoggedQuery = collections.namedtuple( 'LoggedQuery', - ['query', 'args', 'timeout', 'elapsed', 'exception', 'addr', 'params']) + ['query', 'args', 'timeout', 'elapsed', 'exception', 'conn_addr', + 'conn_params']) LoggedQuery.__doc__ = 'Log record of an executed query.' From 3d2924a86002c3b2855b7c14fada67fce067b9eb Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Mon, 9 Oct 2023 12:38:53 -0700 Subject: [PATCH 8/9] Avoid use of logging context when no loggers are configured --- asyncpg/connection.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 2989e068..21b8a810 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -343,7 +343,10 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: self._check_open() if not args: - with self._time_and_log(query, args, timeout): + if self._query_loggers: + with self._time_and_log(query, args, timeout): + result = await self._protocol.query(query, timeout) + else: result = await self._protocol.query(query, timeout) return result @@ -1876,7 +1879,16 @@ async def __execute( timeout=timeout, ) timeout = self._protocol._get_timeout(timeout) - with self._time_and_log(query, args, timeout): + if self._query_loggers: + with self._time_and_log(query, args, timeout): + result, stmt = await self._do_execute( + query, + executor, + timeout, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + else: result, stmt = await self._do_execute( query, executor, From 040892d391e32097b5edc3b0852425a31892adaf Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Mon, 9 Oct 2023 12:48:07 -0700 Subject: [PATCH 9/9] s/logger/query_logger/, report `BaseException` too --- asyncpg/connection.py | 8 ++++---- tests/test_logging.py | 9 ++++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 21b8a810..0367e365 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1803,7 +1803,7 @@ async def _execute( return result @contextlib.contextmanager - def logger(self, callback): + def query_logger(self, callback): """Context manager that adds `callback` to the list of query loggers, and removes it upon exit. @@ -1822,7 +1822,7 @@ def __init__(self): self.queries = [] def __call__(self, record): self.queries.append(record.query) - >>> with con.logger(QuerySaver()) as log: + >>> with con.query_logger(QuerySaver()): >>> await con.execute("SELECT 1") >>> print(log.queries) ['SELECT 1'] @@ -1830,7 +1830,7 @@ def __call__(self, record): .. versionadded:: 0.29.0 """ self.add_query_logger(callback) - yield callback + yield self.remove_query_logger(callback) @contextlib.contextmanager @@ -1839,7 +1839,7 @@ def _time_and_log(self, query, args, timeout): exception = None try: yield - except Exception as ex: + except BaseException as ex: exception = ex raise finally: diff --git a/tests/test_logging.py b/tests/test_logging.py index b40ba335..a9af94c4 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -20,10 +20,12 @@ async def test_logging_context(self): def query_saver(record): queries.put_nowait(record) - with self.con.logger(query_saver): + log = LogCollector() + + with self.con.query_logger(query_saver): self.assertEqual(len(self.con._query_loggers), 1) await self.con.execute("SELECT 1") - with self.con.logger(LogCollector()) as log: + with self.con.query_logger(log): self.assertEqual(len(self.con._query_loggers), 2) await self.con.execute("SELECT 2") @@ -36,7 +38,8 @@ def query_saver(record): self.assertEqual(len(self.con._query_loggers), 0) async def test_error_logging(self): - with self.con.logger(LogCollector()) as log: + log = LogCollector() + with self.con.query_logger(log): with self.assertRaises(exceptions.UndefinedColumnError): await self.con.execute("SELECT x")