Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] use random ports in network setup #3823

Merged
merged 11 commits into from
Feb 24, 2021
34 changes: 29 additions & 5 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,22 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc
return worker_ip_to_port


def _find_random_open_port() -> int:
"""Find a random open port on the machine.

Returns
-------
port : int
A free port on the machine.
"""
import socket

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
jmoralez marked this conversation as resolved.
Show resolved Hide resolved
s.bind(('', 0))
port = s.getsockname()[1]
return port


def _concat(seq: List[_DaskPart]) -> _DaskPart:
if isinstance(seq[0], np.ndarray):
return np.concatenate(seq, axis=0)
Expand Down Expand Up @@ -296,11 +312,19 @@ def _train(
# find an open port on each worker. note that multiple workers can run
# on the same machine, so this needs to ensure that each one gets its
# own port
worker_address_to_port = _find_ports_for_workers(
client=client,
worker_addresses=worker_map.keys(),
local_listen_port=params["local_listen_port"]
)
if params["local_listen_port"] == 12400:
# If local_listen_port is set to LightGBM's default value
# then we find a random open port for each worker
worker_address_to_port = client.run(_find_random_open_port,
workers=worker_map.keys())
jmoralez marked this conversation as resolved.
Show resolved Hide resolved
else:
# If another port was specified then we search for an open port
# in [local_listen_port, local_listen_port+999] for each worker
worker_address_to_port = _find_ports_for_workers(
client=client,
worker_addresses=worker_map.keys(),
local_listen_port=params["local_listen_port"]
)

# Tell each worker to train on the parts that it has locally
futures_classifiers = [
Expand Down
12 changes: 10 additions & 2 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,23 @@ def test_classifier_pred_contrib(output, centers, client, listen_port):
client.close(timeout=CLIENT_CLOSE_TIMEOUT)


def test_training_does_not_fail_on_port_conflicts(client):
def test_different_ports_on_local_cluster(client):
    """Ports returned by ``_find_random_open_port`` must all be distinct.

    The check is repeated five times; each round runs the port finder
    through ``client.run`` and compares the number of unique ports with
    the number of ports returned.
    """
    for _ in range(5):
        worker_address_to_port = client.run(lgb.dask._find_random_open_port)
        found_ports = worker_address_to_port.values()
        # a duplicate port would make the set smaller than the values view
        assert len(set(found_ports)) == len(found_ports)


@pytest.mark.parametrize('local_listen_port', [None, 12400, 13000])
def test_training_does_not_fail_on_port_conflicts(client, local_listen_port):
_, _, _, dX, dy, dw = _create_data('classification', output='array')

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 12400))

dask_classifier = lgb.DaskLGBMClassifier(
time_out=5,
local_listen_port=12400,
local_listen_port=local_listen_port,
n_estimators=5,
num_leaves=5
)
Expand Down