2 changes: 1 addition & 1 deletion README.md
@@ -13,7 +13,7 @@ The full paper is available at: https://arxiv.org/pdf/1904.01941.pdf
1、PyTroch>=0.4.1
2、torchvision>=0.2.1
3、opencv-python>=3.4.2
-4、check requiremtns.txt
+4、check requirements.txt
5、4 nvidia GPUs(we use 4 nvidia titanX)


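A minimal sketch for checking a local environment against the requirement list in the README hunk above; the version thresholds come from that list, and the script itself is an illustration rather than part of the repository.

```python
# Quick environment check against the README requirements (illustrative only).
import torch
import torchvision
import cv2

print('PyTorch:', torch.__version__)            # README asks for >= 0.4.1
print('torchvision:', torchvision.__version__)  # README asks for >= 0.2.1
print('opencv-python:', cv2.__version__)        # README asks for >= 3.4.2
print('CUDA GPUs visible:', torch.cuda.device_count())  # training uses 4 GPUs
```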
52 changes: 4 additions & 48 deletions trainic15data.py
@@ -22,41 +22,26 @@
from data_loader import ICDAR2015, Synth80k, ICDAR2013

###import file#######
from augmentation import random_rot, crop_img_bboxes
from gaussianmap import gaussion_transform, four_point_transform
from generateheatmap import add_character, generate_target, add_affinity, generate_affinity, sort_box, real_affinity, generate_affinity_box
from mseloss import Maploss



from collections import OrderedDict
from eval.script import getresult



from PIL import Image
from torchvision.transforms import transforms
from craft import CRAFT
from torch.autograd import Variable
from multiprocessing import Pool

#3.2768e-5
random.seed(42)

# class SynAnnotationTransform(object):
# def __init__(self):
# pass
# def __call__(self, gt):
# image_name = gt['imnames'][0]
parser = argparse.ArgumentParser(description='CRAFT reimplementation')


parser.add_argument('--resume', default=None, type=str,
help='Checkpoint state_dict file to resume training from')
-parser.add_argument('--batch_size', default=128, type = int,
+parser.add_argument('--batch_size', default=128, type=int,
help='batch size of training')
#parser.add_argument('--cdua', default=True, type=str2bool,
#help='Use CUDA to train model')
parser.add_argument('--lr', '--learning-rate', default=3.2768e-5, type=float,
help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float,
@@ -72,10 +57,6 @@
args = parser.parse_args()






def copyStateDict(state_dict):
if list(state_dict.keys())[0].startswith("module"):
start_idx = 1
@@ -87,6 +68,7 @@ def copyStateDict(state_dict):
new_state_dict[name] = v
return new_state_dict


def adjust_learning_rate(optimizer, gamma, step):
"""Sets the learning rate to the initial LR decayed by 10 at every
specified step
@@ -100,7 +82,6 @@ def adjust_learning_rate(optimizer, gamma, step):


if __name__ == '__main__':

dataloader = Synth80k('/data/CRAFT-pytorch/syntext/SynthText/SynthText', target_size = 768)
train_loader = torch.utils.data.DataLoader(
dataloader,
@@ -110,14 +91,12 @@ def adjust_learning_rate(optimizer, gamma, step):
drop_last=True,
pin_memory=True)
batch_syn = iter(train_loader)

net = CRAFT()

net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/1-7.pth')))

net = net.cuda()


net = net.cuda()

net = torch.nn.DataParallel(net,device_ids=[0,1,2,3]).cuda()
cudnn.benchmark = True
@@ -134,13 +113,9 @@ def adjust_learning_rate(optimizer, gamma, step):

optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
criterion = Maploss()
#criterion = torch.nn.MSELoss(reduce=True, size_average=True)



step_index = 0


loss_time = 0
loss_value = 0
compare_loss = 1
@@ -153,13 +128,11 @@ def adjust_learning_rate(optimizer, gamma, step):

st = time.time()
for index, (real_images, real_gh_label, real_gah_label, real_mask, _) in enumerate(real_data_loader):
#real_images, real_gh_label, real_gah_label, real_mask = next(batch_real)
syn_images, syn_gh_label, syn_gah_label, syn_mask, __ = next(batch_syn)
images = torch.cat((syn_images,real_images), 0)
gh_label = torch.cat((syn_gh_label, real_gh_label), 0)
gah_label = torch.cat((syn_gah_label, real_gah_label), 0)
mask = torch.cat((syn_mask, real_mask), 0)
#affinity_mask = torch.cat((syn_mask, real_affinity_mask), 0)


images = Variable(images.type(torch.FloatTensor)).cuda()
@@ -169,8 +142,6 @@ def adjust_learning_rate(optimizer, gamma, step):
gah_label = Variable(gah_label).cuda()
mask = mask.type(torch.FloatTensor)
mask = Variable(mask).cuda()
# affinity_mask = affinity_mask.type(torch.FloatTensor)
# affinity_mask = Variable(affinity_mask).cuda()

out, _ = net(images)

@@ -189,24 +160,9 @@ def adjust_learning_rate(optimizer, gamma, step):
loss_time = 0
loss_value = 0
st = time.time()
# if loss < compare_loss:
# print('save the lower loss iter, loss:',loss)
# compare_loss = loss
# torch.save(net.module.state_dict(),
# '/data/CRAFT-pytorch/real_weights/lower_loss.pth')

print('Saving state, iter:', epoch)
torch.save(net.module.state_dict(),
'/data/CRAFT-pytorch/real_weights/CRAFT_clr_' + repr(epoch) + '.pth')
test('/data/CRAFT-pytorch/real_weights/CRAFT_clr_' + repr(epoch) + '.pth')
#test('/data/CRAFT-pytorch/craft_mlt_25k.pth')
getresult()
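The copyStateDict helper in the diff strips the "module." prefix that torch.nn.DataParallel prepends to parameter names, so a checkpoint saved from a wrapped model can be loaded into a bare CRAFT instance. A sketch of that round trip; the middle of the function is collapsed in the hunk above and is reconstructed here in the usual way, and the checkpoint path is a placeholder.

```python
# Sketch: load a DataParallel checkpoint into an unwrapped CRAFT model.
# The collapsed middle of copyStateDict is an assumed reconstruction;
# the checkpoint path is a placeholder.
from collections import OrderedDict

import torch

from craft import CRAFT


def copyStateDict(state_dict):
    # Keys look like "module.<param name>" when saved from a DataParallel-wrapped
    # model; drop the leading "module" so they match the plain model's names.
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict


net = CRAFT()
net.load_state_dict(copyStateDict(torch.load('some_checkpoint.pth')))  # placeholder path
```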
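The body of adjust_learning_rate is collapsed in the diff; its docstring says the initial LR is decayed by 10 at every specified step. A sketch of that schedule, assuming gamma = 0.1 and the script's default --lr of 3.2768e-5; the signature here takes the initial LR explicitly instead of reading the argparse globals, and the exact body in the repository may differ.

```python
# Sketch of the decay schedule described by the docstring (assumed form).
def adjust_learning_rate(optimizer, initial_lr, gamma, step):
    lr = initial_lr * (gamma ** step)   # gamma = 0.1 gives the "decayed by 10" behaviour
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# With the default lr of 3.2768e-5 and gamma = 0.1:
#   step 0 -> 3.2768e-5, step 1 -> 3.2768e-6, step 2 -> 3.2768e-7
```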
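The core of the loop in this diff pairs every ICDAR2015 batch with a SynthText batch and concatenates images, region score targets (gh_label), affinity targets (gah_label) and confidence masks along the batch dimension before a single forward pass through CRAFT. A condensed sketch of that step; the loaders, model, optimizer and Maploss criterion are assumed to be constructed as in the script, and the Maploss argument order is likewise an assumption.

```python
# Condensed sketch of one mixed SynthText + ICDAR2015 training epoch.
# net, criterion (Maploss), optimizer, syn_loader and real_loader are assumed
# to be built as in trainic15data.py; the Maploss argument order is assumed.
import torch


def train_one_epoch(net, criterion, optimizer, syn_loader, real_loader):
    batch_syn = iter(syn_loader)  # SynthText is assumed to yield at least as many batches
    for real_images, real_gh, real_gah, real_mask, _ in real_loader:
        syn_images, syn_gh, syn_gah, syn_mask, _ = next(batch_syn)

        # Concatenate along the batch dimension so both domains share one forward pass.
        images = torch.cat((syn_images, real_images), 0).float().cuda()
        gh_label = torch.cat((syn_gh, real_gh), 0).float().cuda()     # region score targets
        gah_label = torch.cat((syn_gah, real_gah), 0).float().cuda()  # affinity score targets
        mask = torch.cat((syn_mask, real_mask), 0).float().cuda()     # pixel confidence mask

        out, _ = net(images)  # CRAFT output: (N, H/2, W/2, 2) region and affinity maps
        optimizer.zero_grad()
        loss = criterion(gh_label, gah_label, out[:, :, :, 0], out[:, :, :, 1], mask)
        loss.backward()
        optimizer.step()
```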