Skip to content

Commit

Permalink
[Enhance] Use fvcore to calculate FLOPS. (#1000)
Browse files Browse the repository at this point in the history
* [Feature] Use fvcore for flops count

* update requirements

* update
  • Loading branch information
tonysy authored Aug 31, 2022
1 parent ccece7e commit 4367d05
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
1 change: 1 addition & 0 deletions requirements/optional.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
albumentations>=0.3.2 --no-binary qudida,albumentations
colorama
fvcore
requests
35 changes: 31 additions & 4 deletions tools/analysis_tools/get_flops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

from mmcv import Config
from mmcv.cnn.utils import get_model_complexity_info
import torch

try:
from fvcore.nn import (ActivationCountAnalysis, FlopCountAnalysis,
flop_count_str, flop_count_table, parameter_count)
except ImportError:
print('You may need to install fvcore for flops computation, '
'and you can use `pip install -r requirements/optional.txt` '
'to set up the environment')
from fvcore.nn.print_model_statistics import _format_size
from mmengine import Config

from mmcls.models import build_classifier

Expand Down Expand Up @@ -42,10 +51,28 @@ def main():
'FLOPs counter is currently not currently supported with {}'.
format(model.__class__.__name__))

flops, params = get_model_complexity_info(model, input_shape)
inputs = (torch.randn((1, *input_shape)), )
flops_ = FlopCountAnalysis(model, inputs)
activations_ = ActivationCountAnalysis(model, inputs)

flops = _format_size(flops_.total())
activations = _format_size(activations_.total())
params = _format_size(parameter_count(model)[''])

flop_table = flop_count_table(
flops=flops_,
activations=activations_,
show_param_shapes=True,
)
flop_str = flop_count_str(flops=flops_, activations=activations_)

print('\n' + flop_str)
print('\n' + flop_table)

split_line = '=' * 30
print(f'{split_line}\nInput shape: {input_shape}\n'
f'Flops: {flops}\nParams: {params}\n{split_line}')
f'Flops: {flops}\nParams: {params}\n'
f'Activation: {activations}\n{split_line}')
print('!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify that the '
'flops computation is correct.')
Expand Down

0 comments on commit 4367d05

Please sign in to comment.