-
Notifications
You must be signed in to change notification settings - Fork 0
/
monitor.py
53 lines (44 loc) · 1.76 KB
/
monitor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""Utilities to log results to visdom"""
try:
    from collections.abc import Iterable
except ImportError:  # very old interpreters without ``collections.abc``
    from collections import Iterable

import torch
import torchvision
import visdom
from PIL import ImageColor
class Scalar():
    """Stream scalar values to a visdom line plot.

    Each instance opens its own ``visdom.Visdom`` connection and targets a
    single window whose id is ``name``.

    Args:
        name: Window id passed to visdom as ``win``; also used as the plot
            title when ``title`` is ``None``.
        title: Plot title stored in ``opts``.
        xlabel: X-axis label stored in ``opts``.
        ylabel: Y-axis label stored in ``opts``.
        multi_trace: Stored in ``opts`` as ``showlegend``. When true, every
            ``add`` call must name the trace it writes to.
        env: Environment name handed to ``visdom.Visdom``.
    """

    def __init__(self, name, title=None, xlabel=None, ylabel=None,
                 multi_trace=False, env=None):
        self.vis = visdom.Visdom(env=env)
        if title is None:
            title = name
        self.win = name
        self.opts = dict(title=title, xlabel=xlabel, ylabel=ylabel,
                         showlegend=multi_trace)
        self.multi_trace = multi_trace

    def add(self, step, data, trace=None):
        """Send ``data`` to the plot at x position(s) ``step``.

        Args:
            step: X value(s). A ``torch.Tensor`` is used as-is; a bare
                scalar or any other iterable is converted with
                ``torch.Tensor(...)``.
            data: Y value(s), forwarded to ``visdom.line`` unchanged.
            trace: Trace name forwarded as ``name``. Required when the
                instance was created with ``multi_trace=True``.

        Raises:
            ValueError: If ``multi_trace`` is set and ``trace`` is ``None``.
        """
        # Validate before any request is sent: previously an invalid call
        # could still create the window (with an unnamed trace) and only
        # then raise.
        if self.multi_trace and trace is None:
            raise ValueError('Set trace when using multi-trace graph')

        if not isinstance(step, torch.Tensor):
            if not isinstance(step, Iterable):
                step = [step]
            step = torch.Tensor(list(step))

        if not self.vis.win_exists(self.win):
            # The window is missing, so create it with this first sample.
            self.vis.line(X=step, Y=data, name=trace, win=self.win,
                          opts=self.opts)
        # NOTE(review): on the creating call the same sample is sent a second
        # time by the append below. Kept exactly as before -- confirm whether
        # this is a deliberate workaround or an accidental duplicate.
        self.vis.line(X=step, Y=data, name=trace, win=self.win,
                      update='append', opts=self.opts)
class Image():
    """Send image tensors to visdom, titled ``<title>_<step>``.

    Each instance opens its own ``visdom.Visdom`` connection. No window id
    is passed to visdom by ``add``; only the title is set.

    Args:
        title: Title prefix. ``'_{}'`` is appended so the step can be
            formatted in later.
        env: Environment name handed to ``visdom.Visdom``.
    """

    def __init__(self, title, env=None):
        self.vis = visdom.Visdom(env=env)
        self.title = title + '_{}'

    @staticmethod
    def _step_label(step):
        """Return the single scalar used to label an image.

        Plain Python ints and floats are returned untouched. Routing them
        through a float32 tensor (as was done before) rendered ``5`` as
        ``5.0``, ``0.1`` as ``0.10000000149011612`` and silently rounded
        integers above 2**24.
        """
        # bool is an int subclass; keep it on the tensor path as before.
        if isinstance(step, (int, float)) and not isinstance(step, bool):
            return step
        if not isinstance(step, torch.Tensor):
            if not isinstance(step, Iterable):
                step = [step]
            step = torch.Tensor(list(step))
        # ``item()`` only succeeds for a one-element tensor.
        return step.item()

    def add(self, step, image):
        """Display ``image`` with the current step in its title.

        Args:
            step: Step number. A Python number, a one-element tensor, or a
                one-element iterable.
            image: Tensor to show. When ``image.shape[0] == 1`` the leading
                axis is squeezed away and ``visdom.image`` is used,
                otherwise the tensor goes to ``visdom.images``.
                NOTE(review): presumably the leading axis is a batch (or
                single-channel) dimension -- confirm against callers.
        """
        opts = {'title': self.title.format(self._step_label(step))}
        if image.shape[0] == 1:
            self.vis.image(image.squeeze(0).detach(), opts=opts)
        else:
            self.vis.images(image.detach(), opts=opts)