diff --git a/dask_cuda/__init__.py b/dask_cuda/__init__.py index a93cb1cb5..355588920 100644 --- a/dask_cuda/__init__.py +++ b/dask_cuda/__init__.py @@ -5,8 +5,6 @@ import dask import dask.utils -import dask.dataframe.shuffle -from .explicit_comms.dataframe.shuffle import patch_shuffle_expression from distributed.protocol.cuda import cuda_deserialize, cuda_serialize from distributed.protocol.serialize import dask_deserialize, dask_serialize @@ -14,30 +12,43 @@ from .cuda_worker import CUDAWorker from .local_cuda_cluster import LocalCUDACluster -from .proxify_device_objects import proxify_decorator, unproxify_decorator -# Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True` -patch_shuffle_expression() -# Monkey patching Dask to make use of proxify and unproxify in compatibility mode -dask.dataframe.shuffle.shuffle_group = proxify_decorator( - dask.dataframe.shuffle.shuffle_group -) -dask.dataframe.core._concat = unproxify_decorator(dask.dataframe.core._concat) - - -def _register_cudf_spill_aware(): - import cudf - - # Only enable Dask/cuDF spilling if cuDF spilling is disabled, see - # https://github.com/rapidsai/dask-cuda/issues/1363 - if not cudf.get_option("spill"): - # This reproduces the implementation of `_register_cudf`, see - # https://github.com/dask/distributed/blob/40fcd65e991382a956c3b879e438be1b100dff97/distributed/protocol/__init__.py#L106-L115 - from cudf.comm import serialize - - -for registry in [cuda_serialize, cuda_deserialize, dask_serialize, dask_deserialize]: - for lib in ["cudf", "dask_cudf"]: - if lib in registry._lazy: - registry._lazy[lib] = _register_cudf_spill_aware +try: + import dask.dataframe as dask_dataframe +except ImportError: + # Dask DataFrame (optional) isn't installed + dask_dataframe = None + + +if dask_dataframe is not None: + from .explicit_comms.dataframe.shuffle import patch_shuffle_expression + from .proxify_device_objects import proxify_decorator, unproxify_decorator + + # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True` + patch_shuffle_expression() + # Monkey patching Dask to make use of proxify and unproxify in compatibility mode + dask_dataframe.shuffle.shuffle_group = proxify_decorator( + dask.dataframe.shuffle.shuffle_group + ) + dask_dataframe.core._concat = unproxify_decorator(dask.dataframe.core._concat) + + def _register_cudf_spill_aware(): + import cudf + + # Only enable Dask/cuDF spilling if cuDF spilling is disabled, see + # https://github.com/rapidsai/dask-cuda/issues/1363 + if not cudf.get_option("spill"): + # This reproduces the implementation of `_register_cudf`, see + # https://github.com/dask/distributed/blob/40fcd65e991382a956c3b879e438be1b100dff97/distributed/protocol/__init__.py#L106-L115 + from cudf.comm import serialize + + for registry in [ + cuda_serialize, + cuda_deserialize, + dask_serialize, + dask_deserialize, + ]: + for lib in ["cudf", "dask_cudf"]: + if lib in registry._lazy: + registry._lazy[lib] = _register_cudf_spill_aware diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index b42af7b1c..e1a4e7d68 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -11,10 +11,6 @@ import pandas import dask -import dask.array.core -import dask.dataframe.backends -import dask.dataframe.dispatch -import dask.dataframe.utils import dask.utils import distributed.protocol import distributed.utils @@ -30,6 +26,22 @@ from .proxify_host_file import ProxyManager +try: + import dask.dataframe as dask_dataframe + import dask.dataframe.backends + import dask.dataframe.dispatch + import dask.dataframe.utils +except ImportError: + dask_dataframe = None + + +try: + import dask.array as dask_array + import dask.array.core +except ImportError: + dask_array = None + + # List of attributes that should be copied to the proxy at creation, which makes # them accessible without deserialization of the proxied object _FIXED_ATTRS = ["name", "__len__"] @@ -884,14 +896,6 @@ def obj_pxy_dask_deserialize(header, frames): return subclass(pxy) -@dask.dataframe.dispatch.get_parallel_type.register(ProxyObject) -def get_parallel_type_proxy_object(obj: ProxyObject): - # Notice, `get_parallel_type()` needs a instance not a type object - return dask.dataframe.dispatch.get_parallel_type( - obj.__class__.__new__(obj.__class__) - ) - - def unproxify_input_wrapper(func): """Unproxify the input of `func`""" @@ -904,26 +908,42 @@ def wrapper(*args, **kwargs): return wrapper -# Register dispatch of ProxyObject on all known dispatch objects -for dispatch in ( - dask.dataframe.dispatch.hash_object_dispatch, - dask.dataframe.dispatch.make_meta_dispatch, - dask.dataframe.utils.make_scalar, - dask.dataframe.dispatch.group_split_dispatch, - dask.array.core.tensordot_lookup, - dask.array.core.einsum_lookup, - dask.array.core.concatenate_lookup, -): - dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch)) - -dask.dataframe.dispatch.concat_dispatch.register( - ProxyObject, unproxify_input_wrapper(dask.dataframe.dispatch.concat) -) - - -# We overwrite the Dask dispatch of Pandas objects in order to -# deserialize all ProxyObjects before concatenating -dask.dataframe.dispatch.concat_dispatch.register( - (pandas.DataFrame, pandas.Series, pandas.Index), - unproxify_input_wrapper(dask.dataframe.backends.concat_pandas), -) +if dask_array is not None: + + # Register dispatch of ProxyObject on all known dispatch objects + for dispatch in ( + dask.array.core.tensordot_lookup, + dask.array.core.einsum_lookup, + dask.array.core.concatenate_lookup, + ): + dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch)) + + +if dask_dataframe is not None: + + @dask.dataframe.dispatch.get_parallel_type.register(ProxyObject) + def get_parallel_type_proxy_object(obj: ProxyObject): + # Notice, `get_parallel_type()` needs a instance not a type object + return dask.dataframe.dispatch.get_parallel_type( + obj.__class__.__new__(obj.__class__) + ) + + # Register dispatch of ProxyObject on all known dispatch objects + for dispatch in ( + dask.dataframe.dispatch.hash_object_dispatch, + dask.dataframe.dispatch.make_meta_dispatch, + dask.dataframe.utils.make_scalar, + dask.dataframe.dispatch.group_split_dispatch, + ): + dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch)) + + dask.dataframe.dispatch.concat_dispatch.register( + ProxyObject, unproxify_input_wrapper(dask.dataframe.dispatch.concat) + ) + + # We overwrite the Dask dispatch of Pandas objects in order to + # deserialize all ProxyObjects before concatenating + dask.dataframe.dispatch.concat_dispatch.register( + (pandas.DataFrame, pandas.Series, pandas.Index), + unproxify_input_wrapper(dask.dataframe.backends.concat_pandas), + ) diff --git a/dask_cuda/tests/test_initialize.py b/dask_cuda/tests/test_initialize.py index a953a10c1..a098b64cf 100644 --- a/dask_cuda/tests/test_initialize.py +++ b/dask_cuda/tests/test_initialize.py @@ -1,4 +1,5 @@ import multiprocessing as mp +import sys import numpy import psutil @@ -214,3 +215,38 @@ def test_initialize_ucx_all(protocol): p.start() p.join() assert not p.exitcode + + +def _test_dask_cuda_import(): + # Check that importing `dask_cuda` does NOT + # require `dask.dataframe` or `dask.array`. + + # Patch sys.modules so that `dask.dataframe` + # and `dask.array` cannot be found. + with pytest.MonkeyPatch.context() as monkeypatch: + for k in list(sys.modules): + if k.startswith("dask.dataframe") or k.startswith("dask.array"): + monkeypatch.setitem(sys.modules, k, None) + monkeypatch.delitem(sys.modules, "dask_cuda") + + # Check that top-level imports still succeed. + import dask_cuda # noqa: F401 + from dask_cuda import CUDAWorker # noqa: F401 + from dask_cuda import LocalCUDACluster + + with LocalCUDACluster( + dashboard_address=None, + n_workers=1, + threads_per_worker=1, + processes=True, + worker_class=IncreasedCloseTimeoutNanny, + ) as cluster: + with Client(cluster) as client: + client.run(lambda *args: None) + + +def test_dask_cuda_import(): + p = mp.Process(target=_test_dask_cuda_import) + p.start() + p.join() + assert not p.exitcode