diff --git a/clickhouse_connect/driver/asyncclient.py b/clickhouse_connect/driver/asyncclient.py index d2346f04..e78dfd5d 100644 --- a/clickhouse_connect/driver/asyncclient.py +++ b/clickhouse_connect/driver/asyncclient.py @@ -64,11 +64,12 @@ def min_version(self, version_str: str) -> bool: """ return self.client.min_version(version_str) - def close(self): + async def close(self): """ Subclass implementation to close the connection to the server/deallocate the client """ self.client.close() + await asyncio.to_thread(self.executor.shutdown, True) async def query(self, query: Optional[str] = None, @@ -676,3 +677,9 @@ def _raw_insert(): loop = asyncio.get_running_loop() result = await loop.run_in_executor(self.executor, _raw_insert) return result + + async def __aenter__(self) -> "AsyncClient": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.close() diff --git a/examples/run_async.py b/examples/run_async.py index c4cb338e..069e3f90 100644 --- a/examples/run_async.py +++ b/examples/run_async.py @@ -41,6 +41,7 @@ async def semaphore_wrapper(sm: asyncio.Semaphore, num: int): semaphore = asyncio.Semaphore(SEMAPHORE) await asyncio.gather(*[semaphore_wrapper(semaphore, num) for num in range(QUERIES)]) + await client.close() async def main(): diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index f2fb6be1..6beeef25 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -3,8 +3,9 @@ import random import time from subprocess import Popen, PIPE -from typing import Iterator, NamedTuple, Sequence, Optional, Callable +from typing import Iterator, NamedTuple, Sequence, Optional, Callable, AsyncContextManager +import pytest_asyncio from pytest import fixture from clickhouse_connect import common @@ -129,9 +130,10 @@ def test_client_fixture(test_config: TestConfig, test_create_client: Callable) - sys.stderr.write('Successfully stopped docker compose') -@fixture(scope='session', autouse=True, name='test_async_client') -def test_async_client_fixture(test_client: Client) -> Iterator[AsyncClient]: - yield AsyncClient(client=test_client) +@pytest_asyncio.fixture(scope='session', autouse=True, name='test_async_client') +async def test_async_client_fixture(test_client: Client) -> AsyncContextManager[AsyncClient]: + async with AsyncClient(client=test_client) as client: + yield client @fixture(scope='session', name='table_context') diff --git a/tests/integration_tests/test_session_id.py b/tests/integration_tests/test_session_id.py index 84c683f9..17ac9112 100644 --- a/tests/integration_tests/test_session_id.py +++ b/tests/integration_tests/test_session_id.py @@ -46,7 +46,7 @@ async def test_async_client_default_session_id(test_config: TestConfig): user=test_config.username, password=test_config.password) assert async_client.get_client_setting(SESSION_KEY) is None - async_client.close() + await async_client.close() @pytest.mark.asyncio @@ -62,7 +62,7 @@ async def test_async_client_autogenerate_session_id(test_config: TestConfig): uuid.UUID(session_id) except ValueError: pytest.fail(f"Invalid session_id: {session_id}") - async_client.close() + await async_client.close() @pytest.mark.asyncio @@ -75,4 +75,4 @@ async def test_async_client_custom_session_id(test_config: TestConfig): password=test_config.password, session_id=session_id) assert async_client.get_client_setting(SESSION_KEY) == session_id - async_client.close() + await async_client.close()