This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit eb9b72a

ENAS and DRATS search space zoo (#2589)
1 parent f8633ac commit eb9b72a

18 files changed

+1372
-1
lines changed

docs/en_US/NAS/Overview.md

+11
@@ -54,6 +54,17 @@ Please refer to [here](NasGuide.md) for the usage of one-shot NAS algorithms.
One-shot NAS can be visualized with our visualization tool. Learn more details [here](./Visualization.md).


## Search Space Zoo

NNI provides several predefined search spaces that can be easily reused. By stacking the extracted cells, users can quickly reproduce those NAS models.

Search Space Zoo contains the following NAS cells:

* [DartsCell](./SearchSpaceZoo.md#DartsCell)
* [ENAS micro](./SearchSpaceZoo.md#ENASMicroLayer)
* [ENAS macro](./SearchSpaceZoo.md#ENASMacroLayer)

## Using NNI API to Write Your Search Space

The programming interface of designing and searching a model is often demanded in two scenarios.

docs/en_US/NAS/SearchSpaceZoo.md

+175
@@ -0,0 +1,175 @@
# Search Space Zoo

## DartsCell

DartsCell is extracted from the [CNN model](./DARTS.md) designed [here](https://github.com/microsoft/nni/tree/master/examples/nas/darts). A DartsCell is a directed acyclic graph containing an ordered sequence of N nodes, and each node stands for a latent representation (e.g. a feature map in a convolutional network). A directed edge from Node 1 to Node 2 is associated with an operation that transforms Node 1, and the result is stored on Node 2. The [operations](#predefined-operations-darts) between nodes are predefined and cannot be changed. Each edge represents one operation, chosen from the predefined ones, that is applied to the starting node of the edge. One cell contains two input nodes, a single output node, and `n_node` intermediate nodes. The input nodes are defined as the cell outputs of the previous two layers. The output of the cell is obtained by applying a reduction operation (e.g. concatenation) to all the intermediate nodes. To make the search space continuous, the categorical choice of a particular operation is relaxed to a softmax over all possible operations. By adjusting the softmax weights on every edge, the operation with the highest probability is chosen to be part of the final structure. A CNN model can be formed by stacking several cells together, which builds a search space. Note that in the DARTS paper all cells in the model share the same structure.

One structure in the DARTS search space is shown below. Note that NNI merges the last of the four intermediate nodes with the output node.

![](../../img/NAS_Darts_cell.svg)

The predefined operations are shown in [references](#predefined-operations-darts).

```eval_rst
.. autoclass:: nni.nas.pytorch.search_space_zoo.DartsCell
    :members:
```
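The softmax relaxation described in the DartsCell overview can be illustrated with a minimal, framework-free sketch. It uses plain floats in place of tensors and toy callables in place of the real operations; the names `mixed_op`, `derive_op`, `candidate_ops` and `alphas` are hypothetical and are not part of the NNI API.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mixed_op(x, ops, alphas):
    """Continuous relaxation: softmax-weighted sum of every candidate op on one edge."""
    weights = softmax(alphas)
    return sum(w * op(x) for w, op in zip(weights, ops.values()))


def derive_op(ops, alphas):
    """Discretization: keep the op whose architecture weight is largest."""
    names = list(ops)
    best = max(range(len(names)), key=lambda i: alphas[i])
    return names[best]


# Toy stand-ins for the predefined operations (assumption: scalars instead of feature maps).
candidate_ops = {
    "skip": lambda x: x,
    "double": lambda x: 2.0 * x,
    "zero": lambda x: 0.0,
}
alphas = [0.1, 1.5, -0.3]  # hypothetical architecture weights for one edge

mixed = mixed_op(3.0, candidate_ops, alphas)
chosen = derive_op(candidate_ops, alphas)
```

During search every candidate contributes to `mixed`; once search finishes, only `chosen` is kept on that edge.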

### Example code

[example code](https://github.com/microsoft/nni/tree/master/examples/nas/search_space_zoo/darts_example.py)

```bash
git clone https://github.com/Microsoft/nni.git
cd nni/examples/nas/search_space_zoo
# search the best structure
python3 darts_example.py
```

<a name="predefined-operations-darts"></a>

### References

All supported operations for DARTS are listed below.

* MaxPool / AvgPool
    * MaxPool: Call `torch.nn.MaxPool2d`. This operation applies a 2D max pooling over all input channels. Its parameters `kernel_size=3` and `padding=1` are fixed. The pooling result is passed through a BatchNorm2d and then returned.
    * AvgPool: Call `torch.nn.AvgPool2d`. This operation applies a 2D average pooling over all input channels. Its parameters `kernel_size=3` and `padding=1` are fixed. The pooling result is passed through a BatchNorm2d and then returned.

    MaxPool / AvgPool with `kernel_size=3` and `padding=1` followed by BatchNorm2d
    ```eval_rst
    .. autoclass:: nni.nas.pytorch.search_space_zoo.darts_ops.PoolBN
    ```
* SkipConnect

    There is no operation between two nodes. Call `torch.nn.Identity` to forward what it gets to the output.
* Zero operation

    There is no connection between two nodes.
* DilConv3x3 / DilConv5x5

    <a name="DilConv"></a>DilConv3x3: (Dilated) depthwise separable Conv. It is a 3x3 depthwise convolution with `C_in` groups, followed by a 1x1 pointwise convolution, which reduces the number of parameters. The input is first passed through ReLU, then DilConv and finally BatchNorm2d. **Note that the operation is not Dilated Convolution, but we follow the convention in NAS papers to name it DilConv.** 3x3 DilConv has parameters `kernel_size=3`, `padding=1` and 5x5 DilConv has parameters `kernel_size=5`, `padding=4`.
    ```eval_rst
    .. autoclass:: nni.nas.pytorch.search_space_zoo.darts_ops.DilConv
    ```
* SepConv3x3 / SepConv5x5

    Composed of two DilConvs applied sequentially, with fixed `kernel_size=3`, `padding=1` or `kernel_size=5`, `padding=2`.
    ```eval_rst
    .. autoclass:: nni.nas.pytorch.search_space_zoo.darts_ops.SepConv
    ```
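To see how the fixed `kernel_size` / `padding` pairs listed above affect the feature-map size, the sketch below applies the standard output-size formula documented for `torch.nn.Conv2d` (the pooling layers use the same arithmetic). The helper name `conv_out_size`, the 32-pixel input and the `dilation` values are assumptions made for illustration; they are not read from the NNI source.

```python
def conv_out_size(n, kernel_size, padding, stride=1, dilation=1):
    """Output length along one spatial dimension for a conv or pooling layer."""
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# kernel_size=3, padding=1 and kernel_size=5, padding=2 keep a 32-pixel side unchanged.
same_3x3 = conv_out_size(32, kernel_size=3, padding=1)
same_5x5 = conv_out_size(32, kernel_size=5, padding=2)

# kernel_size=5 with padding=4 keeps the size only if the kernel is dilated by 2
# (assumed here); without dilation the output grows.
dilated_5x5 = conv_out_size(32, kernel_size=5, padding=4, dilation=2)
undilated_5x5 = conv_out_size(32, kernel_size=5, padding=4)

# stride=2, as used by reduction cells, halves the side length.
halved = conv_out_size(32, kernel_size=3, padding=1, stride=2)
```

Under these assumptions the first three results equal the input size, which is what allows the outputs of different candidate operations to be summed or concatenated inside a cell.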

## ENASMicroLayer

This layer is extracted from the model designed [here](https://github.com/microsoft/nni/tree/master/examples/nas/enas). A model contains several blocks that share the same architecture. A block is made up of some normal layers and reduction layers; `ENASMicroLayer` is a unified implementation of the two types of layers. The only difference between the two is that reduction layers apply all operations with `stride=2`.

ENAS Micro employs a DAG with N nodes in one cell, where the nodes represent local computations and the edges represent the flow of information between the N nodes. One cell contains two input nodes and a single output node. Each of the following nodes chooses two previous nodes as inputs, applies two operations from the [predefined ones](#predefined-operations-enas), and adds the results as the output of this node. For example, Node 4 chooses Node 1 and Node 3 as inputs, applies `MaxPool` and `AvgPool` to them respectively, and sums the results as the output of Node 4. Nodes that do not serve as input to any other node are viewed as the output of the layer. If there are multiple output nodes, the model calculates the average of these nodes as the layer output.

One structure in the ENAS micro search space is shown below.

![](../../img/NAS_ENAS_micro.svg)

The predefined operations can be seen [here](#predefined-operations-enas).

```eval_rst
.. autoclass:: nni.nas.pytorch.search_space_zoo.ENASMicroLayer
    :members:
```

The Reduction Layer is made up of two Conv operations followed by BatchNorm. Each Conv outputs `C_out//2` channels, and the two outputs are concatenated along the channel dimension. The convolutions have `kernel_size=1` and `stride=2`, and they perform alternate sampling on the input to reduce the resolution without loss of information. This layer is wrapped in `ENASMicroLayer`.
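The node wiring of an ENAS micro cell, as described earlier in this section, can be sketched without any deep-learning framework. The example below uses floats as stand-ins for tensors and follows the textual description (two inputs per node, results added, unused nodes averaged); `run_micro_cell`, `toy_ops` and the node specification format are hypothetical and do not mirror the `ENASMicroLayer` implementation.

```python
def run_micro_cell(inputs, node_specs, ops):
    """Evaluate one micro-style cell on scalar stand-ins for tensors.

    inputs: values of the two input nodes (indexed 0 and 1 here).
    node_specs: one (idx_a, op_a, idx_b, op_b) tuple per following node.
    ops: mapping from operation name to a callable.
    """
    nodes = list(inputs)
    used = set()
    for idx_a, op_a, idx_b, op_b in node_specs:
        # Each node applies one op to each chosen predecessor and adds the results.
        nodes.append(ops[op_a](nodes[idx_a]) + ops[op_b](nodes[idx_b]))
        used.update((idx_a, idx_b))
    # Nodes that feed no other node form the cell output; average them.
    loose_ends = [value for i, value in enumerate(nodes) if i not in used]
    return sum(loose_ends) / len(loose_ends)


# Toy operations standing in for the predefined ones (assumption).
toy_ops = {
    "identity": lambda x: x,
    "double": lambda x: 2.0 * x,
    "inc": lambda x: x + 1.0,
}

specs = [
    (0, "identity", 1, "double"),  # node 2 = 1 + 4 = 5
    (1, "inc", 2, "identity"),     # node 3 = 3 + 5 = 8
    (0, "inc", 2, "double"),       # node 4 = 2 + 10 = 12
]
cell_output = run_micro_cell([1.0, 2.0], specs, toy_ops)
```

Here nodes 3 and 4 are never consumed by another node, so the cell output is their average.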

### Example code

[example code](https://github.com/microsoft/nni/tree/master/examples/nas/search_space_zoo/enas_micro_example.py)

```bash
git clone https://github.com/Microsoft/nni.git
cd nni/examples/nas/search_space_zoo
# search the best cell structure
python3 enas_micro_example.py
```

<a name="predefined-operations-enas"></a>

### References

All supported operations for ENAS micro search are listed below.

* MaxPool / AvgPool
    * MaxPool: Call `torch.nn.MaxPool2d`. This operation applies a 2D max pooling over all input channels followed by BatchNorm2d. Its parameters are fixed to `kernel_size=3`, `stride=1` and `padding=1`.
    * AvgPool: Call `torch.nn.AvgPool2d`. This operation applies a 2D average pooling over all input channels followed by BatchNorm2d. Its parameters are fixed to `kernel_size=3`, `stride=1` and `padding=1`.
    ```eval_rst
    .. autoclass:: nni.nas.pytorch.search_space_zoo.enas_ops.Pool
    ```

* SepConv
    * SepConvBN3x3: ReLU followed by a [DilConv](#DilConv) and BatchNorm. Convolution parameters are `kernel_size=3`, `stride=1` and `padding=1`.
    * SepConvBN5x5: The same operation as SepConvBN3x3, but with kernel size and padding set to 5 and 2 respectively.

    ```eval_rst
    .. autoclass:: nni.nas.pytorch.search_space_zoo.enas_ops.SepConvBN
    ```

* SkipConnect

    Call `torch.nn.Identity` to connect directly to the next cell.

## ENASMacroLayer

In Macro search, the controller makes two decisions for each layer: i) the [operation](#macro-operations) to perform on the result of the previous layer, and ii) which previous layer to connect to for SkipConnects. ENAS uses a controller to design the whole model architecture instead of one of its components. The output of the operation is concatenated with the tensor of the layer chosen for SkipConnect. NNI provides [predefined operations](#macro-operations) for macro search, which are listed in [references](#macro-operations).

Part of one structure in the ENAS macro search space is shown below.

![](../../img/NAS_ENAS_macro.svg)

```eval_rst
.. autoclass:: nni.nas.pytorch.search_space_zoo.ENASMacroLayer
    :members:
```

To describe the whole search space, NNI provides a model, which is built by stacking the layers.

```eval_rst
.. autoclass:: nni.nas.pytorch.search_space_zoo.ENASMacroGeneralModel
    :members:
```
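The two per-layer decisions of macro search can be made concrete with a small sketch. It replaces the ENAS controller with uniform random sampling, which is an assumption made purely for illustration; `MACRO_OPS`, `sample_macro_architecture` and the dictionary layout are hypothetical names and are not part of the NNI API.

```python
import random

# Labels for the six macro operations listed in the references (names are illustrative).
MACRO_OPS = ["conv3x3", "conv5x5", "sepconv3x3", "sepconv5x5", "avgpool", "maxpool"]


def sample_macro_architecture(n_layers, rng):
    """Pick, for every layer, an operation and an earlier layer to skip-connect from."""
    arch = []
    for layer in range(n_layers):
        op = rng.choice(MACRO_OPS)  # decision i): which operation to apply
        # decision ii): which previous layer feeds the skip connection (none for layer 0)
        skip_from = rng.randrange(layer) if layer > 0 else None
        arch.append({"layer": layer, "op": op, "skip_from": skip_from})
    return arch


arch = sample_macro_architecture(4, random.Random(0))
```

A trained controller would make these choices from a learned policy rather than uniformly at random; the structure of the resulting description stays the same.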

### Example code

[example code](https://github.com/microsoft/nni/tree/master/examples/nas/search_space_zoo/enas_macro_example.py)

```bash
git clone https://github.com/Microsoft/nni.git
cd nni/examples/nas/search_space_zoo
# search the best cell structure
python3 enas_macro_example.py
```

<a name="macro-operations"></a>

### References

All supported operations for ENAS macro search are listed below.

* ConvBranch

    All inputs first pass through a StdConv, which is made up of a 1x1 Conv followed by BatchNorm2d and ReLU. Then the intermediate result goes through one of the operations listed below. The final result is obtained by applying BatchNorm2d and ReLU as post-processing.
    * Separable Conv3x3: If `separable=True`, the cell uses [SepConv](#DilConv) instead of a normal Conv operation. SepConv's `kernel_size=3`, `stride=1` and `padding=1`.
    * Separable Conv5x5: SepConv's `kernel_size=5`, `stride=1` and `padding=2`.
    * Normal Conv3x3: If `separable=False`, the cell uses a normal Conv operation with `kernel_size=3`, `stride=1` and `padding=1`.
    * Normal Conv5x5: Conv's `kernel_size=5`, `stride=1` and `padding=2`.

    ```eval_rst
    .. autoclass:: nni.nas.pytorch.search_space_zoo.enas_ops.ConvBranch
    ```
* PoolBranch

    All inputs first pass through a StdConv, which is made up of a 1x1 Conv followed by BatchNorm2d and ReLU. Then the intermediate result goes through a pooling operation followed by BatchNorm.
    * AvgPool: Call `torch.nn.AvgPool2d`. This operation applies a 2D average pooling over all input channels. Its parameters are fixed to `kernel_size=3`, `stride=1` and `padding=1`.
    * MaxPool: Call `torch.nn.MaxPool2d`. This operation applies a 2D max pooling over all input channels. Its parameters are fixed to `kernel_size=3`, `stride=1` and `padding=1`.

    ```eval_rst
    .. autoclass:: nni.nas.pytorch.search_space_zoo.enas_ops.PoolBranch
    ```

<!-- push -->

docs/en_US/nas.rst

+1
@@ -23,5 +23,6 @@ For details, please refer to the following tutorials:
    One-shot NAS <NAS/one_shot_nas>
    Customize a NAS Algorithm <NAS/Advanced>
    NAS Visualization <NAS/Visualization>
    Search Space Zoo <NAS/SearchSpaceZoo>
    NAS Benchmarks <NAS/Benchmarks>
    API Reference <NAS/NasReference>

docs/img/NAS_Darts_cell.svg

+1

docs/img/NAS_ENAS_macro.svg

+1

docs/img/NAS_ENAS_micro.svg

+1
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn

import datasets
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy

from nni.nas.pytorch.search_space_zoo import DartsCell
from darts_search_space import DartsStackedCells

logger = logging.getLogger('nni')

if __name__ == "__main__":
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=8, type=int)
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--channels", default=16, type=int)
    parser.add_argument("--unrolled", default=False, action="store_true")
    parser.add_argument("--visualization", default=False, action="store_true")
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    # 3 input channels (RGB), 10 output classes (CIFAR-10); DartsCell is the cell factory.
    model = DartsStackedCells(3, args.channels, 10, args.layers, DartsCell)
    criterion = nn.CrossEntropyLoss()

    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

    trainer = DartsTrainer(model,
                           loss=criterion,
                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                           optimizer=optim,
                           num_epochs=args.epochs,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency,
                           unrolled=args.unrolled,
                           callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
    if args.visualization:
        trainer.enable_visualization()
    trainer.train()
@@ -0,0 +1,83 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch.nn as nn
import ops


class DartsStackedCells(nn.Module):
    """
    Builtin DARTS search space.
    Compared to the DARTS example, DartsStackedCells removes the auxiliary head, which
    is considered a trick rather than part of the model.

    Attributes
    ----------
    in_channels: int
        the number of input channels
    channels: int
        the number of initial channels expected
    n_classes: int
        the number of classes for the final classification
    n_layers: int
        the number of cells contained in this network
    factory_func: function
        a callable that returns the desired cell structure.
        Pass in the cell class itself; it is called with the required parameters
        (see nni.nas.pytorch.search_space_zoo.DartsCell for details).
    n_nodes: int
        the number of nodes contained in each cell
    stem_multiplier: int
        channel multiplier applied to ``channels`` for the stem convolution
    """

    def __init__(self, in_channels, channels, n_classes, n_layers, factory_func, n_nodes=4,
                 stem_multiplier=3):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers

        c_cur = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_cur)
        )

        # for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p are output channel sizes, but c_cur is the input channel size.
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels

        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            # Reduce the feature map size and double the channels at 1/3 and 2/3 of the layers.
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True

            cell = factory_func(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        s0 = s1 = self.stem(x)

        for cell in self.cells:
            s0, s1 = s1, cell(s0, s1)

        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)

        return logits

    def drop_path_prob(self, p):
        for module in self.modules():
            if isinstance(module, ops.DropPath):
                module.p = p
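
The channel bookkeeping in `DartsStackedCells.__init__` is easier to follow when traced without building any layers. The sketch below replays the same loop with plain integers and records the arguments each cell would receive; `darts_channel_plan` is a hypothetical helper written for this illustration and is not part of the commit.

```python
def darts_channel_plan(channels, n_layers, n_nodes=4, stem_multiplier=3):
    """Replay the channel arithmetic of DartsStackedCells.__init__ with integers only.

    Returns the per-cell tuples (channels_pp, channels_p, c_cur, reduction_p, reduction)
    and the number of input features of the final linear classifier.
    """
    c_cur = stem_multiplier * channels
    # The stem output feeds both inputs of the first cell.
    channels_pp, channels_p, c_cur = c_cur, c_cur, channels

    plan = []
    reduction = False
    for i in range(n_layers):
        reduction_p, reduction = reduction, False
        # Reduction cells sit at 1/3 and 2/3 of the depth and double the channels.
        if i in [n_layers // 3, 2 * n_layers // 3]:
            c_cur *= 2
            reduction = True
        plan.append((channels_pp, channels_p, c_cur, reduction_p, reduction))
        # A cell concatenates its n_nodes intermediate nodes along the channel axis.
        channels_pp, channels_p = channels_p, c_cur * n_nodes
    return plan, channels_p


# Defaults of the example script: --channels 16, --layers 8.
plan, classifier_in = darts_channel_plan(channels=16, n_layers=8)
reduction_layers = [i for i, step in enumerate(plan) if step[4]]
```

With these defaults the reduction cells land at layers 2 and 5, and the classifier receives 256 features.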
+56
@@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10


class Cutout(object):
    """Zero out a square patch of side ``length`` centered at a random pixel."""

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        # Clip the patch to the image borders.
        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1: y2, x1: x2] = 0.
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask

        return img


def get_dataset(cls, cutout_length=0):
    # Per-channel mean and standard deviation used to normalize CIFAR-10 images.
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]
    cutout = []
    if cutout_length > 0:
        cutout.append(Cutout(cutout_length))

    # Augmentation and cutout are applied to the training split only.
    train_transform = transforms.Compose(transf + normalize + cutout)
    valid_transform = transforms.Compose(normalize)

    if cls == "cifar10":
        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    else:
        raise NotImplementedError
    return dataset_train, dataset_valid
