Skip to content

Commit

Permalink
Partial revert defaultclient config setting (#7803)
Browse files Browse the repository at this point in the history
Co-authored-by: Hendrik Makait <[email protected]>
  • Loading branch information
fjetter and hendrikmakait authored Apr 27, 2023
1 parent 839f4a9 commit 68b5bbf
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 5 deletions.
17 changes: 12 additions & 5 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,8 @@ def __init__(

self._start_arg = address
self._set_as_default = set_as_default

if set_as_default:
self._set_config = dask.config.set(scheduler="dask.distributed")
self._event_handlers = {}

self._stream_handlers = {
Expand Down Expand Up @@ -1054,10 +1055,11 @@ class method to return self. Any Future objects deserialized inside this
context manager will be automatically attached to this Client.
"""
tok = _current_client.set(self)
try:
yield
finally:
_current_client.reset(tok)
with dask.config.set(scheduler="dask.distributed"):
try:
yield
finally:
_current_client.reset(tok)

@classmethod
def current(cls, allow_global=True):
Expand Down Expand Up @@ -1669,6 +1671,11 @@ async def _close(self, fast=False):
with log_errors():
_del_global_client(self)
self._scheduler_identity = {}
if self._set_as_default and not _get_global_client():
with suppress(AttributeError):
# clear the dask.config set keys
with self._set_config:
pass
if self.get == dask.config.get("get", None):
del dask.config.config["get"]

Expand Down
62 changes: 62 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3240,6 +3240,68 @@ def test_default_get(loop_in_thread):
assert dask.base.get_scheduler() == pre_get


@gen_cluster(config={"scheduler": "sync"}, nthreads=[])
async def test_get_scheduler_default_client_config_interleaving(s):
# This test is using context managers intentionally. We should not refactor
# this to use it in more places to make the client closing cleaner.
with pytest.warns(UserWarning):
assert dask.base.get_scheduler() == dask.local.get_sync
with dask.config.set(scheduler="threads"):
assert dask.base.get_scheduler() == dask.threaded.get
client = await Client(s.address, set_as_default=False, asynchronous=True)
try:
assert dask.base.get_scheduler() == dask.threaded.get
finally:
await client.close()

client = await Client(s.address, set_as_default=True, asynchronous=True)
try:
assert dask.base.get_scheduler() == client.get
finally:
await client.close()
assert dask.base.get_scheduler() == dask.threaded.get

# FIXME: As soon as async with uses as_current this will be true as well
# async with Client(s.address, set_as_default=False, asynchronous=True) as c:
# assert dask.base.get_scheduler() == c.get
# assert dask.base.get_scheduler() == dask.threaded.get

client = await Client(s.address, set_as_default=False, asynchronous=True)
try:
assert dask.base.get_scheduler() == dask.threaded.get
with client.as_current():
sc = dask.base.get_scheduler()
assert sc == client.get
assert dask.base.get_scheduler() == dask.threaded.get
finally:
await client.close()

# If it comes to a race between default and current, current wins
client = await Client(s.address, set_as_default=True, asynchronous=True)
client2 = await Client(s.address, set_as_default=False, asynchronous=True)
try:
with client2.as_current():
assert dask.base.get_scheduler() == client2.get
assert dask.base.get_scheduler() == client.get
finally:
await client.close()
await client2.close()

assert dask.base.get_scheduler() == dask.threaded.get

assert dask.base.get_scheduler() == dask.local.get_sync

client = await Client(s.address, set_as_default=True, asynchronous=True)
try:
assert dask.base.get_scheduler() == client.get
with dask.config.set(scheduler="threads"):
assert dask.base.get_scheduler() == dask.threaded.get
with client.as_current():
assert dask.base.get_scheduler() == client.get
finally:
await client.close()


@gen_cluster(client=True)
async def test_ensure_default_client(c, s, a, b):
assert c is default_client()
Expand Down

0 comments on commit 68b5bbf

Please sign in to comment.