utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional, Tuple

import torch
from torch.utils._python_dispatch import TorchDispatchMode

from packaging import version
from functools import reduce
from math import gcd

__all__ = [
    "find_multiple",
    "compute_error",
    "_apply_logging_hook",
    "get_model_size_in_bytes",
    "TORCH_VERSION_AFTER_2_3",
]

def find_multiple(n: int, *args: Tuple[int]) -> int:
    k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))  # type: ignore[9]
    if n % k == 0:
        return n
    return n + k - (n % k)
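
# Illustrative usage (not part of the original file): find_multiple rounds n up to
# the nearest multiple of the least common multiple of the given args, e.g. to pad
# a dimension so it divides evenly into a hardware-friendly block size.
#   >>> find_multiple(20, 1024)
#   1024
#   >>> find_multiple(2048, 1024)
#   2048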

# basic SQNR
def compute_error(x, y):
    Ps = torch.linalg.norm(x)
    Pn = torch.linalg.norm(x - y)
    return 20 * torch.log10(Ps / Pn)
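
# Illustrative usage (not part of the original file): compute_error returns the
# signal-to-quantization-noise ratio in dB, where a larger value means y is a
# closer approximation of x.
#   >>> x = torch.randn(64, 64)
#   >>> compute_error(x, x + 1e-3 * torch.randn(64, 64))  # roughly 60 dB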

# logger for fqn + op + shape
# note: not safe for any kind of multithreading
_cur_fqn: Optional[str] = None


def _get_logging_hook(fqn):

    def forward_hook(module, input):
        global _cur_fqn
        _cur_fqn = fqn

    return forward_hook


def _apply_logging_hook(model):
    for name, mod in model.named_modules():
        mod.register_forward_pre_hook(_get_logging_hook(name))


# collections.defaultdict printing is weird with lambdas, so hand writing for now
_fqn_to_op_to_shape_to_count: Dict[
    Optional[str], Dict[Optional[str], Dict[Optional[str], int]]
] = {}

class LoggingTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        global _cur_fqn
        op_name: str = f"{func.__module__}.{func.__name__}"
        shape_str = ""
        for arg in args:
            if isinstance(arg, torch.Tensor):
                shape_str += str(list(arg.shape)) + ", "
        if shape_str != "":
            shape_str = shape_str[:-2]

        if _cur_fqn not in _fqn_to_op_to_shape_to_count:
            _fqn_to_op_to_shape_to_count[_cur_fqn] = {}
        if op_name not in _fqn_to_op_to_shape_to_count[_cur_fqn]:
            _fqn_to_op_to_shape_to_count[_cur_fqn][op_name] = {}
        if shape_str not in _fqn_to_op_to_shape_to_count[_cur_fqn][op_name]:
            _fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] = 0
        _fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] += 1

        return rs
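
# Illustrative usage (not part of the original file): register the pre-forward
# hooks so _cur_fqn tracks the module currently executing, then run the model
# under LoggingTensorMode to count (fqn, op, input shapes) occurrences.
#   >>> m = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
#   >>> _apply_logging_hook(m)
#   >>> with LoggingTensorMode():
#   ...     _ = m(torch.randn(2, 4))
#   >>> print(_fqn_to_op_to_shape_to_count)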

# https://discuss.pytorch.org/t/finding-model-size/130275
def get_model_size_in_bytes(model):
    s = 0
    for p in model.parameters():
        s += p.nelement() * p.element_size()
    for b in model.buffers():
        s += b.nelement() * b.element_size()
    return s
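
# Illustrative usage (not part of the original file): parameters and buffers are
# both counted, so a float32 Linear(128, 256) reports 128 * 256 * 4 + 256 * 4
# = 132096 bytes.
#   >>> get_model_size_in_bytes(torch.nn.Linear(128, 256))
#   132096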

if version.parse(torch.__version__) >= version.parse("2.4.0.dev"):
    TORCH_VERSION_AFTER_2_4 = True
else:
    TORCH_VERSION_AFTER_2_4 = False

if version.parse(torch.__version__) >= version.parse("2.3.0.dev"):
    TORCH_VERSION_AFTER_2_3 = True
else:
    TORCH_VERSION_AFTER_2_3 = False

if version.parse(torch.__version__) >= version.parse("2.2.0.dev"):
    TORCH_VERSION_AFTER_2_2 = True
else:
    TORCH_VERSION_AFTER_2_2 = False
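
# Illustrative note (not part of the original file): these flags gate version-
# dependent code paths, e.g. on torch 2.2.x TORCH_VERSION_AFTER_2_2 is True while
# TORCH_VERSION_AFTER_2_3 and TORCH_VERSION_AFTER_2_4 are False.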