-
Notifications
You must be signed in to change notification settings - Fork 101
/
albumentations.py
64 lines (54 loc) · 1.9 KB
/
albumentations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#coding=utf-8
import glob
import cv2
from albumentations import (
PadIfNeeded,
HorizontalFlip, # 随机水平翻转
VerticalFlip, # 随机垂直翻转
CenterCrop,
Crop,
Compose,
Transpose,
RandomRotate90, # 随机90度旋转
ElasticTransform,
GridDistortion,
OpticalDistortion,
RandomSizedCrop, # 随机尺寸裁剪并缩放回原始大小
OneOf,
CLAHE,
RandomBrightnessContrast,
RandomGamma
)
def data_num(train_path, mask_path):
    """Collect image and mask file paths matching two glob patterns.

    Args:
        train_path: glob pattern for the input images (e.g. ``"dir/*.jpg"``).
        mask_path: glob pattern for the corresponding masks.

    Returns:
        ``(train_img, masks)``: two lists of matching paths, each sorted so that
        images and masks whose names correspond line up at the same index.
    """
    # glob.glob returns paths in arbitrary, filesystem-dependent order. Callers
    # pair images with masks by index, so sort both lists to make that pairing
    # deterministic.
    train_img = sorted(glob.glob(train_path))
    masks = sorted(glob.glob(mask_path))
    return train_img, masks
def mask_aug():
    """Build the joint image/mask augmentation pipeline.

    Returns an albumentations ``Compose`` that applies, in order and each with
    probability 0.5:

    - a vertical flip,
    - a random 90-degree rotation,
    - a horizontal flip,
    - a ``RandomSizedCrop`` with ``min_max_height=(128, 512)``, resized to 384x384.

    NOTE(review): because the crop only fires half the time, outputs are
    384x384 only when it is applied; otherwise they keep the input size.
    Confirm that mixed output sizes are intended.
    """
    transforms = [
        VerticalFlip(p=0.5),
        RandomRotate90(p=0.5),
        HorizontalFlip(p=0.5),
        RandomSizedCrop(min_max_height=(128, 512), height=384, width=384, p=0.5),
    ]
    pipeline = Compose(transforms)
    return pipeline
def main():
    """Write ``num`` augmented copies of every image/mask pair to disk.

    Images are read from ``./data/data-2/train_image/*.jpg`` and masks from
    ``./data/data-2/train_label/*.png``; these paths are relative to the
    current working directory. Each pair is passed through the pipeline from
    ``mask_aug()`` ``num`` times. Results are written to
    ``new_image/aug{index}_{k}.jpg`` and ``new_label/aug{index}_{k}.png`` under
    ``./data/data-2``.

    Raises:
        FileNotFoundError: if no input images match the pattern.
        ValueError: if the numbers of images and masks differ.
        OSError: if an input cannot be decoded or an output cannot be written.
    """
    import os

    train_path = r"./data/data-2/train_image/*.jpg"  # glob for input images
    mask_path = r"./data/data-2/train_label/*.png"  # glob for input masks
    augtrain_path = r"./data/data-2/new_image"  # output dir for augmented images
    augmask_path = r"./data/data-2/new_label"  # output dir for augmented masks
    num = 3  # number of augmented copies generated per input pair

    aug = mask_aug()
    train_img, masks = data_num(train_path, mask_path)
    # glob order is filesystem-dependent; sort so that images and masks with
    # corresponding names are paired at the same index below.
    train_img, masks = sorted(train_img), sorted(masks)
    if not train_img:
        # Fail loudly instead of silently doing nothing (e.g. wrong cwd).
        raise FileNotFoundError("no images match {!r}".format(train_path))
    if len(train_img) != len(masks):
        raise ValueError(
            "found {} images but {} masks; cannot pair them by index".format(
                len(train_img), len(masks)
            )
        )

    # cv2.imwrite does not create directories; it just returns False.
    os.makedirs(augtrain_path, exist_ok=True)
    os.makedirs(augmask_path, exist_ok=True)

    for data, (img_file, mask_file) in enumerate(zip(train_img, masks)):
        # Read each pair once rather than once per augmented copy.
        image = cv2.imread(img_file)
        # NOTE(review): the default IMREAD_COLOR flag loads the mask as 3-channel
        # BGR (as the original did); confirm this matches the label format the
        # training code expects.
        mask = cv2.imread(mask_file)
        # cv2.imread signals a missing/undecodable file by returning None.
        if image is None:
            raise OSError("could not read image {!r}".format(img_file))
        if mask is None:
            raise OSError("could not read mask {!r}".format(mask_file))

        for i in range(num):
            augmented = aug(image=image, mask=mask)
            img_out = os.path.join(augtrain_path, "aug{}_{}.jpg".format(data, i))
            mask_out = os.path.join(augmask_path, "aug{}_{}.png".format(data, i))
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(img_out, augmented["image"]):
                raise OSError("failed to write {!r}".format(img_out))
            if not cv2.imwrite(mask_out, augmented["mask"]):
                raise OSError("failed to write {!r}".format(mask_out))
        print(data)  # progress: index of the pair just processed


if __name__ == "__main__":
    main()