This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix bug of FLOPs counter #3497

Merged
merged 6 commits on Apr 21, 2021
Changes from 2 commits
153 changes: 143 additions & 10 deletions nni/compression/pytorch/utils/counter.py
@@ -7,6 +7,7 @@

import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence
from nni.compression.pytorch.compressor import PrunerModuleWrapper


@@ -39,14 +40,14 @@ def __init__(self, custom_ops=None, mode='default'):
nn.Conv1d: self._count_convNd,
nn.Conv2d: self._count_convNd,
nn.Conv3d: self._count_convNd,
nn.ConvTranspose1d: self._count_convNd,
nn.ConvTranspose2d: self._count_convNd,
nn.ConvTranspose3d: self._count_convNd,
nn.Linear: self._count_linear
}
self._count_bias = False
if mode == 'full':
self.ops.update({
nn.ConvTranspose1d: self._count_convNd,
nn.ConvTranspose2d: self._count_convNd,
nn.ConvTranspose3d: self._count_convNd,
nn.BatchNorm1d: self._count_bn,
nn.BatchNorm2d: self._count_bn,
nn.BatchNorm3d: self._count_bn,
@@ -59,7 +60,13 @@ def __init__(self, custom_ops=None, mode='default'):
nn.AdaptiveAvgPool3d: self._count_adap_avgpool,
nn.Upsample: self._count_upsample,
nn.UpsamplingBilinear2d: self._count_upsample,
nn.UpsamplingNearest2d: self._count_upsample
nn.UpsamplingNearest2d: self._count_upsample,
nn.RNNCell: self._count_rnn_cell,
nn.GRUCell: self._count_gru_cell,
nn.LSTMCell: self._count_lstm_cell,
nn.RNN: self._count_rnn,
nn.GRU: self._count_gru,
nn.LSTM: self._count_lstm,
Contributor

just curious, why only conv and linear are supported by default?

Contributor Author

Most pruning papers target the classification task and only report results for conv and linear layers. Moving the RNN modules to the default mode would be better.

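For readers who want to try the counters discussed above, here is a minimal usage sketch. The toy LSTM and variable names are made up; the count_flops_params signature, the fact that the input is passed as a tuple of tensors, and the 'full' mode are taken from this file, while the exact return values may differ by NNI version:

import torch
import torch.nn as nn
from nni.compression.pytorch.utils.counter import count_flops_params

# A toy two-layer LSTM; recurrent modules are only counted when mode='full'.
model = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
dummy_input = torch.randn(1, 5, 10)  # (batch, steps, features)

# The counter calls model(*x), so the input is passed as a tuple of tensors.
stats = count_flops_params(model, (dummy_input, ), mode='full')
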
})
self._count_bias = True

@@ -86,7 +93,7 @@ def _get_result(self, m, flops):

def _count_convNd(self, m, x, y):
cin = m.in_channels
kernel_ops = m.weight.size()[2] * m.weight.size()[3]
kernel_ops = torch.zeros(m.weight.size()[2:]).numel()
output_size = torch.zeros(y.size()[2:]).numel()
cout = y.size()[1]

@@ -156,13 +163,143 @@ def _count_upsample(self, m, x, y):

return self._get_result(m, total_ops)

def _count_cell_flops(self, input_size, hidden_size, cell_type):
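# Per-gate cost of one recurrent cell step for a single sample:
#   W_ih @ x        -> hidden_size * input_size multiply-accumulates
#   W_hh @ h        -> hidden_size * hidden_size multiply-accumulates
#   combining adds  -> hidden_size
# i.e. state_ops = hidden_size * (input_size + hidden_size) + hidden_size,
# plus two bias vectors (2 * hidden_size) when biases are counted.
# Example without bias: input_size=10, hidden_size=20 gives
# state_ops = 20 * (10 + 20) + 20 = 620, so an LSTM cell costs about
# 4 * 620 + 4 * 20 = 2560 ops per time step per sample.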
state_ops = hidden_size * (input_size + hidden_size) + hidden_size
if self._count_bias:
state_ops += hidden_size * 2

if cell_type == 'rnn':
return state_ops

total_ops = 0
if cell_type == 'gru':
total_ops += state_ops * 2
total_ops += (hidden_size + input_size) * hidden_size + hidden_size
if self._count_bias:
total_ops += hidden_size * 2

total_ops += hidden_size * 4

elif cell_type == 'lstm':
total_ops += state_ops * 4
total_ops += hidden_size * 4

return total_ops


def _count_rnn_cell(self, m, x, y):
total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'rnn')
batch_size = x[0].size(0)
total_ops *= batch_size

return self._get_result(m, total_ops)

def _count_gru_cell(self, m, x, y):
total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'gru')
batch_size = x[0].size(0)
total_ops *= batch_size

return self._get_result(m, total_ops)

def _count_lstm_cell(self, m, x, y):
total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'lstm')
batch_size = x[0].size(0)
total_ops *= batch_size

return self._get_result(m, total_ops)

def _get_bsize_nsteps(self, m, x):
if isinstance(x[0], PackedSequence):
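# For a PackedSequence, batch_sizes[t] is the number of sequences still
# active at step t, so its maximum is the effective batch size.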
batch_size = torch.max(x[0].batch_sizes)
num_steps = x[0].batch_sizes.size(0)
else:
if m.batch_first:
batch_size = x[0].size(0)
num_steps = x[0].size(1)
else:
batch_size = x[0].size(1)
num_steps = x[0].size(0)

return batch_size, num_steps

def _count_rnn(self, m, x, y):
input_size = m.input_size
hidden_size = m.hidden_size
num_layers = m.num_layers

batch_size, num_steps = self._get_bsize_nsteps(m, x)
total_ops = self._count_cell_flops(input_size, hidden_size, 'rnn')

if m.bidirectional:
total_ops *= 2

for _ in range(num_layers - 1):
if m.bidirectional:
cell_flops = self._count_cell_flops(hidden_size * 2, hidden_size, 'rnn') * 2
else:
cell_flops = self._count_cell_flops(hidden_size, hidden_size, 'rnn')
total_ops += cell_flops

total_ops *= num_steps
total_ops *= batch_size

return self._get_result(m, total_ops)

def _count_gru(self, m, x, y):
hidden_size = m.hidden_size
num_layers = m.num_layers

batch_size, num_steps = self._get_bsize_nsteps(m, x)
total_ops = self._count_cell_flops(hidden_size * 2, hidden_size, 'gru')
if m.bidirectional:
total_ops *= 2

for _ in range(num_layers - 1):
if m.bidirectional:
cell_flops = self._count_cell_flops(hidden_size * 2, hidden_size,
'gru') * 2
else:
cell_flops = self._count_cell_flops(hidden_size, hidden_size, 'gru')
total_ops += cell_flops

total_ops *= num_steps
total_ops *= batch_size

return self._get_result(m, total_ops)

def _count_lstm(self, m, x, y):
hidden_size = m.hidden_size
num_layers = m.num_layers
batch_size, num_steps = self._get_bsize_nsteps(m, x)

total_ops = self._count_cell_flops(hidden_size * 2, hidden_size,
'lstm')
if m.bidirectional:
total_ops *= 2

for _ in range(num_layers - 1):
if m.bidirectional:
cell_flops = self._count_cell_flops(hidden_size * 2, hidden_size,
'lstm') * 2
else:
cell_flops = self._count_cell_flops(hidden_size, hidden_size, 'lstm')
total_ops += cell_flops

total_ops *= num_steps
total_ops *= batch_size

return self._get_result(m, total_ops)


def count_module(self, m, x, y, name):
# assume x is tuple of single tensor
result = self.ops[type(m)](m, x, y)
output_size = y[0].size() if isinstance(y, tuple) else y.size()
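# Recurrent modules return a tuple (output, hidden state); report the size
# of the main output tensor so the summary format stays uniform.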

total_result = {
'name': name,
'input_size': tuple(x[0].size()),
'output_size': tuple(y.size()),
'output_size': tuple(output_size),
'module_type': type(m).__name__,
**result
}
@@ -279,10 +416,6 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
model(*x)

# restore origin status
for name, m in model.named_modules():
if hasattr(m, 'weight_mask'):
delattr(m, 'weight_mask')

model.train(training).to(original_device)
for handler in handler_collection:
handler.remove()