Skip to content

Commit

Permalink
Enable gpu scheduler in AML mode (microsoft#2769)
Browse files Browse the repository at this point in the history
  • Loading branch information
SparkSnail authored Aug 11, 2020
1 parent accb40f commit 995f625
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 15 deletions.
18 changes: 11 additions & 7 deletions docs/en_US/TrainingService/AMLMode.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,34 @@ tuner:
trial:
command: python3 mnist.py
codeDir: .
computeTarget: ${replace_to_your_computeTarget}
image: msranni/nni
gpuNum: 1
amlConfig:
subscriptionId: ${replace_to_your_subscriptionId}
resourceGroup: ${replace_to_your_resourceGroup}
workspaceName: ${replace_to_your_workspaceName}

computeTarget: ${replace_to_your_computeTarget}
```
Note: You should set `trainingServicePlatform: aml` in NNI config YAML file if you want to start experiment in aml mode.

Compared with [LocalMode](LocalMode.md) trial configuration in aml mode have these additional keys:
* computeTarget
* required key. The compute cluster name you want to use in your AML workspace. See Step 6.
* image
* required key. The docker image name used in job. The image `msranni/nni` of this example only support GPU computeTargets.

amlConfig:
* subscriptionId
* the subscriptionId of your account
* required key, the subscriptionId of your account
* resourceGroup
* the resourceGroup of your account
* required key, the resourceGroup of your account
* workspaceName
* the workspaceName of your account
* required key, the workspaceName of your account
* computeTarget
* required key, the compute cluster name you want to use in your AML workspace. See Step 6.
* maxTrialNumPerGpu
* optional key, used to specify the max concurrency trial number on a GPU device.
* useActiveGpu
* optional key, used to specify whether to use a GPU if there is another process. By default, NNI will use the GPU only if there is no other active process in the GPU.

The required information of amlConfig could be found in the downloaded `config.json` in Step 5.

Expand Down
6 changes: 4 additions & 2 deletions src/nni_manager/rest_server/restValidationSchemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ export namespace ValidationSchemas {
nniManagerNFSMountPath: joi.string().min(1),
containerNFSMountPath: joi.string().min(1),
paiConfigPath: joi.string(),
computeTarget: joi.string(),
nodeCount: joi.number(),
paiStorageConfigName: joi.string().min(1),
nasMode: joi.string().valid('classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
Expand Down Expand Up @@ -159,7 +158,10 @@ export namespace ValidationSchemas {
aml_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
subscriptionId: joi.string().min(1),
resourceGroup: joi.string().min(1),
workspaceName: joi.string().min(1)
workspaceName: joi.string().min(1),
computeTarget: joi.string().min(1),
maxTrialNumPerGpu: joi.number(),
useActiveGpu: joi.boolean()
}),
nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
nniManagerIp: joi.string().min(1)
Expand Down
13 changes: 9 additions & 4 deletions src/nni_manager/training_service/reusable/aml/amlConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,31 @@ export class AMLClusterConfig {
public readonly subscriptionId: string;
public readonly resourceGroup: string;
public readonly workspaceName: string;
public readonly computeTarget: string;
public useActiveGpu?: boolean;
public maxTrialNumPerGpu?: number;

constructor(subscriptionId: string, resourceGroup: string, workspaceName: string) {
constructor(subscriptionId: string, resourceGroup: string, workspaceName: string, computeTarget: string,
useActiveGpu?: boolean, maxTrialNumPerGpu?: number) {
this.subscriptionId = subscriptionId;
this.resourceGroup = resourceGroup;
this.workspaceName = workspaceName;
this.computeTarget = computeTarget;
this.useActiveGpu = useActiveGpu;
this.maxTrialNumPerGpu = maxTrialNumPerGpu;
}
}

export class AMLTrialConfig extends TrialConfig {
public readonly image: string;
public readonly command: string;
public readonly codeDir: string;
public readonly computeTarget: string;

constructor(codeDir: string, command: string, image: string, computeTarget: string) {
constructor(codeDir: string, command: string, image: string) {
super("", codeDir, 0);
this.codeDir = codeDir;
this.command = command;
this.image = image;
this.computeTarget = computeTarget;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,15 @@ export class AMLEnvironmentService extends EnvironmentService {
const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
const environmentLocalTempFolder = path.join(this.experimentRootDir, this.experimentId, "environment-temp");
environment.command = `import os\nos.system('${amlEnvironment.command}')`;
environment.useActiveGpu = this.amlClusterConfig.useActiveGpu;
environment.maxTrialNumberPerGpu = this.amlClusterConfig.maxTrialNumPerGpu;
await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' });
const amlClient = new AMLClient(
this.amlClusterConfig.subscriptionId,
this.amlClusterConfig.resourceGroup,
this.amlClusterConfig.workspaceName,
this.experimentId,
this.amlTrialConfig.computeTarget,
this.amlClusterConfig.computeTarget,
this.amlTrialConfig.image,
'nni_script.py',
environmentLocalTempFolder
Expand Down
5 changes: 4 additions & 1 deletion tools/nni_cmd/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def validate(self, data):
'codeDir': setPathCheck('codeDir'),
'command': setType('command', str),
'image': setType('image', str),
'computeTarget': setType('computeTarget', str)
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
}
}

Expand All @@ -254,6 +254,9 @@ def validate(self, data):
'subscriptionId': setType('subscriptionId', str),
'resourceGroup': setType('resourceGroup', str),
'workspaceName': setType('workspaceName', str),
'computeTarget': setType('computeTarget', str),
Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
Optional('useActiveGpu'): setType('useActiveGpu', bool),
}
}

Expand Down

0 comments on commit 995f625

Please sign in to comment.