Channel-wise pruning leading to very poor results #53

Closed
s36srini opened this issue Aug 1, 2019 · 1 comment
Labels
technique:pruning Regarding tfmot.sparsity.keras APIs and docs

Comments

s36srini commented Aug 1, 2019

Hello, I'm updating the channel_mask method to introduce channel-wise sparsity for pointwise convolution layers (MobileNet V1). I have yet to check the final weights of the model as it's still training; however, the training accuracy is saturating at around 35%. Can anyone let me know if I made a mistake somewhere in my tensor manipulation?

The reason for channel-wise pruning is to be able to use network surgery to eliminate channels that contribute very little to the inferential ability of the model, allowing for faster inference on embedded devices.

 # Imports referenced by this method (it lives inside the pruning class):
 import tensorflow as tf
 from tensorflow.python.framework import dtypes, ops
 from tensorflow.python.ops import array_ops, math_ops

 def _update_block_channel_mask(self, weights):
      """Performs channel-wise masking of the weights (this function doesn't work yet).

      Args:
        weights: The weight tensor that needs to be masked.

      Returns:
        new_threshold: The new value of the threshold, based on the weights and
          the sparsity at the current global_step.
        new_mask: A numpy array of the same size and shape as the weights,
          containing 0 or 1 to indicate which of the values in weights fall
          below the threshold.

      Raises:
        ValueError: if block pooling function is not AVG or MAX
      """
      # The weights should be of shape (1, 1, j, k), the kernel shape of a
      # pointwise convolutional layer.
      sparsity = self._pruning_schedule(self._step_fn())[1]
      with ops.name_scope('pruning_ops'):
        abs_weights = math_ops.abs(weights)

        # Number of output channels to prune at the current sparsity level.
        k = math_ops.cast(
            sparsity * math_ops.cast(abs_weights.shape[-1], dtypes.float32),
            dtypes.int32)

        # Transpose so each row holds the values of one output channel.
        squeezed_weights = tf.transpose(array_ops.squeeze(abs_weights))

        # Sum of absolute weights per output channel, sorted ascending.
        channel_sums = tf.sort(tf.reduce_sum(squeezed_weights, 1))

        # Grab the k smallest-magnitude channels.
        min_sums = tf.slice(channel_sums, [0], [k])

        current_threshold = array_ops.gather(min_sums, k - 1)

        # If a row's sum matches one of the k smallest sums, set the row to
        # zero (prune it); otherwise set it to 1.
        new_mask = tf.map_fn(
            lambda x: tf.cond(
                tf.reduce_any(tf.equal(tf.reduce_sum(x), min_sums)),
                lambda: tf.zeros_like(x),
                lambda: tf.ones_like(x)),
            squeezed_weights)

        # Transpose back to I x O and reshape into 1 x 1 x I x O.
        new_mask = tf.transpose(new_mask)
        new_mask = tf.reshape(new_mask, abs_weights.shape)

      return current_threshold, new_mask
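
For reference, here is a standalone sketch of the same channel selection that sidesteps the exact float-equality test inside the map_fn above, by picking the k smallest channels by index with tf.math.top_k and scattering zeros into a per-channel mask. This is a hypothetical rewrite, not tfmot API; the name channel_mask_sketch and the standalone signature are assumptions for illustration.

 import tensorflow as tf

 def channel_mask_sketch(weights, sparsity):
      """Builds a 0/1 mask zeroing the lowest-magnitude output channels.

      Hypothetical standalone version of the logic above: `weights` is the
      (1, 1, in, out) kernel of a pointwise convolution and `sparsity` is
      the fraction of output channels to prune.
      """
      abs_weights = tf.abs(weights)
      num_channels = tf.shape(abs_weights)[-1]
      k = tf.cast(sparsity * tf.cast(num_channels, tf.float32), tf.int32)

      # L1 norm of each output channel: sum |w| over the input dimension.
      channel_sums = tf.reduce_sum(tf.squeeze(abs_weights, axis=[0, 1]), axis=0)

      # Indices of the k smallest channel sums; top_k on the negated sums
      # avoids comparing float sums for exact equality.
      _, prune_idx = tf.math.top_k(-channel_sums, k=k)

      # Per-channel mask: 1 everywhere except the pruned channels.
      channel_keep = tf.ones([num_channels])
      channel_keep = tf.tensor_scatter_nd_update(
          channel_keep,
          tf.expand_dims(prune_idx, 1),
          tf.zeros_like(prune_idx, dtype=tf.float32))

      # Broadcast the per-channel mask back to the full kernel shape.
      new_mask = tf.broadcast_to(channel_keep, tf.shape(weights))
      threshold = tf.cond(k > 0,
                          lambda: tf.sort(channel_sums)[k - 1],
                          lambda: tf.constant(0.0))
      return threshold, new_mask

Usage would be e.g. threshold, mask = channel_mask_sketch(kernel, 0.5). Selecting by index also handles ties between channel sums, which the equality test would prune together even past the target sparsity.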
s36srini commented Aug 9, 2019

For some reason, adding regularizers to the convolution layers when reconstructing the model with pruning layers causes the masking functionality to fail: the resulting model had no zeros. After removing the regularizer, the pruning began to work channel-wise. Closing this issue for now.
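
A quick sanity check for the "no zeros" symptom is to count, per pointwise kernel, how many output channels are entirely zero. A minimal sketch (count_zero_channels is a hypothetical helper, and the commented loop assumes a plain Keras model):

 import numpy as np

 def count_zero_channels(kernel):
      # kernel: numpy array of shape (1, 1, in, out), e.g. from
      # layer.get_weights()[0] on a pointwise Conv2D.
      per_channel = np.abs(kernel).reshape(-1, kernel.shape[-1]).sum(axis=0)
      return int((per_channel == 0).sum())

 # Hypothetical usage on a Keras model:
 # for layer in model.layers:
 #     if isinstance(layer, tf.keras.layers.Conv2D) and layer.kernel_size == (1, 1):
 #         print(layer.name, count_zero_channels(layer.get_weights()[0]))

If the count stays at zero after training, the mask was never applied (as happened here with regularizers attached).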

@s36srini s36srini closed this as completed Aug 9, 2019
@alanchiao alanchiao added the technique:pruning Regarding tfmot.sparsity.keras APIs and docs label Feb 6, 2020