@@ -18,16 +18,28 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
18
18
data = config .json ()
19
19
20
20
ts = data .pop ('trainingService' )
21
- if ts ['platform' ] == 'openpai' :
22
- ts ['platform' ] = 'pai'
21
+ if isinstance (ts , list ):
22
+ hybrid_names = []
23
+ for conf in ts :
24
+ if conf ['platform' ] == 'openpai' :
25
+ conf ['platform' ] = 'pai'
26
+ hybrid_names .append (conf ['platform' ])
27
+ _handle_training_service (conf , data )
28
+ data ['trainingServicePlatform' ] = 'hybrid'
29
+ data ['hybridConfig' ] = {'trainingServicePlatforms' : hybrid_names }
30
+ else :
31
+ if ts ['platform' ] == 'openpai' :
32
+ ts ['platform' ] = 'pai'
33
+ data ['trainingServicePlatform' ] = ts ['platform' ]
34
+ _handle_training_service (ts , data )
23
35
24
36
data ['authorName' ] = 'N/A'
25
37
data ['experimentName' ] = data .get ('experimentName' , 'N/A' )
26
38
data ['maxExecDuration' ] = data .pop ('maxExperimentDuration' , '999d' )
27
39
if data ['debug' ]:
28
40
data ['versionCheck' ] = False
29
41
data ['maxTrialNum' ] = data .pop ('maxTrialNumber' , 99999 )
30
- data [ 'trainingServicePlatform' ] = ts [ 'platform' ]
42
+
31
43
ss = data .pop ('searchSpace' , None )
32
44
ss_file = data .pop ('searchSpaceFile' , None )
33
45
if ss is not None :
@@ -66,6 +78,9 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
66
78
if 'trialGpuNumber' in data :
67
79
data ['trial' ]['gpuNum' ] = data .pop ('trialGpuNumber' )
68
80
81
+ return data
82
+
83
+ def _handle_training_service (ts , data ):
69
84
if ts ['platform' ] == 'local' :
70
85
data ['localConfig' ] = {
71
86
'useActiveGpu' : ts .get ('useActiveGpu' , False ),
@@ -140,8 +155,6 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
140
155
elif ts ['platform' ] == 'adl' :
141
156
data ['trial' ]['image' ] = ts ['dockerImage' ]
142
157
143
- return data
144
-
145
158
def _convert_gpu_indices (indices ):
146
159
return ',' .join (str (idx ) for idx in indices ) if indices is not None else None
147
160
@@ -175,19 +188,34 @@ def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
175
188
experiment_config = to_v1_yaml (config , skip_nnictl = True )
176
189
ret = []
177
190
178
- if config .training_service .platform == 'local' :
191
+ if isinstance (config .training_service , list ):
192
+ hybrid_conf = dict ()
193
+ hybrid_conf ['hybrid_config' ] = experiment_config ['hybridConfig' ]
194
+ for conf in config .training_service :
195
+ metadata = _get_cluster_metadata (conf .platform , experiment_config )
196
+ if metadata is not None :
197
+ hybrid_conf .update (metadata )
198
+ ret .append (hybrid_conf )
199
+ else :
200
+ metadata = _get_cluster_metadata (config .training_service .platform , experiment_config )
201
+ if metadata is not None :
202
+ ret .append (metadata )
203
+
204
+ if experiment_config .get ('nniManagerIp' ) is not None :
205
+ ret .append ({'nni_manager_ip' : {'nniManagerIp' : experiment_config ['nniManagerIp' ]}})
206
+ ret .append ({'trial_config' : experiment_config ['trial' ]})
207
+ return ret
208
+
209
+ def _get_cluster_metadata (platform : str , experiment_config ) -> Dict :
210
+ if platform == 'local' :
179
211
request_data = dict ()
180
212
request_data ['local_config' ] = experiment_config ['localConfig' ]
181
213
if request_data ['local_config' ]:
182
214
if request_data ['local_config' ].get ('gpuIndices' ) and isinstance (request_data ['local_config' ].get ('gpuIndices' ), int ):
183
215
request_data ['local_config' ]['gpuIndices' ] = str (request_data ['local_config' ].get ('gpuIndices' ))
184
- if request_data ['local_config' ].get ('maxTrialNumOnEachGpu' ):
185
- request_data ['local_config' ]['maxTrialNumOnEachGpu' ] = request_data ['local_config' ].get ('maxTrialNumOnEachGpu' )
186
- if request_data ['local_config' ].get ('useActiveGpu' ):
187
- request_data ['local_config' ]['useActiveGpu' ] = request_data ['local_config' ].get ('useActiveGpu' )
188
- ret .append (request_data )
216
+ return request_data
189
217
190
- elif config . training_service . platform == 'remote' :
218
+ elif platform == 'remote' :
191
219
request_data = dict ()
192
220
if experiment_config .get ('remoteConfig' ):
193
221
request_data ['remote_config' ] = experiment_config ['remoteConfig' ]
@@ -198,31 +226,25 @@ def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
198
226
for i in range (len (request_data ['machine_list' ])):
199
227
if isinstance (request_data ['machine_list' ][i ].get ('gpuIndices' ), int ):
200
228
request_data ['machine_list' ][i ]['gpuIndices' ] = str (request_data ['machine_list' ][i ].get ('gpuIndices' ))
201
- ret . append ( request_data )
229
+ return request_data
202
230
203
- elif config . training_service . platform == 'openpai' :
204
- ret . append ( {'pai_config' : experiment_config ['paiConfig' ]})
231
+ elif platform == 'openpai' :
232
+ return {'pai_config' : experiment_config ['paiConfig' ]}
205
233
206
- elif config . training_service . platform == 'aml' :
207
- ret . append ( {'aml_config' : experiment_config ['amlConfig' ]})
234
+ elif platform == 'aml' :
235
+ return {'aml_config' : experiment_config ['amlConfig' ]}
208
236
209
- elif config . training_service . platform == 'kubeflow' :
210
- ret . append ( {'kubeflow_config' : experiment_config ['kubeflowConfig' ]})
237
+ elif platform == 'kubeflow' :
238
+ return {'kubeflow_config' : experiment_config ['kubeflowConfig' ]}
211
239
212
- elif config . training_service . platform == 'frameworkcontroller' :
213
- ret . append ( {'frameworkcontroller_config' : experiment_config ['frameworkcontrollerConfig' ]})
240
+ elif platform == 'frameworkcontroller' :
241
+ return {'frameworkcontroller_config' : experiment_config ['frameworkcontrollerConfig' ]}
214
242
215
- elif config . training_service . platform == 'adl' :
216
- pass
243
+ elif platform == 'adl' :
244
+ return None
217
245
218
246
else :
219
- raise RuntimeError ('Unsupported training service ' + config .training_service .platform )
220
-
221
- if experiment_config .get ('nniManagerIp' ) is not None :
222
- ret .append ({'nni_manager_ip' : {'nniManagerIp' : experiment_config ['nniManagerIp' ]}})
223
- ret .append ({'trial_config' : experiment_config ['trial' ]})
224
- return ret
225
-
247
+ raise RuntimeError ('Unsupported training service ' + platform )
226
248
227
249
def to_rest_json (config : ExperimentConfig ) -> Dict [str , Any ]:
228
250
experiment_config = to_v1_yaml (config , skip_nnictl = True )
0 commit comments