Skip to content

Commit

Permalink
Set Client.as_current when entering ctx (#6527)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Jul 7, 2023
1 parent 631d92a commit 8107f1f
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 37 deletions.
23 changes: 20 additions & 3 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def _get_global_client() -> Client | None:

def _set_global_client(c: Client | None) -> None:
if c is not None:
c._set_as_default = True
_global_clients[_global_client_index[0]] = c
_global_client_index[0] += 1

Expand Down Expand Up @@ -899,6 +900,7 @@ def __init__(
deserializers = serializers
self._deserializers = deserializers
self.direct_to_workers = direct_to_workers
self._previous_as_current = None

# Communication
self.scheduler_comm = None
Expand Down Expand Up @@ -1092,6 +1094,10 @@ def current(cls, allow_global=True):
------
ValueError
If there is no client set, a ValueError is raised
See also
--------
default_client
"""
out = _current_client.get()
if out:
Expand Down Expand Up @@ -1397,8 +1403,6 @@ async def _ensure_connected(self, timeout=None):
bcomm = BatchedSend(interval="10ms", loop=self.loop)
bcomm.start(comm)
self.scheduler_comm = bcomm
if self._set_as_default:
_set_global_client(self)
self.status = "running"

for msg in self._pending_msg_buffer:
Expand Down Expand Up @@ -1486,13 +1490,19 @@ def _heartbeat(self):
def __enter__(self):
if not self._loop_runner.is_started():
self.start()
if self._set_as_default:
self._previous_as_current = _current_client.set(self)
return self

async def __aenter__(self):
await self
if self._set_as_default:
self._previous_as_current = _current_client.set(self)
return self

async def __aexit__(self, exc_type, exc_value, traceback):
if self._previous_as_current:
_current_client.reset(self._previous_as_current)
await self._close(
# if we're handling an exception, we assume that it's more
# important to deliver that exception than shutdown gracefully.
Expand All @@ -1501,6 +1511,8 @@ async def __aexit__(self, exc_type, exc_value, traceback):
)

def __exit__(self, exc_type, exc_value, traceback):
if self._previous_as_current:
_current_client.reset(self._previous_as_current)
self.close()

def __del__(self):
Expand Down Expand Up @@ -5526,6 +5538,10 @@ def default_client(c=None):
-------
c : Client
The client, if one has started
See also
--------
Client.current (alias)
"""
c = c or _get_global_client()
if c:
Expand Down Expand Up @@ -5878,7 +5894,8 @@ def temp_default_client(c):
old_exec = default_client()
_set_global_client(c)
try:
yield
with c.as_current():
yield
finally:
_set_global_client(old_exec)

Expand Down
146 changes: 133 additions & 13 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
Client,
Future,
_get_global_client,
_global_clients,
as_completed,
default_client,
ensure_default_client,
Expand Down Expand Up @@ -1142,10 +1143,10 @@ async def test_get_releases_data(c, s, a, b):
await asyncio.sleep(0.01)


def test_current(s, a, b, loop_in_thread):
loop = loop_in_thread
def test_current(s, loop):
with Client(s["address"], loop=loop) as c:
assert Client.current() is c
assert Client.current(allow_global=False) is c
with pytest.raises(
ValueError,
match=r"No clients found"
Expand All @@ -1156,6 +1157,121 @@ def test_current(s, a, b, loop_in_thread):
Client.current()
with Client(s["address"], loop=loop) as c:
assert Client.current() is c
assert Client.current(allow_global=False) is c


def test_current_nested(s, loop):
with pytest.raises(
ValueError,
match=r"No clients found"
r"\nStart a client and point it to the scheduler address"
r"\n from distributed import Client"
r"\n client = Client\('ip-addr-of-scheduler:8786'\)",
):
Client.current()

class MyException(Exception):
pass

with Client(s["address"], loop=loop) as c_outer:
assert Client.current() is c_outer
assert Client.current(allow_global=False) is c_outer

with Client(s["address"], loop=loop) as c_inner:
assert Client.current() is c_inner
assert Client.current(allow_global=False) is c_inner

assert Client.current() is c_outer
assert Client.current(allow_global=False) is c_outer

with pytest.raises(MyException):
with Client(s["address"], loop=loop) as c_inner2:
assert Client.current() is c_inner2
assert Client.current(allow_global=False) is c_inner2
raise MyException

assert Client.current() is c_outer
assert Client.current(allow_global=False) is c_outer


@gen_cluster(nthreads=[])
async def test_current_nested_async(s):
with pytest.raises(
ValueError,
match=r"No clients found"
r"\nStart a client and point it to the scheduler address"
r"\n from distributed import Client"
r"\n client = Client\('ip-addr-of-scheduler:8786'\)",
):
Client.current()

class MyException(Exception):
pass

async with Client(s.address, asynchronous=True) as c_outer:
assert Client.current() is c_outer
assert Client.current(allow_global=False) is c_outer

async with Client(s.address, asynchronous=True) as c_inner:
assert Client.current() is c_inner
assert Client.current(allow_global=False) is c_inner

assert Client.current() is c_outer
assert Client.current(allow_global=False) is c_outer

with pytest.raises(MyException):
async with Client(s.address, asynchronous=True) as c_inner2:
assert Client.current() is c_inner2
assert Client.current(allow_global=False) is c_inner2
raise MyException

assert Client.current() is c_outer
assert Client.current(allow_global=False) is c_outer


@gen_cluster(nthreads=[])
async def test_current_concurrent(s):
client_1_started = asyncio.Event()
client_2_started = asyncio.Event()
stop_client_1 = asyncio.Event()
stop_client_2 = asyncio.Event()
client_2_stopped = asyncio.Event()

c1 = None
c2 = None

def _all_global_clients():
return [v for _, v in sorted(_global_clients.items())]

async def client_1():
nonlocal c1
async with Client(s.address, asynchronous=True) as c1:
assert _all_global_clients() == [c1]
assert Client.current() is c1
client_1_started.set()
await client_2_started.wait()
# c2 is the highest priority global client
assert _all_global_clients() == [c1, c2]
# but the contextvar means the current client is still us
assert Client.current() is c1
stop_client_2.set()
await stop_client_1.wait()

async def client_2():
nonlocal c2
await client_1_started.wait()
async with Client(s.address, asynchronous=True) as c2:
assert _all_global_clients() == [c1, c2]
assert Client.current() is c2
client_2_started.set()
await stop_client_2.wait()

assert _all_global_clients() == [c1]
# Client.current() is now based on _global_clients instead of the cvar
assert Client.current() is c1
stop_client_1.set()

await asyncio.gather(client_1(), client_2())


def test_global_clients(loop):
Expand Down Expand Up @@ -3294,16 +3410,21 @@ async def test_get_scheduler_default_client_config_interleaving(s):
await client.close()


@gen_cluster(client=True)
async def test_ensure_default_client(c, s, a, b):
assert c is default_client()

async with Client(s.address, set_as_default=False, asynchronous=True) as c2:
@gen_cluster()
async def test_ensure_default_client(s, a, b):
# Note: this test will fail if you use `async with Client`
c = await Client(s.address, asynchronous=True)
try:
assert c is default_client()
assert c2 is not default_client()
ensure_default_client(c2)
assert c is not default_client()
assert c2 is default_client()

async with Client(s.address, set_as_default=False, asynchronous=True) as c2:
assert c is default_client()
assert c2 is not default_client()
ensure_default_client(c2)
assert c is not default_client()
assert c2 is default_client()
finally:
await c.close()


@gen_cluster()
Expand Down Expand Up @@ -4022,8 +4143,7 @@ async def test_as_current(c, s, a, b):
) as c2:
with temp_default_client(c):
assert Client.current() is c
with pytest.raises(ValueError):
Client.current(allow_global=False)
assert Client.current(allow_global=False) is c
with c1.as_current():
assert Client.current() is c1
assert Client.current(allow_global=True) is c1
Expand Down
10 changes: 5 additions & 5 deletions distributed/tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,15 +224,15 @@ def event_is_set(event_name):
assert not s.extensions["events"]._waiter_count


@gen_cluster(client=True, nthreads=[])
async def test_unpickle_without_client(c, s):
@gen_cluster(nthreads=[])
async def test_unpickle_without_client(s):
"""Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context.
This typically happens if the object is being deserialized on the scheduler.
"""
obj = await Event()
pickled = pickle.dumps(obj)
await c.close()
async with Client(s.address, asynchronous=True) as c:
obj = await Event()
pickled = pickle.dumps(obj)

# We do not want to initialize a client during unpickling
with pytest.raises(ValueError):
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,4 @@ async def test_deserialize_client(c, s, a, b):
# Ensure cleanup
from distributed.client import _current_client

assert _current_client.get() is None
assert _current_client.get() is c
10 changes: 5 additions & 5 deletions distributed/tests/test_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,15 @@ def foo():
assert result == 123


@gen_cluster(client=True, nthreads=[])
async def test_unpickle_without_client(c, s):
@gen_cluster(nthreads=[])
async def test_unpickle_without_client(s):
"""Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context.
This typically happens if the object is being deserialized on the scheduler.
"""
q = await Queue()
pickled = pickle.dumps(q)
await c.close()
async with Client(s.address, asynchronous=True) as c:
q = await Queue()
pickled = pickle.dumps(q)

# We do not want to initialize a client during unpickling
with pytest.raises(ValueError):
Expand Down
10 changes: 5 additions & 5 deletions distributed/tests/test_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,15 +585,15 @@ async def test_release_failure(c, s, a, b, caplog):
await pool.close()


@gen_cluster(client=True, nthreads=[])
async def test_unpickle_without_client(c, s):
@gen_cluster(nthreads=[])
async def test_unpickle_without_client(s):
"""Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context.
This typically happens if the object is being deserialized on the scheduler.
"""
sem = await Semaphore()
pickled = pickle.dumps(sem)
await c.close()
async with Client(s.address, asynchronous=True) as c:
sem = await Semaphore()
pickled = pickle.dumps(sem)

# We do not want to initialize a client during unpickling
with pytest.raises(ValueError):
Expand Down
10 changes: 5 additions & 5 deletions distributed/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,15 @@ async def test_variables_do_not_leak_client(c, s, a, b):
assert time() < start + 5


@gen_cluster(client=True, nthreads=[])
async def test_unpickle_without_client(c, s):
@gen_cluster(nthreads=[])
async def test_unpickle_without_client(s):
"""Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context.
This typically happens if the object is being deserialized on the scheduler.
"""
obj = Variable("foo")
pickled = pickle.dumps(obj)
await c.close()
async with Client(s.address, asynchronous=True) as c:
obj = Variable("foo")
pickled = pickle.dumps(obj)

# We do not want to initialize a client during unpickling
with pytest.raises(ValueError):
Expand Down

0 comments on commit 8107f1f

Please sign in to comment.