Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

support hybrid training service v2.0 config #3251

Merged
merged 12 commits into from
Jan 6, 2021
14 changes: 9 additions & 5 deletions nni/experiment/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,19 @@ class ExperimentConfig(ConfigBase):
tuner: Optional[_AlgorithmConfig] = None
accessor: Optional[_AlgorithmConfig] = None
advisor: Optional[_AlgorithmConfig] = None
training_service: TrainingServiceConfig
training_service: Union[TrainingServiceConfig, List[TrainingServiceConfig]]

def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
def __init__(self, training_service_platform: Optional[Union[str, List[str]]] = None, **kwargs):
kwargs = util.case_insensitive(kwargs)
if training_service_platform is not None:
assert 'trainingservice' not in kwargs
kwargs['trainingservice'] = util.training_service_config_factory(training_service_platform)
elif isinstance(kwargs.get('trainingservice'), dict):
kwargs['trainingservice'] = util.training_service_config_factory(**kwargs['trainingservice'])
kwargs['trainingservice'] = util.training_service_config_factory(platform = training_service_platform)
elif isinstance(kwargs.get('trainingservice'), (dict, list)):
# dict means a single training service
# list means hybrid training service
kwargs['trainingservice'] = util.training_service_config_factory(config = kwargs['trainingservice'])
else:
raise RuntimeError('Unsupported Training service configuration!')
super().__init__(**kwargs)

def validate(self, initialized_tuner: bool = False) -> None:
Expand Down
23 changes: 18 additions & 5 deletions nni/experiment/config/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,28 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
data = config.json()

ts = data.pop('trainingService')
if ts['platform'] == 'openpai':
ts['platform'] = 'pai'
if isinstance(ts, list):
hybrid_names = []
for conf in ts:
if conf['platform'] == 'openpai':
conf['platform'] = 'pai'
hybrid_names.append(conf['platform'])
_handle_training_service(conf, data)
data['trainingServicePlatform'] = 'heterogeneous'
data['heterogeneousConfig'] = {'trainingServicePlatforms': hybrid_names}
else:
if ts['platform'] == 'openpai':
ts['platform'] = 'pai'
data['trainingServicePlatform'] = ts['platform']
_handle_training_service(ts, data)

data['authorName'] = 'N/A'
data['experimentName'] = data.get('experimentName', 'N/A')
data['maxExecDuration'] = data.pop('maxExperimentDuration', '999d')
if data['debug']:
data['versionCheck'] = False
data['maxTrialNum'] = data.pop('maxTrialNumber', 99999)
data['trainingServicePlatform'] = ts['platform']

ss = data.pop('searchSpace', None)
ss_file = data.pop('searchSpaceFile', None)
if ss is not None:
Expand Down Expand Up @@ -66,6 +78,9 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
if 'trialGpuNumber' in data:
data['trial']['gpuNum'] = data.pop('trialGpuNumber')

return data

def _handle_training_service(ts, data):
if ts['platform'] == 'local':
data['localConfig'] = {
'useActiveGpu': ts.get('useActiveGpu', False),
Expand Down Expand Up @@ -140,8 +155,6 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
elif ts['platform'] == 'adl':
data['trial']['image'] = ts['dockerImage']

return data

def _convert_gpu_indices(indices):
return ','.join(str(idx) for idx in indices) if indices is not None else None

Expand Down
26 changes: 20 additions & 6 deletions nni/experiment/config/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import math
import os.path
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union, List

PathLike = Union[Path, str]

Expand All @@ -29,12 +29,26 @@ def canonical_path(path: Optional[PathLike]) -> Optional[str]:
def count(*values) -> int:
return sum(value is not None and value is not False for value in values)

def training_service_config_factory(platform: str, **kwargs): # -> TrainingServiceConfig
def training_service_config_factory(platform: Union[str, List[str]] = None, config: Union[List, Dict] = None): # -> TrainingServiceConfig
from .common import TrainingServiceConfig
for cls in TrainingServiceConfig.__subclasses__():
if cls.platform == platform:
return cls(**kwargs)
raise ValueError(f'Unrecognized platform {platform}')
ts_configs = []
if platform is not None:
assert config is None
platforms = platform if isinstance(platform, list) else [platform]
for cls in TrainingServiceConfig.__subclasses__():
if cls.platform in platforms:
ts_configs.append(cls())
if len(ts_configs) < len(platforms):
raise RuntimeError('There is unrecognized platform!')
else:
assert config is not None
supported_platforms = {cls.platform: cls for cls in TrainingServiceConfig.__subclasses__()}
configs = config if isinstance(config, list) else [config]
for conf in configs:
if conf['platform'] not in supported_platforms:
raise RuntimeError(f'Unrecognized platform {conf["platform"]}')
ts_configs.append(supported_platforms[conf['platform']](**conf))
return ts_configs if len(ts_configs) > 1 else ts_configs[0]

def load_config(Type, value):
if isinstance(value, list):
Expand Down
2 changes: 1 addition & 1 deletion nni/retiarii/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if training_service_platform is not None:
assert 'training_service' not in kwargs
self.training_service = util.training_service_config_factory(training_service_platform)
self.training_service = util.training_service_config_factory(platform = training_service_platform)

def validate(self, initialized_tuner: bool = False) -> None:
super().validate()
Expand Down
23 changes: 14 additions & 9 deletions nni/tools/nnictl/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .rest_utils import rest_put, rest_post, check_rest_server, check_response
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls
from .config_utils import Config, Experiments
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, \
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \
detect_port, get_user

from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER
Expand Down Expand Up @@ -592,16 +592,21 @@ def create_experiment(args):
print_error('Please set correct config path!')
exit(1)
experiment_config = get_yml_content(config_path)
try:
config = ExperimentConfig(**experiment_config)
experiment_config = convert.to_v1_yaml(config)
except Exception:
pass

try:
validate_all_content(experiment_config, config_path)
except Exception as e:
print_error(e)
exit(1)
except Exception:
print_warning('Validation with V1 schema failed. Trying to convert from V2 format...')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When using a V2 YAML file, will there always be a warning here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is content that I merged from yuge's PR. I think it is okay for now.

try:
config = ExperimentConfig(**experiment_config)
experiment_config = convert.to_v1_yaml(config)
except Exception as e:
print_error(f'Conversion from v2 format failed: {repr(e)}')
try:
validate_all_content(experiment_config, config_path)
except Exception as e:
print_error(f'Config validation failed. {repr(e)}')
exit(1)

nni_config.set_config('experimentConfig', experiment_config)
nni_config.set_config('restServerPort', args.port)
Expand Down