-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
36 lines (28 loc) · 1.12 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import numpy.ma as ma
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from dataset import FATDataset, RandomCrop, ToTensor
import matplotlib.pyplot as plt
import matplotlib.colors as cls
def show_batch(batched_samples, vmin=0, vmax=255):
    """Draw the ``'depth'`` entries of one batch as a single tiled image.

    The batch is tiled with ``torchvision.utils.make_grid``. Values are scaled
    from ``[vmin, vmax]`` to ``[0, 1]``, with out-of-range values clipped. The
    result is drawn with ``plt.imshow`` on the current figure. The caller
    creates the figure and calls ``plt.show()``.

    Args:
        batched_samples: batch mapping from the DataLoader; only the
            ``'depth'`` tensor is read. Presumably shaped (B, C, H, W) --
            TODO confirm against FATDataset.
        vmin: depth value rendered as black.
        vmax: depth value rendered as white.

    Returns:
        The ``AxesImage`` created by ``plt.imshow``.
    """
    batch_depth = batched_samples['depth']
    # Move the tiled tensor to host memory so it can be handed to numpy.
    grid = utils.make_grid(batch_depth).detach().cpu().numpy()
    # Keep the data as float in [0, 1]; imshow reads float RGB on that scale.
    # Casting to int here would floor every pixel below vmax to 0 (black).
    norm = cls.Normalize(vmin=vmin, vmax=vmax, clip=True)
    grid = ma.getdata(norm(grid))  # Normalize returns a masked array
    # imshow wants channels last: (C, H, W) -> (H, W, C).
    return plt.imshow(grid.transpose((1, 2, 0)))
def main(root="/home/dong/dgn-pytorch/dataset/fat", split="train",
         batch_size=4, num_workers=4, batch_to_show=3):
    """Print per-batch tensor sizes and display the depth grid of one batch.

    Iterates a shuffled DataLoader over ``FATDataset(root, split)``, with
    each sample transformed by ``RandomCrop(480)`` then ``ToTensor()``. For
    every batch it prints the sizes of the ``'depth'`` and ``'left'``
    tensors. It stops after plotting batch number ``batch_to_show``
    (0-based).

    Args:
        root: dataset directory passed to ``FATDataset``.
        split: split name passed to ``FATDataset``.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
        batch_to_show: 0-based index of the batch to plot before stopping.
    """
    fat_dataset = FATDataset(
        root, split,
        trans=transforms.Compose([RandomCrop(480), ToTensor()]))
    dataloader = DataLoader(fat_dataset, batch_size=batch_size,
                            shuffle=True, num_workers=num_workers)
    for i_batch, sample_batched in enumerate(dataloader):
        print(i_batch, sample_batched['depth'].size(),
              sample_batched['left'].size())
        # Observe the requested batch and stop.
        if i_batch == batch_to_show:
            plt.figure()
            show_batch(sample_batched)
            plt.axis('off')
            plt.ioff()
            plt.show()
            break


# Workers started with the "spawn" method (the default on Windows and macOS)
# re-import this module. Without this guard, each worker would re-run the
# whole script.
if __name__ == "__main__":
    main()