diff --git a/dask_cuda/initialize.py b/dask_cuda/initialize.py index 2317f8d5..bb8fa628 100644 --- a/dask_cuda/initialize.py +++ b/dask_cuda/initialize.py @@ -45,7 +45,10 @@ def _warn_generic(): _create_cuda_context_handler() - if not distributed.comm.ucx.cuda_context_created.has_context: + if ( + distributed.comm.ucx.cuda_context_created is False + or distributed.comm.ucx.cuda_context_created.has_context + ): ctx = has_cuda_context() if ctx.has_context and ctx.device_info != cuda_visible_device: distributed.comm.ucx._warn_cuda_context_wrong_device( @@ -184,40 +187,10 @@ def initialize( default=False, help="Create CUDA context", ) -@click.option( - "--protocol", - default=None, - type=str, - help="Communication protocol, such as: 'tcp', 'tls', 'ucx' or 'ucxx'.", -) -@click.option( - "--enable-tcp-over-ucx/--disable-tcp-over-ucx", - default=False, - help="Enable TCP communication over UCX", -) -@click.option( - "--enable-infiniband/--disable-infiniband", - default=False, - help="Enable InfiniBand communication", -) -@click.option( - "--enable-nvlink/--disable-nvlink", - default=False, - help="Enable NVLink communication", -) -@click.option( - "--enable-rdmacm/--disable-rdmacm", - default=False, - help="Enable RDMA connection manager, currently requires InfiniBand enabled.", -) def dask_setup( - service, + worker, create_cuda_context, - protocol, - enable_tcp_over_ucx, - enable_infiniband, - enable_nvlink, - enable_rdmacm, ): + protocol = worker._protocol.split("://")[0] if create_cuda_context: _create_cuda_context(protocol=protocol) diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index 26383078..b0e0a686 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -420,7 +420,7 @@ def __init__( ) + ["dask_cuda.initialize"] self.new_spec["options"]["preload_argv"] = self.new_spec["options"].get( "preload_argv", [] - ) + ["--create-cuda-context", "--protocol", protocol] + ) + ["--create-cuda-context"] self.cuda_visible_devices = CUDA_VISIBLE_DEVICES self.scale(n_workers) diff --git a/dask_cuda/tests/test_dask_setup.py b/dask_cuda/tests/test_dask_setup.py new file mode 100644 index 00000000..649e9c98 --- /dev/null +++ b/dask_cuda/tests/test_dask_setup.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +import time +from contextlib import contextmanager +from unittest.mock import Mock, patch + +import pytest + +from distributed import Client +from distributed.utils import open_port +from distributed.utils_test import popen + +from dask_cuda.initialize import dask_setup +from dask_cuda.utils import wait_workers + + +@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"]) +def test_dask_setup_function_with_mock_worker(protocol): + """Test the dask_setup function directly with mock worker.""" + # Create a mock worker object + mock_worker = Mock() + mock_worker._protocol = protocol + + with patch("dask_cuda.initialize._create_cuda_context") as mock_create_context: + # Test with create_cuda_context=True + # Call the underlying function directly (the Click decorator wraps the real + # function) + dask_setup.callback( + worker=mock_worker, + create_cuda_context=True, + ) + + mock_create_context.assert_called_once_with(protocol=protocol) + + mock_create_context.reset_mock() + + # Test with create_cuda_context=False + dask_setup.callback( + worker=mock_worker, + create_cuda_context=False, + ) + + mock_create_context.assert_not_called() + + +@contextmanager +def start_dask_scheduler(protocol: str, max_attempts: int = 5, timeout: int = 10): + """Start Dask scheduler in subprocess. + + Attempts to start a Dask scheduler in subprocess, if the port is not available + retry on a different port up to a maximum of `max_attempts` attempts. The stdout + and stderr of the process is read to determine whether the scheduler failed to + bind to port or succeeded, and ensures no more than `timeout` seconds are awaited + for between reads. + + This is primarily useful because UCX does not release TCP ports immediately. A + workaround without the need for this function is setting `UCX_TCP_CM_REUSEADDR=y`, + but that requires to be explicitly set when running tests, and that is not very + friendly. + + Parameters + ---------- + protocol: str + Communication protocol to use. + max_attempts: int + Maximum attempts to try to open scheduler. + timeout: int + Time to wait while reading stdout/stderr of subprocess. + """ + port = open_port() + for _ in range(max_attempts): + with popen( + [ + "dask", + "scheduler", + "--no-dashboard", + "--protocol", + protocol, + "--port", + str(port), + ], + capture_output=True, # Capture stdout and stderr + ) as scheduler_process: + # Check if the scheduler process started successfully by streaming output + try: + start_time = time.monotonic() + while True: + if time.monotonic() - start_time > timeout: + raise TimeoutError("Timeout while waiting for scheduler output") + + line = scheduler_process.stdout.readline() + if not line: + break # End of output + print( + line.decode(), end="" + ) # Since capture_output=True, print the line here + if b"Scheduler at:" in line: + # Scheduler is now listening + break + elif b"UCXXBusyError" in line: + raise Exception("UCXXBusyError detected in scheduler output") + except Exception: + port += 1 + else: + yield scheduler_process, port + return + else: + pytest.fail(f"Failed to start dask scheduler after {max_attempts} attempts.") + + +@pytest.mark.timeout(30) +@patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) +@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"]) +def test_dask_cuda_worker_cli_integration(protocol, tmp_path): + """Test that dask cuda worker CLI correctly passes arguments to dask_setup. + + Verifies the end-to-end integration where the CLI tool actually launches and calls + dask_setup with correct args. + """ + + # Use pytest's tmp_path for file management + capture_file_path = tmp_path / "dask_setup_integration_test.json" + preload_file = tmp_path / "preload_capture.py" + + # Write the preload script to tmp_path + preload_file.write_text( + f''' +import json +import os + +def capture_dask_setup_call(worker, create_cuda_context): + """Capture dask_setup arguments and write to file.""" + result = {{ + 'worker_protocol': getattr(worker, '_protocol', 'unknown'), + 'create_cuda_context': create_cuda_context, + 'test_success': True + }} + + # Write immediately to ensure it gets captured + with open(r"{capture_file_path}", 'w') as f: + json.dump(result, f) + +# Patch dask_setup callback +from dask_cuda.initialize import dask_setup +dask_setup.callback = capture_dask_setup_call +''' + ) + + with start_dask_scheduler(protocol=protocol) as scheduler_process_port: + scheduler_process, scheduler_port = scheduler_process_port + sched_addr = f"{protocol}://127.0.0.1:{scheduler_port}" + print(f"{sched_addr=}", flush=True) + + # Build dask cuda worker args + dask_cuda_worker_args = [ + "dask", + "cuda", + "worker", + sched_addr, + "--host", + "127.0.0.1", + "--no-dashboard", + "--preload", + str(preload_file), + "--death-timeout", + "10", + ] + + with popen(dask_cuda_worker_args): + # Wait and check for worker connection + with Client(sched_addr) as client: + assert wait_workers(client, n_gpus=1) + + # Check if dask_setup was called and captured correctly + if capture_file_path.exists(): + with open(capture_file_path, "r") as cf: + captured_args = json.load(cf) + + # Verify the critical arguments were passed correctly + assert ( + captured_args["create_cuda_context"] is True + ), "create_cuda_context should be True" + + # Verify worker has a protocol set + assert ( + captured_args["worker_protocol"] == protocol + ), "Worker should have a protocol" + else: + pytest.fail( + "capture file not found: dask_setup was not called or " + "failed to write to file" + )