-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
97 lines (74 loc) · 2.76 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from pathlib import Path
from PIL import Image
from torch.utils.data import DataLoader,Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from config import Config
import numpy as np
from torchvision.utils import make_grid
from ignite.handlers.tensorboard_logger import TensorboardLogger
# Dataset lives next to this file: ./pix2pix_dataset/maps/maps/{train,val}
data_path = Path(__file__).parent.joinpath('pix2pix_dataset', 'maps', 'maps')
train_path = data_path / 'train'
val_path = data_path / 'val'
def get_imgs_files(path: Path) -> list[str]:
    """Return the ``*.jpg`` files directly inside *path* as string paths.

    The result is sorted so dataset ordering is deterministic across
    filesystems — ``Path.glob`` yields entries in an unspecified order,
    which would otherwise make the validation subset (first 16 files)
    machine-dependent.
    """
    return sorted(str(img_file) for img_file in path.glob('*.jpg'))
class MapDataset(Dataset):
    """Paired-image dataset for pix2pix map/aerial translation.

    Each ``*.jpg`` under *path* is one wide image holding the input on the
    left and the target on the right; ``__getitem__`` splits it at
    *split_col* pixels and applies the same augmentation to both halves
    (the transform must declare ``output_image`` via ``additional_targets``).
    """

    def __init__(self, path: Path, tranform, valid=False, device='cuda', split_col=600):
        super().__init__()
        self.files_paths = get_imgs_files(path)
        # NOTE: parameter is spelled "tranform" for caller compatibility.
        self.transform = tranform
        self.device = device
        # x-coordinate where the input half ends and the target half starts.
        self.split_col = split_col
        if valid:
            # A small fixed subset is enough for validation logging.
            self.files_paths = self.files_paths[:16]

    def __len__(self):
        return len(self.files_paths)

    def __getitem__(self, index):
        img_path = self.files_paths[index]
        # Force RGB so a stray grayscale/CMYK jpg can't change channel count.
        full_img = np.array(Image.open(img_path).convert('RGB'))
        input_img = full_img[:, :self.split_col, :]
        output_img = full_img[:, self.split_col:, :]
        # additional_targets ensures both halves receive identical augmentations.
        transformed = self.transform(image=input_img, output_image=output_img)
        return (transformed['image'].to(self.device),
                transformed['output_image'].to(self.device))
# Shared augmentation pipeline: both halves of a pair are resized and
# flipped together, then scaled into [-1, 1] (mean/std of 0.5 per channel).
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
transform = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
        ToTensorV2(),
    ],
    # Apply the exact same spatial transforms to the target half.
    additional_targets={'output_image': 'image'},
)
# Datasets and loaders. Validation is unshuffled so logged image grids stay
# comparable across epochs; drop_last keeps every batch full-sized.
train_ds = MapDataset(train_path, tranform=transform)
val_ds = MapDataset(val_path, tranform=transform, valid=True)

train_dl = DataLoader(
    dataset=train_ds,
    batch_size=Config.BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)
valid_dl = DataLoader(
    dataset=val_ds,
    batch_size=Config.BATCH_SIZE,
    shuffle=False,
    drop_last=True,
)
if __name__ == "__main__":
    # Smoke test: pull one validation batch and write the paired image
    # grids to TensorBoard so the split/augmentation can be inspected.
    inputs, targets = next(iter(valid_dl))
    inputs = inputs.cpu().detach()
    targets = targets.cpu().detach()

    # make_grid rescales from the normalized [-1, 1] range back to [0, 1].
    grid_in = make_grid(inputs, nrow=4, normalize=True, value_range=(-1, 1))
    grid_out = make_grid(targets, nrow=4, normalize=True, value_range=(-1, 1))

    tb_logger = TensorboardLogger('./log_imgs_test')
    tb_logger.writer.add_image("Inputs", grid_in, 0)
    tb_logger.writer.add_image("Outputs", grid_out, 0)
    tb_logger.close()