diff --git a/examples/nas/multi-trial/nasbench101/base_ops.py b/examples/nas/multi-trial/nasbench101/base_ops.py
new file mode 100644
index 0000000000..abe2c2730d
--- /dev/null
+++ b/examples/nas/multi-trial/nasbench101/base_ops.py
@@ -0,0 +1,51 @@
+import math
+
+import torch.nn as nn
+
+
+def truncated_normal_(tensor, mean=0, std=1):
+    # https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
+    size = tensor.shape
+    tmp = tensor.new_empty(size + (4,)).normal_()
+    valid = (tmp < 2) & (tmp > -2)
+    ind = valid.max(-1, keepdim=True)[1]
+    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
+    tensor.data.mul_(std).add_(mean)
+
+
+class ConvBnRelu(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
+        super(ConvBnRelu, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.conv_bn_relu = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
+                truncated_normal_(m.weight.data, mean=0., std=math.sqrt(1. / fan_in))
+            if isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def forward(self, x):
+        return self.conv_bn_relu(x)
+
+
+class Conv3x3BnRelu(ConvBnRelu):
+    def __init__(self, in_channels, out_channels):
+        super(Conv3x3BnRelu, self).__init__(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+
+class Conv1x1BnRelu(ConvBnRelu):
+    def __init__(self, in_channels, out_channels):
+        super(Conv1x1BnRelu, self).__init__(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+
+Projection = Conv1x1BnRelu
diff --git a/examples/nas/multi-trial/nasbench101/network.py b/examples/nas/multi-trial/nasbench101/network.py
new file mode 100644
index 0000000000..df8bee2182
--- /dev/null
+++ b/examples/nas/multi-trial/nasbench101/network.py
@@ -0,0 +1,173 @@
+import click
+import nni
+import nni.retiarii.evaluator.pytorch.lightning as pl
+import torch.nn as nn
+import torchmetrics
+from nni.retiarii import model_wrapper, serialize, serialize_cls
+from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
+from nni.retiarii.nn.pytorch import NasBench101Cell
+from nni.retiarii.strategy import Random
+from pytorch_lightning.callbacks import LearningRateMonitor
+from timm.optim import RMSpropTF
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+
+from base_ops import Conv3x3BnRelu, Conv1x1BnRelu, Projection
+
+
+@model_wrapper
+class NasBench101(nn.Module):
+    def __init__(self,
+                 stem_out_channels: int = 128,
+                 num_stacks: int = 3,
+                 num_modules_per_stack: int = 3,
+                 max_num_vertices: int = 7,
+                 max_num_edges: int = 9,
+                 num_labels: int = 10,
+                 bn_eps: float = 1e-5,
+                 bn_momentum: float = 0.003):
+        super().__init__()
+
+        op_candidates = {
+            'conv3x3': lambda num_features: Conv3x3BnRelu(num_features, num_features),
+            'conv1x1': lambda num_features: Conv1x1BnRelu(num_features, num_features),
+            'maxpool': lambda num_features: nn.MaxPool2d(3, 1, 1)
+        }
+
+        # initial stem convolution
+        self.stem_conv = Conv3x3BnRelu(3, stem_out_channels)
+
+        layers = []
+        in_channels = out_channels = stem_out_channels
+        for stack_num in range(num_stacks):
+            if stack_num > 0:
+                downsample = nn.MaxPool2d(kernel_size=2, stride=2)
+                layers.append(downsample)
+                out_channels *= 2
+            for _ in range(num_modules_per_stack):
+                cell = NasBench101Cell(op_candidates, in_channels, out_channels,
+                                       lambda cin, cout: Projection(cin, cout),
+                                       max_num_vertices, max_num_edges, label='cell')
+                layers.append(cell)
+                in_channels = out_channels
+
+        self.features = nn.ModuleList(layers)
+        self.gap = nn.AdaptiveAvgPool2d(1)
+        self.classifier = nn.Linear(out_channels, num_labels)
+
+        for module in self.modules():
+            if isinstance(module, nn.BatchNorm2d):
+                module.eps = bn_eps
+                module.momentum = bn_momentum
+
+    def forward(self, x):
+        bs = x.size(0)
+        out = self.stem_conv(x)
+        for layer in self.features:
+            out = layer(out)
+        out = self.gap(out).view(bs, -1)
+        out = self.classifier(out)
+        return out
+
+    def reset_parameters(self):
+        for module in self.modules():
+            if isinstance(module, nn.BatchNorm2d):
+                module.eps = self.config.bn_eps
+                module.momentum = self.config.bn_momentum
+
+
+class AccuracyWithLogits(torchmetrics.Accuracy):
+    def update(self, pred, target):
+        return super().update(nn.functional.softmax(pred), target)
+
+
+@serialize_cls
+class NasBench101TrainingModule(pl.LightningModule):
+    def __init__(self, max_epochs=108, learning_rate=0.1, weight_decay=1e-4):
+        super().__init__()
+        self.save_hyperparameters('learning_rate', 'weight_decay', 'max_epochs')
+        self.criterion = nn.CrossEntropyLoss()
+        self.accuracy = AccuracyWithLogits()
+
+    def forward(self, x):
+        y_hat = self.model(x)
+        return y_hat
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = self.criterion(y_hat, y)
+        self.log('train_loss', loss, prog_bar=True)
+        self.log('train_accuracy', self.accuracy(y_hat, y), prog_bar=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
+        self.log('val_accuracy', self.accuracy(y_hat, y), prog_bar=True)
+
+    def configure_optimizers(self):
+        optimizer = RMSpropTF(self.parameters(), lr=self.hparams.learning_rate,
+                              weight_decay=self.hparams.weight_decay,
+                              momentum=0.9, alpha=0.9, eps=1.0)
+        return {
+            'optimizer': optimizer,
+            'scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs)
+        }
+
+    def on_validation_epoch_end(self):
+        nni.report_intermediate_result(self.trainer.callback_metrics['val_accuracy'].item())
+
+    def teardown(self, stage):
+        if stage == 'fit':
+            nni.report_final_result(self.trainer.callback_metrics['val_accuracy'].item())
+
+
+@click.command()
+@click.option('--epochs', default=108, help='Training length.')
+@click.option('--batch_size', default=256, help='Batch size.')
+@click.option('--port', default=8081, help='On which port the experiment is run.')
+def _multi_trial_test(epochs, batch_size, port):
+    # initialize dataset. Note that 50k + 10k is used, which is a little different from the paper.
+    transf = [
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip()
+    ]
+    normalize = [
+        transforms.ToTensor(),
+        transforms.Normalize([0.49139968, 0.48215827, 0.44653124], [0.24703233, 0.24348505, 0.26158768])
+    ]
+    train_dataset = serialize(CIFAR10, 'data', train=True, download=True, transform=transforms.Compose(transf + normalize))
+    test_dataset = serialize(CIFAR10, 'data', train=False, transform=transforms.Compose(normalize))
+
+    # specify training hyper-parameters
+    training_module = NasBench101TrainingModule(max_epochs=epochs)
+    # FIXME: need to fix a bug in serializer for this to work
+    # lr_monitor = serialize(LearningRateMonitor, logging_interval='step')
+    trainer = pl.Trainer(max_epochs=epochs, gpus=1)
+    lightning = pl.Lightning(
+        lightning_module=training_module,
+        trainer=trainer,
+        train_dataloader=pl.DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
+        val_dataloaders=pl.DataLoader(test_dataset, batch_size=batch_size),
+    )
+
+    strategy = Random()
+
+    model = NasBench101()
+
+    exp = RetiariiExperiment(model, lightning, [], strategy)
+
+    exp_config = RetiariiExeConfig('local')
+    exp_config.trial_concurrency = 2
+    exp_config.max_trial_number = 20
+    exp_config.trial_gpu_number = 1
+    exp_config.training_service.use_active_gpu = False
+
+    exp.run(exp_config, port)
+
+
+if __name__ == '__main__':
+    _multi_trial_test()
diff --git a/nni/retiarii/mutator.py b/nni/retiarii/mutator.py
index e7d5708169..abf69703fd 100644
--- a/nni/retiarii/mutator.py
+++ b/nni/retiarii/mutator.py
@@ -1,12 +1,12 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-from typing import (Any, Iterable, List, Optional)
+from typing import (Any, Iterable, List, Optional, Tuple)
 
 from .graph import Model, Mutation, ModelStatus
 
 
-__all__ = ['Sampler', 'Mutator']
+__all__ = ['Sampler', 'Mutator', 'InvalidMutation']
 
 
 Choice = Any
@@ -77,7 +77,7 @@ def apply(self, model: Model) -> Model:
         self._cur_choice_idx = None
         return copy
 
-    def dry_run(self, model: Model) -> List[List[Choice]]:
+    def dry_run(self, model: Model) -> Tuple[List[List[Choice]], Model]:
         """
         Dry run mutator on a model to collect choice candidates.
         If you invoke this method multiple times on same or different models,
@@ -115,3 +115,7 @@ def __init__(self):
     def choice(self, candidates: List[Choice], *args) -> Choice:
         self.recorded_candidates.append(candidates)
         return candidates[0]
+
+
+class InvalidMutation(Exception):
+    pass
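Note on the dry_run signature change above (an illustrative sketch, not part of the patch; the mutator and model names are hypothetical): callers now receive the recorded candidate lists together with the model produced by the dry run, rather than the candidate lists alone.

    # before: candidates = some_mutator.dry_run(base_model)
    candidates, dry_run_model = some_mutator.dry_run(base_model)
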
""" - def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs): + # FIXME: prior is designed but not supported yet + + def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *, + prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs): try: chosen = get_fixed_value(label) if isinstance(candidates, list): return candidates[int(chosen)] else: return candidates[chosen] - except AssertionError: + except NoContextError: return super().__new__(cls) - def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs): + def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *, + prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs): super(LayerChoice, self).__init__() if 'key' in kwargs: warnings.warn(f'"key" is deprecated. Assuming label.') @@ -75,10 +81,12 @@ def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], lab if 'reduction' in kwargs: warnings.warn(f'"reduction" is deprecated. Ignoring...') self.candidates = candidates + self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))] + assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.' self._label = generate_new_label(label) self.names = [] - if isinstance(candidates, OrderedDict): + if isinstance(candidates, dict): for name, module in candidates.items(): assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \ "Please don't use a reserved name '{}' for your module.".format(name) @@ -169,17 +177,23 @@ class InputChoice(nn.Module): Recommended inputs to choose. If None, mutator is instructed to select any. reduction : str ``mean``, ``concat``, ``sum`` or ``none``. + prior : list of float + Prior distribution used in random sampling. label : str Identifier of the input choice. """ - def __new__(cls, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs): + def __new__(cls, n_candidates: int, n_chosen: Optional[int] = 1, + reduction: str = 'sum', *, + prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs): try: return ChosenInputs(get_fixed_value(label), reduction=reduction) - except AssertionError: + except NoContextError: return super().__new__(cls) - def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs): + def __init__(self, n_candidates: int, n_chosen: Optional[int] = 1, + reduction: str = 'sum', *, + prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs): super(InputChoice, self).__init__() if 'key' in kwargs: warnings.warn(f'"key" is deprecated. Assuming label.') @@ -191,6 +205,7 @@ def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', self.n_candidates = n_candidates self.n_chosen = n_chosen self.reduction = reduction + self.prior = prior or [1 / n_candidates for _ in range(n_candidates)] assert self.reduction in ['mean', 'concat', 'sum', 'none'] self._label = generate_new_label(label) @@ -277,19 +292,25 @@ def forward(self, x): ---------- candidates : list List of values to choose from. + prior : list of float + Prior distribution to sample from. label : str Identifier of the value choice. 
""" - def __new__(cls, candidates: List[Any], label: Optional[str] = None): + # FIXME: prior is designed but not supported yet + + def __new__(cls, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None): try: return get_fixed_value(label) - except AssertionError: + except NoContextError: return super().__new__(cls) - def __init__(self, candidates: List[Any], label: Optional[str] = None): + def __init__(self, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None): super().__init__() self.candidates = candidates + self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))] + assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.' self._label = generate_new_label(label) self._accessor = [] @@ -323,7 +344,7 @@ def __copy__(self): return self def __deepcopy__(self, memo): - new_item = ValueChoice(self.candidates, self.label) + new_item = ValueChoice(self.candidates, label=self.label) new_item._accessor = [*self._accessor] return new_item diff --git a/nni/retiarii/nn/pytorch/component.py b/nni/retiarii/nn/pytorch/component.py index 20c6e30de5..8a1470b3b6 100644 --- a/nni/retiarii/nn/pytorch/component.py +++ b/nni/retiarii/nn/pytorch/component.py @@ -7,10 +7,12 @@ from .api import LayerChoice, InputChoice from .nn import ModuleList +from .nasbench101 import NasBench101Cell, NasBench101Mutator from .utils import generate_new_label, get_fixed_value +from ...utils import NoContextError -__all__ = ['Repeat', 'Cell'] +__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator'] class Repeat(nn.Module): @@ -33,7 +35,7 @@ def __new__(cls, blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Mod try: repeat = get_fixed_value(label) return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat)) - except AssertionError: + except NoContextError: return super().__new__(cls) def __init__(self, diff --git a/nni/retiarii/nn/pytorch/mutator.py b/nni/retiarii/nn/pytorch/mutator.py index 4307aeac68..d515f6b7eb 100644 --- a/nni/retiarii/nn/pytorch/mutator.py +++ b/nni/retiarii/nn/pytorch/mutator.py @@ -9,7 +9,7 @@ from ...mutator import Mutator from ...graph import Cell, Graph, Model, ModelStatus, Node from .api import LayerChoice, InputChoice, ValueChoice, Placeholder -from .component import Repeat +from .component import Repeat, NasBench101Cell, NasBench101Mutator from ...utils import uid @@ -47,7 +47,12 @@ def mutate(self, model): n_candidates = self.nodes[0].operation.parameters['n_candidates'] n_chosen = self.nodes[0].operation.parameters['n_chosen'] candidates = list(range(n_candidates)) - chosen = [self.choice(candidates) for _ in range(n_chosen)] + if n_chosen is None: + chosen = [i for i in candidates if self.choice([False, True])] + # FIXME This is a hack to make choice align with the previous format + self._cur_samples = chosen + else: + chosen = [self.choice(candidates) for _ in range(n_chosen)] for node in self.nodes: target = model.get_node_by_name(node.name) target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs', @@ -199,8 +204,15 @@ def number_of_chosen(node): def mutate(self, model: Model): # this mutate does not have any effect, but it is recorded in the mutation history for node in model.get_nodes_by_label(self.label): - for _ in range(self.number_of_chosen(node)): - self.choice(self.candidates(node)) + n_chosen = self.number_of_chosen(node) + if n_chosen is None: + candidates = [i for i in self.candidates(node) if self.choice([False, True])] + # 
diff --git a/nni/retiarii/nn/pytorch/mutator.py b/nni/retiarii/nn/pytorch/mutator.py
index 4307aeac68..d515f6b7eb 100644
--- a/nni/retiarii/nn/pytorch/mutator.py
+++ b/nni/retiarii/nn/pytorch/mutator.py
@@ -9,7 +9,7 @@
 from ...mutator import Mutator
 from ...graph import Cell, Graph, Model, ModelStatus, Node
 from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
-from .component import Repeat
+from .component import Repeat, NasBench101Cell, NasBench101Mutator
 from ...utils import uid
 
 
@@ -47,7 +47,12 @@ def mutate(self, model):
             n_candidates = self.nodes[0].operation.parameters['n_candidates']
             n_chosen = self.nodes[0].operation.parameters['n_chosen']
             candidates = list(range(n_candidates))
-            chosen = [self.choice(candidates) for _ in range(n_chosen)]
+            if n_chosen is None:
+                chosen = [i for i in candidates if self.choice([False, True])]
+                # FIXME This is a hack to make choice align with the previous format
+                self._cur_samples = chosen
+            else:
+                chosen = [self.choice(candidates) for _ in range(n_chosen)]
             for node in self.nodes:
                 target = model.get_node_by_name(node.name)
                 target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs',
@@ -199,8 +204,15 @@ def number_of_chosen(node):
     def mutate(self, model: Model):
         # this mutate does not have any effect, but it is recorded in the mutation history
         for node in model.get_nodes_by_label(self.label):
-            for _ in range(self.number_of_chosen(node)):
-                self.choice(self.candidates(node))
+            n_chosen = self.number_of_chosen(node)
+            if n_chosen is None:
+                candidates = [i for i in self.candidates(node) if self.choice([False, True])]
+                # FIXME This is a hack to make choice align with the previous format
+                # For example, it will convert [False, True, True] into [1, 2].
+                self._cur_samples = candidates
+            else:
+                for _ in range(n_chosen):
+                    self.choice(self.candidates(node))
             break
 
 
@@ -242,6 +254,11 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
                 'candidates': list(range(module.min_depth, module.max_depth + 1))
             })
             node.label = module.label
+        if isinstance(module, NasBench101Cell):
+            node = graph.add_node(name, 'NasBench101Cell', {
+                'max_num_edges': module.max_num_edges
+            })
+            node.label = module.label
         if isinstance(module, Placeholder):
             raise NotImplementedError('Placeholder is not supported in python execution mode.')
 
@@ -250,13 +267,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
         return model, None
 
     mutators = []
+    mutators_final = []
     for nodes in _group_by_label_and_type(graph.hidden_nodes):
         assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
             f'Node with label "{nodes[0].label}" does not all have the same type.'
         assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
            f'Node with label "{nodes[0].label}" does not agree on parameters.'
-        mutators.append(ManyChooseManyMutator(nodes[0].label))
-    return model, mutators
+        if nodes[0].operation.type == 'NasBench101Cell':
+            mutators_final.append(NasBench101Mutator(nodes[0].label))
+        else:
+            mutators.append(ManyChooseManyMutator(nodes[0].label))
+    return model, mutators + mutators_final
 
 
 # utility functions
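The n_chosen-is-None branches above sample one Boolean per candidate and keep the indices that drew True; a minimal sketch of the conversion described in the FIXME comments (values are hypothetical):

    mask = [False, True, True]                     # one choice([False, True]) per candidate
    chosen = [i for i, m in enumerate(mask) if m]  # -> [1, 2]
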
diff --git a/nni/retiarii/nn/pytorch/nasbench101.py b/nni/retiarii/nn/pytorch/nasbench101.py
new file mode 100644
index 0000000000..2f15421121
--- /dev/null
+++ b/nni/retiarii/nn/pytorch/nasbench101.py
@@ -0,0 +1,390 @@
+import logging
+from collections import OrderedDict
+from typing import Callable, List, Optional, Union, Dict
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .api import InputChoice, ValueChoice, LayerChoice
+from .utils import generate_new_label, get_fixed_dict
+from ...mutator import InvalidMutation, Mutator
+from ...graph import Model
+from ...utils import NoContextError
+
+_logger = logging.getLogger(__name__)
+
+
+def compute_vertex_channels(input_channels, output_channels, matrix):
+    """
+    This is (almost) copied from the original NAS-Bench-101 implementation.
+
+    Computes the number of channels at every vertex.
+
+    Given the input channels and output channels, this calculates the number of channels at each interior vertex.
+    Interior vertices have the same number of channels as the max of the channels of the vertices it feeds into.
+    The output channels are divided amongst the vertices that are directly connected to it.
+    When the division is not even, some vertices may receive an extra channel to compensate.
+
+    Parameters
+    ----------
+    input_channels : int
+        Input channel count.
+    output_channels : int
+        Output channel count.
+    matrix : np.ndarray
+        Adjacency matrix for the module (pruned by model_spec).
+
+    Returns
+    -------
+    list of int
+        List of channel counts, in order of the vertices.
+    """
+
+    num_vertices = np.shape(matrix)[0]
+
+    vertex_channels = [0] * num_vertices
+    vertex_channels[0] = input_channels
+    vertex_channels[num_vertices - 1] = output_channels
+
+    if num_vertices == 2:
+        # Edge case where module only has input and output vertices
+        return vertex_channels
+
+    # Compute the in-degree ignoring input, axis 0 is the src vertex and axis 1 is
+    # the dst vertex. Summing over 0 gives the in-degree count of each vertex.
+    in_degree = np.sum(matrix[1:], axis=0)
+    interior_channels = output_channels // in_degree[num_vertices - 1]
+    correction = output_channels % in_degree[num_vertices - 1]  # Remainder to add
+
+    # Set channels of vertices that flow directly to output
+    for v in range(1, num_vertices - 1):
+        if matrix[v, num_vertices - 1]:
+            vertex_channels[v] = interior_channels
+            if correction:
+                vertex_channels[v] += 1
+                correction -= 1
+
+    # Set channels for all other vertices to the max of the out edges, going backwards.
+    # (num_vertices - 2) index skipped because it only connects to output.
+    for v in range(num_vertices - 3, 0, -1):
+        if not matrix[v, num_vertices - 1]:
+            for dst in range(v + 1, num_vertices - 1):
+                if matrix[v, dst]:
+                    vertex_channels[v] = max(vertex_channels[v], vertex_channels[dst])
+        assert vertex_channels[v] > 0
+
+    _logger.debug('vertex_channels: %s', str(vertex_channels))
+
+    # Sanity check, verify that channels never increase and final channels add up.
+    final_fan_in = 0
+    for v in range(1, num_vertices - 1):
+        if matrix[v, num_vertices - 1]:
+            final_fan_in += vertex_channels[v]
+        for dst in range(v + 1, num_vertices - 1):
+            if matrix[v, dst]:
+                assert vertex_channels[v] >= vertex_channels[dst]
+    assert final_fan_in == output_channels or num_vertices == 2
+    # num_vertices == 2 means only input/output nodes, so 0 fan-in
+
+    return vertex_channels
+
+
+def prune(matrix, ops):
+    """
+    Prune the extraneous parts of the graph.
+
+    General procedure:
+
+    1. Remove parts of graph not connected to input.
+    2. Remove parts of graph not connected to output.
+    3. Reorder the vertices so that they are consecutive after steps 1 and 2.
+
+    These 3 steps can be combined by deleting the rows and columns of the
+    vertices that are not reachable from both the input and output (in reverse).
+    """
+    num_vertices = np.shape(matrix)[0]
+
+    # calculate the connection matrix within V number of steps.
+    connections = np.linalg.matrix_power(matrix + np.eye(num_vertices), num_vertices)
+
+    visited_from_input = set([i for i in range(num_vertices) if connections[0, i]])
+    visited_from_output = set([i for i in range(num_vertices) if connections[i, -1]])
+
+    # Any vertex that isn't connected to both input and output is extraneous to the computation graph.
+    extraneous = set(range(num_vertices)).difference(
+        visited_from_input.intersection(visited_from_output))
+
+    if len(extraneous) > num_vertices - 2:
+        raise InvalidMutation('Non-extraneous graph is less than 2 vertices, '
+                              'the input is not connected to the output and the spec is invalid.')
+
+    matrix = np.delete(matrix, list(extraneous), axis=0)
+    matrix = np.delete(matrix, list(extraneous), axis=1)
+    for index in sorted(extraneous, reverse=True):
+        del ops[index]
+    return matrix, ops
+
+
+def truncate(inputs, channels):
+    input_channels = inputs.size(1)
+    if input_channels < channels:
+        raise ValueError('input channel < output channels for truncate')
+    elif input_channels == channels:
+        return inputs  # No truncation necessary
+    else:
+        # Truncation should only be necessary when channel division leads to
+        # vertices with +1 channels. The input vertex should always be projected to
+        # the minimum channel count.
+        assert input_channels - channels == 1
+        return inputs[:, :channels]
+ """ + + def __init__(self, operations: List[Callable[[int], nn.Module]], + adjacency_list: List[List[int]], + in_features: int, out_features: int, num_nodes: int, + projection: Callable[[int, int], nn.Module]): + super().__init__() + + assert num_nodes == len(operations) + 2 == len(adjacency_list) + 1 + + self.operations = ['IN'] + operations + ['OUT'] # add psuedo nodes + self.connection_matrix = self.build_connection_matrix(adjacency_list, num_nodes) + del num_nodes # raw number of nodes is no longer used + + self.connection_matrix, self.operations = prune(self.connection_matrix, self.operations) + + self.hidden_features = compute_vertex_channels(in_features, out_features, self.connection_matrix) + + self.num_nodes = len(self.connection_matrix) + self.in_features = in_features + self.out_features = out_features + _logger.info('Prund number of nodes: %d', self.num_nodes) + _logger.info('Pruned connection matrix: %s', str(self.connection_matrix)) + + self.projections = nn.ModuleList([nn.Identity()]) + self.ops = nn.ModuleList([nn.Identity()]) + for i in range(1, self.num_nodes): + self.projections.append(projection(in_features, self.hidden_features[i])) + + for i in range(1, self.num_nodes - 1): + self.ops.append(operations[i - 1](self.hidden_features[i])) + + @staticmethod + def build_connection_matrix(adjacency_list, num_nodes): + adjacency_list = [[]] + adjacency_list # add adjacency for first node + connections = np.zeros((num_nodes, num_nodes), dtype='int') + for i, lst in enumerate(adjacency_list): + assert all([0 <= k < i for k in lst]) + for k in lst: + connections[k, i] = 1 + return connections + + def forward(self, inputs): + tensors = [inputs] + for t in range(1, self.num_nodes - 1): + + # Create interior connections, truncating if necessary + add_in = [truncate(tensors[src], self.hidden_features[t]) + for src in range(1, t) if self.connection_matrix[src, t]] + + # Create add connection from projected input + if self.connection_matrix[0, t]: + add_in.append(self.projections[t](tensors[0])) + + if len(add_in) == 1: + vertex_input = add_in[0] + else: + vertex_input = sum(add_in) + + # Perform op at vertex t + vertex_out = self.ops[t](vertex_input) + tensors.append(vertex_out) + + # Construct final output tensor by concating all fan-in and adding input. + if np.sum(self.connection_matrix[:, -1]) == 1: + src = np.where(self.connection_matrix[:, -1] == 1)[0][0] + return self.projections[-1](tensors[0]) if src == 0 else tensors[src] + + outputs = torch.cat([tensors[src] for src in range(1, self.num_nodes - 1) if self.connection_matrix[src, -1]], 1) + if self.connection_matrix[0, -1]: + outputs += self.projections[-1](tensors[0]) + assert outputs.size(1) == self.out_features + return outputs + + +class NasBench101Cell(nn.Module): + """ + Cell structure that is proposed in NAS-Bench-101 [nasbench101]_ . + + This cell is usually used in evaluation of NAS algorithms because there is a ``comprehensive analysis'' of this search space + available, which includes a full architecture-dataset that ``maps 423k unique architectures to metrics + including run time and accuracy''. You can also use the space in your own space design, in which scenario it should be possible + to leverage results in the benchmark to narrow the huge space down to a few efficient architectures. 
+
+
+class NasBench101Cell(nn.Module):
+    """
+    Cell structure that is proposed in NAS-Bench-101 [nasbench101]_ .
+
+    This cell is usually used in evaluation of NAS algorithms because there is a "comprehensive analysis" of this search space
+    available, which includes a full architecture dataset that "maps 423k unique architectures to metrics
+    including run time and accuracy". You can also use the space in your own space design, in which scenario it should be possible
+    to leverage results in the benchmark to narrow the huge space down to a few efficient architectures.
+
+    The space of this cell architecture consists of all possible directed acyclic graphs on no more than ``max_num_nodes`` nodes,
+    where each possible node (other than IN and OUT) has one of ``op_candidates``, representing the corresponding operation.
+    Edges connecting the nodes can be no more than ``max_num_edges``.
+    To align with the paper settings, the two vertices specially labeled as operations IN and OUT are also counted into
+    ``max_num_nodes`` in our implementation. The default value of ``max_num_nodes`` is 7 and ``max_num_edges`` is 9.
+
+    Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`. The shape
+    of each hidden node will be automatically computed, depending on the cell structure. Each of the ``op_candidates``
+    should be a callable that accepts the computed ``num_features`` and returns a ``Module``. For example,
+
+    .. code-block:: python
+
+        def conv_bn_relu(num_features):
+            return nn.Sequential(
+                nn.Conv2d(num_features, num_features, 1),
+                nn.BatchNorm2d(num_features),
+                nn.ReLU()
+            )
+
+    The output of each node is the sum of its input nodes fed into its operation, except for the last node (output node),
+    which is the concatenation of its input *hidden* nodes, adding the *IN* node (if IN and OUT are connected).
+
+    When the input tensor is added with any other tensor, there could be a shape mismatch. Therefore, a projection transformation
+    is needed to transform the input tensor. In the paper, this is simply a Conv1x1 followed by BN and ReLU. The ``projection``
+    parameter accepts ``in_features`` and ``out_features`` and returns a ``Module``. This parameter has no default value,
+    as we hold no assumption that users are dealing with images. An example for this parameter is,
+
+    .. code-block:: python
+
+        def projection_fn(in_features, out_features):
+            return nn.Conv2d(in_features, out_features, 1)
+
+    Parameters
+    ----------
+    op_candidates : list of callable
+        Operation candidates. Each should be a function that accepts the number of features and returns an ``nn.Module``.
+    in_features : int
+        Input dimension of cell.
+    out_features : int
+        Output dimension of cell.
+    projection : callable
+        Projection module that is used to preprocess the input tensor of the whole cell.
+        A callable that accepts input features and output features, returning an ``nn.Module``.
+    max_num_nodes : int
+        Maximum number of nodes in the cell, input and output included. At least 2. Default: 7.
+    max_num_edges : int
+        Maximum number of edges in the cell. Default: 9.
+    label : str
+        Identifier of the cell. Cells sharing the same label will semantically share the same choice.
+
+    References
+    ----------
+    .. [nasbench101] Ying, Chris, et al. "NAS-Bench-101: Towards Reproducible Neural Architecture Search."
+        International Conference on Machine Learning. PMLR, 2019.
+    """
+
+    @staticmethod
+    def _make_dict(x):
+        if isinstance(x, list):
+            return OrderedDict([(str(i), t) for i, t in enumerate(x)])
+        return OrderedDict(x)
+
+    def __new__(cls, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
+                in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
+                max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
+
+        def make_list(x): return x if isinstance(x, list) else [x]
+
+        try:
+            label, selected = get_fixed_dict(label)
+            op_candidates = cls._make_dict(op_candidates)
+            num_nodes = selected[f'{label}/num_nodes']
+            adjacency_list = [make_list(selected[f'{label}/input_{i}']) for i in range(1, num_nodes)]
+            if sum([len(e) for e in adjacency_list]) > max_num_edges:
+                raise InvalidMutation(f'Expected {max_num_edges} edges, found: {adjacency_list}')
+            return _NasBench101CellFixed(
+                [op_candidates[selected[f'{label}/op_{i}']] for i in range(1, num_nodes - 1)],
+                adjacency_list, in_features, out_features, num_nodes, projection)
+        except NoContextError:
+            return super().__new__(cls)
+
+    def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
+                 in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
+                 max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
+
+        super().__init__()
+        self._label = generate_new_label(label)
+        num_vertices_prior = [2 ** i for i in range(2, max_num_nodes + 1)]
+        num_vertices_prior = (np.array(num_vertices_prior) / sum(num_vertices_prior)).tolist()
+        self.num_nodes = ValueChoice(list(range(2, max_num_nodes + 1)),
+                                     prior=num_vertices_prior,
+                                     label=f'{self._label}/num_nodes')
+        self.max_num_nodes = max_num_nodes
+        self.max_num_edges = max_num_edges
+
+        op_candidates = self._make_dict(op_candidates)
+
+        # this is only for input validation and instantiating enough layer choices and input choices
+        self.hidden_features = out_features
+
+        self.projections = nn.ModuleList([nn.Identity()])
+        self.ops = nn.ModuleList([nn.Identity()])
+        self.inputs = nn.ModuleList([nn.Identity()])
+        for _ in range(1, max_num_nodes):
+            self.projections.append(projection(in_features, self.hidden_features))
+        for i in range(1, max_num_nodes):
+            if i < max_num_nodes - 1:
+                self.ops.append(LayerChoice(OrderedDict([(k, op(self.hidden_features)) for k, op in op_candidates.items()]),
+                                            label=f'{self._label}/op_{i}'))
+            self.inputs.append(InputChoice(i, None, label=f'{self._label}/input_{i}'))
+
+    @property
+    def label(self):
+        return self._label
+
+    def forward(self, x):
+        # This is a dummy forward and is actually not used
+        tensors = [x]
+        for i in range(1, self.max_num_nodes):
+            node_input = self.inputs[i]([self.projections[i](tensors[0])] + [t for t in tensors[1:]])
+            if i < self.max_num_nodes - 1:
+                node_output = self.ops[i](node_input)
+            else:
+                node_output = node_input
+            tensors.append(node_output)
+        return tensors[-1]
+ """ + + @staticmethod + def _make_dict(x): + if isinstance(x, list): + return OrderedDict([(str(i), t) for i, t in enumerate(x)]) + return OrderedDict(x) + + def __new__(cls, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]], + in_features: int, out_features: int, projection: Callable[[int, int], nn.Module], + max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None): + def make_list(x): return x if isinstance(x, list) else [x] + + try: + label, selected = get_fixed_dict(label) + op_candidates = cls._make_dict(op_candidates) + num_nodes = selected[f'{label}/num_nodes'] + adjacency_list = [make_list(selected[f'{label}/input_{i}']) for i in range(1, num_nodes)] + if sum([len(e) for e in adjacency_list]) > max_num_edges: + raise InvalidMutation(f'Expected {max_num_edges} edges, found: {adjacency_list}') + return _NasBench101CellFixed( + [op_candidates[selected[f'{label}/op_{i}']] for i in range(1, num_nodes - 1)], + adjacency_list, in_features, out_features, num_nodes, projection) + except NoContextError: + return super().__new__(cls) + + def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]], + in_features: int, out_features: int, projection: Callable[[int, int], nn.Module], + max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None): + + super().__init__() + self._label = generate_new_label(label) + num_vertices_prior = [2 ** i for i in range(2, max_num_nodes + 1)] + num_vertices_prior = (np.array(num_vertices_prior) / sum(num_vertices_prior)).tolist() + self.num_nodes = ValueChoice(list(range(2, max_num_nodes + 1)), + prior=num_vertices_prior, + label=f'{self._label}/num_nodes') + self.max_num_nodes = max_num_nodes + self.max_num_edges = max_num_edges + + op_candidates = self._make_dict(op_candidates) + + # this is only for input validation and instantiating enough layer choice and input choice + self.hidden_features = out_features + + self.projections = nn.ModuleList([nn.Identity()]) + self.ops = nn.ModuleList([nn.Identity()]) + self.inputs = nn.ModuleList([nn.Identity()]) + for _ in range(1, max_num_nodes): + self.projections.append(projection(in_features, self.hidden_features)) + for i in range(1, max_num_nodes): + if i < max_num_nodes - 1: + self.ops.append(LayerChoice(OrderedDict([(k, op(self.hidden_features)) for k, op in op_candidates.items()]), + label=f'{self._label}/op_{i}')) + self.inputs.append(InputChoice(i, None, label=f'{self._label}/input_{i}')) + + @property + def label(self): + return self._label + + def forward(self, x): + # This is a dummy forward and actually not used + tensors = [x] + for i in range(1, self.max_num_nodes): + node_input = self.inputs[i]([self.projections[i](tensors[0])] + [t for t in tensors[1:]]) + if i < self.max_num_nodes - 1: + node_output = self.ops[i](node_input) + else: + node_output = node_input + tensors.append(node_output) + return tensors[-1] + + +class NasBench101Mutator(Mutator): + # for validation purposes + # for python execution engine + + def __init__(self, label: Optional[str]): + super().__init__(label=label) + + @staticmethod + def candidates(node): + if 'n_candidates' in node.operation.parameters: + return list(range(node.operation.parameters['n_candidates'])) + else: + return node.operation.parameters['candidates'] + + @staticmethod + def number_of_chosen(node): + if 'n_chosen' in node.operation.parameters: + return node.operation.parameters['n_chosen'] + return 1 + + def mutate(self, model: 
diff --git a/nni/retiarii/nn/pytorch/utils.py b/nni/retiarii/nn/pytorch/utils.py
index 352348b997..d577a47cfc 100644
--- a/nni/retiarii/nn/pytorch/utils.py
+++ b/nni/retiarii/nn/pytorch/utils.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional, Tuple
 
 from ...utils import uid, get_current_context
 
@@ -9,9 +9,21 @@ def generate_new_label(label: Optional[str]):
     return label
 
 
-def get_fixed_value(label: str):
+def get_fixed_value(label: str) -> Any:
     ret = get_current_context('fixed')
     try:
         return ret[generate_new_label(label)]
     except KeyError:
         raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')
+
+
+def get_fixed_dict(label_prefix: str) -> Tuple[str, Any]:
+    ret = get_current_context('fixed')
+    try:
+        label_prefix = generate_new_label(label_prefix)
+        ret = {k: v for k, v in ret.items() if k.startswith(label_prefix + '/')}
+        if not ret:
+            raise KeyError
+        return label_prefix, ret
+    except KeyError:
+        raise KeyError(f'Fixed context with prefix {label_prefix} not found. Existing values are: {ret}')
diff --git a/nni/retiarii/strategy/bruteforce.py b/nni/retiarii/strategy/bruteforce.py
index 971711f0d9..4046551ea7 100644
--- a/nni/retiarii/strategy/bruteforce.py
+++ b/nni/retiarii/strategy/bruteforce.py
@@ -8,7 +8,7 @@
 import time
 from typing import Any, Dict, List
 
-from .. import Sampler, submit_models, query_available_resources, budget_exhausted
+from .. import InvalidMutation, Sampler, submit_models, query_available_resources, budget_exhausted
 from .base import BaseStrategy
 from .utils import dry_run_for_search_space, get_targeted_model
 
@@ -121,4 +121,7 @@ def run(self, base_model, applied_mutators):
                 if budget_exhausted():
                     return
                 time.sleep(self._polling_interval)
-            submit_models(get_targeted_model(base_model, applied_mutators, sample))
+            try:
+                submit_models(get_targeted_model(base_model, applied_mutators, sample))
+            except InvalidMutation as e:
+                _logger.warning(f'Invalid mutation: {e}. Skip.')
diff --git a/nni/retiarii/utils.py b/nni/retiarii/utils.py
index c8b02dfba4..066dd993e3 100644
--- a/nni/retiarii/utils.py
+++ b/nni/retiarii/utils.py
@@ -67,6 +67,10 @@ def get_importable_name(cls, relocate_module=False):
     return module_name + '.' + cls.__name__
 
 
+class NoContextError(Exception):
+    pass
+
+
 class ContextStack:
     """
     This is to maintain a globally-accessible context envinronment that is visible to everywhere.
@@ -98,7 +102,8 @@ def pop(cls, key: str) -> None:
 
     @classmethod
     def top(cls, key: str) -> Any:
-        assert cls._stack[key], 'Context is empty.'
+        if not cls._stack[key]:
+            raise NoContextError('Context is empty.')
         return cls._stack[key][-1]
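For context on get_fixed_dict and NoContextError above (a hedged sketch; the label and values are hypothetical): inside a fixed-architecture context, get_fixed_dict('cell') returns the label plus every entry whose key starts with 'cell/', which is exactly what NasBench101Cell.__new__ uses to build a _NasBench101CellFixed; outside any context, ContextStack.top raises NoContextError and the symbolic search-space module is constructed instead.

    # hypothetical fixed architecture for a cell labeled 'cell' (max_num_nodes=4)
    fixed = {
        'cell/num_nodes': 4,
        'cell/op_1': 'conv3x3', 'cell/op_2': 'conv1x1',
        'cell/input_1': [0], 'cell/input_2': [0, 1], 'cell/input_3': [1, 2],
    }
    # within this context: get_fixed_dict('cell') -> ('cell', fixed)
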
diff --git a/test/.gitignore b/test/.gitignore
index 1065b0ee85..35133a8063 100644
--- a/test/.gitignore
+++ b/test/.gitignore
@@ -10,3 +10,4 @@ _generated_model
 data
 generated
 lightning_logs
+model.onnx
diff --git a/test/ut/retiarii/test_highlevel_apis.py b/test/ut/retiarii/test_highlevel_apis.py
index e246e748f2..db24f7ec3c 100644
--- a/test/ut/retiarii/test_highlevel_apis.py
+++ b/test/ut/retiarii/test_highlevel_apis.py
@@ -5,7 +5,7 @@
 import nni.retiarii.nn.pytorch as nn
 import torch
 import torch.nn.functional as F
-from nni.retiarii import Sampler, basic_unit
+from nni.retiarii import InvalidMutation, Sampler, basic_unit
 from nni.retiarii.converter import convert_to_graph
 from nni.retiarii.codegen import model_to_pytorch_script
 from nni.retiarii.execution.python import _unpack_if_only_one
@@ -518,3 +518,29 @@ def test_valuechoice_access_functional(self):
         ...
 
     @unittest.skip
     def test_valuechoice_access_functional_expression(self):
         ...
+
+    def test_nasbench101_cell(self):
+        # this is only supported in python engine for now.
+        @self.get_serializer()
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.cell = nn.NasBench101Cell([lambda x: nn.Linear(x, x), lambda x: nn.Linear(x, x, bias=False)],
+                                               10, 16, lambda x, y: nn.Linear(x, y), max_num_nodes=5, max_num_edges=7)
+
+            def forward(self, x):
+                return self.cell(x)
+
+        raw_model, mutators = self._get_model_with_mutators(Net())
+
+        succeeded = 0
+        sampler = RandomSampler()
+        while succeeded <= 10:
+            try:
+                model = raw_model
+                for mutator in mutators:
+                    model = mutator.bind_sampler(sampler).apply(model)
+                succeeded += 1
+            except InvalidMutation:
+                continue
+        self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))