Showing 12 changed files with 738 additions and 29 deletions.
base_ops.py (new file, 51 added lines; the name comes from the import in the search script below):
import math

import torch.nn as nn


def truncated_normal_(tensor, mean=0, std=1):
    # https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
    # Draw 4 candidate samples per element and keep the first one that falls
    # within two standard deviations, approximating a truncated normal.
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)


class ConvBnRelu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super(ConvBnRelu, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.reset_parameters()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Fan-in-scaled truncated normal initialization for conv weights.
                fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                truncated_normal_(m.weight.data, mean=0., std=math.sqrt(1. / fan_in))
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        return self.conv_bn_relu(x)


class Conv3x3BnRelu(ConvBnRelu):
    def __init__(self, in_channels, out_channels):
        super(Conv3x3BnRelu, self).__init__(in_channels, out_channels, kernel_size=3, stride=1, padding=1)


class Conv1x1BnRelu(ConvBnRelu):
    def __init__(self, in_channels, out_channels):
        super(Conv1x1BnRelu, self).__init__(in_channels, out_channels, kernel_size=1, stride=1, padding=0)


Projection = Conv1x1BnRelu
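
A minimal usage sketch for these primitives (not part of the commit; the tensor shapes are illustrative assumptions):

    import torch

    from base_ops import Conv3x3BnRelu, Projection

    x = torch.randn(8, 3, 32, 32)   # a hypothetical CIFAR-sized batch
    feat = Conv3x3BnRelu(3, 16)(x)  # 3x3 conv, stride 1, padding 1: spatial size preserved -> (8, 16, 32, 32)
    proj = Projection(16, 8)(feat)  # Projection is a 1x1 ConvBnRelu, here reducing channels -> (8, 8, 32, 32)
    print(feat.shape, proj.shape)

Each module initializes its conv weights with the fan-in-scaled truncated normal above, so the forward pass is ready to train without further setup.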
The multi-trial search script (new file, 173 added lines; its file name is not shown on this page):
import click
import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn
import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench101Cell
from nni.retiarii.strategy import Random
from pytorch_lightning.callbacks import LearningRateMonitor
from timm.optim import RMSpropTF
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torchvision.datasets import CIFAR10

from base_ops import Conv3x3BnRelu, Conv1x1BnRelu, Projection


@model_wrapper
class NasBench101(nn.Module):
    def __init__(self,
                 stem_out_channels: int = 128,
                 num_stacks: int = 3,
                 num_modules_per_stack: int = 3,
                 max_num_vertices: int = 7,
                 max_num_edges: int = 9,
                 num_labels: int = 10,
                 bn_eps: float = 1e-5,
                 bn_momentum: float = 0.003):
        super().__init__()
        # Keep the BN settings so reset_parameters() can re-apply them later.
        self.bn_eps = bn_eps
        self.bn_momentum = bn_momentum

        # The three candidate operations of the NAS-Bench-101 search space.
        # Max-pooling preserves the channel count, so num_features is unused there.
        op_candidates = {
            'conv3x3': lambda num_features: Conv3x3BnRelu(num_features, num_features),
            'conv1x1': lambda num_features: Conv1x1BnRelu(num_features, num_features),
            'maxpool': lambda num_features: nn.MaxPool2d(3, 1, 1)
        }

        # initial stem convolution
        self.stem_conv = Conv3x3BnRelu(3, stem_out_channels)

        layers = []
        in_channels = out_channels = stem_out_channels
        for stack_num in range(num_stacks):
            if stack_num > 0:
                # Every stack after the first halves the resolution and doubles the channels.
                downsample = nn.MaxPool2d(kernel_size=2, stride=2)
                layers.append(downsample)
                out_channels *= 2
            for _ in range(num_modules_per_stack):
                cell = NasBench101Cell(op_candidates, in_channels, out_channels,
                                       lambda cin, cout: Projection(cin, cout),
                                       max_num_vertices, max_num_edges, label='cell')
                layers.append(cell)
                in_channels = out_channels

        self.features = nn.ModuleList(layers)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(out_channels, num_labels)

        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eps = bn_eps
                module.momentum = bn_momentum

    def forward(self, x):
        bs = x.size(0)
        out = self.stem_conv(x)
        for layer in self.features:
            out = layer(out)
        out = self.gap(out).view(bs, -1)
        out = self.classifier(out)
        return out

    def reset_parameters(self):
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eps = self.bn_eps
                module.momentum = self.bn_momentum


class AccuracyWithLogits(torchmetrics.Accuracy):
    def update(self, pred, target):
        return super().update(nn.functional.softmax(pred, dim=-1), target)


@serialize_cls
class NasBench101TrainingModule(pl.LightningModule):
    def __init__(self, max_epochs=108, learning_rate=0.1, weight_decay=1e-4):
        super().__init__()
        self.save_hyperparameters('learning_rate', 'weight_decay', 'max_epochs')
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = AccuracyWithLogits()

    def forward(self, x):
        # self.model is attached by the Retiarii evaluator before training starts.
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_accuracy', self.accuracy(y_hat, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
        self.log('val_accuracy', self.accuracy(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        optimizer = RMSpropTF(self.parameters(), lr=self.hparams.learning_rate,
                              weight_decay=self.hparams.weight_decay,
                              momentum=0.9, alpha=0.9, eps=1.0)
        return {
            'optimizer': optimizer,
            'lr_scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs)
        }

    def on_validation_epoch_end(self):
        nni.report_intermediate_result(self.trainer.callback_metrics['val_accuracy'].item())

    def teardown(self, stage):
        if stage == 'fit':
            nni.report_final_result(self.trainer.callback_metrics['val_accuracy'].item())


@click.command()
@click.option('--epochs', default=108, help='Training length.')
@click.option('--batch_size', default=256, help='Batch size.')
@click.option('--port', default=8081, help='On which port the experiment is run.')
def _multi_trial_test(epochs, batch_size, port):
    # Initialize the dataset. Note that the full 50k train / 10k test split is used;
    # this is slightly different from the paper.
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize([0.49139968, 0.48215827, 0.44653124], [0.24703233, 0.24348505, 0.26158768])
    ]
    train_dataset = serialize(CIFAR10, 'data', train=True, download=True, transform=transforms.Compose(transf + normalize))
    test_dataset = serialize(CIFAR10, 'data', train=False, transform=transforms.Compose(normalize))

    # specify training hyper-parameters
    training_module = NasBench101TrainingModule(max_epochs=epochs)
    # FIXME: need to fix a bug in serializer for this to work
    # lr_monitor = serialize(LearningRateMonitor, logging_interval='step')
    trainer = pl.Trainer(max_epochs=epochs, gpus=1)
    lightning = pl.Lightning(
        lightning_module=training_module,
        trainer=trainer,
        train_dataloader=pl.DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        val_dataloaders=pl.DataLoader(test_dataset, batch_size=batch_size),
    )

    strategy = Random()

    model = NasBench101()

    exp = RetiariiExperiment(model, lightning, [], strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 20
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = False

    exp.run(exp_config, port)


if __name__ == '__main__':
    _multi_trial_test()
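
The script is a click entry point, so a local run (assuming it is saved alongside base_ops.py; the file name multi_trial.py below is a placeholder, not given in the commit) would look like `python multi_trial.py --epochs 108 --batch_size 256 --port 8081`. The Random strategy then samples up to 20 cell architectures, the local training service runs up to two trials concurrently with one GPU each, and every trial reports its validation accuracy to NNI after each epoch, sending the final value when fitting tears down.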