From a10da4e6b02bb4d72aaaf5f59110d9577c7b0a5d Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 8 Dec 2020 03:09:30 +0800 Subject: [PATCH 1/3] Try bring back host ip. --- python-package/xgboost/dask.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index c449c0e75665..79fb131c36c2 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -67,11 +67,9 @@ LOGGER = logging.getLogger('[xgboost.dask]') -def _start_tracker(n_workers): +def _start_tracker(host, n_workers): """Start Rabit tracker """ env = {'DMLC_NUM_WORKER': n_workers} - import socket - host = socket.gethostbyname(socket.gethostname()) rabit_context = RabitTracker(hostIP=host, nslave=n_workers) env.update(rabit_context.slave_envs()) @@ -603,7 +601,9 @@ def _dmatrix_from_list_of_parts(is_quantile, **kwargs): async def _get_rabit_args(n_workers: int, client): '''Get rabit context arguments from data distribution in DaskDMatrix.''' - env = await client.run_on_scheduler(_start_tracker, n_workers) + host = distributed.comm.get_address_host(client.scheduler.address) + env = await client.run_on_scheduler( + _start_tracker, host.strip('/:'), n_workers) rabit_args = [('%s=%s' % item).encode() for item in env.items()] return rabit_args From 55faba59866c0f3728a77e06dd2889302c2634f3 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 8 Dec 2020 03:29:32 +0800 Subject: [PATCH 2/3] Get host ip. --- python-package/xgboost/dask.py | 9 ++++----- python-package/xgboost/tracker.py | 21 +++++++++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 79fb131c36c2..d836e96c3292 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -33,7 +33,7 @@ from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter from .core import _deprecate_positional_args from .training import train as worker_train -from .tracker import RabitTracker +from .tracker import RabitTracker, get_host_ip from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase from .sklearn import xgboost_model_doc @@ -67,9 +67,10 @@ LOGGER = logging.getLogger('[xgboost.dask]') -def _start_tracker(host, n_workers): +def _start_tracker(n_workers): """Start Rabit tracker """ env = {'DMLC_NUM_WORKER': n_workers} + host = get_host_ip('auto') rabit_context = RabitTracker(hostIP=host, nslave=n_workers) env.update(rabit_context.slave_envs()) @@ -601,9 +602,7 @@ def _dmatrix_from_list_of_parts(is_quantile, **kwargs): async def _get_rabit_args(n_workers: int, client): '''Get rabit context arguments from data distribution in DaskDMatrix.''' - host = distributed.comm.get_address_host(client.scheduler.address) - env = await client.run_on_scheduler( - _start_tracker, host.strip('/:'), n_workers) + env = await client.run_on_scheduler(_start_tracker, n_workers) rabit_args = [('%s=%s' % item).encode() for item in env.items()] return rabit_args diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py index 5b217b5c86f0..10edc8343f05 100644 --- a/python-package/xgboost/tracker.py +++ b/python-package/xgboost/tracker.py @@ -52,6 +52,27 @@ def get_some_ip(host): return socket.getaddrinfo(host, None)[0][4][0] +def get_host_ip(hostIP=None): + if hostIP is None or hostIP == 'auto': + hostIP = 'ip' + + if hostIP == 'dns': + hostIP = socket.getfqdn() + elif hostIP == 'ip': + from socket import gaierror + try: + hostIP = socket.gethostbyname(socket.getfqdn()) + except gaierror: + logging.warn('gethostbyname(socket.getfqdn()) failed... trying on hostname()') + hostIP = socket.gethostbyname(socket.gethostname()) + if hostIP.startswith("127."): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + # doesn't have to be reachable + s.connect(('10.255.255.255', 1)) + hostIP = s.getsockname()[0] + return hostIP + + def get_family(addr): return socket.getaddrinfo(addr, None)[0][0] From a547d3bb9e082dc016d26604429c26827cbc590d Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 8 Dec 2020 03:41:33 +0800 Subject: [PATCH 3/3] Lint. --- python-package/xgboost/tracker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py index 10edc8343f05..700b6898fa44 100644 --- a/python-package/xgboost/tracker.py +++ b/python-package/xgboost/tracker.py @@ -63,7 +63,8 @@ def get_host_ip(hostIP=None): try: hostIP = socket.gethostbyname(socket.getfqdn()) except gaierror: - logging.warn('gethostbyname(socket.getfqdn()) failed... trying on hostname()') + logging.warning( + 'gethostbyname(socket.getfqdn()) failed... trying on hostname()') hostIP = socket.gethostbyname(socket.gethostname()) if hostIP.startswith("127."): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)