Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] allow tight control over ports #3994

Merged
merged 30 commits into from
Feb 23, 2021
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
58beb97
[dask] allow tight control over ports
jameslamb Feb 16, 2021
a79b13d
getting there, getting there
jameslamb Feb 16, 2021
30828d3
fix params maybe
jameslamb Feb 16, 2021
6f538ea
merge master
jameslamb Feb 16, 2021
5c36083
fixing params
jameslamb Feb 16, 2021
ff6f50b
remove unnecessary stuff
jameslamb Feb 16, 2021
26bbc48
fix tests
jameslamb Feb 16, 2021
bbe0d10
fixes
jameslamb Feb 16, 2021
c352eb5
some minor changes
jameslamb Feb 17, 2021
58c6470
fix flaky test
jameslamb Feb 17, 2021
69dea53
linting
jameslamb Feb 17, 2021
0cc9d67
more linting
jameslamb Feb 17, 2021
52e0c39
merge master
jameslamb Feb 18, 2021
f00b3a7
Merge branch 'master' into feat/network-params
jameslamb Feb 18, 2021
25462ea
clarify parameter description
jameslamb Feb 18, 2021
5bdf5be
Merge branch 'master' into feat/network-params
jameslamb Feb 19, 2021
deeab63
add warning
jameslamb Feb 19, 2021
eeb75f5
Merge branch 'master' into feat/network-params
jameslamb Feb 19, 2021
0c81f60
revert docs change
jameslamb Feb 19, 2021
e1a4d4d
Update python-package/lightgbm/dask.py
jameslamb Feb 19, 2021
da1c0ea
Apply suggestions from code review
jameslamb Feb 21, 2021
e36b169
Merge branch 'master' into feat/network-params
jameslamb Feb 21, 2021
dcae2d0
trying to fix stuff
jameslamb Feb 21, 2021
d507210
Merge branch 'master' into feat/network-params
jameslamb Feb 22, 2021
e474c37
this is working
jameslamb Feb 22, 2021
1e9244d
update tests
jameslamb Feb 22, 2021
c36ec28
Merge branch 'master' into feat/network-params
jameslamb Feb 22, 2021
4fc9f70
Apply suggestions from code review
jameslamb Feb 23, 2021
040ad1f
Merge branch 'master' into feat/network-params
jameslamb Feb 23, 2021
b3c8a2c
indent
jameslamb Feb 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1137,7 +1137,7 @@ Network Parameters

- this parameter is needed to be set in both **socket** and **mpi** versions

- ``local_listen_port`` :raw-html:`<a id="local_listen_port" title="Permalink to this parameter" href="#local_listen_port">&#x1F517;&#xFE0E;</a>`, default = ``12400``, type = int, aliases: ``local_port``, ``port``, constraints: ``local_listen_port > 0``
- ``local_listen_port`` :raw-html:`<a id="local_listen_port" title="Permalink to this parameter" href="#local_listen_port">&#x1F517;&#xFE0E;</a>`, default = ``12400 (random for Dask-package)``, type = int, aliases: ``local_port``, ``port``, constraints: ``local_listen_port > 0``

- TCP listen port for local machines

Expand Down
1 change: 1 addition & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,7 @@ struct Config {
int num_machines = 1;

// check = >0
// default = 12400 (random for Dask-package)
// alias = local_port, port
// desc = TCP listen port for local machines
// desc = **Note**: don't forget to allow this port in firewall settings before training
Expand Down
3 changes: 1 addition & 2 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,13 @@ def _check_sample_weight(sample_weight, X, dtype=None):
from dask.array import Array as dask_Array
from dask.dataframe import DataFrame as dask_DataFrame
from dask.dataframe import Series as dask_Series
from dask.distributed import Client, default_client, get_worker, wait
from dask.distributed import Client, default_client, wait
DASK_INSTALLED = True
except ImportError:
DASK_INSTALLED = False
delayed = None
Client = object
default_client = None
get_worker = None
wait = None

class dask_Array:
Expand Down
167 changes: 139 additions & 28 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
import numpy as np
import scipy.sparse as ss

from .basic import _LIB, LightGBMError, _choose_param_value, _ConfigAliases, _log_warning, _safe_call
from .basic import _LIB, LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning, _safe_call
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
dask_Array, dask_DataFrame, dask_Series, default_client, delayed, get_worker, pd_DataFrame,
pd_Series, wait)
dask_Array, dask_DataFrame, dask_Series, default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict

