-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathutils.py
77 lines (59 loc) · 2.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Some utility functions
"""
import torch
import torch.nn as nn
import numpy as np
from omegaconf import DictConfig
from omegaconf.listconfig import ListConfig
from typing import Tuple, Any, List, Union, Dict
def get_device() -> Tuple[torch.device, bool]:
    """Pick the torch device to run on.

    Returns:
        A ``(device, use_gpu)`` pair: the CUDA device and ``True`` when
        ``torch.cuda.is_available()`` reports a GPU, otherwise the CPU
        device and ``False``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda"), True
    return torch.device("cpu"), False
def load_model(model_path: str) -> Dict[str, Any]:
    """Load a model checkpoint from ``model_path`` with ``torch.load``.

    When CUDA is unavailable, the stored tensors are remapped onto the CPU;
    otherwise ``torch.load`` restores them with its default placement.

    NOTE: ``torch.load`` unpickles the file, so only load checkpoints that
    come from a trusted source.
    """
    # map_location=None is torch.load's default, i.e. "no remapping".
    target = None if torch.cuda.is_available() else torch.device("cpu")
    return torch.load(model_path, map_location=target)  # type: ignore
def count_params(model: nn.Module) -> int:
    """Return the total number of scalar entries across ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def _count_matches(pred_seq: List[Any], gt_seq: List[Any]) -> int:
    "Number of positions at which `pred_seq` and `gt_seq` hold equal items"
    # zip stops at the shorter sequence, so unmatched trailing items are
    # never counted as correct.
    return sum(1 for x, y in zip(pred_seq, gt_seq) if x == y)


def count_actions(
    pred: Union[List[Any], List[List[Any]]], gt: Union[List[Any], List[List[Any]]]
) -> Tuple[int, int]:
    """Count the number of correct actions and the number of total actions.

    ``pred`` and ``gt`` are either flat lists of actions or lists of action
    sequences; which one is decided by whether ``pred[0]`` is a list.
    Items are compared position by position with ``==``.

    Args:
        pred: predicted actions, or a list of predicted action sequences.
        gt: ground-truth actions in the same layout as ``pred``.

    Returns:
        ``(num_correct, num_total)`` as plain Python ints, where
        ``num_total`` is the number of predicted actions. Empty ``pred``
        yields ``(0, 0)``.
    """
    # Nothing was predicted; also avoids indexing pred[0] on an empty list.
    if len(pred) == 0:
        return 0, 0

    if isinstance(pred[0], list):
        # The builtin sum is used on purpose: passing a generator to np.sum
        # is deprecated and only works through a fallback.
        num_correct = sum(
            _count_matches(pred_seq, gt_seq) for pred_seq, gt_seq in zip(pred, gt)
        )
        num_total = sum(len(pred_seq) for pred_seq in pred)
    else:
        num_correct = _count_matches(pred, gt)
        num_total = len(pred)
    return num_correct, num_total
def conf2list(cfg: ListConfig) -> List[Any]:
    """Recursively convert a ``ListConfig`` into a plain Python list.

    Nested ``ListConfig`` / ``DictConfig`` items are converted with
    ``conf2list`` / ``conf2dict``; every other item must be ``None`` or a
    ``str``, ``int``, ``float`` or ``bool`` and is copied as is.
    """

    def _to_plain(item: Any) -> Any:
        if isinstance(item, ListConfig):
            return conf2list(item)
        if isinstance(item, DictConfig):
            return conf2dict(item)
        # Leaves are expected to be primitives only.
        assert item is None or isinstance(item, (str, int, float, bool))
        return item

    return [_to_plain(item) for item in cfg]
def conf2dict(cfg: DictConfig) -> Dict[str, Any]:
    """Recursively convert a ``DictConfig`` into a plain Python dict.

    Keys must be strings. Nested ``ListConfig`` / ``DictConfig`` values are
    converted with ``conf2list`` / ``conf2dict``; every other value must be
    ``None`` or a ``str``, ``int``, ``float`` or ``bool`` and is copied as is.
    """
    plain: Dict[str, Any] = {}
    for key, value in cfg.items():
        assert isinstance(key, str)
        if isinstance(value, ListConfig):
            converted: Any = conf2list(value)
        elif isinstance(value, DictConfig):
            converted = conf2dict(value)
        else:
            # Leaves are expected to be primitives only.
            assert value is None or isinstance(value, (str, int, float, bool))
            converted = value
        plain[key] = converted
    return plain