Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

support hybrid training service v2.0 config #3251

Merged
merged 12 commits into from
Jan 6, 2021
14 changes: 9 additions & 5 deletions nni/experiment/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,19 @@ class ExperimentConfig(ConfigBase):
tuner: Optional[_AlgorithmConfig] = None
accessor: Optional[_AlgorithmConfig] = None
advisor: Optional[_AlgorithmConfig] = None
training_service: TrainingServiceConfig
training_service: Union[TrainingServiceConfig, List[TrainingServiceConfig]]

def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
def __init__(self, training_service_platform: Optional[Union[str, List[str]]] = None, **kwargs):
kwargs = util.case_insensitive(kwargs)
if training_service_platform is not None:
assert 'trainingservice' not in kwargs
kwargs['trainingservice'] = util.training_service_config_factory(training_service_platform)
elif isinstance(kwargs.get('trainingservice'), dict):
kwargs['trainingservice'] = util.training_service_config_factory(**kwargs['trainingservice'])
kwargs['trainingservice'] = util.training_service_config_factory(platform = training_service_platform)
elif isinstance(kwargs.get('trainingservice'), (dict, list)):
# dict means a single training service
# list means hybrid training service
kwargs['trainingservice'] = util.training_service_config_factory(config = kwargs['trainingservice'])
else:
raise RuntimeError('Unsupported Training service configuration!')
super().__init__(**kwargs)

def validate(self, initialized_tuner: bool = False) -> None:
Expand Down
23 changes: 18 additions & 5 deletions nni/experiment/config/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,28 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
data = config.json()

ts = data.pop('trainingService')
if ts['platform'] == 'openpai':
ts['platform'] = 'pai'
if isinstance(ts, list):
hybrid_names = []
for conf in ts:
if conf['platform'] == 'openpai':
conf['platform'] = 'pai'
hybrid_names.append(conf['platform'])
_handle_training_service(conf, data)
data['trainingServicePlatform'] = 'heterogeneous'
data['heterogeneousConfig'] = {'trainingServicePlatforms': hybrid_names}
else:
if ts['platform'] == 'openpai':
ts['platform'] = 'pai'
data['trainingServicePlatform'] = ts['platform']
_handle_training_service(ts, data)

data['authorName'] = 'N/A'
data['experimentName'] = data.get('experimentName', 'N/A')
data['maxExecDuration'] = data.pop('maxExperimentDuration', '999d')
if data['debug']:
data['versionCheck'] = False
data['maxTrialNum'] = data.pop('maxTrialNumber', 99999)
data['trainingServicePlatform'] = ts['platform']

ss = data.pop('searchSpace', None)
ss_file = data.pop('searchSpaceFile', None)
if ss is not None:
Expand Down Expand Up @@ -66,6 +78,9 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
if 'trialGpuNumber' in data:
data['trial']['gpuNum'] = data.pop('trialGpuNumber')

return data

def _handle_training_service(ts, data):
if ts['platform'] == 'local':
data['localConfig'] = {
'useActiveGpu': ts.get('useActiveGpu', False),
Expand Down Expand Up @@ -140,8 +155,6 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
elif ts['platform'] == 'adl':
data['trial']['image'] = ts['dockerImage']

return data

def _convert_gpu_indices(indices):
return ','.join(str(idx) for idx in indices) if indices is not None else None

Expand Down
26 changes: 20 additions & 6 deletions nni/experiment/config/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import math
import os.path
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union, List

PathLike = Union[Path, str]

Expand All @@ -29,12 +29,26 @@ def canonical_path(path: Optional[PathLike]) -> Optional[str]:
def count(*values) -> int:
return sum(value is not None and value is not False for value in values)

def training_service_config_factory(platform: str, **kwargs): # -> TrainingServiceConfig
def training_service_config_factory(platform: Union[str, List[str]] = None, config: Union[List, Dict] = None): # -> TrainingServiceConfig
from .common import TrainingServiceConfig
for cls in TrainingServiceConfig.__subclasses__():
if cls.platform == platform:
return cls(**kwargs)
raise ValueError(f'Unrecognized platform {platform}')
ts_configs = []
if platform is not None:
assert config is None
platforms = platform if isinstance(platform, list) else [platform]
for cls in TrainingServiceConfig.__subclasses__():
if cls.platform in platforms:
ts_configs.append(cls())
if len(ts_configs) < len(platforms):
raise RuntimeError('There is unrecognized platform!')
else:
assert config is not None
supported_platforms = {cls.platform: cls for cls in TrainingServiceConfig.__subclasses__()}
configs = config if isinstance(config, list) else [config]
for conf in configs:
if conf['platform'] not in supported_platforms:
raise RuntimeError(f'Unrecognized platform {conf["platform"]}')
ts_configs.append(supported_platforms[conf['platform']](**conf))
return ts_configs if len(ts_configs) > 1 else ts_configs[0]

def load_config(Type, value):
if isinstance(value, list):
Expand Down
2 changes: 1 addition & 1 deletion nni/retiarii/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if training_service_platform is not None:
assert 'training_service' not in kwargs
self.training_service = util.training_service_config_factory(training_service_platform)
self.training_service = util.training_service_config_factory(platform = training_service_platform)

def validate(self, initialized_tuner: bool = False) -> None:
super().validate()
Expand Down
1 change: 0 additions & 1 deletion nni/tools/nnictl/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,6 @@ def create_experiment(args):
print_error('Please set correct config path!')
exit(1)
experiment_config = get_yml_content(config_path)

try:
validate_all_content(experiment_config, config_path)
except Exception:
Expand Down