Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit b177bdc

Browse files
authored
support hybrid training service v2.0 config (#3251)
1 parent 6330df2 commit b177bdc

File tree

11 files changed

+162
-58
lines changed

11 files changed

+162
-58
lines changed

examples/trials/mnist-tfv1/config_heterogeneous.yml renamed to examples/trials/mnist-tfv1/config_hybrid.yml

+3-4
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ remoteConfig:
2626
reuse: true
2727
machineList:
2828
- ip: 10.1.1.1
29-
username: bob
30-
passwd: bob123
31-
#port can be skip if using default ssh port 22
32-
#port: 22
29+
username: xxx
30+
passwd: xxx
31+
port: 22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
experimentName: example_mnist
2+
trialConcurrency: 3
3+
maxExperimentDuration: 1h
4+
maxTrialNumber: 10
5+
searchSpaceFile: search_space.json
6+
7+
trialCodeDirectory: .
8+
trialCommand: python3 mnist.py
9+
trialGpuNumber: 0
10+
tuner:
11+
name: TPE
12+
classArgs:
13+
optimize_mode: maximize
14+
15+
trainingService:
16+
- platform: local
17+
- platform: remote
18+
reuseMode: true
19+
machineList:
20+
- host: 10.1.1.1
21+
user: xxx
22+
password: xxx
23+
#port can be skip if using default ssh port 22
24+
port: 22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# FIXME: For demonstration only. It should not be here
2+
3+
from pathlib import Path
4+
5+
from nni.experiment import Experiment
6+
from nni.experiment import RemoteMachineConfig
7+
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
8+
9+
tuner = HyperoptTuner('tpe')
10+
11+
search_space = {
12+
"dropout_rate": { "_type": "uniform", "_value": [0.5, 0.9] },
13+
"conv_size": { "_type": "choice", "_value": [2, 3, 5, 7] },
14+
"hidden_size": { "_type": "choice", "_value": [124, 512, 1024] },
15+
"batch_size": { "_type": "choice", "_value": [16, 32] },
16+
"learning_rate": { "_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1] }
17+
}
18+
19+
experiment = Experiment(tuner, ['local', 'remote'])
20+
experiment.config.experiment_name = 'test'
21+
experiment.config.trial_concurrency = 3
22+
experiment.config.max_trial_number = 10
23+
experiment.config.search_space = search_space
24+
experiment.config.trial_command = 'python3 mnist.py'
25+
experiment.config.trial_code_directory = Path(__file__).parent
26+
experiment.config.training_service[0].use_active_gpu = True
27+
experiment.config.training_service[1].reuse_mode = True
28+
rm_conf = RemoteMachineConfig()
29+
rm_conf.host = '10.1.1.1'
30+
rm_conf.user = 'xxx'
31+
rm_conf.password = 'xxx'
32+
rm_conf.port = 22
33+
experiment.config.training_service[1].machine_list = [rm_conf]
34+
35+
experiment.run(26780, debug=True)

nni/experiment/config/common.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,19 @@ class ExperimentConfig(ConfigBase):
6565
tuner: Optional[_AlgorithmConfig] = None
6666
accessor: Optional[_AlgorithmConfig] = None
6767
advisor: Optional[_AlgorithmConfig] = None
68-
training_service: TrainingServiceConfig
68+
training_service: Union[TrainingServiceConfig, List[TrainingServiceConfig]]
6969

70-
def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
70+
def __init__(self, training_service_platform: Optional[Union[str, List[str]]] = None, **kwargs):
7171
kwargs = util.case_insensitive(kwargs)
7272
if training_service_platform is not None:
7373
assert 'trainingservice' not in kwargs
74-
kwargs['trainingservice'] = util.training_service_config_factory(training_service_platform)
75-
elif isinstance(kwargs.get('trainingservice'), dict):
76-
kwargs['trainingservice'] = util.training_service_config_factory(**kwargs['trainingservice'])
74+
kwargs['trainingservice'] = util.training_service_config_factory(platform = training_service_platform)
75+
elif isinstance(kwargs.get('trainingservice'), (dict, list)):
76+
# dict means a single training service
77+
# list means hybrid training service
78+
kwargs['trainingservice'] = util.training_service_config_factory(config = kwargs['trainingservice'])
79+
else:
80+
raise RuntimeError('Unsupported Training service configuration!')
7781
super().__init__(**kwargs)
7882

7983
def validate(self, initialized_tuner: bool = False) -> None:

nni/experiment/config/convert.py

+52-30
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,28 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
1818
data = config.json()
1919

2020
ts = data.pop('trainingService')
21-
if ts['platform'] == 'openpai':
22-
ts['platform'] = 'pai'
21+
if isinstance(ts, list):
22+
hybrid_names = []
23+
for conf in ts:
24+
if conf['platform'] == 'openpai':
25+
conf['platform'] = 'pai'
26+
hybrid_names.append(conf['platform'])
27+
_handle_training_service(conf, data)
28+
data['trainingServicePlatform'] = 'hybrid'
29+
data['hybridConfig'] = {'trainingServicePlatforms': hybrid_names}
30+
else:
31+
if ts['platform'] == 'openpai':
32+
ts['platform'] = 'pai'
33+
data['trainingServicePlatform'] = ts['platform']
34+
_handle_training_service(ts, data)
2335

2436
data['authorName'] = 'N/A'
2537
data['experimentName'] = data.get('experimentName', 'N/A')
2638
data['maxExecDuration'] = data.pop('maxExperimentDuration', '999d')
2739
if data['debug']:
2840
data['versionCheck'] = False
2941
data['maxTrialNum'] = data.pop('maxTrialNumber', 99999)
30-
data['trainingServicePlatform'] = ts['platform']
42+
3143
ss = data.pop('searchSpace', None)
3244
ss_file = data.pop('searchSpaceFile', None)
3345
if ss is not None:
@@ -66,6 +78,9 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
6678
if 'trialGpuNumber' in data:
6779
data['trial']['gpuNum'] = data.pop('trialGpuNumber')
6880

81+
return data
82+
83+
def _handle_training_service(ts, data):
6984
if ts['platform'] == 'local':
7085
data['localConfig'] = {
7186
'useActiveGpu': ts.get('useActiveGpu', False),
@@ -140,8 +155,6 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
140155
elif ts['platform'] == 'adl':
141156
data['trial']['image'] = ts['dockerImage']
142157

143-
return data
144-
145158
def _convert_gpu_indices(indices):
146159
return ','.join(str(idx) for idx in indices) if indices is not None else None
147160

@@ -175,19 +188,34 @@ def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
175188
experiment_config = to_v1_yaml(config, skip_nnictl=True)
176189
ret = []
177190

178-
if config.training_service.platform == 'local':
191+
if isinstance(config.training_service, list):
192+
hybrid_conf = dict()
193+
hybrid_conf['hybrid_config'] = experiment_config['hybridConfig']
194+
for conf in config.training_service:
195+
metadata = _get_cluster_metadata(conf.platform, experiment_config)
196+
if metadata is not None:
197+
hybrid_conf.update(metadata)
198+
ret.append(hybrid_conf)
199+
else:
200+
metadata = _get_cluster_metadata(config.training_service.platform, experiment_config)
201+
if metadata is not None:
202+
ret.append(metadata)
203+
204+
if experiment_config.get('nniManagerIp') is not None:
205+
ret.append({'nni_manager_ip': {'nniManagerIp': experiment_config['nniManagerIp']}})
206+
ret.append({'trial_config': experiment_config['trial']})
207+
return ret
208+
209+
def _get_cluster_metadata(platform: str, experiment_config) -> Dict:
210+
if platform == 'local':
179211
request_data = dict()
180212
request_data['local_config'] = experiment_config['localConfig']
181213
if request_data['local_config']:
182214
if request_data['local_config'].get('gpuIndices') and isinstance(request_data['local_config'].get('gpuIndices'), int):
183215
request_data['local_config']['gpuIndices'] = str(request_data['local_config'].get('gpuIndices'))
184-
if request_data['local_config'].get('maxTrialNumOnEachGpu'):
185-
request_data['local_config']['maxTrialNumOnEachGpu'] = request_data['local_config'].get('maxTrialNumOnEachGpu')
186-
if request_data['local_config'].get('useActiveGpu'):
187-
request_data['local_config']['useActiveGpu'] = request_data['local_config'].get('useActiveGpu')
188-
ret.append(request_data)
216+
return request_data
189217

190-
elif config.training_service.platform == 'remote':
218+
elif platform == 'remote':
191219
request_data = dict()
192220
if experiment_config.get('remoteConfig'):
193221
request_data['remote_config'] = experiment_config['remoteConfig']
@@ -198,31 +226,25 @@ def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
198226
for i in range(len(request_data['machine_list'])):
199227
if isinstance(request_data['machine_list'][i].get('gpuIndices'), int):
200228
request_data['machine_list'][i]['gpuIndices'] = str(request_data['machine_list'][i].get('gpuIndices'))
201-
ret.append(request_data)
229+
return request_data
202230

203-
elif config.training_service.platform == 'openpai':
204-
ret.append({'pai_config': experiment_config['paiConfig']})
231+
elif platform == 'openpai':
232+
return {'pai_config': experiment_config['paiConfig']}
205233

206-
elif config.training_service.platform == 'aml':
207-
ret.append({'aml_config': experiment_config['amlConfig']})
234+
elif platform == 'aml':
235+
return {'aml_config': experiment_config['amlConfig']}
208236

209-
elif config.training_service.platform == 'kubeflow':
210-
ret.append({'kubeflow_config': experiment_config['kubeflowConfig']})
237+
elif platform == 'kubeflow':
238+
return {'kubeflow_config': experiment_config['kubeflowConfig']}
211239

212-
elif config.training_service.platform == 'frameworkcontroller':
213-
ret.append({'frameworkcontroller_config': experiment_config['frameworkcontrollerConfig']})
240+
elif platform == 'frameworkcontroller':
241+
return {'frameworkcontroller_config': experiment_config['frameworkcontrollerConfig']}
214242

215-
elif config.training_service.platform == 'adl':
216-
pass
243+
elif platform == 'adl':
244+
return None
217245

218246
else:
219-
raise RuntimeError('Unsupported training service ' + config.training_service.platform)
220-
221-
if experiment_config.get('nniManagerIp') is not None:
222-
ret.append({'nni_manager_ip': {'nniManagerIp': experiment_config['nniManagerIp']}})
223-
ret.append({'trial_config': experiment_config['trial']})
224-
return ret
225-
247+
raise RuntimeError('Unsupported training service ' + platform)
226248

227249
def to_rest_json(config: ExperimentConfig) -> Dict[str, Any]:
228250
experiment_config = to_v1_yaml(config, skip_nnictl=True)

nni/experiment/config/remote.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class RemoteMachineConfig(ConfigBase):
1818
port: int = 22
1919
user: str
2020
password: Optional[str] = None
21-
ssh_key_file: PathLike = '~/.ssh/id_rsa'
21+
ssh_key_file: PathLike = None #'~/.ssh/id_rsa'
2222
ssh_passphrase: Optional[str] = None
2323
use_active_gpu: bool = False
2424
max_trial_number_per_gpu: int = 1

nni/experiment/config/util.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import math
99
import os.path
1010
from pathlib import Path
11-
from typing import Any, Dict, Optional, Union
11+
from typing import Any, Dict, Optional, Union, List
1212

1313
PathLike = Union[Path, str]
1414

@@ -29,12 +29,26 @@ def canonical_path(path: Optional[PathLike]) -> Optional[str]:
2929
def count(*values) -> int:
3030
return sum(value is not None and value is not False for value in values)
3131

32-
def training_service_config_factory(platform: str, **kwargs): # -> TrainingServiceConfig
32+
def training_service_config_factory(platform: Union[str, List[str]] = None, config: Union[List, Dict] = None): # -> TrainingServiceConfig
3333
from .common import TrainingServiceConfig
34-
for cls in TrainingServiceConfig.__subclasses__():
35-
if cls.platform == platform:
36-
return cls(**kwargs)
37-
raise ValueError(f'Unrecognized platform {platform}')
34+
ts_configs = []
35+
if platform is not None:
36+
assert config is None
37+
platforms = platform if isinstance(platform, list) else [platform]
38+
for cls in TrainingServiceConfig.__subclasses__():
39+
if cls.platform in platforms:
40+
ts_configs.append(cls())
41+
if len(ts_configs) < len(platforms):
42+
raise RuntimeError('There is unrecognized platform!')
43+
else:
44+
assert config is not None
45+
supported_platforms = {cls.platform: cls for cls in TrainingServiceConfig.__subclasses__()}
46+
configs = config if isinstance(config, list) else [config]
47+
for conf in configs:
48+
if conf['platform'] not in supported_platforms:
49+
raise RuntimeError(f'Unrecognized platform {conf["platform"]}')
50+
ts_configs.append(supported_platforms[conf['platform']](**conf))
51+
return ts_configs if len(ts_configs) > 1 else ts_configs[0]
3852

3953
def load_config(Type, value):
4054
if isinstance(value, list):

nni/experiment/experiment.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from subprocess import Popen
66
from threading import Thread
77
import time
8-
from typing import Optional, overload
8+
from typing import Optional, Union, List, overload
99

1010
import colorama
1111
import psutil
@@ -54,7 +54,7 @@ def __init__(self, tuner: Tuner, config: ExperimentConfig) -> None:
5454
...
5555

5656
@overload
57-
def __init__(self, tuner: Tuner, training_service: str) -> None:
57+
def __init__(self, tuner: Tuner, training_service: Union[str, List[str]]) -> None:
5858
"""
5959
Prepare an experiment, leaving configuration fields to be set later.
6060
@@ -86,7 +86,7 @@ def __init__(self, tuner: Tuner, config=None, training_service=None):
8686
self._dispatcher: Optional[MsgDispatcher] = None
8787
self._dispatcher_thread: Optional[Thread] = None
8888

89-
if isinstance(config, str):
89+
if isinstance(config, (str, list)):
9090
config, training_service = None, config
9191

9292
if config is None:

nni/experiment/launcher.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
2727

2828
config.validate(initialized_tuner=True)
2929
_ensure_port_idle(port)
30-
if config.training_service.platform == 'openpai':
31-
_ensure_port_idle(port + 1, 'OpenPAI requires an additional port')
30+
if isinstance(config.training_service, list): # hybrid training service
31+
_ensure_port_idle(port + 1, 'Hybrid training service requires an additional port')
32+
elif config.training_service.platform in ['remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']:
33+
_ensure_port_idle(port + 1, f'{config.training_service.platform} requires an additional port')
3234

3335
try:
34-
_logger.info('Creating experiment %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL)
36+
_logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL)
3537
pipe = Pipe(exp_id)
3638
start_time, proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
3739
_logger.info('Connecting IPC pipe...')
@@ -40,7 +42,8 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
4042
nni.runtime.protocol._out_file = pipe_file
4143
_logger.info('Statring web server...')
4244
_check_rest_server(port)
43-
_save_experiment_information(exp_id, port, start_time, config.training_service.platform,
45+
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
46+
_save_experiment_information(exp_id, port, start_time, platform,
4447
config.experiment_name, proc.pid, config.experiment_working_directory)
4548
_logger.info('Setting up...')
4649
_init_experiment(config, port, debug)
@@ -66,9 +69,12 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
6669

6770

6871
def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str) -> Tuple[int, Popen]:
69-
ts = config.training_service.platform
70-
if ts == 'openpai':
71-
ts = 'pai'
72+
if isinstance(config.training_service, list):
73+
ts = 'hybrid'
74+
else:
75+
ts = config.training_service.platform
76+
if ts == 'openpai':
77+
ts = 'pai'
7278

7379
args = {
7480
'port': port,

nni/retiarii/experiment.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
4646
super().__init__(**kwargs)
4747
if training_service_platform is not None:
4848
assert 'training_service' not in kwargs
49-
self.training_service = util.training_service_config_factory(training_service_platform)
49+
self.training_service = util.training_service_config_factory(platform = training_service_platform)
5050

5151
def validate(self, initialized_tuner: bool = False) -> None:
5252
super().validate()

nni/tools/nnictl/launcher.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ def create_experiment(args):
607607
try:
608608
validate_all_content(experiment_config, config_path)
609609
except Exception as e:
610-
print_error(f'Config validation failed. {repr(e)}')
610+
print_error(f'Config in v1 format validation failed. {repr(e)}')
611611
exit(1)
612612

613613
nni_config.set_config('experimentConfig', experiment_config)

0 commit comments

Comments
 (0)