From 912627627d1dd8c16b9553085fd8d60a47d91562 Mon Sep 17 00:00:00 2001
From: Ching-Feng Yeh
Date: Tue, 12 Apr 2022 15:09:50 -0700
Subject: [PATCH] Add _multi_head_attention_forward_flops_compute().

---
 .../profiling/flops_profiler/profiler.py     | 120 ++++++++++++++++++
 .../profiling/testing/unit_profiling_test.py |   2 +-
 2 files changed, 121 insertions(+), 1 deletion(-)

diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py
index caa6fc1a6b91..5ed5b85d52e6 100644
--- a/deepspeed/profiling/flops_profiler/profiler.py
+++ b/deepspeed/profiling/flops_profiler/profiler.py
@@ -828,6 +828,120 @@ def _dropout_flops_compute(input, p=0.5, training=True, inplace=False):
     return 0, 0
 
 
+def _multi_head_attention_forward_flops_compute(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    embed_dim_to_check: int,
+    num_heads: int,
+    in_proj_weight: torch.Tensor,
+    in_proj_bias: typing.Optional[torch.Tensor],
+    bias_k: typing.Optional[torch.Tensor],
+    bias_v: typing.Optional[torch.Tensor],
+    add_zero_attn: bool,
+    dropout_p: float,
+    out_proj_weight: torch.Tensor,
+    out_proj_bias: typing.Optional[torch.Tensor],
+    training: bool = True,
+    key_padding_mask: typing.Optional[torch.Tensor] = None,
+    need_weights: bool = True,
+    attn_mask: typing.Optional[torch.Tensor] = None,
+    use_separate_proj_weight: bool = False,
+    q_proj_weight: typing.Optional[torch.Tensor] = None,
+    k_proj_weight: typing.Optional[torch.Tensor] = None,
+    v_proj_weight: typing.Optional[torch.Tensor] = None,
+    static_k: typing.Optional[torch.Tensor] = None,
+    static_v: typing.Optional[torch.Tensor] = None,
+    average_attn_weights: bool = True,
+    out: typing.Tuple[torch.Tensor, typing.Optional[torch.Tensor]] = None,
+    ):
+    """
+    Count flops and macs for torch.nn.functional.multi_head_attention_forward.
+    """
+    assert query.dim() == 2 or query.dim() == 3
+    assert query.dim() == key.dim()
+    assert query.dim() == value.dim()
+
+    if query.dim() == 3:
+        batch_size = query.shape[1]
+        assert key.shape[1] == batch_size
+        assert value.shape[1] == batch_size
+    else:
+        batch_size = 1
+
+    embed_dim = query.shape[-1]
+    assert embed_dim == embed_dim_to_check
+    assert embed_dim % num_heads == 0
+
+    head_dim = embed_dim // num_heads
+    tgt_len, src_len = query.shape[0], key.shape[0]
+
+    if use_separate_proj_weight:
+        assert key.shape[:-1] == value.shape[:-1]
+    else:
+        assert key.shape == value.shape
+
+    flops, macs = 0, 0
+
+    # flops and macs for in-projection.
+    if not use_separate_proj_weight:
+        # using in_proj_weight, which is of shape (3E, E), where E = embed_dim.
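+        # Q is a (tgt_len * batch, E) x (E, E) matmul; K and V are each
+        # (src_len * batch, E) x (E, E): one MAC per weight element per row,
+        # two flops per MAC.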
+        n = query.numel() * embed_dim + 2 * key.numel() * embed_dim
+        flops += 2 * n
+        macs += n
+    else:
+        # with separate projection weights, k and v may have different widths,
+        # so charge each projection against its own weight.
+        n = (
+            query.numel() * q_proj_weight.shape[0]
+            + key.numel() * k_proj_weight.shape[0]
+            + value.numel() * v_proj_weight.shape[0]
+        )
+        flops += 2 * n
+        macs += n
+
+    if in_proj_bias is not None:
+        n = query.numel() + key.numel() + value.numel()
+        flops += n
+        macs += n
+
+    # q = q / sqrt(head_dim)
+    flops += query.numel()
+    macs += query.numel()
+
+    # q * k^T (bmm)
+    n = batch_size * num_heads * src_len * tgt_len * head_dim
+    flops += 2 * n
+    macs += n
+
+    # attn_mask / key_padding_mask
+    if attn_mask is not None or key_padding_mask is not None:
+        n = batch_size * num_heads * src_len * tgt_len
+        flops += n
+        macs += n
+
+    # softmax
+    n = batch_size * num_heads * src_len * tgt_len
+    flops += 3 * n
+    macs += 2 * n
+
+    # dropout (only applied in training mode)
+    if dropout_p > 0.0 and training:
+        n = batch_size * num_heads * src_len * tgt_len
+        flops += n
+        macs += n
+
+    # attn * v
+    n = batch_size * num_heads * src_len * tgt_len * head_dim
+    flops += 2 * n
+    macs += n
+
+    # out-projection
+    n = batch_size * tgt_len * embed_dim * out_proj_weight.shape[0]
+    flops += 2 * n
+    macs += n
+
+    if out_proj_bias is not None:
+        n = batch_size * tgt_len * out_proj_weight.shape[0]
+        flops += n
+        macs += n
+
+    return flops, macs
+
+
 def _matmul_flops_compute(input, other, *, out=None):
     """
     Count flops for the matmul operation.
@@ -1114,6 +1228,12 @@ def _patch_nn_functionals():
 
     # embedding
     F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)
+
+    # multi_head_attention_forward
+    F.multi_head_attention_forward = wrapFunc(
+        F.multi_head_attention_forward,
+        _multi_head_attention_forward_flops_compute,
+    )
 
     # not implemented
     _check_function_level_patch(F)
diff --git a/deepspeed/profiling/testing/unit_profiling_test.py b/deepspeed/profiling/testing/unit_profiling_test.py
index e5acff625e59..2777c89ac9b1 100644
--- a/deepspeed/profiling/testing/unit_profiling_test.py
+++ b/deepspeed/profiling/testing/unit_profiling_test.py
@@ -66,4 +66,4 @@ def get_args():
         output_file=os.path.join(os.path.dirname(os.path.realpath(__file__)), "log/{}.txt".format(args.test_name)),
         ignore_modules=None,
         show_untracked=args.show_untracked,
-    )
\ No newline at end of file
+    )
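A minimal sanity check for the new counter, not part of the patch: run the flops profiler over torch.nn.MultiheadAttention, whose forward dispatches to F.multi_head_attention_forward while the module is in training mode, and compare the reported totals against the dominant matmul terms. This sketch assumes the upstream DeepSpeed FlopsProfiler API (start_profile / stop_profile / get_total_flops / get_total_macs / end_profile) is unchanged in this fork; the module and shapes below are illustrative only.

    import torch
    import torch.nn as nn
    from deepspeed.profiling.flops_profiler import FlopsProfiler

    embed_dim, num_heads, batch, tgt_len, src_len = 64, 4, 2, 10, 12
    # fresh module is in training mode, so it takes the functional (non-fused) path
    mha = nn.MultiheadAttention(embed_dim, num_heads)

    prof = FlopsProfiler(mha)
    prof.start_profile()
    q = torch.randn(tgt_len, batch, embed_dim)  # (seq, batch, embed), batch_first=False
    k = torch.randn(src_len, batch, embed_dim)
    v = torch.randn(src_len, batch, embed_dim)
    mha(q, k, v)
    prof.stop_profile()
    flops = prof.get_total_flops()
    macs = prof.get_total_macs()
    prof.end_profile()

    # Dominant matmul flops the counter should roughly reproduce:
    #   in-projection : 2 * batch * (tgt_len + 2 * src_len) * embed_dim ** 2
    #   q*k^T, attn*v : 2 * 2 * batch * num_heads * tgt_len * src_len * (embed_dim // num_heads)
    #   out-projection: 2 * batch * tgt_len * embed_dim ** 2
    print(flops, macs)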