_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
Expand Down Expand Up @@ -140,22 +139,18 @@ def _train_part(
params: Dict[str, Any],
model_factory: Type[LGBMModel],
list_of_parts: List[Dict[str, _DaskPart]],
worker_address_to_port: Dict[str, int],
machines: str,
local_listen_port: int,
num_machines: int,
return_model: bool,
time_out: int = 120,
**kwargs: Any
) -> Optional[LGBMModel]:
local_worker_address = get_worker().address
machine_list = ','.join([
'%s:%d' % (urlparse(worker_address).hostname, port)
for worker_address, port
in worker_address_to_port.items()
])
network_params = {
'machines': machine_list,
'local_listen_port': worker_address_to_port[local_worker_address],
'machines': machines,
'local_listen_port': local_listen_port,
'time_out': time_out,
'num_machines': len(worker_address_to_port)
'num_machines': num_machines
}
params.update(network_params)

Expand Down Expand Up @@ -199,6 +194,39 @@ def _split_to_parts(data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]:
return parts


def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[str, int]:
"""Create a worker_map from machines list.

Given ``machines`` and a list of Dask worker addresses, return a mapping where the keys are
``worker_addresses`` and the values are ports from ``machines``.

Parameters
----------
machines : str
A comma-delimited list of workers, of the form ``ip1:port,ip2:port``.
worker_addresses : list of str
A list of Dask worker addresses, of the form ``{protocol}{hostname}:{port}``, where ``port`` is the port Dask's scheduler uses to talk to that worker.

Returns
-------
result : Dict[str, int]
Dictionary where keys are work addresses in the form expected by Dask and values are a port for LightGBM to use.
"""
machine_addresses = machines.split(",")
machine_to_port = {}
for address in machine_addresses:
host, port = address.split(":")
machine_to_port[host] = machine_to_port.get(host, set())
machine_to_port[host].add(int(port))
jameslamb marked this conversation as resolved.
Show resolved Hide resolved

out = {}
for address in worker_addresses:
worker_host = urlparse(address).hostname
out[address] = machine_to_port[worker_host].pop()

return out


