diff --git a/docs/en_US/NAS/ApiReference.rst b/docs/en_US/NAS/ApiReference.rst index 359fc5c3aa..be7a773c49 100644 --- a/docs/en_US/NAS/ApiReference.rst +++ b/docs/en_US/NAS/ApiReference.rst @@ -105,4 +105,6 @@ Retiarii Experiments Utilities --------- -.. autofunction:: nni.retiarii.serialize \ No newline at end of file +.. autofunction:: nni.retiarii.serialize + +.. autofunction:: nni.retiarii.fixed_arch diff --git a/docs/en_US/NAS/OneshotTrainer.rst b/docs/en_US/NAS/OneshotTrainer.rst index 76baa36b92..e276a48125 100644 --- a/docs/en_US/NAS/OneshotTrainer.rst +++ b/docs/en_US/NAS/OneshotTrainer.rst @@ -34,4 +34,10 @@ See `API reference <./ApiReference.rst>`__ for detailed usages. Here, we show an trainer.fit() final_architecture = trainer.export() -**Format of the exported architecture.** TBD. +After the searching is done, we can use the exported architecture to instantiate the full network for retraining. Here is an example: + +.. code-block:: python + + from nni.retiarii import fixed_arch + with fixed_arch('/path/to/checkpoint.json'): + model = Model() diff --git a/docs/en_US/NAS/WriteOneshot.rst b/docs/en_US/NAS/WriteOneshot.rst index c190099f3e..19b546b3de 100644 --- a/docs/en_US/NAS/WriteOneshot.rst +++ b/docs/en_US/NAS/WriteOneshot.rst @@ -16,7 +16,7 @@ A typical example is DartsTrainer, where learnable-parameters are used to combin class DartsLayerChoice(nn.Module): def __init__(self, layer_choice): super(DartsLayerChoice, self).__init__() - self.name = layer_choice.key + self.name = layer_choice.label self.op_choices = nn.ModuleDict(layer_choice.named_children()) self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3) diff --git a/examples/nas/oneshot/darts/model.py b/examples/nas/oneshot/darts/model.py index 53e567faee..c4135463ae 100644 --- a/examples/nas/oneshot/darts/model.py +++ b/examples/nas/oneshot/darts/model.py @@ -7,7 +7,7 @@ import torch.nn as nn import ops -from nni.nas.pytorch import mutables +from nni.retiarii.nn.pytorch import 
LayerChoice, InputChoice class AuxiliaryHead(nn.Module): @@ -45,7 +45,7 @@ def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect): stride = 2 if i < num_downsample_connect else 1 choice_keys.append("{}_p{}".format(node_id, i)) self.ops.append( - mutables.LayerChoice(OrderedDict([ + LayerChoice(OrderedDict([ ("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)), ("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)), ("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)), @@ -53,9 +53,9 @@ def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect): ("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)), ("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)), ("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)) - ]), key=choice_keys[-1])) + ]), label=choice_keys[-1])) self.drop_path = ops.DropPath() - self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id)) + self.input_switch = InputChoice(n_candidates=len(choice_keys), n_chosen=2, label="{}_switch".format(node_id)) def forward(self, prev_nodes): assert len(self.ops) == len(prev_nodes) diff --git a/examples/nas/oneshot/darts/retrain.py b/examples/nas/oneshot/darts/retrain.py index 16f8b46129..30765fea57 100644 --- a/examples/nas/oneshot/darts/retrain.py +++ b/examples/nas/oneshot/darts/retrain.py @@ -12,8 +12,8 @@ import datasets import utils from model import CNN -from nni.nas.pytorch.fixed import apply_fixed_architecture from nni.nas.pytorch.utils import AverageMeter +from nni.retiarii import fixed_arch logger = logging.getLogger('nni') @@ -119,8 +119,8 @@ def validate(config, valid_loader, model, criterion, epoch, cur_step): args = parser.parse_args() dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16) - model = CNN(32, 3, 36, 10, 
args.layers, auxiliary=True) - apply_fixed_architecture(model, args.arc_checkpoint) + with fixed_arch(args.arc_checkpoint): + model = CNN(32, 3, 36, 10, args.layers, auxiliary=True) criterion = nn.CrossEntropyLoss() model.to(device) diff --git a/nni/retiarii/__init__.py b/nni/retiarii/__init__.py index f441367460..0b387e6c1b 100644 --- a/nni/retiarii/__init__.py +++ b/nni/retiarii/__init__.py @@ -4,5 +4,6 @@ from .operation import Operation from .graph import * from .execution import * +from .fixed import fixed_arch from .mutator import * from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper diff --git a/nni/retiarii/fixed.py b/nni/retiarii/fixed.py new file mode 100644 index 0000000000..e85cea582e --- /dev/null +++ b/nni/retiarii/fixed.py @@ -0,0 +1,40 @@ +import json +import logging +from pathlib import Path +from typing import Union, Dict, Any + +from .utils import ContextStack + +_logger = logging.getLogger(__name__) + + +def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True): + """ + Load architecture from ``fixed_arch`` and apply to model. This should be used as a context manager. For example, + + .. code-block:: python + + with fixed_arch('/path/to/export.json'): + model = Model(3, 224, 224) + + Parameters + ---------- + fixed_arch : str, Path or dict + Path to the JSON that stores the architecture, or dict that stores the exported architecture. + verbose : bool + Print log messages if set to True + + Returns + ------- + ContextStack + Context manager that provides a fixed architecture when creating the model. 
+ """ + + if isinstance(fixed_arch, (str, Path)): + with open(fixed_arch) as f: + fixed_arch = json.load(f) + + if verbose: + _logger.info(f'Fixed architecture: %s', fixed_arch) + + return ContextStack('fixed', fixed_arch) diff --git a/nni/retiarii/oneshot/pytorch/darts.py b/nni/retiarii/oneshot/pytorch/darts.py index edcb6d7b86..5b014d00a5 100644 --- a/nni/retiarii/oneshot/pytorch/darts.py +++ b/nni/retiarii/oneshot/pytorch/darts.py @@ -3,6 +3,7 @@ import copy import logging +from collections import OrderedDict import torch import torch.nn as nn @@ -18,8 +19,8 @@ class DartsLayerChoice(nn.Module): def __init__(self, layer_choice): super(DartsLayerChoice, self).__init__() - self.name = layer_choice.key - self.op_choices = nn.ModuleDict(layer_choice.named_children()) + self.name = layer_choice.label + self.op_choices = nn.ModuleDict(OrderedDict([(name, layer_choice[name]) for name in layer_choice.names])) self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3) def forward(self, *args, **kwargs): @@ -38,13 +39,13 @@ def named_parameters(self): yield name, p def export(self): - return torch.argmax(self.alpha).item() + return list(self.op_choices.keys())[torch.argmax(self.alpha).item()] class DartsInputChoice(nn.Module): def __init__(self, input_choice): super(DartsInputChoice, self).__init__() - self.name = input_choice.key + self.name = input_choice.label self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3) self.n_chosen = input_choice.n_chosen or 1