forked from PkuDavidGuan/CurvedSynthText
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_dataset.py
69 lines (57 loc) · 2.08 KB
/
load_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from __future__ import print_function
from __future__ import division
import numpy as np
import h5py
import os, sys, traceback
import cv2
import pickle
def group(iterator, count):
    """Yield successive ``count``-sized tuples drawn from ``iterator``.

    Items are consumed in order. A trailing run shorter than ``count`` is
    dropped, and the generator finishes normally once the input is exhausted.

    Args:
        iterator: any iterable (an open text file yields its lines).
        count: number of items per yielded tuple.

    Yields:
        tuple: the next ``count`` items.
    """
    itr = iter(iterator)
    while True:
        try:
            # A list comprehension (not a generator expression) so that the
            # StopIteration raised by next() reaches this handler unchanged.
            chunk = tuple([next(itr) for _ in range(count)])
        except StopIteration:
            # PEP 479 (Python 3.7+): a StopIteration escaping a generator body
            # is converted to RuntimeError, so end the generator explicitly.
            return
        yield chunk
class DatasetLoader(object):
    """Random-access reader for a SynthText-style background-image dataset.

    Files read under ``data_path``:

    * ``bg_img/<name>`` -- background images, read with ``cv2.imread``;
    * ``imnames.cp``    -- list of image names, parsed as text
      (see ``_read_image_names``);
    * ``seg.h5``        -- HDF5 file whose ``mask`` group holds one
      segmentation dataset per image name;
    * ``depth.h5``      -- HDF5 file with one depth dataset per image name.

    Only names that have an entry in ``seg.h5['mask']`` are kept in
    ``self.filenames``. The two HDF5 files stay open for the loader's
    lifetime; call ``close()`` (or use the loader in a ``with`` statement)
    to release them.
    """

    def __init__(self, data_path):
        # generate data paths
        self.data_path = data_path
        self.image_path = os.path.join(data_path, 'bg_img')
        image_name_file = os.path.join(data_path, 'imnames.cp')
        seg_file = os.path.join(data_path, 'seg.h5')
        depth_file = os.path.join(data_path, 'depth.h5')
        # open h5 files; keep the seg File handle so it can be closed later
        self._seg_file = h5py.File(seg_file, 'r')
        try:
            self.seg = self._seg_file['mask']
            self.depth = h5py.File(depth_file, 'r')
            # load file names
            self.filenames = self._read_image_names(image_name_file)
        except Exception:
            # don't leak already-opened HDF5 handles if construction fails
            self.close()
            raise

    def _read_image_names(self, path):
        """Parse the image-name file and return names present in ``self.seg``.

        The file is read as text, two lines at a time. The second line of
        each pair is expected to contain ``V<name>`` followed by a newline;
        this looks like a protocol-0 pickle of a list of unicode strings --
        TODO confirm against whatever produced ``imnames.cp``. Parsing stops
        at the first pair whose second line contains no ``V``.
        """
        names = []
        with open(path) as namefile:
            for _memo_line, entry in group(namefile, 2):
                vindex = entry.find('V')
                if vindex == -1:
                    break
                # drop the 'V' marker and the trailing newline
                name = entry[vindex + 1:-1]
                if name in self.seg:
                    names.append(name)
        return names

    @staticmethod
    def _segment_areas(seg):
        """Return ``(label, area)`` for segment ids ``1 .. max(seg)``.

        ``label`` is ``[1, 2, ..., max(seg)]`` and ``area[k]`` is the number
        of elements of ``seg`` equal to ``label[k]``. Both arrays use the
        builtin ``int`` dtype (what the removed ``np.int`` alias referred to).
        """
        # Cast to a Python int so max+1 cannot wrap around for small unsigned
        # mask dtypes (e.g. uint8 with label 255).
        seg_max = int(np.max(seg))
        label = np.arange(1, seg_max + 1, dtype=int)
        area = np.array([np.sum(seg == i) for i in label], dtype=int)
        return label, area

    def load(self, filename):
        """Load one sample by image name.

        Args:
            filename: image name; used as the key into both HDF5 files and
                as the file name under ``bg_img``.

        Returns:
            ``(img, depth, seg, area, label)``:

            * ``depth`` is index 1 along the last axis of the transposed
              stored depth array;
            * ``img`` and ``seg`` are resized to the depth map's size,
              ``seg`` as float32 with nearest-neighbour interpolation;
            * ``label`` is ``[1, ..., max(seg)]`` and ``area[k]`` is the pixel
              count of ``label[k]`` in the mask *before* resizing.

        Raises:
            IOError: if the background image cannot be read.
        """
        # load depth -- assumes the stored array is 3-D (TODO confirm layout)
        depth = self.depth[filename][:].T
        depth = depth[:, :, 1]
        # cv2.resize takes (width, height), i.e. the reversed numpy shape
        sz = depth.shape[:2][::-1]
        # load image
        img_file = os.path.join(self.image_path, filename)
        img = cv2.imread(img_file)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising
            raise IOError('cannot read image: %s' % img_file)
        img = cv2.resize(img, sz)
        # load segmentation; areas are measured at the stored resolution
        seg = self.seg[filename][:]
        label, area = self._segment_areas(seg)
        seg = cv2.resize(seg.astype('float32'), sz,
                         interpolation=cv2.INTER_NEAREST)
        return img, depth, seg, area, label

    def close(self):
        """Close the HDF5 files opened by ``__init__`` (safe to call twice)."""
        for handle in (getattr(self, '_seg_file', None),
                       getattr(self, 'depth', None)):
            if handle is not None:
                handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()
        return False
if __name__ == '__main__':
    # Smoke test: load one sample and pickle it for offline inspection.
    # Usage: python load_dataset.py [DATA_PATH [INDEX]]
    # (defaults preserve the original hard-coded path and sample index)
    data_path = sys.argv[1] if len(sys.argv) > 1 else '/home/zwj/SynthTextData'
    index = int(sys.argv[2]) if len(sys.argv) > 2 else 333
    loader = DatasetLoader(data_path)
    filename = loader.filenames[index]
    data = loader.load(filename)
    # 'wb': the former mode 'wbd' is invalid and raises ValueError on Python 3
    with open('dump.pkl', 'wb') as f:
        pickle.dump(data, f)