generate_pseudo_label.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
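# Generate pseudo labels on the target-domain training set with a trained
# DeepLabv2 model. By default, hard labels are saved as PNGs together with a
# color-coded visualization and a per-pixel confidence map; with --soft, the
# full softmax probabilities are saved as .npy files instead.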
import os
import logging
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from parser_train import parser_, relative_path_to_absolute_path
from tqdm import tqdm
from data import create_dataset
from models import adaptation_modelv2
from utils import fliplr

def test(opt, logger):
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)

    ## create dataset
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    datasets = create_dataset(opt, logger)

    # load the trained segmentation checkpoint
    if opt.model_name == 'deeplabv2':
        checkpoint = torch.load(opt.resume_path)['ResNet101']["model_state"]
        model = adaptation_modelv2.CustomModel(opt, logger)
        model.BaseNet.load_state_dict(checkpoint)

    validation(model, logger, datasets, device, opt)

def validation(model, logger, datasets, device, opt):
    _k = -1
    model.eval(logger=logger)
    torch.cuda.empty_cache()
    with torch.no_grad():
        validate(datasets.target_train_loader, device, model, opt)
        #validate(datasets.target_valid_loader, device, model, opt)

def label2rgb(func, label):
    # decode a batch of label maps to RGB images using the dataset's decode_segmap
    rgbs = []
    for k in range(label.shape[0]):
        rgb = func(label[k, 0].cpu().numpy())
        rgbs.append(torch.from_numpy(rgb).permute(2, 0, 1))
    rgbs = torch.stack(rgbs, dim=0).float()
    return rgbs

def validate(valid_loader, device, model, opt):
    ori_LP = os.path.join(opt.root, 'Code/ProDA', opt.save_path, opt.name)
    if not os.path.exists(ori_LP):
        os.makedirs(ori_LP)

    sm = torch.nn.Softmax(dim=1)
    for data_i in tqdm(valid_loader):
        images_val = data_i['img'].to(device)
        labels_val = data_i['label'].to(device)
        filename = data_i['img_path']

        out = model.BaseNet_DP(images_val)

        if opt.soft:
            # save the full softmax probabilities as soft pseudo labels
            threshold_arg = F.softmax(out['out'], dim=1)
            for k in range(labels_val.shape[0]):
                name = os.path.basename(filename[k])
                np.save(os.path.join(ori_LP, name.replace('.png', '.npy')), threshold_arg[k].cpu().numpy())
        else:
            if opt.flip:
                # average predictions over the original and horizontally flipped input
                flip_out = model.BaseNet_DP(fliplr(images_val))
                flip_out['out'] = F.interpolate(sm(flip_out['out']), size=images_val.size()[2:], mode='bilinear', align_corners=True)
                out['out'] = F.interpolate(sm(out['out']), size=images_val.size()[2:], mode='bilinear', align_corners=True)
                out['out'] = (out['out'] + fliplr(flip_out['out'])) / 2

            # hard pseudo label, its color visualization, and the per-pixel max confidence
            confidence, pseudo = out['out'].max(1, keepdim=True)
            #entropy = -(out['out']*torch.log(out['out']+1e-6)).sum(1, keepdim=True)
            pseudo_rgb = label2rgb(valid_loader.dataset.decode_segmap, pseudo).float() * 255
            for k in range(labels_val.shape[0]):
                name = os.path.basename(filename[k])
                Image.fromarray(pseudo[k, 0].cpu().numpy().astype(np.uint8)).save(os.path.join(ori_LP, name))
                Image.fromarray(pseudo_rgb[k].permute(1, 2, 0).cpu().numpy().astype(np.uint8)).save(os.path.join(ori_LP, name[:-4] + '_color.png'))
                np.save(os.path.join(ori_LP, name.replace('.png', '_conf.npy')), confidence[k, 0].cpu().numpy().astype(np.float16))
                #np.save(os.path.join(ori_LP, name.replace('.png', '_entropy.npy')), entropy[k, 0].cpu().numpy().astype(np.float16))

def get_logger(logdir):
    logger = logging.getLogger('ptsemseg')
    file_path = os.path.join(logdir, 'run.log')
    hdlr = logging.FileHandler(file_path)
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.setLevel(logging.INFO)
    return logger

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="config")
    parser.add_argument('--save_path', type=str, default='./Pseudo', help='path to save the generated pseudo labels')
    parser.add_argument('--soft', action='store_true', help='save soft pseudo labels (softmax probabilities) instead of hard labels')
    parser.add_argument('--flip', action='store_true', help='average predictions over horizontally flipped inputs')
    parser = parser_(parser)
    opt = parser.parse_args()
    opt = relative_path_to_absolute_path(opt)
    opt.logdir = opt.logdir.replace(opt.name, 'debug')
    opt.noaug = True
    opt.noshuffle = True

    print('RUNDIR: {}'.format(opt.logdir))
    if not os.path.exists(opt.logdir):
        os.makedirs(opt.logdir)
    logger = get_logger(opt.logdir)

    test(opt, logger)
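
# Example invocations: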
#python generate_pseudo_label.py --name gta2citylabv2_warmup_soft --soft --resume_path ./logs/gta2citylabv2_warmup/from_gta5_to_cityscapes_on_deeplabv2_best_model.pkl --no_droplast
#python generate_pseudo_label.py --name gta2citylabv2_stage1Denoise --flip --resume_path ./logs/gta2citylabv2_stage1Denoisev2/from_gta5_to_cityscapes_on_deeplabv2_best_model.pkl --no_droplast
#python generate_pseudo_label.py --name gta2citylabv2_stage2 --flip --resume_path ./logs/gta2citylabv2_stage2/from_gta5_to_cityscapes_on_deeplabv2_best_model.pkl --no_droplast --bn_clr --student_init simclr
#python generate_pseudo_label.py --name syn2citylabv2_warmup_soft --soft --src_dataset synthia --n_class 16 --src_rootpath Dataset/SYNTHIA-RAND-CITYSCAPES --resume_path ./logs/syn2citylabv2_warmup/from_synthia_to_cityscapes_on_deeplabv2_best_model.pkl --no_droplast
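
# A minimal sketch of how the saved hard-label outputs might be read back downstream
# (follows the naming used in validate() above; the 0.9 threshold and the ignore
# index of 250 are illustrative assumptions, not fixed by this script):
#   import numpy as np
#   from PIL import Image
#   pseudo = np.array(Image.open(os.path.join(ori_LP, name)))                # per-pixel class ids
#   conf = np.load(os.path.join(ori_LP, name.replace('.png', '_conf.npy')))  # per-pixel max softmax confidence
#   pseudo[conf.astype(np.float32) < 0.9] = 250                              # mask low-confidence pixels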