-
Notifications
You must be signed in to change notification settings - Fork 249
/
Copy pathutils.py
90 lines (74 loc) · 2.63 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import pickle,os
from PIL import Image
import scipy.io
import time
from tqdm import tqdm
import pandas as pd
import shutil
from random import randint
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
# other util
def accuracy(output, target, topk=(1,)):
    """Compute the top-k precision (in percent) for each k in ``topk``.

    Args:
        output: score tensor; ``topk`` is taken along dim 1, so it is
            presumably ``(batch, num_classes)`` — TODO confirm against callers.
        target: tensor holding one ground-truth class index per sample; it is
            reshaped with ``view(1, -1)`` and its ``size(0)`` is used as the
            batch size.
        topk: iterable of k values; ``max(topk)`` must not exceed dim 1 of
            ``output``.

    Returns:
        list with one scalar tensor per k (in ``topk`` order): the percentage
        of samples whose target is among their k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    # Indices of the maxk best scores per sample (sorted, largest first),
    # transposed so that row i holds every sample's (i+1)-th best guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # `correct` inherits the non-contiguous layout of `pred.t()`, so
        # `.view(-1)` can raise on current PyTorch; `.reshape` copies if needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
class AverageMeter(object):
    """Keep the latest value of a metric together with its weighted running mean.

    Attributes (all start at 0 after ``reset``):
        val: the value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over every update since the last reset.
        count: accumulated weight ``n`` over every update since the last reset.
        avg: ``sum / count``, refreshed on each update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, mean, sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value and fold it into the mean with weight ``n``."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
def save_checkpoint(state, is_best, checkpoint, model_best):
    """Serialize ``state`` to ``checkpoint`` and, when it is the best so far,
    also copy that file to ``model_best``.

    Args:
        state: object handed to ``torch.save``.
        is_best: whether this checkpoint should also become ``model_best``.
        checkpoint: destination path for ``torch.save``.
        model_best: destination path for the copy made when ``is_best`` is true.
    """
    torch.save(state, checkpoint)
    if not is_best:
        return
    # Copy the freshly written file rather than serializing ``state`` twice.
    shutil.copyfile(checkpoint, model_best)
def record_info(info, filename, mode):
    """Print one set of training/testing metrics and append it to a CSV log.

    Args:
        info: mapping of column name -> value(s). It is converted with
            ``pd.DataFrame.from_dict`` and must contain every column listed
            for ``mode`` below; keys outside that list are not written.
            NOTE(review): values are presumably one-element lists such as
            ``[0.123]`` (all-scalar values make ``from_dict`` raise) —
            confirm against callers.
        filename: path of the CSV log. It is created with a header row when
            missing; otherwise rows are appended without a header.
        mode: ``'train'`` (columns Epoch, Batch Time, Data Time, Loss,
            Prec@1, Prec@5, lr) or ``'test'`` (columns Epoch, Batch Time,
            Loss, Prec@1, Prec@5).

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """
    if mode == 'train':
        # Adjacent string literals are joined at compile time, so .format
        # applies to the whole message.
        result = (
            'Time {batch_time} '
            'Data {data_time} \n'
            'Loss {loss} '
            'Prec@1 {top1} '
            'Prec@5 {top5}\n'
            'LR {lr}\n'.format(batch_time=info['Batch Time'],
                               data_time=info['Data Time'], loss=info['Loss'],
                               top1=info['Prec@1'], top5=info['Prec@5'],
                               lr=info['lr']))
        column_names = ['Epoch', 'Batch Time', 'Data Time', 'Loss', 'Prec@1', 'Prec@5', 'lr']
    elif mode == 'test':
        result = (
            'Time {batch_time} \n'
            'Loss {loss} '
            'Prec@1 {top1} '
            'Prec@5 {top5} \n'.format(batch_time=info['Batch Time'],
                                      loss=info['Loss'], top1=info['Prec@1'],
                                      top5=info['Prec@5']))
        column_names = ['Epoch', 'Batch Time', 'Loss', 'Prec@1', 'Prec@5']
    else:
        # Any other mode would leave `column_names` undefined below.
        raise ValueError("mode must be 'train' or 'test', got {!r}".format(mode))

    # Function-call form is valid on both Python 2 and 3 (the bare
    # `print result` statement is a SyntaxError on Python 3).
    print(result)
    df = pd.DataFrame.from_dict(info)
    if not os.path.isfile(filename):
        df.to_csv(filename, index=False, columns=column_names)
    else:  # else it exists so append without writing the header
        df.to_csv(filename, mode='a', header=False, index=False, columns=column_names)