diff --git a/distributed/client.py b/distributed/client.py index b03f967128..ec55b7f2c0 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -153,6 +153,7 @@ def _get_global_client() -> Client | None: def _set_global_client(c: Client | None) -> None: if c is not None: + c._set_as_default = True _global_clients[_global_client_index[0]] = c _global_client_index[0] += 1 @@ -899,6 +900,7 @@ def __init__( deserializers = serializers self._deserializers = deserializers self.direct_to_workers = direct_to_workers + self._previous_as_current = None # Communication self.scheduler_comm = None @@ -1092,6 +1094,10 @@ def current(cls, allow_global=True): ------ ValueError If there is no client set, a ValueError is raised + + See also + -------- + default_client """ out = _current_client.get() if out: @@ -1397,8 +1403,6 @@ async def _ensure_connected(self, timeout=None): bcomm = BatchedSend(interval="10ms", loop=self.loop) bcomm.start(comm) self.scheduler_comm = bcomm - if self._set_as_default: - _set_global_client(self) self.status = "running" for msg in self._pending_msg_buffer: @@ -1486,13 +1490,19 @@ def _heartbeat(self): def __enter__(self): if not self._loop_runner.is_started(): self.start() + if self._set_as_default: + self._previous_as_current = _current_client.set(self) return self async def __aenter__(self): await self + if self._set_as_default: + self._previous_as_current = _current_client.set(self) return self async def __aexit__(self, exc_type, exc_value, traceback): + if self._previous_as_current: + _current_client.reset(self._previous_as_current) await self._close( # if we're handling an exception, we assume that it's more # important to deliver that exception than shutdown gracefully. @@ -1501,6 +1511,8 @@ async def __aexit__(self, exc_type, exc_value, traceback): ) def __exit__(self, exc_type, exc_value, traceback): + if self._previous_as_current: + _current_client.reset(self._previous_as_current) self.close() def __del__(self): @@ -5526,6 +5538,10 @@ def default_client(c=None): ------- c : Client The client, if one has started + + See also + -------- + Client.current (alias) """ c = c or _get_global_client() if c: @@ -5878,7 +5894,8 @@ def temp_default_client(c): old_exec = default_client() _set_global_client(c) try: - yield + with c.as_current(): + yield finally: _set_global_client(old_exec) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 2f9c252b0c..2ed731a160 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -62,6 +62,7 @@ Client, Future, _get_global_client, + _global_clients, as_completed, default_client, ensure_default_client, @@ -1142,10 +1143,10 @@ async def test_get_releases_data(c, s, a, b): await asyncio.sleep(0.01) -def test_current(s, a, b, loop_in_thread): - loop = loop_in_thread +def test_current(s, loop): with Client(s["address"], loop=loop) as c: assert Client.current() is c + assert Client.current(allow_global=False) is c with pytest.raises( ValueError, match=r"No clients found" @@ -1156,6 +1157,121 @@ def test_current(s, a, b, loop_in_thread): Client.current() with Client(s["address"], loop=loop) as c: assert Client.current() is c + assert Client.current(allow_global=False) is c + + +def test_current_nested(s, loop): + with pytest.raises( + ValueError, + match=r"No clients found" + r"\nStart a client and point it to the scheduler address" + r"\n from distributed import Client" + r"\n client = Client\('ip-addr-of-scheduler:8786'\)", + ): + Client.current() + + class MyException(Exception): + pass + + with Client(s["address"], loop=loop) as c_outer: + assert Client.current() is c_outer + assert Client.current(allow_global=False) is c_outer + + with Client(s["address"], loop=loop) as c_inner: + assert Client.current() is c_inner + assert Client.current(allow_global=False) is c_inner + + assert Client.current() is c_outer + assert Client.current(allow_global=False) is c_outer + + with pytest.raises(MyException): + with Client(s["address"], loop=loop) as c_inner2: + assert Client.current() is c_inner2 + assert Client.current(allow_global=False) is c_inner2 + raise MyException + + assert Client.current() is c_outer + assert Client.current(allow_global=False) is c_outer + + +@gen_cluster(nthreads=[]) +async def test_current_nested_async(s): + with pytest.raises( + ValueError, + match=r"No clients found" + r"\nStart a client and point it to the scheduler address" + r"\n from distributed import Client" + r"\n client = Client\('ip-addr-of-scheduler:8786'\)", + ): + Client.current() + + class MyException(Exception): + pass + + async with Client(s.address, asynchronous=True) as c_outer: + assert Client.current() is c_outer + assert Client.current(allow_global=False) is c_outer + + async with Client(s.address, asynchronous=True) as c_inner: + assert Client.current() is c_inner + assert Client.current(allow_global=False) is c_inner + + assert Client.current() is c_outer + assert Client.current(allow_global=False) is c_outer + + with pytest.raises(MyException): + async with Client(s.address, asynchronous=True) as c_inner2: + assert Client.current() is c_inner2 + assert Client.current(allow_global=False) is c_inner2 + raise MyException + + assert Client.current() is c_outer + assert Client.current(allow_global=False) is c_outer + + +@gen_cluster(nthreads=[]) +async def test_current_concurrent(s): + client_1_started = asyncio.Event() + client_2_started = asyncio.Event() + stop_client_1 = asyncio.Event() + stop_client_2 = asyncio.Event() + client_2_stopped = asyncio.Event() + + c1 = None + c2 = None + + def _all_global_clients(): + return [v for _, v in sorted(_global_clients.items())] + + async def client_1(): + nonlocal c1 + async with Client(s.address, asynchronous=True) as c1: + assert _all_global_clients() == [c1] + assert Client.current() is c1 + client_1_started.set() + await client_2_started.wait() + # c2 is the highest priority global client + assert _all_global_clients() == [c1, c2] + # but the contextvar means the current client is still us + assert Client.current() is c1 + stop_client_2.set() + await stop_client_1.wait() + + async def client_2(): + nonlocal c2 + await client_1_started.wait() + async with Client(s.address, asynchronous=True) as c2: + assert _all_global_clients() == [c1, c2] + assert Client.current() is c2 + client_2_started.set() + await stop_client_2.wait() + + assert _all_global_clients() == [c1] + # Client.current() is now based on _global_clients instead of the cvar + assert Client.current() is c1 + stop_client_1.set() + + await asyncio.gather(client_1(), client_2()) def test_global_clients(loop): @@ -3294,16 +3410,21 @@ async def test_get_scheduler_default_client_config_interleaving(s): await client.close() -@gen_cluster(client=True) -async def test_ensure_default_client(c, s, a, b): - assert c is default_client() - - async with Client(s.address, set_as_default=False, asynchronous=True) as c2: +@gen_cluster() +async def test_ensure_default_client(s, a, b): + # Note: this test will fail if you use `async with Client` + c = await Client(s.address, asynchronous=True) + try: assert c is default_client() - assert c2 is not default_client() - ensure_default_client(c2) - assert c is not default_client() - assert c2 is default_client() + + async with Client(s.address, set_as_default=False, asynchronous=True) as c2: + assert c is default_client() + assert c2 is not default_client() + ensure_default_client(c2) + assert c is not default_client() + assert c2 is default_client() + finally: + await c.close() @gen_cluster() @@ -4022,8 +4143,7 @@ async def test_as_current(c, s, a, b): ) as c2: with temp_default_client(c): assert Client.current() is c - with pytest.raises(ValueError): - Client.current(allow_global=False) + assert Client.current(allow_global=False) is c with c1.as_current(): assert Client.current() is c1 assert Client.current(allow_global=True) is c1 diff --git a/distributed/tests/test_events.py b/distributed/tests/test_events.py index 5d6f4bbe84..b26df7ad39 100644 --- a/distributed/tests/test_events.py +++ b/distributed/tests/test_events.py @@ -224,15 +224,15 @@ def event_is_set(event_name): assert not s.extensions["events"]._waiter_count -@gen_cluster(client=True, nthreads=[]) -async def test_unpickle_without_client(c, s): +@gen_cluster(nthreads=[]) +async def test_unpickle_without_client(s): """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context. This typically happens if the object is being deserialized on the scheduler. """ - obj = await Event() - pickled = pickle.dumps(obj) - await c.close() + async with Client(s.address, asynchronous=True) as c: + obj = await Event() + pickled = pickle.dumps(obj) # We do not want to initialize a client during unpickling with pytest.raises(ValueError): diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index fe898d1513..f1107b5761 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -300,4 +300,4 @@ async def test_deserialize_client(c, s, a, b): # Ensure cleanup from distributed.client import _current_client - assert _current_client.get() is None + assert _current_client.get() is c diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index ef92caa77b..8a0dd96ca9 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -303,15 +303,15 @@ def foo(): assert result == 123 -@gen_cluster(client=True, nthreads=[]) -async def test_unpickle_without_client(c, s): +@gen_cluster(nthreads=[]) +async def test_unpickle_without_client(s): """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context. This typically happens if the object is being deserialized on the scheduler. """ - q = await Queue() - pickled = pickle.dumps(q) - await c.close() + async with Client(s.address, asynchronous=True) as c: + q = await Queue() + pickled = pickle.dumps(q) # We do not want to initialize a client during unpickling with pytest.raises(ValueError): diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 025a965396..0cd98e1613 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -585,15 +585,15 @@ async def test_release_failure(c, s, a, b, caplog): await pool.close() -@gen_cluster(client=True, nthreads=[]) -async def test_unpickle_without_client(c, s): +@gen_cluster(nthreads=[]) +async def test_unpickle_without_client(s): """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context. This typically happens if the object is being deserialized on the scheduler. """ - sem = await Semaphore() - pickled = pickle.dumps(sem) - await c.close() + async with Client(s.address, asynchronous=True) as c: + sem = await Semaphore() + pickled = pickle.dumps(sem) # We do not want to initialize a client during unpickling with pytest.raises(ValueError): diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 2e711e9940..00618da70d 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -297,15 +297,15 @@ async def test_variables_do_not_leak_client(c, s, a, b): assert time() < start + 5 -@gen_cluster(client=True, nthreads=[]) -async def test_unpickle_without_client(c, s): +@gen_cluster(nthreads=[]) +async def test_unpickle_without_client(s): """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context. This typically happens if the object is being deserialized on the scheduler. """ - obj = Variable("foo") - pickled = pickle.dumps(obj) - await c.close() + async with Client(s.address, asynchronous=True) as c: + obj = Variable("foo") + pickled = pickle.dumps(obj) # We do not want to initialize a client during unpickling with pytest.raises(ValueError):