
[Retiarii] Pure-python execution engine #3605

Merged
merged 13 commits on May 25, 2021
12 changes: 11 additions & 1 deletion docs/en_US/NAS/retiarii/Advanced.rst
@@ -1,7 +1,17 @@
Advanced Tutorial
=================

This document includes two parts. The first part explains the design decisions behind ``@basic_unit`` and ``serializer``. The second part is a tutorial on how to write a model space with mutators.
Pure-python execution engine (experimental)
-------------------------------------------

If you are experiencing issues with TorchScript, or with the model code generated by Retiarii, you can try another execution engine, the pure-python execution engine, which does not rely on the code-graph conversion. Models and strategies work the same in most cases, but customized mutation might not be supported.

This will become the default execution engine in a future version of Retiarii.

For now, two steps are needed to enable this engine, as shown in the example below.

1. Add ``@nni.retiarii.model_wrapper`` decorator outside the whole PyTorch model.
2. Add ``config.execution_engine = 'py'`` to ``RetiariiExeConfig``.
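
A minimal sketch of both steps (the toy model and the ``local`` training service are illustrative, not part of this change):

.. code-block:: python

    import nni.retiarii.nn.pytorch as nn
    from nni.retiarii import model_wrapper
    from nni.retiarii.experiment.pytorch import RetiariiExeConfig

    @model_wrapper                      # step 1: wrap the whole PyTorch model
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # LayerChoice is an inline mutation; the pure-python engine fixes it per trial
            self.conv = nn.LayerChoice([nn.Conv2d(3, 8, 3), nn.Conv2d(3, 8, 5)])

        def forward(self, x):
            return self.conv(x)

    config = RetiariiExeConfig('local')
    config.execution_engine = 'py'      # step 2: select the pure-python engine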

``@basic_unit`` and ``serializer``
----------------------------------
2 changes: 1 addition & 1 deletion nni/retiarii/__init__.py
@@ -5,4 +5,4 @@
from .graph import *
from .execution import *
from .mutator import *
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper
9 changes: 4 additions & 5 deletions nni/retiarii/execution/api.py
@@ -15,19 +15,18 @@
'list_models', 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec', 'budget_exhausted']

def set_execution_engine(engine) -> None:

def set_execution_engine(engine: AbstractExecutionEngine) -> None:
global _execution_engine
if _execution_engine is None:
_execution_engine = engine
else:
raise RuntimeError('execution engine is already set')
raise RuntimeError('Execution engine is already set.')


def get_execution_engine() -> AbstractExecutionEngine:
"""
Currently we assume the default execution engine is BaseExecutionEngine.
"""
global _execution_engine
assert _execution_engine is not None, 'You need to set execution engine, before using it.'
return _execution_engine


8 changes: 6 additions & 2 deletions nni/retiarii/execution/base.py
@@ -5,7 +5,7 @@
import os
import random
import string
from typing import Dict, Iterable, List
from typing import Any, Dict, Iterable, List

from .interface import AbstractExecutionEngine, AbstractGraphListener
from .. import codegen, utils
@@ -59,7 +59,7 @@ def __init__(self) -> None:

def submit_models(self, *models: Model) -> None:
for model in models:
data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
data = self.pack_model_data(model)
self._running_models[send_trial(data.dump())] = model
self._history.append(model)

@@ -108,6 +108,10 @@ def budget_exhausted(self) -> bool:
advisor = get_advisor()
return advisor.stopping

@classmethod
def pack_model_data(cls, model: Model) -> Any:
return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)

@classmethod
def trial_execute_graph(cls) -> None:
"""
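
In effect, ``submit_models`` now delegates payload construction to ``pack_model_data``, which subclasses can override. A condensed, paraphrased view of the pattern (bodies abridged, not a verbatim copy of the diff):

class BaseExecutionEngine:
    def submit_models(self, *models):
        for model in models:
            data = self.pack_model_data(model)   # engine-specific payload
            self._running_models[send_trial(data.dump())] = model
            self._history.append(model)

    @classmethod
    def pack_model_data(cls, model):
        # graph-based engine: generate PyTorch code from the graph IR
        return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)

# PurePythonExecutionEngine (below) overrides pack_model_data to ship the class
# name, init parameters and sampled mutation instead of generated code.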
53 changes: 53 additions & 0 deletions nni/retiarii/execution/python.py
@@ -0,0 +1,53 @@
from typing import Dict, Any, List

from ..graph import Evaluator, Model
from ..integration_api import receive_trial_parameters
from ..utils import ContextStack, import_, get_importable_name
from .base import BaseExecutionEngine


class PythonGraphData:
def __init__(self, class_name: str, init_parameters: Dict[str, Any],
mutation: Dict[str, Any], evaluator: Evaluator) -> None:
self.class_name = class_name
self.init_parameters = init_parameters
self.mutation = mutation
self.evaluator = evaluator

def dump(self) -> dict:
return {
'class_name': self.class_name,
'init_parameters': self.init_parameters,
'mutation': self.mutation,
'evaluator': self.evaluator
}

@staticmethod
def load(data) -> 'PythonGraphData':
return PythonGraphData(data['class_name'], data['init_parameters'], data['mutation'], data['evaluator'])


class PurePythonExecutionEngine(BaseExecutionEngine):
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
graph_data = PythonGraphData(get_importable_name(model.python_class, relocate_module=True),
model.python_init_params, mutation, model.evaluator)
return graph_data

@classmethod
def trial_execute_graph(cls) -> None:
graph_data = PythonGraphData.load(receive_trial_parameters())

class _model(import_(graph_data.class_name)):
def __init__(self):
super().__init__(**graph_data.init_parameters)

with ContextStack('fixed', graph_data.mutation):
graph_data.evaluator._execute(_model)


def _unpack_if_only_one(ele: List[Any]):
if len(ele) == 1:
return ele[0]
return ele
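
To make the flow concrete, here is a hypothetical payload as it travels from the experiment to a trial (the class path, parameters and labels are made up for illustration):

# Experiment side: PurePythonExecutionEngine.pack_model_data() builds the payload
data = PythonGraphData(
    class_name='my_package.model.Net',        # importable path of the wrapped model class
    init_parameters={'num_classes': 10},      # init kwargs recorded for the model class
    mutation={'conv': 1, 'skip': [0, 2]},     # mutator label -> sampled decision(s)
    evaluator=my_evaluator,
)
payload = data.dump()                         # plain dict handed to send_trial()

# Trial side: trial_execute_graph() reverses the process
restored = PythonGraphData.load(payload)
# The model class is re-imported and instantiated under ContextStack('fixed', mutation),
# so each LayerChoice / InputChoice reads its decision from the mutation dict.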
50 changes: 36 additions & 14 deletions nni/retiarii/experiment/pytorch.py
@@ -28,11 +28,11 @@

from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph
from ..execution import list_models
from ..execution import list_models, set_execution_engine
from ..graph import Model, Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import process_inline_mutation
from ..nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
from ..strategy import BaseStrategy
from ..oneshot.interface import BaseOneShotTrainer

@@ -43,7 +43,7 @@
class RetiariiExeConfig(ConfigBase):
experiment_name: Optional[str] = None
search_space: Any = '' # TODO: remove
trial_command: str = 'python3 -m nni.retiarii.trial_entry'
trial_command: str = '_reserved'
trial_code_directory: PathLike = '.'
trial_concurrency: int
trial_gpu_number: int = 0
@@ -55,21 +55,26 @@ class RetiariiExeConfig(ConfigBase):
experiment_working_directory: Optional[PathLike] = None
# remove configuration of tuner/assessor/advisor
training_service: TrainingServiceConfig
execution_engine: str = 'base'

def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if training_service_platform is not None:
assert 'training_service' not in kwargs
self.training_service = util.training_service_config_factory(platform = training_service_platform)
self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry base'

def __setattr__(self, key, value):
fixed_attrs = {'search_space': '',
'trial_command': 'python3 -m nni.retiarii.trial_entry'}
'trial_command': '_reserved'}
if key in fixed_attrs and fixed_attrs[key] != value:
raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
# 'trial_code_directory' is handled differently because the path will be converted to absolute path by us
if key == 'trial_code_directory' and not (value == Path('.') or os.path.isabs(value)):
raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
if key == 'execution_engine':
assert value in ['base', 'py', 'cgo'], f'The specified execution engine "{value}" is not supported.'
self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
self.__dict__[key] = value

def validate(self, initialized_tuner: bool = False) -> None:
@@ -100,23 +105,27 @@ def _validation_rules(self):
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}

def preprocess_model(base_model, trainer, applied_mutators):
def preprocess_model(base_model, trainer, applied_mutators, full_ir=True):
# TODO: this logic might need to be refactored into execution engine
if full_ir:
try:
script_module = torch.jit.script(base_model)
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
base_model_ir = convert_to_graph(script_module, base_model)
base_model_ir.evaluator = trainer

# handle inline mutations
mutators = process_inline_mutation(base_model_ir)
if mutators is not None and applied_mutators:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
'do not use mutators when you use LayerChoice/InputChoice')
if mutators is not None:
applied_mutators = mutators
return base_model_ir, applied_mutators
else:
base_model_ir, mutators = extract_mutation_from_pt_module(base_model)
base_model_ir.evaluator = trainer

if mutators is not None and applied_mutators:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
'do not use mutators when you use LayerChoice/InputChoice')
if mutators is not None:
applied_mutators = mutators
return base_model_ir, applied_mutators

def debug_mutated_model(base_model, trainer, applied_mutators):
"""
@@ -160,7 +169,8 @@ def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotT
self._pipe: Optional[Pipe] = None

def _start_strategy(self):
base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.trainer, self.applied_mutators)
base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py')

_logger.info('Start strategy...')
self.strategy.run(base_model_ir, self.applied_mutators)
@@ -182,6 +192,18 @@ def start(self, port: int = 8080, debug: bool = False) -> None:
"""
atexit.register(self.stop)

# we will probably need an execution engine factory to make this clean and elegant
if self.config.execution_engine == 'base':
from ..execution.base import BaseExecutionEngine
engine = BaseExecutionEngine()
elif self.config.execution_engine == 'cgo':
from ..execution.cgo_engine import CGOExecutionEngine
engine = CGOExecutionEngine()
elif self.config.execution_engine == 'py':
from ..execution.python import PurePythonExecutionEngine
engine = PurePythonExecutionEngine()
set_execution_engine(engine)

self.id = management.generate_experiment_id()

if self.config.experiment_working_directory is not None:
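
One consequence of the config changes above: users never set ``trial_command`` themselves; it is derived from the chosen engine. A rough sketch of the behaviour (the ``local`` platform is just an example):

config = RetiariiExeConfig('local')
# the constructor pre-fills the command for the default graph-based engine
print(config.trial_command)    # python3 -m nni.retiarii.trial_entry base

config.execution_engine = 'py'
# __setattr__ rewrites the reserved command to match the selected engine
print(config.trial_command)    # python3 -m nni.retiarii.trial_entry py

# config.trial_command = 'python3 my_trial.py'   # would raise AttributeError: not user-settable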
57 changes: 50 additions & 7 deletions nni/retiarii/graph.py
@@ -9,12 +9,12 @@
import copy
import json
from enum import Enum
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Union, overload)
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload)

from .operation import Cell, Operation, _IOPseudoOperation
from .utils import get_importable_name, import_, uid

__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData']


MetricData = Any
@@ -80,6 +80,10 @@ class Model:

Attributes
----------
python_class
Python class that base model is converted from.
python_init_params
Initialization parameters of python class.
status
See `ModelStatus`.
root_graph
@@ -102,6 +106,8 @@ class Model:
def __init__(self, _internal=False):
assert _internal, '`Model()` is private, use `model.fork()` instead'
self.model_id: int = uid('model')
self.python_class: Optional[Type] = None
self.python_init_params: Optional[Dict[str, Any]] = None

self.status: ModelStatus = ModelStatus.Mutating

@@ -116,7 +122,8 @@ def __init__(self, _internal=False):

def __repr__(self):
return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})'
f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics}, ' + \
f'python_class={self.python_class})'

@property
def root_graph(self) -> 'Graph':
@@ -133,9 +140,12 @@ def fork(self) -> 'Model':
"""
new_model = Model(_internal=True)
new_model._root_graph_name = self._root_graph_name
new_model.python_class = self.python_class
new_model.python_init_params = self.python_init_params
new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
new_model.evaluator = copy.deepcopy(self.evaluator) # TODO this may be a problem when evaluator is large
new_model.history = self.history + [self]
new_model.history = [*self.history]
# Note: the history is not updated. It will be updated when the model is changed, that is in mutator.
return new_model

@staticmethod
@@ -167,8 +177,8 @@ def get_nodes(self) -> Iterable['Node']:

def get_nodes_by_label(self, label: str) -> List['Node']:
"""
Traverse all the nodes to find the matched node(s) with the given name.
There could be multiple nodes with the same name. Name space name can uniquely
Traverse all the nodes to find the matched node(s) with the given label.
There could be multiple nodes with the same label. Name space name can uniquely
identify a graph or node.

NOTE: the implementation does not support the class abstraction
@@ -493,6 +503,8 @@ class Node:
If two models have nodes with same ID, they are semantically the same node.
name
Mnemonic name. It should have an one-to-one mapping with ID.
label
Optional. If two nodes have the same label, they are considered same by the mutator.
operation
...
cell
@@ -515,7 +527,7 @@ def __init__(self, graph, node_id, name, operation, _internal=False):
# TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
# maybe we should copy it here or make Operation class immutable, in next release
self.operation: Operation = operation
self.label: str = None
self.label: Optional[str] = None

def __repr__(self):
return f'Node(id={self.id}, name={self.name}, label={self.label}, operation={self.operation})'
@@ -673,6 +685,37 @@ def _dump(self) -> Any:
}


class Mutation:
"""
An execution of mutation, which consists of four parts: a mutator, a list of decisions (choices),
the model that it comes from, and the model that it becomes.

In general, the mutation logs are not reliable and should not be replayed, as the mutators can
be arbitrarily complex. However, for inline mutations, the labels correspond to the mutator labels here,
which is useful for metadata visualization and for the pure-python execution mode.

Attributes
----------
mutator
Mutator.
samples
Decisions/choices.
from_
Model that it comes from.
to
Model that it becomes.
"""

def __init__(self, mutator: 'Mutator', samples: List[Any], from_: Model, to: Model): # noqa: F821
self.mutator: 'Mutator' = mutator # noqa: F821
self.samples: List[Any] = samples
self.from_: Model = from_
self.to: Model = to

def __repr__(self):
return f'Mutation(mutator={self.mutator}, samples={self.samples}, from={self.from_}, to={self.to})'


class IllegalGraphError(ValueError):
def __init__(self, graph, *args):
self._debug_dump_graph(graph)
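
The ``Mutation`` records introduced here are exactly what the pure-python engine consumes: ``pack_model_data`` flattens ``model.history`` into a label-to-sample dict. A hypothetical example (labels and values are illustrative):

# Suppose a model went through two inline mutations, so its history holds:
#   Mutation(mutator=<mutator label='conv'>, samples=[1],    from_=m0, to=m1)
#   Mutation(mutator=<mutator label='skip'>, samples=[0, 2], from_=m1, to=m2)
mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
# -> {'conv': 1, 'skip': [0, 2]}   # single-element sample lists are unpacked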