-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvisualise.py
69 lines (56 loc) · 2.13 KB
/
visualise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import time
# There are two ways to load the data from the PANDA dataset:
# Option 1: Load images using openslide
import openslide
# Option 2: Load images using skimage (requires that tifffile is installed)
import skimage.io
import random
import seaborn as sns
import cv2
# General packages
import pandas as pd
import numpy as np
from scipy import ndimage
import matplotlib
import matplotlib.pyplot as plt
plt.ion()
import PIL
from torch.utils.data import DataLoader
from dataset import PandaDataset
from albumentations import Compose, HorizontalFlip, VerticalFlip, Transpose, HueSaturationValue, RandomBrightness, RandomContrast, RandomGamma, ShiftScaleRotate
from tqdm import tqdm
# Root of the local PANDA (prostate-cancer-grade-assessment) Kaggle download.
root_path = '/home/nvme/Kaggle/prostate-cancer-grade-assessment/'
df = pd.read_csv(os.path.join(root_path, 'train.csv'))

# Only about 100 images in the dataset have no mask, so just ignore them for
# training: keep only rows whose <image_id>_mask.tiff exists on disk.
mask_present = [
    os.path.isfile(os.path.join(root_path, 'train_label_masks', image_id + '_mask.tiff'))
    for image_id in df['image_id']
]
# Reset to a contiguous 0..n-1 index after filtering so label-based (.loc) and
# positional (.iloc) row lookups agree; how PandaDataset indexes `df` is not
# visible from this file.
df = df[mask_present].reset_index(drop=True)
# Random transpose / vertical flip / horizontal flip, each applied with
# probability 0.5. ShiftScaleRotate is deliberately left disabled.
augmentations = [
    Transpose(p=0.5),
    VerticalFlip(p=0.5),
    HorizontalFlip(p=0.5),
    # ShiftScaleRotate(p=0.5),
]
transforms = Compose(augmentations)

# Patch-based dataset over the filtered dataframe; the level / patch_size /
# num_patches semantics live in dataset.PandaDataset (not visible here).
dataset = PandaDataset(
    root_path,
    df,
    level=1,
    patch_size=256,
    num_patches=25,
    use_mask=True,
    transforms=transforms,
)

# Only consumed by the commented-out throughput benchmark further down.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=False,
    num_workers=16,
)
# Time two single-sample loads. perf_counter() is monotonic and
# high-resolution, unlike time.time(), so it is the right clock for intervals.
t0 = time.perf_counter()
x, y = dataset[2209]
t1 = time.perf_counter()
print('Dataloading time', t1 - t0)
x, y = dataset[0]
t2 = time.perf_counter()
print('Dataloading time', t2 - t1)
# Samples unpack as image, (mask, label) in the loop below, so y[0] is the mask.
print(y[0].shape)

# Uncomment to benchmark throughput through the multi-worker DataLoader:
# for x, y in tqdm(dataloader, total=len(dataloader)):
#     pass
# One colour per mask class index 0..5, pinned by vmin=0 / vmax=5 below so the
# colours stay stable regardless of which classes a given patch contains.
cmap = matplotlib.colors.ListedColormap(['black', 'gray', 'blue', 'green', 'yellow', 'red'])

# For the first five samples, draw each patch next to its mask: two panels per
# patch, five patch/mask pairs per row of a 10-column grid.
for j in range(5):
    t0 = time.perf_counter()
    image, (mask, label) = dataset[j]
    print('Dataloading time', time.perf_counter() - t0)

    # Derive the grid from the data instead of hard-coding 25 patches / 5 rows
    # (identical layout when the dataset yields 25 patches).
    n_patches = len(image)
    n_rows = (n_patches + 4) // 5  # ceil(n_patches / 5)
    plt.figure(figsize=(32, 32))
    for i in range(n_patches):
        plt.subplot(n_rows, 10, 2 * i + 1)
        # Move the leading (channel) axis last, as imshow expects H x W x C.
        plt.imshow(image[i].permute(1, 2, 0).numpy())
        plt.subplot(n_rows, 10, 2 * i + 2)
        # Index of the largest value along dim 0, giving a 2-D map of class ids.
        # NOTE(review): presumes mask[i] is a per-class channel stack
        # (e.g. one-hot, 6 channels) - confirm against PandaDataset.
        class_map = mask[i].max(dim=0).indices.numpy()
        plt.imshow(class_map, cmap=cmap, interpolation='nearest', vmin=0, vmax=5)