Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] use random ports in network setup #3823

Merged
merged 11 commits into from
Feb 24, 2021
Prev Previous commit
Next Next commit
parametrize local_listen_port. type hint to _find_random_open_port. find open ports only on workers with data.
jmoralez committed Feb 3, 2021
commit ba7977c7ff8fae38494caf5f2d55158f23651850
5 changes: 3 additions & 2 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
@@ -107,7 +107,7 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc
return worker_ip_to_port


def _find_random_open_port():
def _find_random_open_port() -> int:
"""Find a random open port on the machine.

Returns
@@ -315,7 +315,8 @@ def _train(
if params["local_listen_port"] == 12400:
# If local_listen_port is set to LightGBM's default value
# then we find a random open port for each worker
worker_address_to_port = client.run(_find_random_open_port)
worker_address_to_port = client.run(_find_random_open_port,
workers=worker_map.keys())
else:
# If another port was specified then we search for an open port
# in [local_listen_port, local_listen_port+999] for each worker
30 changes: 15 additions & 15 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
@@ -241,27 +241,27 @@ def test_different_ports_on_local_cluster(client):
assert len(set(found_ports)) == len(found_ports)


def test_training_does_not_fail_on_port_conflicts(client):
@pytest.mark.parametrize('local_listen_port', [None, 12400, 13000])
def test_training_does_not_fail_on_port_conflicts(client, local_listen_port):
_, _, _, dX, dy, dw = _create_data('classification', output='array')

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 12400))

for local_listen_port in (None, 12400, 13000):
dask_classifier = lgb.DaskLGBMClassifier(
time_out=5,
local_listen_port=local_listen_port,
n_estimators=5,
num_leaves=5
dask_classifier = lgb.DaskLGBMClassifier(
time_out=5,
local_listen_port=local_listen_port,
n_estimators=5,
num_leaves=5
)
for _ in range(5):
dask_classifier.fit(
X=dX,
y=dy,
sample_weight=dw,
client=client
)
for _ in range(5):
dask_classifier.fit(
X=dX,
y=dy,
sample_weight=dw,
client=client
)
assert dask_classifier.booster_
assert dask_classifier.booster_

client.close(timeout=CLIENT_CLOSE_TIMEOUT)