Skip to content

Commit

Permalink
Support pruning nested model recursively.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 367125045
  • Loading branch information
liyunlu0618 authored and tensorflower-gardener committed Apr 7, 2021
1 parent f08d37a commit ecf3dc5
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
12 changes: 10 additions & 2 deletions tensorflow_model_optimization/python/core/sparsity/keras/prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@ def _prune_list(layers, **params):
return wrapped_layers

def _add_pruning_wrapper(layer):
if isinstance(layer, keras.Model):
# Check whether the model is a subclass model.
if (not layer._is_graph_network and
not isinstance(layer, keras.models.Sequential)):
raise ValueError('Subclassed models are not supported currently.')

return keras.models.clone_model(
layer, input_tensors=None, clone_function=_add_pruning_wrapper)
if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
return layer
return pruning_wrapper.PruneLowMagnitude(layer, **params)
Expand All @@ -172,6 +180,7 @@ def _add_pruning_wrapper(layer):
'block_size': block_size,
'block_pooling_type': block_pooling_type
}

is_sequential_or_functional = isinstance(
to_prune, keras.Model) and (isinstance(to_prune, keras.Sequential) or
to_prune._is_graph_network)
Expand All @@ -183,8 +192,7 @@ def _add_pruning_wrapper(layer):
if isinstance(to_prune, list):
return _prune_list(to_prune, **params)
elif is_sequential_or_functional:
return keras.models.clone_model(
to_prune, input_tensors=None, clone_function=_add_pruning_wrapper)
return _add_pruning_wrapper(to_prune)
elif is_keras_layer:
params.update(kwargs)
return pruning_wrapper.PruneLowMagnitude(to_prune, **params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,27 @@ def testPrunesEmbedding_ReachesTargetSparsity(self):
input_data = np.random.randint(10, size=(32, 5))
self._check_strip_pruning_matches_original(model, 0.5, input_data)

def testPruneRecursivelyReachesTargetSparsity(self):
internal_model = keras.Sequential(
[keras.layers.Dense(10, input_shape=(10,))])
model = keras.Sequential([
internal_model,
layers.Flatten(),
layers.Dense(1),
])
model.compile(
loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
test_utils.assert_model_sparsity(self, 0.0, model)
model.fit(
np.random.randint(10, size=(32, 10)),
np.random.randint(2, size=(32, 1)),
callbacks=[pruning_callbacks.UpdatePruningStep()])

test_utils.assert_model_sparsity(self, 0.5, model)

input_data = np.random.randint(10, size=(32, 10))
self._check_strip_pruning_matches_original(model, 0.5, input_data)

@parameterized.parameters(test_utils.model_type_keys())
def testPrunesMnist_ReachesTargetSparsity(self, model_type):
model = test_utils.build_mnist_model(model_type, self.params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,16 @@ def testPruneFunctionalModelPreservesBuiltState(self):
json.loads(pruned_model.to_json()))
self.assertEqual(loaded_model.built, True)

def testPruneModelRecursively(self):
internal_model = keras.Sequential(
[keras.layers.Dense(10, input_shape=(10,))])
original_model = keras.Sequential([
internal_model,
layers.Dense(10),
])
pruned_model = prune.prune_low_magnitude(original_model, **self.params)
self.assertEqual(self._count_pruned_layers(pruned_model), 2)

def testPruneSubclassModel(self):
model = TestSubclassedModel()
with self.assertRaises(ValueError) as e:
Expand Down

0 comments on commit ecf3dc5

Please sign in to comment.