
Matmul Flops #108

Open
breuera opened this issue May 17, 2022 · 2 comments
Labels
question Further information is requested

Comments

@breuera

breuera commented May 17, 2022

The matmul flop counts seem to be off by 2x.

I tested the code on a simple MLP, which reads as follows:

import torch.nn

## @package eml.mlp.Module
#  Simple MultiLayer Perceptron (MLP) with fixed dimensions.
#
#  The MLP assumes a 28x28 input image and 10 output classes.
#  These are the dimensions of the Fashion MNIST dataset.
class Model( torch.nn.Module ):
  ## Initializes the class.
  #  @param self object pointer.
  def __init__( self ):
    super( Model, self ).__init__()
    ## flattens the input
    self.m_flatten = torch.nn.Flatten()
    ## layers of the MLP: 3x linear, with a ReLU after each of the first two
    self.m_layers = torch.nn.Sequential( torch.nn.Linear( 28*28, 512 ),
                                         torch.nn.ReLU(),
                                         torch.nn.Linear( 512, 512 ),
                                         torch.nn.ReLU(),
                                         torch.nn.Linear( 512, 10 ) )

  ## Forward pass with the given input.
  #  @param self object pointer.
  #  @param i_input input for the forward pass.
  #  @return output of the MLP.
  def forward( self,
               i_input ):
    l_flatten = self.m_flatten( i_input )
    l_result = self.m_layers( l_flatten )
    return l_result

I embedded this in some code; the crucial piece is:

l_model = eml.mlp.model.Model()
[...]
print( l_model )

#
# flop count code
# https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md
#
import fvcore.nn

l_x, l_y = next(iter(l_data_loader_train))

print( l_x.size() )

l_flops = fvcore.nn.FlopCountAnalysis( l_model,
                                       l_x )

print( l_flops.by_module_and_operator() )

print( fvcore.nn.flop_count_table( l_flops ) )

This returns:

Model(
  (m_flatten): Flatten(start_dim=1, end_dim=-1)
  (m_layers): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
torch.Size([64, 1, 28, 28])
{'': Counter({'addmm': 42795008}), 'm_flatten': Counter(), 'm_layers': Counter({'addmm': 42795008}), 'm_layers.0': Counter({'addmm': 25690112}), 'm_layers.1': Counter(), 'm_layers.2': Counter({'addmm': 16777216}), 'm_layers.3': Counter(), 'm_layers.4': Counter({'addmm': 327680})}
| module     | #parameters or shape   | #flops   |
|:-----------|:-----------------------|:---------|
| m_layers   | 0.67M                  | 42.795M  |
|  0         |  0.402M                |  25.69M  |
|   0.weight |   (512, 784)           |          |
|   0.bias   |   (512,)               |          |
|  2         |  0.263M                |  16.777M |
|   2.weight |   (512, 512)           |          |
|   2.bias   |   (512,)               |          |
|  4         |  5.13K                 |  0.328M  |
|   4.weight |   (10, 512)            |          |
|   4.bias   |   (10,)                |          |

Let's take the first linear layer as an example: the weight matrix A in https://pytorch.org/docs/stable/generated/torch.nn.Linear.html has shape (512, 784).
The input x (batched in this example) has shape (64, 784).
Computing the result C = xA^T requires 2*64*512*784 - 64*512 floating point operations.
In addition, the example uses a bias, i.e., 64*512 additions on top -> 2*64*512*784 = 51,380,224 flops in total; the tool reports 25,690,112 for the first layer. By the way, I am not sure why the bias doesn't show up separately.
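The arithmetic above can be checked with a short sketch that counts multiplications and additions separately for one `nn.Linear` layer applied to a batch. The helper names `linear_flops` and `linear_macs` are hypothetical and exist only for this illustration:

```python
# Flop count for one nn.Linear layer applied to a batch: y = x @ A.T + b
# x: (batch, in_features), A: (out_features, in_features), b: (out_features,)


def linear_flops(batch: int, in_features: int, out_features: int, bias: bool) -> int:
    """Count multiplications and additions separately, then sum them."""
    outputs = batch * out_features          # number of entries in y
    mults = outputs * in_features           # one multiply per (entry, k) pair
    adds = outputs * (in_features - 1)      # summing in_features products per entry
    if bias:
        adds += outputs                     # one bias addition per entry
    return mults + adds


def linear_macs(batch: int, in_features: int, out_features: int) -> int:
    """Multiply-accumulate count: one MAC per (entry, k) pair, bias ignored."""
    return batch * in_features * out_features


first_no_bias = linear_flops(64, 784, 512, bias=False)
first_with_bias = linear_flops(64, 784, 512, bias=True)
first_macs = linear_macs(64, 784, 512)

print(first_no_bias)    # 51347456 = 2*64*512*784 - 64*512
print(first_with_bias)  # 51380224 = 2*64*512*784
print(first_macs)       # 25690112, the number reported for m_layers.0
```

With the bias included, the total is exactly twice the MAC count, which is the factor of 2 between the hand count and the reported number.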

I believe the code below is off, since the operation C += AB (in BLAS notation, with A of shape (M, K) and B of shape (K, N)) requires 2*M*N*K floating point operations, not M*N*K:

flop = prod(input_shapes[0]) * input_shapes[-1][-1]
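To see what the quoted line computes, it can be re-implemented standalone. This is a hypothetical mirror for illustration, not fvcore's actual handler; it assumes `input_shapes[0]` is the shape of x, (M, K), and `input_shapes[-1]` is the shape of A^T, (K, N), given as plain tuples:

```python
from math import prod


def matmul_count(input_shapes: list[tuple[int, ...]]) -> int:
    """Mirror of the quoted line: prod(shape of x) * last dim of A^T = M*K*N."""
    return prod(input_shapes[0]) * input_shapes[-1][-1]


# (shape of x, shape of A^T) for the three Linear layers of the MLP, batch 64
layers = {
    "m_layers.0": [(64, 784), (784, 512)],
    "m_layers.2": [(64, 512), (512, 512)],
    "m_layers.4": [(64, 512), (512, 10)],
}

counts = {name: matmul_count(shapes) for name, shapes in layers.items()}
print(counts)                 # per-layer counts
print(sum(counts.values()))   # total over the three layers
```

Under these assumptions the sketch reproduces the per-layer numbers in the output above (25,690,112, 16,777,216 and 327,680, totaling 42,795,008), i.e., exactly M*N*K per layer with no factor of 2.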

@breuera
Author

breuera commented May 17, 2022

It appears that #69 (comment) raises the same concern:

> We count one fused multiply-add as one flop.

I'd consider this to be an unconventional definition.
Even knowing this, I don't understand how the bias operations fit in.

@ppwwyyxx
Contributor

Different groups adopt different conventions, unfortunately. We implemented the convention common in computer vision, which is to count MACs (multiply-accumulate operations) and ignore the flops of the bias.
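Under this convention the bias question resolves neatly for `addmm`: the matrix product costs 2*M*N*K - M*N flops, and the M*N bias additions fill exactly that gap, so doubling the MAC count gives the multiply-plus-add count with the bias included. Below is a minimal post-processing sketch, assuming only `addmm` entries need rescaling; it is not an fvcore feature, `macs_to_flops` is a hypothetical helper, and the dictionary copies just three entries from the `by_module_and_operator()` output in the issue:

```python
from collections import Counter

# Three per-module entries copied from the by_module_and_operator() output
# in the issue (MAC counts, bias ignored).
reported = {
    "m_layers.0": Counter({"addmm": 25690112}),
    "m_layers.2": Counter({"addmm": 16777216}),
    "m_layers.4": Counter({"addmm": 327680}),
}


def macs_to_flops(counts: dict[str, Counter]) -> dict[str, Counter]:
    """Rescale addmm MAC counts to multiplications + additions (bias included).

    For addmm (x @ A^T + b) with x: (M, K) and A^T: (K, N):
      matmul = 2*M*N*K - M*N, bias = M*N, so total = 2*M*N*K = 2 * MACs.
    Other operators are copied unchanged (assumption: only addmm needs rescaling).
    """
    converted = {}
    for module, counter in counts.items():
        new = Counter(counter)  # copy so the input is not mutated
        if "addmm" in new:
            new["addmm"] *= 2
        converted[module] = new
    return converted


flops = macs_to_flops(reported)
print(flops["m_layers.0"]["addmm"])              # first layer in the 2-flops-per-MAC convention
print(sum(c["addmm"] for c in flops.values()))   # total over the three layers
```

For a bias-free matrix product the conversion would instead be 2*MACs - M*N, which requires the output shape and cannot be recovered from the MAC count alone.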

#77 would improve this, but unfortunately no one is working on it.

@nikhilaravi added the question label on May 19, 2022