diff --git a/examples/model_compress/model_prune_tf.py b/examples/model_compress/model_prune_tf.py index 8b0950aac5..1d39c44d75 100644 --- a/examples/model_compress/model_prune_tf.py +++ b/examples/model_compress/model_prune_tf.py @@ -28,31 +28,21 @@ def get_dataset(dataset_name='mnist'): def create_model(model_name='naive'): assert model_name == 'naive' - return NaiveModel() - -class NaiveModel(tf.keras.Model): - def __init__(self): - super().__init__() - self.seq_layers = [ - tf.keras.layers.Conv2D(filters=20, kernel_size=5), - tf.keras.layers.BatchNormalization(), - tf.keras.layers.ReLU(), - tf.keras.layers.MaxPool2D(pool_size=2), - tf.keras.layers.Conv2D(filters=20, kernel_size=5), - tf.keras.layers.BatchNormalization(), - tf.keras.layers.ReLU(), - tf.keras.layers.MaxPool2D(pool_size=2), - tf.keras.layers.Flatten(), - tf.keras.layers.Dense(units=500), - tf.keras.layers.ReLU(), - tf.keras.layers.Dense(units=10), - tf.keras.layers.Softmax() - ] - - def call(self, x): - for layer in self.seq_layers: - x = layer(x) - return x + return tf.keras.Sequential([ + tf.keras.layers.Conv2D(filters=20, kernel_size=5), + tf.keras.layers.BatchNormalization(), + tf.keras.layers.ReLU(), + tf.keras.layers.MaxPool2D(pool_size=2), + tf.keras.layers.Conv2D(filters=20, kernel_size=5), + tf.keras.layers.BatchNormalization(), + tf.keras.layers.ReLU(), + tf.keras.layers.MaxPool2D(pool_size=2), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(units=500), + tf.keras.layers.ReLU(), + tf.keras.layers.Dense(units=10), + tf.keras.layers.Softmax() + ]) def create_pruner(model, pruner_name): diff --git a/src/sdk/pynni/nni/compression/tensorflow/compressor.py b/src/sdk/pynni/nni/compression/tensorflow/compressor.py index 91fc6da3d9..7f9b1bc6ae 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/compressor.py +++ b/src/sdk/pynni/nni/compression/tensorflow/compressor.py @@ -15,41 +15,6 @@ _logger = logging.getLogger(__name__) -class LayerInfo: - """ - This structure contains all infomation 
needed to compress a TensorFlow ``Layer``. - - - Attributes - ---------- - layer : tf.keras.layers.Layer - The layer. - name : str - The layer's name. Note that it's local to sub-model and may differ from its attribute name. - type : str - Name of the layer's class. - path : list of str or tuple of (str, int) - The layer object's and its parents' attribute name / list index. - For example, if the path is `[('cells', 2), 'conv']`, then the layer can be accessed as `model.cells[2].conv`. - config : JSON object - Selected configuration for this layer. The format is detailed in tutorial. - - Parameters - ---------- - layer : tf.keras.layers.Layer - See attributes section. - path : list of str or tuple of (str, int) - See attributes section. - """ - - def __init__(self, layer, path=None): - self.layer = layer - self.name = layer.name - self.type = type(layer).__name__ - self.path = path - self.config = None - - class Compressor: """ Common base class for all compressors. @@ -57,41 +22,38 @@ class Compressor: This class is designed for other base classes. Algorithms should inherit ``Pruner`` or ``Quantizer`` instead. - Attributes ---------- - bound_model : tf.keras.Model + compressed_model : tf.keras.Model Compressed user model. wrappers : list of tf.keras.Model A wrapper is an instrumented TF ``Layer``, in ``Model`` format. - The list is ordered by preorder traversal. Parameters ---------- - LayerWrapperClass : a class derive from Model - The class used to instrument layers. model : tf.keras.Model The user model to be compressed. config_list : list of JSON object User configuration. The format is detailed in tutorial. + LayerWrapperClass : a class derived from Model + The class used to instrument layers. 
""" - def __init__(self, LayerWrapperClass, model, config_list): + def __init__(self, model, config_list, LayerWrapperClass): assert isinstance(model, tf.keras.Model) - if isinstance(model, tf.keras.Sequential): - raise ValueError('NNI model compression does not support `Sequential` model for now') self.validate_config(model, config_list) - self.bound_model = model - self.wrappers = [] + self._original_model = model + self._config_list = config_list + self._wrapper_class = LayerWrapperClass + self._wrappers = {} # key: id(layer) , value: Wrapper(layer) + + self.compressed_model = self._instrument(model) + self.wrappers = list(self._wrappers.values()) - for layer_info in _detect_layers_to_compress(model, config_list): - self.wrappers.append(LayerWrapperClass(layer_info, self)) if not self.wrappers: _logger.warning('Nothing is configured to compress, please check your model and config list') - _instrument_model(model, self.wrappers) - def set_wrappers_attribute(self, name, value): """ Call ``setattr`` on all wrappers. @@ -99,6 +61,73 @@ def set_wrappers_attribute(self, name, value): for wrapper in self.wrappers: setattr(wrapper, name, value) + def validate_config(self, model, config_list): + """ + Compression algorithms should override this function to validate configuration. 
+ """ + pass + + + def _instrument(self, layer): + if isinstance(layer, tf.keras.Sequential): + return self._instrument_sequential(layer) + if isinstance(layer, tf.keras.Model): + return self._instrument_model(layer) + + # a layer can be referenced in multiple attributes of a model, + # but should only be instrumented once + if id(layer) in self._wrappers: + return self._wrappers[id(layer)] + + config = self._select_config(layer) + if config is not None: + wrapper = self._wrapper_class(layer, config, self) + self._wrappers[id(layer)] = wrapper + return wrapper + + return layer + + def _instrument_sequential(self, seq): + layers = list(seq.layers) # seq.layers is read-only property + need_rebuild = False + for i, layer in enumerate(layers): + new_layer = self._instrument(layer) + if new_layer is not layer: + layers[i] = new_layer + need_rebuild = True + return tf.keras.Sequential(layers) if need_rebuild else seq + + def _instrument_model(self, model): + for key, value in list(model.__dict__.items()): # avoid "dictionary keys changed during iteration" + if isinstance(value, tf.keras.layers.Layer): + new_layer = self._instrument(value) + if new_layer is not value: + setattr(model, key, new_layer) + elif isinstance(value, list): + for i, item in enumerate(value): + if isinstance(item, tf.keras.layers.Layer): + value[i] = self._instrument(item) + return model + + + def _select_config(self, layer): + # Find the last matching config block for given layer. + # Returns None if the layer should not be compressed. 
+ layer_type = type(layer).__name__ + last_match = None + for config in self._config_list: + if 'op_types' in config: + match = layer_type in config['op_types'] + match_default = 'default' in config['op_types'] and layer_type in default_layers.weighted_modules + if not match and not match_default: + continue + if 'op_names' in config and layer.name not in config['op_names']: + continue + last_match = config + if last_match is None or 'exclude' in last_match: + return None + return last_match + class Pruner(Compressor): """ @@ -121,7 +150,7 @@ class Pruner(Compressor): User configuration. The format is detailed in tutorial. """ def __init__(self, model, config_list): - super().__init__(PrunerLayerWrapper, model, config_list) + super().__init__(model, config_list, PrunerLayerWrapper) #self.callback = PrunerCallback(self) def compress(self): @@ -133,10 +162,10 @@ def compress(self): Returns ------- tf.keras.Model - The compressed model, for convenience. This is exactly the same object to constructor argument. + The compressed model. """ self._update_mask() - return self.bound_model + return self.compressed_model def calc_masks(self, wrapper, **kwargs): """ @@ -195,11 +224,10 @@ class PrunerLayerWrapper(tf.keras.Model): Afterwards, `masks` is the last return value of ``Pruner.calc_masks``. See ``Pruner.calc_masks`` for details. """ - def __init__(self, layer_info, pruner): + def __init__(self, layer, config, pruner): super().__init__() - self.layer_info = layer_info - self.layer = layer_info.layer - self.config = layer_info.config + self.layer = layer + self.config = config self.pruner = pruner self.masks = {} _logger.info('Layer detected to compress: %s', self.layer.name) @@ -226,82 +254,3 @@ def call(self, *inputs): # # def on_train_batch_end(self, batch, logs=None): # self._pruner.update_mask() - - -def _detect_layers_to_compress(model, config_list): - # Returns list of LayerInfo. 
- located_layers = _locate_layers(model) - ret = [] - for layer in model.layers: - config = _select_config(LayerInfo(layer), config_list) - if config is not None: - if id(layer) not in located_layers: - _logger.error('Failed to locate layer %s in model. The layer will not be compressed. ' - 'This is a bug in NNI, feel free to fire an issue.', layer.name) - continue - layer_info = located_layers[id(layer)] - layer_info.config = config - ret.append(layer_info) - return ret - -def _locate_layers(model, cur_path=[]): - # Find out how to access layers from model object. - # Returns dict of (layer's object ID, LayerInfo). - # This function is required because TF framework does not track layer's attribute name, - # and to my knowledge `Layer.name` is only useful for read-only access. - # `cur_path`s format is documented in `LayerInfo.path`. - # TODO: it can only find layers in `Model` and `list` for now. - assert isinstance(model, tf.keras.Model) - if isinstance(model, tf.keras.Sequential): - _logger.warning('`Sequential` model is not supported yet, ignored.') - ret = {} - for key, value in model.__dict__.items(): - if isinstance(value, tf.keras.Model): - ret.update(_locate_layers(value, cur_path + [key])) - elif isinstance(value, tf.keras.layers.Layer): - ret[id(value)] = LayerInfo(value, cur_path + [key]) - elif isinstance(value, list): - for i, item in enumerate(value): - if isinstance(item, tf.keras.Model): - ret.update(_locate_layers(item, cur_path + [(key, i)])) - elif isinstance(item, tf.keras.layers.Layer): - ret[id(item)] = LayerInfo(item, cur_path + [(key, i)]) - return ret - -def _select_config(layer_info, config_list): - # Find the last matching config block for given layer. - # Returns None if the layer should not be compressed. 
- ret = None - for config in config_list: - if 'op_types' in config: - match = layer_info.type in config['op_types'] - match_default = 'default' in config['op_types'] and layer_info.type in default_layers.weighted_modules - if not match and not match_default: - continue - if 'op_names' in config and layer_info.name not in config['op_names']: - continue - ret = config - if ret is None or 'exclude' in ret: - return None - return ret - - -def _instrument_model(model, wrappers): - # Replace layers to wrappers - for wrapper in reversed(wrappers): - cur = model - for key in wrapper.layer_info.path[:-1]: - if isinstance(key, str): - cur = getattr(cur, key) - else: - name, index = key - cur = getattr(cur, name)[index] - key = wrapper.layer_info.path[-1] - if isinstance(key, str): - setattr(cur, key, wrapper) - else: - name, index = key - getattr(cur, name)[index] = wrapper - #if isinstance(cur, tf.keras.Sequential): - # cur._graph_initialized = False - # cur._layer_call_argspecs[wrapper] = cur._layer_call_argspecs[wrapper.layer] diff --git a/src/sdk/pynni/tests/test_compressor_tf.py b/src/sdk/pynni/tests/test_compressor_tf.py index f52bd34f2c..6e148aa776 100644 --- a/src/sdk/pynni/tests/test_compressor_tf.py +++ b/src/sdk/pynni/tests/test_compressor_tf.py @@ -9,9 +9,12 @@ #### # -# This file tests pruners on 2 models: a classic CNN model, and a naive model with one linear layer +# This file tests pruners on 3 models: +# A classic CNN model built by inheriting `Model`; +# The same CNN model built with `Sequential`; +# A naive model with only one linear layer. # -# The CNN model is used to test layer detecting and instrumenting. +# The CNN models are used to test layer detecting and instrumenting. # # The naive model is used to test mask calculation. # It has a single 10x10 linear layer without bias, and `reduce_sum` its result. @@ -31,11 +34,12 @@ def test_layer_detection(self): # Conv and dense layers should be compressed, pool and flatten should not. 
# This also tests instrumenting functionality. self._test_layer_detection_on_model(CnnModel()) + self._test_layer_detection_on_model(build_sequential_model()) def _test_layer_detection_on_model(self, model): pruner = pruners['level'](model) pruner.compress() - layer_types = sorted(wrapper.layer_info.type for wrapper in pruner.wrappers) + layer_types = sorted(type(wrapper.layer).__name__ for wrapper in pruner.wrappers) assert layer_types == ['Conv2D', 'Dense', 'Dense'], layer_types def test_level_pruner(self): @@ -73,6 +77,15 @@ def call(self, x): x = self.fc2(x) return x + def build_sequential_model(): + return Sequential([ + Conv2D(filters=10, kernel_size=3, activation='relu'), + MaxPool2D(pool_size=2), + Flatten(), + Dense(units=10, activation='relu'), + Dense(units=5, activation='softmax'), + ]) + class NaiveModel(Model): def __init__(self): super().__init__()