From c114930df84ee448e796077af353d8faf23e8ca5 Mon Sep 17 00:00:00 2001 From: Semar Augusto Date: Fri, 17 Jan 2020 10:02:55 -0300 Subject: [PATCH] fix typos and clean up code --- README.md | 2 +- trainic15data.py | 52 ++++-------------------------------------------- 2 files changed, 5 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 9b32b95..c7dbff5 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ The full paper is available at: https://arxiv.org/pdf/1904.01941.pdf 1、PyTroch>=0.4.1 2、torchvision>=0.2.1 3、opencv-python>=3.4.2 -4、check requiremtns.txt +4、check requirements.txt 5、4 nvidia GPUs(we use 4 nvidia titanX) diff --git a/trainic15data.py b/trainic15data.py index 9763d12..fa3a580 100644 --- a/trainic15data.py +++ b/trainic15data.py @@ -22,41 +22,26 @@ from data_loader import ICDAR2015, Synth80k, ICDAR2013 ###import file####### -from augmentation import random_rot, crop_img_bboxes -from gaussianmap import gaussion_transform, four_point_transform -from generateheatmap import add_character, generate_target, add_affinity, generate_affinity, sort_box, real_affinity, generate_affinity_box from mseloss import Maploss - - from collections import OrderedDict from eval.script import getresult - - from PIL import Image from torchvision.transforms import transforms from craft import CRAFT from torch.autograd import Variable from multiprocessing import Pool -#3.2768e-5 random.seed(42) -# class SynAnnotationTransform(object): -# def __init__(self): -# pass -# def __call__(self, gt): -# image_name = gt['imnames'][0] parser = argparse.ArgumentParser(description='CRAFT reimplementation') parser.add_argument('--resume', default=None, type=str, help='Checkpoint state_dict file to resume training from') -parser.add_argument('--batch_size', default=128, type = int, +parser.add_argument('--batch_size', default=128, type=int, help='batch size of training') -#parser.add_argument('--cdua', default=True, type=str2bool, - #help='Use CUDA to train model') parser.add_argument('--lr', '--learning-rate', default=3.2768e-5, type=float, help='initial learning rate') parser.add_argument('--momentum', default=0.9, type=float, @@ -72,10 +57,6 @@ args = parser.parse_args() - - - - def copyStateDict(state_dict): if list(state_dict.keys())[0].startswith("module"): start_idx = 1 @@ -87,6 +68,7 @@ def copyStateDict(state_dict): new_state_dict[name] = v return new_state_dict + def adjust_learning_rate(optimizer, gamma, step): """Sets the learning rate to the initial LR decayed by 10 at every specified step @@ -100,7 +82,6 @@ def adjust_learning_rate(optimizer, gamma, step): if __name__ == '__main__': - dataloader = Synth80k('/data/CRAFT-pytorch/syntext/SynthText/SynthText', target_size = 768) train_loader = torch.utils.data.DataLoader( dataloader, @@ -110,14 +91,12 @@ def adjust_learning_rate(optimizer, gamma, step): drop_last=True, pin_memory=True) batch_syn = iter(train_loader) - + net = CRAFT() net.load_state_dict(copyStateDict(torch.load('/data/CRAFT-pytorch/1-7.pth'))) - - net = net.cuda() - + net = net.cuda() net = torch.nn.DataParallel(net,device_ids=[0,1,2,3]).cuda() cudnn.benchmark = True @@ -134,13 +113,9 @@ def adjust_learning_rate(optimizer, gamma, step): optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay) criterion = Maploss() - #criterion = torch.nn.MSELoss(reduce=True, size_average=True) - - step_index = 0 - loss_time = 0 loss_value = 0 compare_loss = 1 @@ -153,13 +128,11 @@ def adjust_learning_rate(optimizer, gamma, step): st = time.time() for index, (real_images, real_gh_label, real_gah_label, real_mask, _) in enumerate(real_data_loader): - #real_images, real_gh_label, real_gah_label, real_mask = next(batch_real) syn_images, syn_gh_label, syn_gah_label, syn_mask, __ = next(batch_syn) images = torch.cat((syn_images,real_images), 0) gh_label = torch.cat((syn_gh_label, real_gh_label), 0) gah_label = torch.cat((syn_gah_label, real_gah_label), 0) mask = torch.cat((syn_mask, real_mask), 0) - #affinity_mask = torch.cat((syn_mask, real_affinity_mask), 0) images = Variable(images.type(torch.FloatTensor)).cuda() @@ -169,8 +142,6 @@ def adjust_learning_rate(optimizer, gamma, step): gah_label = Variable(gah_label).cuda() mask = mask.type(torch.FloatTensor) mask = Variable(mask).cuda() - # affinity_mask = affinity_mask.type(torch.FloatTensor) - # affinity_mask = Variable(affinity_mask).cuda() out, _ = net(images) @@ -189,24 +160,9 @@ def adjust_learning_rate(optimizer, gamma, step): loss_time = 0 loss_value = 0 st = time.time() - # if loss < compare_loss: - # print('save the lower loss iter, loss:',loss) - # compare_loss = loss - # torch.save(net.module.state_dict(), - # '/data/CRAFT-pytorch/real_weights/lower_loss.pth') print('Saving state, iter:', epoch) torch.save(net.module.state_dict(), '/data/CRAFT-pytorch/real_weights/CRAFT_clr_' + repr(epoch) + '.pth') test('/data/CRAFT-pytorch/real_weights/CRAFT_clr_' + repr(epoch) + '.pth') - #test('/data/CRAFT-pytorch/craft_mlt_25k.pth') getresult() - - - - - - - - -