From c647750430d0135632e049154a2404ec8f9408a1 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 15 Feb 2021 12:06:38 +0800 Subject: [PATCH 1/4] Serializer end --- nni/retiarii/nn/pytorch/api.py | 8 ++++++-- nni/retiarii/utils.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/nni/retiarii/nn/pytorch/api.py b/nni/retiarii/nn/pytorch/api.py index 46a1715469..fa774f20dd 100644 --- a/nni/retiarii/nn/pytorch/api.py +++ b/nni/retiarii/nn/pytorch/api.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn -from ...utils import uid, add_record, del_record +from ...utils import uid, add_record, del_record, Translatable __all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs'] @@ -189,7 +189,7 @@ def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor: return candidate_inputs[0] -class ValueChoice(nn.Module): +class ValueChoice(Translatable, nn.Module): """ ValueChoice is to choose one from ``candidates``. @@ -237,6 +237,10 @@ def forward(self): warnings.warn('You should not run forward of this module directly.') return self.candidates[0] + def __translate__(self): + # Will function as a value when used in serializer. + return self.candidates[0] + class Placeholder(nn.Module): # TODO: docstring diff --git a/nni/retiarii/utils.py b/nni/retiarii/utils.py index 27c1f22424..9776a4a866 100644 --- a/nni/retiarii/utils.py +++ b/nni/retiarii/utils.py @@ -1,3 +1,4 @@ +import abc import functools import inspect from collections import defaultdict @@ -89,6 +90,17 @@ def del_record(key): _records.pop(key, None) +class Translatable(abc.ABC): + """ + Inherit this class and implement ``translate`` when the inner class needs a different + parameter from the wrapper class in its init function. 
+ """ + + @abc.abstractmethod + def __translate__(self) -> Any: + pass + + def _blackbox_cls(cls): class wrapper(cls): def __init__(self, *args, **kwargs): @@ -100,6 +112,15 @@ def __init__(self, *args, **kwargs): for argname, value in zip(argname_list, args): full_args[argname] = value + # translate parameters + args = list(args) + for i, value in enumerate(args): + if isinstance(value, Translatable): + args[i] = value.__translate__() + for i, value in kwargs.items(): + if isinstance(value, Translatable): + kwargs[i] = value.__translate__() + add_record(id(self), full_args) # for compatibility. Will remove soon. self.__init_parameters__ = full_args From cf8993cf132a956935cd7464cd7c51cea78f27c0 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 22 Feb 2021 12:06:13 +0800 Subject: [PATCH 2/4] Full support and tests --- nni/retiarii/graph.py | 8 +++ nni/retiarii/nn/pytorch/api.py | 57 ++++++++++++--- nni/retiarii/nn/pytorch/mutator.py | 38 +++++++++- test/ut/retiarii/test_highlevel_apis.py | 93 +++++++++++++++++++++++-- 4 files changed, 180 insertions(+), 16 deletions(-) diff --git a/nni/retiarii/graph.py b/nni/retiarii/graph.py index f8a99b7eb9..49056d65be 100644 --- a/nni/retiarii/graph.py +++ b/nni/retiarii/graph.py @@ -152,6 +152,14 @@ def _dump(self) -> Any: } return ret + def get_nodes(self) -> List['Node']: + """ + Traverse through all the nodes. + """ + for graph in self.graphs.values(): + for node in graph.nodes: + yield node + def get_nodes_by_label(self, label: str) -> List['Node']: """ Traverse all the nodes to find the matched node(s) with the given name. 
diff --git a/nni/retiarii/nn/pytorch/api.py b/nni/retiarii/nn/pytorch/api.py index fa774f20dd..6da0eaff1c 100644 --- a/nni/retiarii/nn/pytorch/api.py +++ b/nni/retiarii/nn/pytorch/api.py @@ -130,6 +130,9 @@ def forward(self, x): warnings.warn('You should not run forward of this module directly.') return x + def __repr__(self): + return f'LayerChoice({self.candidates}, label={repr(self.label)})' + class InputChoice(nn.Module): """ @@ -188,33 +191,66 @@ def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor: warnings.warn('You should not run forward of this module directly.') return candidate_inputs[0] + def __repr__(self): + return f'InputChoice(n_candidates={self.n_candidates}, n_chosen={self.n_chosen}, ' \ + f'reduction={repr(self.reduction)}, label={repr(self.label)})' + class ValueChoice(Translatable, nn.Module): """ ValueChoice is to choose one from ``candidates``. - Should initialize the values to choose from in init and call the module in forward to get the chosen value. + In most use scenarios, ValueChoice should be passed to the init parameters of a serializable module. For example, - A common use is to pass a mutable value to a functional API like ``torch.xxx`` or ``nn.functional.xxx```. For example, + .. code-block:: python + + class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, nn.ValueChoice([32, 64]), kernel_size=nn.ValueChoice([3, 5, 7])) + + def forward(self, x): + return self.conv(x) + + If you want to search a parameter that is used repeatedly, you can do so by sharing the same value choice instance. + (Sharing the label should have the same effect.) For example, .. code-block:: python class Net(nn.Module): def __init__(self): super().__init__() - self.dropout_rate = nn.ValueChoice([0., 1.]) + hidden_dim = nn.ValueChoice([128, 512]) + self.fc = nn.Sequential( + nn.Linear(64, hidden_dim), + nn.Linear(hidden_dim, 10) + ) + + # the following code has the same effect. 
+ # self.fc = nn.Sequential( + # nn.Linear(64, nn.ValueChoice([128, 512], label='dim')), + # nn.Linear(nn.ValueChoice([128, 512], label='dim'), 10) + # ) def forward(self, x): - return F.dropout(x, self.dropout_rate()) + return self.fc(x) + + Note that ValueChoice should be used directly. Transformations like ``nn.Linear(32, nn.ValueChoice([64, 128]) * 2)`` + are not supported. - The following use case is currently not supported because ValueChoice cannot be called in ``__init__``. - Please use LayerChoice as a workaround. + Another common use case is to initialize the values to choose from in init and call the module in forward to get the chosen value. + Usually, this is used to pass a mutable value to a functional API like ``torch.xxx`` or ``nn.functional.xxx``. + For example, .. code-block:: python - # in __init__ code - self.kernel_size = nn.ValueChoice([3, 5]) - self.conv = nn.Conv2d(3, self.out_channels, kernel_size=self.kernel_size()) + class Net(nn.Module): + def __init__(self): + super().__init__() + self.dropout_rate = nn.ValueChoice([0., 1.]) + + def forward(self, x): + return F.dropout(x, self.dropout_rate()) Parameters ---------- @@ -241,6 +277,9 @@ def __translate__(self): # Will function as a value when used in serializer. 
return self.candidates[0] + def __repr__(self): + return f'ValueChoice({self.candidates}, label={repr(self.label)})' + class Placeholder(nn.Module): # TODO: docstring diff --git a/nni/retiarii/nn/pytorch/mutator.py b/nni/retiarii/nn/pytorch/mutator.py index 1d1da4d5c7..85d70dc14f 100644 --- a/nni/retiarii/nn/pytorch/mutator.py +++ b/nni/retiarii/nn/pytorch/mutator.py @@ -1,7 +1,8 @@ -from typing import Any, List, Optional +from typing import Any, List, Optional, Tuple from ...mutator import Mutator from ...graph import Model, Node +from .api import ValueChoice class LayerChoiceMutator(Mutator): @@ -48,6 +49,19 @@ def mutate(self, model): target.update_operation('prim::Constant', {'value': chosen}) +class ParameterChoiceMutator(Mutator): + def __init__(self, nodes: List[Tuple[Node, str]], candidates: List[Any]): + super().__init__() + self.nodes = nodes + self.candidates = candidates + + def mutate(self, model): + chosen = self.choice(self.candidates) + for node, argname in self.nodes: + target = model.get_node_by_name(node.name) + target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen}) + + def process_inline_mutation(model: Model) -> Optional[List[Mutator]]: applied_mutators = [] @@ -73,6 +87,18 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]: mutator = ValueChoiceMutator(node_list, node_list[0].operation.parameters['candidates']) applied_mutators.append(mutator) + pc_nodes = [] + for node in model.get_nodes(): + for name, choice in node.operation.parameters.items(): + if isinstance(choice, ValueChoice): + pc_nodes.append((node, name)) + pc_nodes = _group_parameters_by_label(pc_nodes) + for node_list in pc_nodes: + assert _is_all_equal([node.operation.parameters[name].candidates for node, name in node_list]), \ + 'Value choice with the same label must have the same candidates.' 
+ mutator = ParameterChoiceMutator(node_list, node_list[0][0].operation.parameters[node_list[0][1]].candidates) + applied_mutators.append(mutator) + if applied_mutators: return applied_mutators return None @@ -95,3 +121,13 @@ def _group_by_label(nodes: List[Node]) -> List[List[Node]]: result[label] = [] result[label].append(node) return list(result.values()) + + +def _group_parameters_by_label(nodes: List[Tuple[Node, str]]) -> List[List[Tuple[Node, str]]]: + result = {} + for node, argname in nodes: + label = node.operation.parameters[argname].label + if label not in result: + result[label] = [] + result[label].append((node, argname)) + return list(result.values()) diff --git a/test/ut/retiarii/test_highlevel_apis.py b/test/ut/retiarii/test_highlevel_apis.py index 54b01d8909..b2b74db626 100644 --- a/test/ut/retiarii/test_highlevel_apis.py +++ b/test/ut/retiarii/test_highlevel_apis.py @@ -10,7 +10,7 @@ from nni.retiarii.nn.pytorch.mutator import process_inline_mutation -class EnuemrateSampler(Sampler): +class EnumerateSampler(Sampler): def __init__(self): self.index = 0 @@ -70,7 +70,7 @@ def forward(self, x): model = self._convert_to_ir(Net()) mutators = process_inline_mutation(model) self.assertEqual(len(mutators), 1) - mutator = mutators[0].bind_sampler(EnuemrateSampler()) + mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) model2 = mutator.apply(model) self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), @@ -94,7 +94,7 @@ def forward(self, x): model = self._convert_to_ir(Net()) mutators = process_inline_mutation(model) self.assertEqual(len(mutators), 1) - mutator = mutators[0].bind_sampler(EnuemrateSampler()) + mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) model2 = mutator.apply(model) self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), @@ -119,7 +119,7 @@ def forward(self, x): model = 
self._convert_to_ir(Net(reduction)) mutators = process_inline_mutation(model) self.assertEqual(len(mutators), 1) - mutator = mutators[0].bind_sampler(EnuemrateSampler()) + mutator = mutators[0].bind_sampler(EnumerateSampler()) model = mutator.apply(model) result = self._get_converted_pytorch_model(model)(torch.randn(1, 3, 3, 3)) if reduction == 'none': @@ -144,7 +144,7 @@ def forward(self, x): model = self._convert_to_ir(Net()) mutators = process_inline_mutation(model) self.assertEqual(len(mutators), 1) - mutator = mutators[0].bind_sampler(EnuemrateSampler()) + mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) model2 = mutator.apply(model) self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), @@ -152,6 +152,87 @@ def forward(self, x): self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 5, 3, 3])) + def test_value_choice_as_parameter(self): + class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 5, kernel_size=nn.ValueChoice([3, 5])) + + def forward(self, x): + return self.conv(x) + + model = self._convert_to_ir(Net()) + mutators = process_inline_mutation(model) + self.assertEqual(len(mutators), 1) + mutator = mutators[0].bind_sampler(EnumerateSampler()) + model1 = mutator.apply(model) + model2 = mutator.apply(model) + self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(), + torch.Size([1, 5, 3, 3])) + self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(), + torch.Size([1, 5, 1, 1])) + + def test_value_choice_as_parameter(self): + class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 5, kernel_size=nn.ValueChoice([3, 5])) + + def forward(self, x): + return self.conv(x) + + model = self._convert_to_ir(Net()) + mutators = process_inline_mutation(model) + self.assertEqual(len(mutators), 1) + 
mutator = mutators[0].bind_sampler(EnumerateSampler()) + model1 = mutator.apply(model) + model2 = mutator.apply(model) + self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(), + torch.Size([1, 5, 3, 3])) + self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(), + torch.Size([1, 5, 1, 1])) + + def test_value_choice_as_parameter(self): + class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, nn.ValueChoice([6, 8]), kernel_size=nn.ValueChoice([3, 5])) + + def forward(self, x): + return self.conv(x) + + model = self._convert_to_ir(Net()) + mutators = process_inline_mutation(model) + self.assertEqual(len(mutators), 2) + mutators[0].bind_sampler(EnumerateSampler()) + mutators[1].bind_sampler(EnumerateSampler()) + input = torch.randn(1, 3, 5, 5) + self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(), + torch.Size([1, 6, 3, 3])) + self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(), + torch.Size([1, 8, 1, 1])) + + def test_value_choice_as_parameter_shared(self): + class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, nn.ValueChoice([6, 8], label='shared'), 1) + self.conv2 = nn.Conv2d(3, nn.ValueChoice([6, 8], label='shared'), 1) + + def forward(self, x): + return self.conv1(x) + self.conv2(x) + + model = self._convert_to_ir(Net()) + mutators = process_inline_mutation(model) + self.assertEqual(len(mutators), 1) + mutator = mutators[0].bind_sampler(EnumerateSampler()) + model1 = mutator.apply(model) + model2 = mutator.apply(model) + self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(), + torch.Size([1, 6, 5, 5])) + self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(), + torch.Size([1, 8, 5, 5])) + def test_value_choice_in_functional(self): 
class Net(nn.Module): def __init__(self): @@ -164,7 +245,7 @@ def forward(self, x): model = self._convert_to_ir(Net()) mutators = process_inline_mutation(model) self.assertEqual(len(mutators), 1) - mutator = mutators[0].bind_sampler(EnuemrateSampler()) + mutator = mutators[0].bind_sampler(EnumerateSampler()) model1 = mutator.apply(model) model2 = mutator.apply(model) self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3])) From 9fb9bd512cccde2aebedee81228633f80604f1b8 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 22 Feb 2021 18:23:47 +0800 Subject: [PATCH 3/4] Update docs --- docs/en_US/NAS/retiarii/Tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en_US/NAS/retiarii/Tutorial.rst b/docs/en_US/NAS/retiarii/Tutorial.rst index 925827adeb..309a623007 100644 --- a/docs/en_US/NAS/retiarii/Tutorial.rst +++ b/docs/en_US/NAS/retiarii/Tutorial.rst @@ -83,7 +83,7 @@ For easy usability and also backward compatibility, we provide some APIs for use # invoked in `forward` function, choose one from the three out = self.input_switch([tensor1, tensor2, tensor3]) -* ``nn.ValueChoice``. It is for choosing one value from some candidate values. It can only be used as input argument of the modules in ``nn.modules`` and ``@blackbox_module`` decorated user-defined modules. *Note that it has not been officially supported.* +* ``nn.ValueChoice``. It is for choosing one value from some candidate values. It can only be used as input argument of the modules in ``nn.modules`` and ``@blackbox_module`` decorated user-defined modules. .. 
code-block:: python From 8d480055c3332a6242af716f9b18ef64625e66ba Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Wed, 24 Feb 2021 15:37:07 +0800 Subject: [PATCH 4/4] Resolve comments --- nni/retiarii/graph.py | 4 ++-- nni/retiarii/nn/pytorch/api.py | 2 +- nni/retiarii/utils.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nni/retiarii/graph.py b/nni/retiarii/graph.py index 49056d65be..022094ed90 100644 --- a/nni/retiarii/graph.py +++ b/nni/retiarii/graph.py @@ -6,7 +6,7 @@ import copy import json from enum import Enum -from typing import (Any, Dict, List, Optional, Tuple, Union, overload) +from typing import (Any, Dict, Iterable, List, Optional, Tuple, Union, overload) from .operation import Cell, Operation, _IOPseudoOperation from .utils import get_full_class_name, import_, uid @@ -152,7 +152,7 @@ def _dump(self) -> Any: } return ret - def get_nodes(self) -> List['Node']: + def get_nodes(self) -> Iterable['Node']: """ Traverse through all the nodes. """ diff --git a/nni/retiarii/nn/pytorch/api.py b/nni/retiarii/nn/pytorch/api.py index 6da0eaff1c..9a12257f9d 100644 --- a/nni/retiarii/nn/pytorch/api.py +++ b/nni/retiarii/nn/pytorch/api.py @@ -273,7 +273,7 @@ def forward(self): warnings.warn('You should not run forward of this module directly.') return self.candidates[0] - def __translate__(self): + def _translate(self): # Will function as a value when used in serializer. 
return self.candidates[0] diff --git a/nni/retiarii/utils.py b/nni/retiarii/utils.py index 9776a4a866..bca0881d47 100644 --- a/nni/retiarii/utils.py +++ b/nni/retiarii/utils.py @@ -97,7 +97,7 @@ class Translatable(abc.ABC): """ @abc.abstractmethod - def __translate__(self) -> Any: + def _translate(self) -> Any: pass @@ -116,10 +116,10 @@ def __init__(self, *args, **kwargs): args = list(args) for i, value in enumerate(args): if isinstance(value, Translatable): - args[i] = value.__translate__() + args[i] = value._translate() for i, value in kwargs.items(): if isinstance(value, Translatable): - kwargs[i] = value.__translate__() + kwargs[i] = value._translate() add_record(id(self), full_args) # for compatibility. Will remove soon.