def _train(
client: Client,
data: _DaskMatrixLike,
Expand Down Expand Up @@ -238,14 +266,44 @@ def _train(
-------
model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
Returns fitted underlying model.

Note
----

This method handles setting up the following network parameters based on information
about the Dask cluster referenced by ``client``.

* ``local_listen_port``: port that each LightGBM worker opens a listening socket on,
to accept connections from other workers. This can be differ from LightGBM worker
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
to LightGBM worker, but does not have to.
* ``machines``: a comma-delimited list of all workers in the cluster, in the
form ``ip:port,ip:port``. If running multiple Dask workers on the same host, use different
ports for each worker. For example, for ``LocalCluster(n_workers=3)``, you might
pass ``"127.0.0.1:12400,127.0.0.1:12401,127.0.0.1:12402"```.
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
* ``num_machines``: number of LightGBM workers
* ``timeout``: time in minutes to wait before closing unused sockets
jameslamb marked this conversation as resolved.
Show resolved Hide resolved

The default behavior of this function is to generate ``machines`` from the list of
Dask workers which hold some piece of the training data, and to search for an open
port on each worker to be used as ``local_listen_port``.

If ``machines`` is provided explicitly in ``params``, this function uses the hosts
and ports in that list directly, and does not do any searching. This means that if
any of the Dask workers are missing from the list or any of those ports are not free
when training starts, training will fail.

If ``local_listen_port`` is provided in ``params`` and ``machines`` is not, this function
constructs ``machines`` from the list of Dask workers which hold some piece of the
training data, assuming that each one will use the same ``local_listen_port``.
"""
params = deepcopy(params)

params = _choose_param_value(
main_param_name="local_listen_port",
params=params,
default_value=12400
)
# capture whether local_listen_port or its aliases were provided
port_aliases = _ConfigAliases.get('local_listen_port')
listen_port_in_params = False
for param in params.keys():
if param in port_aliases:
listen_port_in_params = True
jameslamb marked this conversation as resolved.
Show resolved Hide resolved

params = _choose_param_value(
main_param_name="tree_learner",
Expand All @@ -271,10 +329,9 @@ def _train(
)

# Some passed-in parameters can be removed:
# * 'machines': constructed automatically from Dask worker list
# * 'num_machines': set automatically from Dask worker list
# * 'num_threads': overridden to match nthreads on each Dask process
for param_alias in _ConfigAliases.get('machines', 'num_machines', 'num_threads'):
for param_alias in _ConfigAliases.get('num_machines', 'num_threads'):
params.pop(param_alias, None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it is important to notify users about this behavior.

Suggested change
params.pop(param_alias, None)
if param_alias in params:
_log_warning(f"Parameter {param_alias} will be ignored.")
params.pop(param_alias)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, I accepted this suggestion but I think now that we should only apply it to num_machines, not num_threads.

This results in a warning that users cannot suppress.

/opt/conda/lib/python3.8/site-packages/lightgbm/dask.py:338: UserWarning: Parameter n_jobs will be ignored.
_log_warning(f"Parameter {param_alias} will be ignored.")

Caused by the fact that n_jobs is an alias of num_threads

I believe that every warning should be something that can be changed by user code changes. Otherwise, we're just adding noise to logs that might cause people to start filtering out ALL warnings.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is important to notify users about num_threads as well before implementing #3714. Silently ignore parameter is more serious problem compared to unfixable warning, I believe.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree in this specific case about the meaning of "ignore", since this is a parameter default and not something explicitly passed in. However, since num_threads isn't directly related to the purpose of this PR and since I don't want to delay this PR too long because I'd like to merge #3823 soon after it, I'll leave this warning in for now and propose another PR in a few days where we could discuss it further.


# Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
Expand Down Expand Up @@ -312,14 +369,60 @@ def _train(
master_worker = next(iter(worker_map))
worker_ncores = client.ncores()

# find an open port on each worker. note that multiple workers can run
# on the same machine, so this needs to ensure that each one gets its
# own port
worker_address_to_port = _find_ports_for_workers(
client=client,
worker_addresses=worker_map.keys(),
local_listen_port=params["local_listen_port"]
# resolve aliases for network parameters and pop the result off params.
# these values are added back in calls to `_train_part()`
params = _choose_param_value(
main_param_name="local_listen_port",
params=params,
default_value=12400
)
local_listen_port = params.pop("local_listen_port")

params = _choose_param_value(
main_param_name="machines",
params=params,
default_value=None
)
machines = params.pop("machines")

# figure out network params
worker_addresses = worker_map.keys()
if machines is not None:
_log_info("Using passed-in 'machines' parameter")
worker_address_to_port = _machines_to_worker_map(
machines=machines,
worker_addresses=worker_addresses
)
else:
if listen_port_in_params:
_log_info("Using passed-in 'local_listen_port' for all workers")
unique_hosts = set([urlparse(a).hostname for a in worker_addresses])
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
if len(unique_hosts) < len(worker_addresses):
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
msg = (
"'local_listen_port' was provided in Dask training parameters, but at least one "
"machine in the cluster has multiple Dask worker processes running on it. Please omit "
"'local_listen_port' or pass 'machines'."
Comment on lines +407 to +408
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"machine in the cluster has multiple Dask worker processes running on it. Please omit "
"'local_listen_port' or pass 'machines'."
"machine in the cluster has multiple Dask worker processes running on it.\nPlease omit "
"'local_listen_port' or pass full configuration via 'machines' parameter."

Copy link
Collaborator Author

@jameslamb jameslamb Feb 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why you think we should include a newline?

I'm concerned that in logs, it will look like an exception with only the text before the newline followed by a separate print statement.

image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, it is up to you. Feel free to revert new line. I personally don't like long line warnings/errors.

Copy link
Collaborator Author

@jameslamb jameslamb Feb 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alright I'm not going to accept this suggestion then if it's just a matter a matter of personal preference.

I've had problems in the past with external logs-management systems and log messages that have newline characters. You can read about that general problem at https://www.datadoghq.com/blog/multiline-logging-guide/#the-multi-line-logging-problem if you're interested.

Long log messages will also be wrapped automatically in Jupyter notebooks

image

and in python REPLs

image

)
raise LightGBMError(msg)

worker_address_to_port = {
address: local_listen_port
for address in worker_addresses
}
else:
_log_info("Finding random open ports for workers")
worker_address_to_port = _find_ports_for_workers(
client=client,
worker_addresses=worker_map.keys(),
local_listen_port=local_listen_port
)
machines = ','.join([
'%s:%d' % (urlparse(worker_address).hostname, port)
for worker_address, port
in worker_address_to_port.items()
])

num_machines = len(worker_address_to_port)

# Tell each worker to train on the parts that it has locally
futures_classifiers = [
Expand All @@ -328,7 +431,9 @@ def _train(
model_factory=model_factory,
params={**params, 'num_threads': worker_ncores[worker]},
list_of_parts=list_of_parts,
worker_address_to_port=worker_address_to_port,
machines=machines,
local_listen_port=worker_address_to_port[worker],
num_machines=num_machines,
time_out=params.get('time_out', 120),
return_model=(worker == master_worker),
**kwargs
Expand Down Expand Up @@ -500,6 +605,12 @@ def _fit(
**kwargs
)

# if network parameters were updated during training, remove them so that
# they're generated dynamically on every run based on the Dask cluster you're
# connected to and which workers have pieces of the training data
for param in _ConfigAliases.get('local_listen_port', 'machines', 'num_machines', 'timeout'):
model._other_params.pop(param, None)

self.set_params(**model.get_params())
self._copy_extra_params(model, self)

Expand Down
Loading