Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 38 additions & 27 deletions dask_cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,50 @@

import dask
import dask.utils
import dask.dataframe.shuffle
from .explicit_comms.dataframe.shuffle import patch_shuffle_expression
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
from distributed.protocol.serialize import dask_deserialize, dask_serialize

from ._version import __git_commit__, __version__
from .cuda_worker import CUDAWorker

from .local_cuda_cluster import LocalCUDACluster
from .proxify_device_objects import proxify_decorator, unproxify_decorator


# Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
patch_shuffle_expression()
# Monkey patching Dask to make use of proxify and unproxify in compatibility mode
dask.dataframe.shuffle.shuffle_group = proxify_decorator(
dask.dataframe.shuffle.shuffle_group
)
dask.dataframe.core._concat = unproxify_decorator(dask.dataframe.core._concat)


def _register_cudf_spill_aware():
    """Import cuDF's Dask serializers only when cuDF's own spilling is off.

    Intended as a drop-in replacement for distributed's lazy ``_register_cudf``
    hook (see the registry patching below): importing ``cudf.comm.serialize``
    registers the cuDF serializers as a side effect.
    """
    import cudf

    # Only enable Dask/cuDF spilling if cuDF spilling is disabled, see
    # https://github.com/rapidsai/dask-cuda/issues/1363
    if not cudf.get_option("spill"):
        # This reproduces the implementation of `_register_cudf`, see
        # https://github.com/dask/distributed/blob/40fcd65e991382a956c3b879e438be1b100dff97/distributed/protocol/__init__.py#L106-L115
        from cudf.comm import serialize


# Replace distributed's lazy serializer registration for cudf/dask_cudf with
# our spill-aware variant, so registration only happens when cuDF spilling is
# disabled.  Only entries already present in each registry's `_lazy` table are
# overridden.
for registry in [cuda_serialize, cuda_deserialize, dask_serialize, dask_deserialize]:
    for lib in ["cudf", "dask_cudf"]:
        if lib in registry._lazy:
            registry._lazy[lib] = _register_cudf_spill_aware
try:
import dask.dataframe as dask_dataframe
except ImportError:
# Dask DataFrame (optional) isn't installed
dask_dataframe = None


# Dask DataFrame is an optional dependency; only patch its shuffle machinery
# when it is importable (``dask_dataframe`` is None otherwise).
if dask_dataframe is not None:
    from .explicit_comms.dataframe.shuffle import patch_shuffle_expression
    from .proxify_device_objects import proxify_decorator, unproxify_decorator

    # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
    patch_shuffle_expression()
    # Monkey patching Dask to make use of proxify and unproxify in compatibility mode
    dask_dataframe.shuffle.shuffle_group = proxify_decorator(
        dask.dataframe.shuffle.shuffle_group
    )
    dask_dataframe.core._concat = unproxify_decorator(dask.dataframe.core._concat)

def _register_cudf_spill_aware():
    """Import cuDF's Dask serializers only when cuDF's own spilling is off.

    Intended as a drop-in replacement for distributed's lazy ``_register_cudf``
    hook (see the registry patching below): importing ``cudf.comm.serialize``
    registers the cuDF serializers as a side effect.
    """
    import cudf

    # Only enable Dask/cuDF spilling if cuDF spilling is disabled, see
    # https://github.com/rapidsai/dask-cuda/issues/1363
    if not cudf.get_option("spill"):
        # This reproduces the implementation of `_register_cudf`, see
        # https://github.com/dask/distributed/blob/40fcd65e991382a956c3b879e438be1b100dff97/distributed/protocol/__init__.py#L106-L115
        from cudf.comm import serialize

# Replace distributed's lazy serializer registration for cudf/dask_cudf with
# our spill-aware variant, so registration only happens when cuDF spilling is
# disabled.  Only entries already present in each registry's `_lazy` table are
# overridden.
for registry in [
    cuda_serialize,
    cuda_deserialize,
    dask_serialize,
    dask_deserialize,
]:
    for lib in ["cudf", "dask_cudf"]:
        if lib in registry._lazy:
            registry._lazy[lib] = _register_cudf_spill_aware
90 changes: 55 additions & 35 deletions dask_cuda/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
import pandas

import dask
import dask.array.core
import dask.dataframe.backends
import dask.dataframe.dispatch
import dask.dataframe.utils
import dask.utils
import distributed.protocol
import distributed.utils
Expand All @@ -30,6 +26,22 @@
from .proxify_host_file import ProxyManager


try:
import dask.dataframe as dask_dataframe
import dask.dataframe.backends
import dask.dataframe.dispatch
import dask.dataframe.utils
except ImportError:
dask_dataframe = None


try:
import dask.array as dask_array
import dask.array.core
except ImportError:
dask_array = None


# List of attributes that should be copied to the proxy at creation, which makes
# them accessible without deserialization of the proxied object
_FIXED_ATTRS = ["name", "__len__"]
Expand Down Expand Up @@ -884,14 +896,6 @@ def obj_pxy_dask_deserialize(header, frames):
return subclass(pxy)


@dask.dataframe.dispatch.get_parallel_type.register(ProxyObject)
def get_parallel_type_proxy_object(obj: ProxyObject):
    """Return the parallel (Dask collection) type for the proxied object's class."""
    # Notice, `get_parallel_type()` needs an instance, not a type object, so we
    # create an uninitialized instance of the proxied class via `__new__` —
    # this avoids deserializing the proxied object itself.
    return dask.dataframe.dispatch.get_parallel_type(
        obj.__class__.__new__(obj.__class__)
    )


