diff --git a/train.py b/train.py index 3c3513a3..a21b4adf 100644 --- a/train.py +++ b/train.py @@ -95,7 +95,7 @@ def weights_init(m): crnn.load_state_dict(torch.load(opt.pretrained)) print(crnn) -image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH) +image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgW) text = torch.IntTensor(opt.batchSize * 5) length = torch.IntTensor(opt.batchSize)