diff --git a/newrelic/hooks/database_aiomysql.py b/newrelic/hooks/database_aiomysql.py index 9a2f3d1d18..2cedcb40f9 100644 --- a/newrelic/hooks/database_aiomysql.py +++ b/newrelic/hooks/database_aiomysql.py @@ -78,6 +78,10 @@ async def _wrap_pool__acquire(wrapped, instance, args, kwargs): with FunctionTrace(name=callable_name(wrapped), terminal=True, rollup=rollup, source=wrapped): connection = await wrapped(*args, **kwargs) connection_kwargs = getattr(instance, "_conn_kwargs", {}) + + if hasattr(connection, "__wrapped__"): + return connection + return AsyncConnectionWrapper(connection, dbapi2_module, (((), connection_kwargs))) return _wrap_pool__acquire diff --git a/tests/datastore_aiomysql/test_database.py b/tests/datastore_aiomysql/test_database.py index 20d1a48586..8cc386cfe1 100644 --- a/tests/datastore_aiomysql/test_database.py +++ b/tests/datastore_aiomysql/test_database.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect + import aiomysql from testing_support.db_settings import mysql_settings from testing_support.util import instance_hostname @@ -150,3 +152,35 @@ async def _test(): await pool.wait_closed() loop.run_until_complete(_test()) + + +@background_task() +def test_connection_pool_no_double_wrap(loop): + async def _test(): + pool = await aiomysql.create_pool( + db=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + password=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + loop=loop, + ) + + # Retrieve the same connection from the pool twice to see if it gets double wrapped + async with pool.acquire() as first_connection: + first_connection_unwrapped = inspect.unwrap(first_connection) + async with pool.acquire() as second_connection: + second_connection_unwrapped = inspect.unwrap(second_connection) + + # Ensure we actually retrieved the same underlying connection object from the pool twice + assert first_connection_unwrapped is second_connection_unwrapped, "Did not get same connection from pool" + + # Check that wrapping occurred only once + assert hasattr(first_connection, "__wrapped__"), "first_connection object was not wrapped" + assert hasattr(second_connection, "__wrapped__"), "second_connection object was not wrapped" + assert not hasattr(second_connection.__wrapped__, "__wrapped__"), "second_connection was double wrapped" + + pool.close() + await pool.wait_closed() + + loop.run_until_complete(_test())