This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] Use ValueChoice inline in a serializable instance #3382

Merged
Changes from 2 commits
8 changes: 8 additions & 0 deletions nni/retiarii/graph.py
@@ -152,6 +152,14 @@ def _dump(self) -> Any:
}
return ret

def get_nodes(self) -> List['Node']:
[Review comment (Contributor)] The return type is Iterable in this case.

"""
Traverse through all the nodes.
"""
for graph in self.graphs.values():
for node in graph.nodes:
yield node

def get_nodes_by_label(self, label: str) -> List['Node']:
"""
Traverse all the nodes to find the matched node(s) with the given label.
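As a quick illustration of the new traversal helper (a hedged sketch: ``model`` stands for an already-constructed ``Model`` instance), ``get_nodes`` is written as a generator, which is why the review suggests ``Iterable`` over ``List`` in the annotation:

.. code-block:: python

    # Hypothetical usage of Model.get_nodes(); `model` is assumed to be an
    # existing nni.retiarii.graph.Model.
    for node in model.get_nodes():  # lazily yields nodes from every graph
        print(node.name, node.operation.type)

    # Materialize explicitly when an actual list is needed:
    all_nodes = list(model.get_nodes())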
65 changes: 54 additions & 11 deletions nni/retiarii/nn/pytorch/api.py
@@ -5,7 +5,7 @@
import torch
import torch.nn as nn

-from ...utils import uid, add_record, del_record
from ...utils import uid, add_record, del_record, Translatable


__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
@@ -130,6 +130,9 @@ def forward(self, x):
warnings.warn('You should not run forward of this module directly.')
return x

def __repr__(self):
return f'LayerChoice({self.candidates}, label={repr(self.label)})'


class InputChoice(nn.Module):
"""
@@ -188,33 +191,66 @@ def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor:
warnings.warn('You should not run forward of this module directly.')
return candidate_inputs[0]

def __repr__(self):
return f'InputChoice(n_candidates={self.n_candidates}, n_chosen={self.n_chosen}, ' \
f'reduction={repr(self.reduction)}, label={repr(self.label)})'


-class ValueChoice(nn.Module):
class ValueChoice(Translatable, nn.Module):
"""
ValueChoice is for choosing one value from ``candidates``.

-Should initialize the values to choose from in init and call the module in forward to get the chosen value.
In most use scenarios, ValueChoice should be passed to the init parameters of a serializable module. For example,

.. code-block:: python

class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, nn.ValueChoice([32, 64]), kernel_size=nn.ValueChoice([3, 5, 7]))

def forward(self, x):
return self.conv(x)

-A common use is to pass a mutable value to a functional API like ``torch.xxx`` or ``nn.functional.xxx```. For example,
In case you want to search a parameter that is used repeatedly, this is also possible by sharing the same ValueChoice instance.
(Sharing the label should have the same effect.) For example,

.. code-block:: python

class Net(nn.Module):
def __init__(self):
super().__init__()
-self.dropout_rate = nn.ValueChoice([0., 1.])
hidden_dim = nn.ValueChoice([128, 512])
self.fc = nn.Sequential(
nn.Linear(64, hidden_dim),
nn.Linear(hidden_dim, 10)
)

# the following code has the same effect.
# self.fc = nn.Sequential(
# nn.Linear(64, nn.ValueChoice([128, 512], label='dim')),
# nn.Linear(nn.ValueChoice([128, 512], label='dim'), 10)
# )

def forward(self, x):
-return F.dropout(x, self.dropout_rate())
return self.fc(x)

-The following use case is currently not supported because ValueChoice cannot be called in ``__init__``.
-Please use LayerChoice as a workaround.
Note that ValueChoice should be used directly. Transformations like ``nn.Linear(32, nn.ValueChoice([64, 128]) * 2)``
are not supported.

Another common use case is to initialize the values to choose from in ``__init__`` and call the module in ``forward`` to get the chosen value.
Usually, this is used to pass a mutable value to a functional API like ``torch.xxx`` or ``nn.functional.xxx``. For example,

.. code-block:: python

-# in __init__ code
-self.kernel_size = nn.ValueChoice([3, 5])
-self.conv = nn.Conv2d(3, self.out_channels, kernel_size=self.kernel_size())
class Net(nn.Module):
def __init__(self):
super().__init__()
self.dropout_rate = nn.ValueChoice([0., 1.])

def forward(self, x):
return F.dropout(x, self.dropout_rate())

Parameters
----------
@@ -237,6 +273,13 @@ def forward(self):
warnings.warn('You should not run forward of this module directly.')
return self.candidates[0]

def __translate__(self):
[Review comment (Contributor)] Is it proper to use __xxx__ for a non-builtin method name? I'm not sure.

# Will function as a value when used in the serializer.
return self.candidates[0]

def __repr__(self):
return f'ValueChoice({self.candidates}, label={repr(self.label)})'


class Placeholder(nn.Module):
# TODO: docstring
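To make the inline use concrete, here is a hedged sketch of what the ``Translatable`` hook buys ``ValueChoice``: when one is passed to a serializable module's ``__init__``, the blackbox wrapper swaps it for a concrete placeholder value (the first candidate) so the inner module can be constructed, while the full candidate list remains recorded for mutation.

.. code-block:: python

    import nni.retiarii.nn.pytorch as nn

    vc = nn.ValueChoice([32, 64], label='width')
    print(repr(vc))          # ValueChoice([32, 64], label='width')

    # __translate__ is the hook the blackbox wrapper invokes on Translatable
    # arguments before constructing the inner module:
    assert vc.__translate__() == 32  # the first candidate serves as placeholder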
38 changes: 37 additions & 1 deletion nni/retiarii/nn/pytorch/mutator.py
@@ -1,7 +1,8 @@
-from typing import Any, List, Optional
from typing import Any, List, Optional, Tuple

from ...mutator import Mutator
from ...graph import Model, Node
from .api import ValueChoice


class LayerChoiceMutator(Mutator):
@@ -48,6 +49,19 @@ def mutate(self, model):
target.update_operation('prim::Constant', {'value': chosen})


class ParameterChoiceMutator(Mutator):
def __init__(self, nodes: List[Tuple[Node, str]], candidates: List[Any]):
super().__init__()
self.nodes = nodes
self.candidates = candidates

def mutate(self, model):
chosen = self.choice(self.candidates)
for node, argname in self.nodes:
target = model.get_node_by_name(node.name)
target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen})


def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
applied_mutators = []

@@ -73,6 +87,18 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
mutator = ValueChoiceMutator(node_list, node_list[0].operation.parameters['candidates'])
applied_mutators.append(mutator)

pc_nodes = []
for node in model.get_nodes():
for name, choice in node.operation.parameters.items():
if isinstance(choice, ValueChoice):
pc_nodes.append((node, name))
pc_nodes = _group_parameters_by_label(pc_nodes)
for node_list in pc_nodes:
assert _is_all_equal([node.operation.parameters[name].candidates for node, name in node_list]), \
'Value choice with the same label must have the same candidates.'
mutator = ParameterChoiceMutator(node_list, node_list[0][0].operation.parameters[node_list[0][1]].candidates)
applied_mutators.append(mutator)

if applied_mutators:
return applied_mutators
return None
@@ -95,3 +121,13 @@ def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
result[label] = []
result[label].append(node)
return list(result.values())


def _group_parameters_by_label(nodes: List[Tuple[Node, str]]) -> List[List[Tuple[Node, str]]]:
result = {}
for node, argname in nodes:
label = node.operation.parameters[argname].label
if label not in result:
result[label] = []
result[label].append((node, argname))
return list(result.values())
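
Taken together (a hedged sketch; ``model`` and ``sampler`` are assumed to already exist): ``process_inline_mutation`` scans every node's operation parameters for ``ValueChoice`` instances, groups the ``(node, argname)`` pairs by label via ``_group_parameters_by_label``, and emits one ``ParameterChoiceMutator`` per group, so all parameters sharing a label receive the same sampled value.

.. code-block:: python

    from nni.retiarii.nn.pytorch.mutator import process_inline_mutation

    mutators = process_inline_mutation(model)  # model: a converted graph Model
    if mutators is not None:
        for mutator in mutators:
            # bind_sampler returns the mutator itself (see the tests below);
            # apply() returns a mutated copy of the model.
            model = mutator.bind_sampler(sampler).apply(model)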
21 changes: 21 additions & 0 deletions nni/retiarii/utils.py
@@ -1,3 +1,4 @@
import abc
import functools
import inspect
from collections import defaultdict
@@ -89,6 +90,17 @@ def del_record(key):
_records.pop(key, None)


class Translatable(abc.ABC):
"""
Inherit this class and implement ``translate`` when the inner class needs a different
parameter from the wrapper class in its init function.
"""

@abc.abstractmethod
def __translate__(self) -> Any:
pass


def _blackbox_cls(cls):
class wrapper(cls):
def __init__(self, *args, **kwargs):
@@ -100,6 +112,15 @@ def __init__(self, *args, **kwargs):
for argname, value in zip(argname_list, args):
full_args[argname] = value

# translate parameters
args = list(args)
for i, value in enumerate(args):
if isinstance(value, Translatable):
args[i] = value.__translate__()
for i, value in kwargs.items():
if isinstance(value, Translatable):
kwargs[i] = value.__translate__()

add_record(id(self), full_args) # for compatibility. Will remove soon.

self.__init_parameters__ = full_args
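For illustration, a hedged sketch of the contract (``ConstantPlaceholder`` is hypothetical, not part of this PR): any ``Translatable`` argument handed to a blackbox-wrapped class is replaced by the result of its ``__translate__()`` before the inner ``__init__`` runs, exactly as ``ValueChoice`` translates to its first candidate.

.. code-block:: python

    from nni.retiarii.utils import Translatable

    class ConstantPlaceholder(Translatable):
        """Hypothetical Translatable that always yields a fixed value."""
        def __init__(self, value):
            self.value = value

        def __translate__(self):
            # Called by the blackbox wrapper to obtain the concrete value
            # used when instantiating the inner class.
            return self.value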
93 changes: 87 additions & 6 deletions test/ut/retiarii/test_highlevel_apis.py
@@ -10,7 +10,7 @@
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation


-class EnuemrateSampler(Sampler):
class EnumerateSampler(Sampler):
def __init__(self):
self.index = 0

@@ -70,7 +70,7 @@ def forward(self, x):
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1)
-mutator = mutators[0].bind_sampler(EnuemrateSampler())
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
@@ -94,7 +94,7 @@ def forward(self, x):
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1)
-mutator = mutators[0].bind_sampler(EnuemrateSampler())
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
@@ -119,7 +119,7 @@ def forward(self, x):
model = self._convert_to_ir(Net(reduction))
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1)
-mutator = mutators[0].bind_sampler(EnuemrateSampler())
mutator = mutators[0].bind_sampler(EnumerateSampler())
model = mutator.apply(model)
result = self._get_converted_pytorch_model(model)(torch.randn(1, 3, 3, 3))
if reduction == 'none':
@@ -144,14 +144,95 @@ def forward(self, x):
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1)
-mutator = mutators[0].bind_sampler(EnuemrateSampler())
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
torch.Size([1, 3, 3, 3]))
self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).size(),
torch.Size([1, 5, 3, 3]))

def test_value_choice_as_parameter(self):
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 5, kernel_size=nn.ValueChoice([3, 5]))

def forward(self, x):
return self.conv(x)

model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
torch.Size([1, 5, 3, 3]))
self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
torch.Size([1, 5, 1, 1]))

def test_value_choice_as_two_parameters(self):
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, nn.ValueChoice([6, 8]), kernel_size=nn.ValueChoice([3, 5]))

def forward(self, x):
return self.conv(x)

model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 2)
mutators[0].bind_sampler(EnumerateSampler())
mutators[1].bind_sampler(EnumerateSampler())
input = torch.randn(1, 3, 5, 5)
self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
torch.Size([1, 6, 3, 3]))
self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
torch.Size([1, 8, 1, 1]))

def test_value_choice_as_parameter_shared(self):
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, nn.ValueChoice([6, 8], label='shared'), 1)
self.conv2 = nn.Conv2d(3, nn.ValueChoice([6, 8], label='shared'), 1)

def forward(self, x):
return self.conv1(x) + self.conv2(x)

model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
torch.Size([1, 6, 5, 5]))
self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
torch.Size([1, 8, 5, 5]))

def test_value_choice_in_functional(self):
class Net(nn.Module):
def __init__(self):
@@ -164,7 +245,7 @@ def forward(self, x):
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1)
-mutator = mutators[0].bind_sampler(EnuemrateSampler())
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))