Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit 995f625

Browse files
authored
Enable gpu scheduler in AML mode (#2769)
1 parent accb40f commit 995f625

File tree

5 files changed

+31
-15
lines changed

5 files changed

+31
-15
lines changed

docs/en_US/TrainingService/AMLMode.md

+11-7
Original file line numberDiff line numberDiff line change
@@ -49,30 +49,34 @@ tuner:
4949
trial:
5050
command: python3 mnist.py
5151
codeDir: .
52-
computeTarget: ${replace_to_your_computeTarget}
5352
image: msranni/nni
53+
gpuNum: 1
5454
amlConfig:
5555
subscriptionId: ${replace_to_your_subscriptionId}
5656
resourceGroup: ${replace_to_your_resourceGroup}
5757
workspaceName: ${replace_to_your_workspaceName}
58-
58+
computeTarget: ${replace_to_your_computeTarget}
5959
```
6060
6161
Note: You should set `trainingServicePlatform: aml` in NNI config YAML file if you want to start experiment in aml mode.
6262

6363
Compared with [LocalMode](LocalMode.md) trial configuration in aml mode have these additional keys:
64-
* computeTarget
65-
* required key. The compute cluster name you want to use in your AML workspace. See Step 6.
6664
* image
6765
* required key. The docker image name used in job. The image `msranni/nni` of this example only support GPU computeTargets.
6866

6967
amlConfig:
7068
* subscriptionId
71-
* the subscriptionId of your account
69+
* required key, the subscriptionId of your account
7270
* resourceGroup
73-
* the resourceGroup of your account
71+
* required key, the resourceGroup of your account
7472
* workspaceName
75-
* the workspaceName of your account
73+
* required key, the workspaceName of your account
74+
* computeTarget
75+
* required key, the compute cluster name you want to use in your AML workspace. See Step 6.
76+
* maxTrialNumPerGpu
77+
* optional key, used to specify the max concurrency trial number on a GPU device.
78+
* useActiveGpu
79+
* optional key, used to specify whether to use a GPU if there is another process. By default, NNI will use the GPU only if there is no other active process in the GPU.
7680

7781
The required information of amlConfig could be found in the downloaded `config.json` in Step 5.
7882

src/nni_manager/rest_server/restValidationSchemas.ts

+4-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ export namespace ValidationSchemas {
3939
nniManagerNFSMountPath: joi.string().min(1),
4040
containerNFSMountPath: joi.string().min(1),
4141
paiConfigPath: joi.string(),
42-
computeTarget: joi.string(),
4342
nodeCount: joi.number(),
4443
paiStorageConfigName: joi.string().min(1),
4544
nasMode: joi.string().valid('classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
@@ -159,7 +158,10 @@ export namespace ValidationSchemas {
159158
aml_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
160159
subscriptionId: joi.string().min(1),
161160
resourceGroup: joi.string().min(1),
162-
workspaceName: joi.string().min(1)
161+
workspaceName: joi.string().min(1),
162+
computeTarget: joi.string().min(1),
163+
maxTrialNumPerGpu: joi.number(),
164+
useActiveGpu: joi.boolean()
163165
}),
164166
nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
165167
nniManagerIp: joi.string().min(1)

src/nni_manager/training_service/reusable/aml/amlConfig.ts

+9-4
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,31 @@ export class AMLClusterConfig {
1111
public readonly subscriptionId: string;
1212
public readonly resourceGroup: string;
1313
public readonly workspaceName: string;
14+
public readonly computeTarget: string;
15+
public useActiveGpu?: boolean;
16+
public maxTrialNumPerGpu?: number;
1417

15-
constructor(subscriptionId: string, resourceGroup: string, workspaceName: string) {
18+
constructor(subscriptionId: string, resourceGroup: string, workspaceName: string, computeTarget: string,
19+
useActiveGpu?: boolean, maxTrialNumPerGpu?: number) {
1620
this.subscriptionId = subscriptionId;
1721
this.resourceGroup = resourceGroup;
1822
this.workspaceName = workspaceName;
23+
this.computeTarget = computeTarget;
24+
this.useActiveGpu = useActiveGpu;
25+
this.maxTrialNumPerGpu = maxTrialNumPerGpu;
1926
}
2027
}
2128

2229
export class AMLTrialConfig extends TrialConfig {
2330
public readonly image: string;
2431
public readonly command: string;
2532
public readonly codeDir: string;
26-
public readonly computeTarget: string;
2733

28-
constructor(codeDir: string, command: string, image: string, computeTarget: string) {
34+
constructor(codeDir: string, command: string, image: string) {
2935
super("", codeDir, 0);
3036
this.codeDir = codeDir;
3137
this.command = command;
3238
this.image = image;
33-
this.computeTarget = computeTarget;
3439
}
3540
}
3641

src/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts

+3-1
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,15 @@ export class AMLEnvironmentService extends EnvironmentService {
112112
const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
113113
const environmentLocalTempFolder = path.join(this.experimentRootDir, this.experimentId, "environment-temp");
114114
environment.command = `import os\nos.system('${amlEnvironment.command}')`;
115+
environment.useActiveGpu = this.amlClusterConfig.useActiveGpu;
116+
environment.maxTrialNumberPerGpu = this.amlClusterConfig.maxTrialNumPerGpu;
115117
await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' });
116118
const amlClient = new AMLClient(
117119
this.amlClusterConfig.subscriptionId,
118120
this.amlClusterConfig.resourceGroup,
119121
this.amlClusterConfig.workspaceName,
120122
this.experimentId,
121-
this.amlTrialConfig.computeTarget,
123+
this.amlClusterConfig.computeTarget,
122124
this.amlTrialConfig.image,
123125
'nni_script.py',
124126
environmentLocalTempFolder

tools/nni_cmd/config_schema.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def validate(self, data):
245245
'codeDir': setPathCheck('codeDir'),
246246
'command': setType('command', str),
247247
'image': setType('image', str),
248-
'computeTarget': setType('computeTarget', str)
248+
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
249249
}
250250
}
251251

@@ -254,6 +254,9 @@ def validate(self, data):
254254
'subscriptionId': setType('subscriptionId', str),
255255
'resourceGroup': setType('resourceGroup', str),
256256
'workspaceName': setType('workspaceName', str),
257+
'computeTarget': setType('computeTarget', str),
258+
Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
259+
Optional('useActiveGpu'): setType('useActiveGpu', bool),
257260
}
258261
}
259262

0 commit comments

Comments
 (0)