def unproxify_input_wrapper(func):
"""Unproxify the input of `func`"""

Expand All @@ -904,26 +908,42 @@ def wrapper(*args, **kwargs):
return wrapper


# Register dispatch of ProxyObject on all known dispatch objects, wrapping
# each dispatch so its inputs are unproxified (deserialized) before the real
# implementation runs.
for dispatch in (
    dask.dataframe.dispatch.hash_object_dispatch,
    dask.dataframe.dispatch.make_meta_dispatch,
    dask.dataframe.utils.make_scalar,
    dask.dataframe.dispatch.group_split_dispatch,
    dask.array.core.tensordot_lookup,
    dask.array.core.einsum_lookup,
    dask.array.core.concatenate_lookup,
):
    dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch))

# Route concatenation of ProxyObject through the generic concat dispatch,
# unproxifying inputs first.
dask.dataframe.dispatch.concat_dispatch.register(
    ProxyObject, unproxify_input_wrapper(dask.dataframe.dispatch.concat)
)


# We overwrite the Dask dispatch of Pandas objects in order to
# deserialize all ProxyObjects before concatenating
dask.dataframe.dispatch.concat_dispatch.register(
    (pandas.DataFrame, pandas.Series, pandas.Index),
    unproxify_input_wrapper(dask.dataframe.backends.concat_pandas),
)
# Dask Array is an optional dependency; only hook its lookups when importable.
if dask_array is not None:

    # Register dispatch of ProxyObject on all known dispatch objects, wrapping
    # each lookup so inputs are unproxified before the real implementation runs.
    for dispatch in (
        dask.array.core.tensordot_lookup,
        dask.array.core.einsum_lookup,
        dask.array.core.concatenate_lookup,
    ):
        dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch))


# Dask DataFrame is an optional dependency; only hook its dispatch machinery
# when importable.
if dask_dataframe is not None:

    @dask.dataframe.dispatch.get_parallel_type.register(ProxyObject)
    def get_parallel_type_proxy_object(obj: ProxyObject):
        """Return the parallel (Dask collection) type for the proxied object's class."""
        # Notice, `get_parallel_type()` needs an instance, not a type object,
        # so we create an uninitialized instance via `__new__` — this avoids
        # deserializing the proxied object itself.
        return dask.dataframe.dispatch.get_parallel_type(
            obj.__class__.__new__(obj.__class__)
        )

    # Register dispatch of ProxyObject on all known dispatch objects, wrapping
    # each dispatch so inputs are unproxified before the real implementation runs.
    for dispatch in (
        dask.dataframe.dispatch.hash_object_dispatch,
        dask.dataframe.dispatch.make_meta_dispatch,
        dask.dataframe.utils.make_scalar,
        dask.dataframe.dispatch.group_split_dispatch,
    ):
        dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch))

    # Route concatenation of ProxyObject through the generic concat dispatch,
    # unproxifying inputs first.
    dask.dataframe.dispatch.concat_dispatch.register(
        ProxyObject, unproxify_input_wrapper(dask.dataframe.dispatch.concat)
    )

    # We overwrite the Dask dispatch of Pandas objects in order to
    # deserialize all ProxyObjects before concatenating
    dask.dataframe.dispatch.concat_dispatch.register(
        (pandas.DataFrame, pandas.Series, pandas.Index),
        unproxify_input_wrapper(dask.dataframe.backends.concat_pandas),
    )
36 changes: 36 additions & 0 deletions dask_cuda/tests/test_initialize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import multiprocessing as mp
import sys

import numpy
import psutil
Expand Down Expand Up @@ -214,3 +215,38 @@ def test_initialize_ucx_all(protocol):
p.start()
p.join()
assert not p.exitcode


def _test_dask_cuda_import():
    """Subprocess body: verify ``dask_cuda`` imports and a cluster starts
    even when ``dask.dataframe`` and ``dask.array`` are unavailable.

    Run in a separate process (see ``test_dask_cuda_import``) so the
    ``sys.modules`` patching cannot leak into the main test process.
    """
    # Check that importing `dask_cuda` does NOT
    # require `dask.dataframe` or `dask.array`.

    # Patch sys.modules so that `dask.dataframe`
    # and `dask.array` cannot be found.  Setting a sys.modules entry to None
    # makes any subsequent import of that module raise ImportError.
    with pytest.MonkeyPatch.context() as monkeypatch:
        for k in list(sys.modules):
            if k.startswith("dask.dataframe") or k.startswith("dask.array"):
                monkeypatch.setitem(sys.modules, k, None)
        # NOTE(review): delitem raises KeyError if "dask_cuda" was never
        # imported in this process — assumes it is already in sys.modules;
        # confirm against the test harness / conftest.
        monkeypatch.delitem(sys.modules, "dask_cuda")

        # Check that top-level imports still succeed.
        import dask_cuda  # noqa: F401
        from dask_cuda import CUDAWorker  # noqa: F401
        from dask_cuda import LocalCUDACluster

        with LocalCUDACluster(
            dashboard_address=None,
            n_workers=1,
            threads_per_worker=1,
            processes=True,
            worker_class=IncreasedCloseTimeoutNanny,
        ) as cluster:
            with Client(cluster) as client:
                client.run(lambda *args: None)


def test_dask_cuda_import():
    """Run the import check in a fresh subprocess so patching of
    ``sys.modules`` cannot leak into the main test process."""
    proc = mp.Process(target=_test_dask_cuda_import)
    proc.start()
    proc.join()
    assert not proc.exitcode