Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

ENAS and DRATS search space zoo #2589

Merged
merged 47 commits into from
Jul 27, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
673cf3d
add darts cell and search space
tabVersion Jun 23, 2020
72f9f12
move to search_space_zoo
tabVersion Jun 23, 2020
99c841b
accept a cell to build full model
tabVersion Jun 24, 2020
e0e9e2c
fix compile error
tabVersion Jun 24, 2020
b55c6cd
bug fix
tabVersion Jun 24, 2020
3e162f4
change DartsCell signiture
tabVersion Jun 27, 2020
181e4c1
format code
tabVersion Jun 27, 2020
e35ff4b
change signature & inherit sequencial
tabVersion Jun 29, 2020
7896cb4
add search space example
tabVersion Jun 29, 2020
cd4eb1f
structure adjust & comment change
tabVersion Jun 29, 2020
736d196
clearify darts search space doc
tabVersion Jun 29, 2020
8c4f0bc
move dartsStackCells to example
tabVersion Jun 30, 2020
cf720c9
update docs
tabVersion Jul 3, 2020
823b0be
Merge branch 'master' into darts
tabVersion Jul 3, 2020
5d41f19
doc missing fix
tabVersion Jul 3, 2020
510dc38
Merge branch 'darts' of https://github.com/tabVersion/nni into darts
tabVersion Jul 3, 2020
03f7a28
doc fix
tabVersion Jul 6, 2020
40c517a
change code to fix doc
tabVersion Jul 6, 2020
8696f96
enas test
tabVersion Jul 6, 2020
1a46bf0
enas test
tabVersion Jul 6, 2020
40ab64a
enas test
tabVersion Jul 6, 2020
473e247
enas micro
tabVersion Jul 6, 2020
94f3eba
code format & doc fix & add example
tabVersion Jul 6, 2020
9efdb8a
refine doc
tabVersion Jul 7, 2020
5e2ed66
code format
tabVersion Jul 7, 2020
4cfdb10
add enas micro doc
tabVersion Jul 8, 2020
6a7a6ba
fix trailing whitespace
tabVersion Jul 8, 2020
8ddf8f1
add enas macro
tabVersion Jul 9, 2020
c0aecff
format doc
tabVersion Jul 9, 2020
c785024
fix doc
tabVersion Jul 9, 2020
0316d31
fix systax
tabVersion Jul 9, 2020
f6e9565
fix
tabVersion Jul 9, 2020
1b0c398
refine doc
tabVersion Jul 11, 2020
5b3dc94
refine doc
tabVersion Jul 13, 2020
f12df2c
update
tabVersion Jul 13, 2020
6cdfc5c
refine
tabVersion Jul 14, 2020
ec6ac2b
refine doc
tabVersion Jul 15, 2020
7ef03a6
refine doc
tabVersion Jul 16, 2020
199efb7
doc refine
tabVersion Jul 20, 2020
5b9c3ae
change sketch
tabVersion Jul 22, 2020
d5f63e2
change illustration
tabVersion Jul 24, 2020
05875f0
resolution fix
tabVersion Jul 24, 2020
2a5c434
update doc
tabVersion Jul 24, 2020
1d12bec
update doc
tabVersion Jul 24, 2020
de282c5
update doc
tabVersion Jul 24, 2020
63b20ce
doc
tabVersion Jul 24, 2020
8eb2afa
adjust menu sequence
tabVersion Jul 27, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
import torch.nn as nn

import datasets
# from model import CNN
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy

from nni.nas.pytorch.search_space_zoo import DartsCell
from nni.nas.pytorch.search_space_zoo import DartsStackedCells
from darts_search_space import DartsStackedCells

logger = logging.getLogger('nni')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,8 @@ class DartsStackedCells(nn.Module):
n_layers: int
the number of cells contained in this network
factory_func: function
return a callable instance for demand cell structure
user should pass in ``__init__`` of the cell class with following parameters (see darts_cell for detail)
n_nodes: int
the number of nodes contained in this cell
channels_pp: int
the number of previous previous cell's output channels
channels_p: int
the number of previous cell's output channels
channels: int
the number of output channels for each node
reduction_p: bool
Is previous cell a reduction cell
reduction: bool
is current cell a reduction cell
return a callable instance for demand cell structure.
user should pass in ``__init__`` of the cell class with required parameters (see nni.nas.DartsCell for detail)
tabVersion marked this conversation as resolved.
Show resolved Hide resolved
n_nodes: int
the number of nodes contained in each cell
stem_multiplier: int
Expand Down
56 changes: 56 additions & 0 deletions examples/nas/search_space_zoo/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10


class Cutout(object):
def __init__(self, length):
self.length = length

def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)

y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)

mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask

return img


def get_dataset(cls, cutout_length=0):
MEAN = [0.49139968, 0.48215827, 0.44653124]
STD = [0.24703233, 0.24348505, 0.26158768]
transf = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip()
]
normalize = [
transforms.ToTensor(),
transforms.Normalize(MEAN, STD)
]
cutout = []
if cutout_length > 0:
cutout.append(Cutout(cutout_length))

train_transform = transforms.Compose(transf + normalize + cutout)
valid_transform = transforms.Compose(normalize)

if cls == "cifar10":
dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
else:
raise NotImplementedError
return dataset_train, dataset_valid
21 changes: 21 additions & 0 deletions examples/nas/search_space_zoo/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)

_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]

correct = pred.eq(target.view(1, -1).expand_as(pred))

res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/nas/pytorch/search_space_zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .darts_cell import DartsCell
from .darts_search_space import DartsStackedCells