Commit 23cb35d: import t7 model
Author: mei jieru
Parent commit: f2d4864

8 files changed: +368 -32 lines

README.md (+18)
@@ -3,3 +3,21 @@ Convolutional Recurrent Neural Network
 
 This software implements the Convolutional Recurrent Neural Network (CRNN) in PyTorch.
 The original software can be found at [crnn](https://github.com/bgshih/crnn).
+
+Run demo
+--------
+A demo program can be found in ``src/demo.py``. Before running the demo, download a pretrained model
+from [Baidu Netdisk](https://pan.baidu.com/s/1pLbeCND) or [Dropbox](https://www.dropbox.com/s/dboqjk20qjkpta3/crnn.pth?dl=0).
+This pretrained model was converted from the author's original one with the ``tool`` scripts.
+Put the downloaded model file ``crnn.pth`` into the ``data/`` directory, then launch the demo with:
+
+    python demo.py
+
+The demo reads an example image and recognizes its text content.
+
+Example image:
+![Example Image](./data/demo.png)
+
+Expected output:
+    loading pretrained model from ./data/crnn.pth
+    a-----v--a-i-l-a-bb-l-ee-- => available
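
The expected output illustrates CTC greedy decoding: the network emits one symbol per time step, repeated symbols are collapsed, and the blank symbol (printed as `-`) is dropped, turning the raw sequence into `available`. A minimal self-contained sketch of that collapse step (this helper is illustrative; the repository does the real decoding in `utils.strLabelConverter.decode`):

```python
# Illustrative CTC greedy-decode collapse: drop blanks ('-' here) and
# merge adjacent repeats that are not separated by a blank.
def ctc_collapse(raw, blank='-'):
    out, prev = [], None
    for ch in raw:
        if ch != prev and ch != blank:
            out.append(ch)
        prev = ch
    return ''.join(out)

print(ctc_collapse('a-----v--a-i-l-a-bb-l-ee--'))  # prints: available
```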

crnn_main.py (+16 -10)
@@ -27,7 +27,7 @@
 parser.add_argument('--cuda', action='store_true', help='enables cuda')
 parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
 parser.add_argument('--crnn', default='', help="path to crnn (to continue training)")
-parser.add_argument('--alphabet', type=str, default='abcdefghijklmnopqrstuvwxyz0123456789')
+parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz')
 parser.add_argument('--Diters', type=int, default=5, help='number of D iters per each G iter')
 parser.add_argument('--experiment', default=None, help='Where to store samples and models')
 parser.add_argument('--displayInterval', type=int, default=500, help='Interval to be displayed')
@@ -66,9 +66,9 @@
     train_dataset, batch_size=opt.batchSize,
     shuffle=True, sampler=sampler,
     num_workers=int(opt.workers),
-    collate_fn=dataset.alignCollate(imgH=opt.imgH,
-                                    keep_ratio=opt.keep_ratio))
-test_dataset = dataset.lmdbDataset(root=opt.valroot, transform=dataset.resizeNormalize((128, 32)))
+    collate_fn=dataset.alignCollate(imgH=opt.imgH, keep_ratio=opt.keep_ratio))
+test_dataset = dataset.lmdbDataset(
+    root=opt.valroot, transform=dataset.resizeNormalize((128, 32)))
 
 ngpu = int(opt.ngpu)
 nh = int(opt.nh)
@@ -114,11 +114,12 @@ def weights_init(m):
 
 # setup optimizer
 if opt.adam:
-    optimizer = optim.Adam(crnn.parameters(), lr=opt.lrD, betas=(opt.beta1, 0.999))
+    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr,
+                           betas=(opt.beta1, 0.999))
 elif opt.adadelta:
-    optimizer = optim.Adadelta(crnn.parameters(), lr=opt.lrD)
+    optimizer = optim.Adadelta(crnn.parameters(), lr=opt.lr)
 else:
-    optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lrD)
+    optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)
 
 
 def val(net, dataset, criterion, max_iter=100):
@@ -129,7 +130,7 @@ def val(net, dataset, criterion, max_iter=100):
 
     net.eval()
     data_loader = torch.utils.data.DataLoader(
-        dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
+        dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
     val_iter = iter(data_loader)
 
     i = 0
@@ -167,6 +168,9 @@ def val(net, dataset, criterion, max_iter=100):
     print('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
 
 
+# val(crnn, test_dataset, criterion)
+# exit(0)
+
 def trainBatch(net, criterion, optimizer):
     data = train_iter.next()
     cpu_images, cpu_texts = data
@@ -198,12 +202,14 @@ def trainBatch(net, criterion, optimizer):
         i += 1
 
         if i % opt.displayInterval == 0:
-            print('[%d/%d][%d/%d] Loss: %f' % (epoch, opt.niter, i, len(train_loader), loss_avg.val()))
+            print('[%d/%d][%d/%d] Loss: %f' %
+                  (epoch, opt.niter, i, len(train_loader), loss_avg.val()))
             loss_avg.reset()
 
         if i % opt.valInterval == 0:
             val(crnn, test_dataset, criterion)
 
         # do checkpointing
         if i % opt.saveInterval == 0:
-            torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
+            torch.save(
+                crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
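
Two notes on these changes. The `--alphabet` default is reordered to digits-first so that label indices line up with the character ordering of the imported t7 model, and the optimizer block now reads `opt.lr` rather than the stale `opt.lrD` name. A small sketch of the label layout this implies, under the usual CTC convention that index 0 is the blank (the `char_to_label` dict below is illustrative, not the repository's converter):

```python
# Sketch (assumption): a strLabelConverter-style mapping with index 0
# reserved for the CTC blank and characters mapped to 1..len(alphabet).
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
char_to_label = {ch: i + 1 for i, ch in enumerate(alphabet)}

assert char_to_label['0'] == 1 and char_to_label['z'] == 36
# Hence nclass = len(alphabet) + 1 = 37, matching crnn.CRNN(32, 1, 37, 256, 1)
# in demo.py below.
```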

data/demo.png (binary, 18.7 KB)

demo.py (+36, new file)
@@ -0,0 +1,36 @@
+import torch
+from torch.autograd import Variable
+import utils
+import dataset
+from PIL import Image
+
+import models.crnn as crnn
+
+
+model_path = './data/crnn.pth'
+img_path = './data/demo.png'
+alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
+
+model = crnn.CRNN(32, 1, 37, 256, 1).cuda()
+print('loading pretrained model from %s' % model_path)
+model.load_state_dict(torch.load(model_path))
+
+converter = utils.strLabelConverter(alphabet)
+
+transformer = dataset.resizeNormalize((100, 32))
+image = Image.open(img_path).convert('L')
+image = transformer(image).cuda()
+image = image.view(1, *image.size())
+image = Variable(image)
+
+model.eval()
+preds = model(image)
+
+_, preds = preds.max(2)
+preds = preds.squeeze(2)
+preds = preds.transpose(1, 0).contiguous().view(-1)
+
+preds_size = Variable(torch.IntTensor([preds.size(0)]))
+raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
+sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
+print('%-20s => %-20s' % (raw_pred, sim_pred))
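
As committed, the demo requires a CUDA device: both the model and the input image are moved with `.cuda()`. A plausible (untested) CPU variation is to drop those calls and remap the checkpoint's GPU storages while loading:

```python
# Untested CPU sketch: build the model without .cuda() and remap the
# checkpoint's CUDA storages to CPU on load.
model = crnn.CRNN(32, 1, 37, 256, 1)
model.load_state_dict(
    torch.load(model_path, map_location=lambda storage, loc: storage))

image = transformer(image)  # keep the tensor on CPU, no .cuda()
```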

models/crnn.py (+39 -19)
@@ -2,6 +2,31 @@
 import torch.nn.parallel
 
 
+class BidirectionalLSTM(nn.Module):
+
+    def __init__(self, nIn, nHidden, nOut, ngpu):
+        super(BidirectionalLSTM, self).__init__()
+        self.ngpu = ngpu
+
+        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
+        self.embedding = nn.Linear(nHidden * 2, nOut)
+
+    def forward(self, input):
+        gpu_ids = None
+        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
+            gpu_ids = range(self.ngpu)
+        recurrent, _ = nn.parallel.data_parallel(
+            self.rnn, input, gpu_ids)  # [T, b, h * 2]
+
+        T, b, h = recurrent.size()
+        t_rec = recurrent.view(T * b, h)
+        output = nn.parallel.data_parallel(
+            self.embedding, t_rec, gpu_ids)  # [T * b, nOut]
+        output = output.view(T, b, -1)
+
+        return output
+
+
 class CRNN(nn.Module):
 
     def __init__(self, imgH, nc, nclass, nh, ngpu, n_rnn=2, leakyRelu=False):
@@ -30,45 +55,40 @@ def convRelu(i, batchNormalization=False):
             cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
 
         convRelu(0)
-        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d((2, 2),
-                                                            (2, 2)))  # 64x16x64
+        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
         convRelu(1)
-        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d((2, 2),
-                                                            (2, 2)))  # 128x8x32
+        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
         convRelu(2, True)
         convRelu(3)
         cnn.add_module('pooling{0}'.format(2), nn.MaxPool2d((2, 2),
-                                                            (2, 2)))  # 256x4x16
+                                                            (2, 1),
+                                                            (0, 1)))  # 256x4x16
         convRelu(4, True)
         convRelu(5)
         cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d((2, 2),
-                                                            (2, 1)))  # 512x2x16
+                                                            (2, 1),
+                                                            (0, 1)))  # 512x2x16
         convRelu(6, True)  # 512x1x16
 
         self.cnn = cnn
-        self.rnn = nn.Sequential(nn.LSTM(512, nh, n_rnn, bidirectional=True))
-        self.text = nn.Sequential(nn.Linear(nh * 2, nclass))  # [T, b, nclass]
+        self.rnn = nn.Sequential(
+            BidirectionalLSTM(512, nh, nh, ngpu),
+            BidirectionalLSTM(nh, nh, nclass, ngpu)
+        )
 
     def forward(self, input):
         gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            gpu_ids = range(self.ngpu)
+
         # conv features
         conv = nn.parallel.data_parallel(self.cnn, input, gpu_ids)
-
-        # rnn features
         b, c, h, w = conv.size()
         assert h == 1, "the height of conv must be 1"
         conv = conv.squeeze(2)
         conv = conv.permute(2, 0, 1)  # [w, b, c]
-        recurrent, _ = nn.parallel.data_parallel(self.rnn, conv,
-                                                 gpu_ids)  # [T, b, h * 2]
 
-        # text classifier
-        T, b, h = recurrent.size()
-        t_rec = recurrent.view(T * b, h)
-        text = nn.parallel.data_parallel(self.text, t_rec,
-                                         gpu_ids)  # [T * b, nclass]
-        text = text.view(T, b, -1)
+        # rnn features
+        output = nn.parallel.data_parallel(self.rnn, conv, gpu_ids)
 
-        return text
+        return output
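
The point of the new `(2, 1)` strides with `(0, 1)` padding in the last two pools is to keep halving the height while nearly preserving the width, so the CNN hands the stacked BidirectionalLSTM pair a longer feature sequence for CTC. For the demo's 1x32x100 input, the widths work out as below (a sketch; it assumes the final `convRelu(6)` uses a 2x2 kernel with stride 1 and no padding, as in the original CRNN, details not shown in this hunk):

```python
# Width of the feature sequence the CNN produces for a 100-wide input.
def pool_w(w, kernel=2, stride=1, pad=1):
    # width after MaxPool2d((2, 2), (2, 1), (0, 1))
    return (w + 2 * pad - kernel) // stride + 1

w = 100
w //= 2          # pooling0: stride (2, 2) -> 50
w //= 2          # pooling1: stride (2, 2) -> 25
w = pool_w(w)    # pooling2: stride (2, 1), pad (0, 1) -> 26
w = pool_w(w)    # pooling3: stride (2, 1), pad (0, 1) -> 27
w -= 1           # convRelu(6): assumed 2x2 conv, stride 1, no pad -> 26
print(w)         # 26 time steps; forward() then yields [26, b, nclass]
```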

tool/convert_t7.lua (+89, new file)
@@ -0,0 +1,89 @@
+require('table')
+require('torch')
+require('os')
+
+function clone(t)
+    -- deep-copy a table
+    if type(t) ~= "table" then return t end
+    local meta = getmetatable(t)
+    local target = {}
+    for k, v in pairs(t) do
+        if type(v) == "table" then
+            target[k] = clone(v)
+        else
+            target[k] = v
+        end
+    end
+    setmetatable(target, meta)
+    return target
+end
+
+
+function tableMerge(lhs, rhs)
+    output = clone(lhs)
+    for _, v in pairs(rhs) do
+        table.insert(output, v)
+    end
+    return output
+end
+
+
+function isInTable(val, val_list)
+    for _, item in pairs(val_list) do
+        if val == item then
+            return true
+        end
+    end
+    return false
+end
+
+
+function modelToList(model)
+    local ignoreList = {
+        'nn.Copy',
+        'nn.AddConstant',
+        'nn.MulConstant',
+        'nn.View',
+        'nn.Transpose',
+        'nn.SplitTable',
+        'nn.SharedParallelTable',
+        'nn.JoinTable',
+    }
+    local state = {}
+    local param
+    for i, layer in pairs(model.modules) do
+        local typeName = torch.type(layer)
+        if not isInTable(typeName, ignoreList) then
+            if typeName == 'nn.Sequential' or typeName == 'nn.ConcatTable' then
+                param = modelToList(layer)
+            elseif typeName == 'cudnn.SpatialConvolution' or typeName == 'nn.SpatialConvolution' then
+                param = layer:parameters()
+            elseif typeName == 'cudnn.SpatialBatchNormalization' or typeName == 'nn.SpatialBatchNormalization' then
+                param = layer:parameters()
+                bn_vars = {layer.running_mean, layer.running_var}
+                param = tableMerge(param, bn_vars)
+            elseif typeName == 'nn.LstmLayer' then
+                param = layer:parameters()
+            elseif typeName == 'nn.BiRnnJoin' then
+                param = layer:parameters()
+            elseif typeName == 'cudnn.SpatialMaxPooling' or typeName == 'nn.SpatialMaxPooling' then
+                param = {}
+            elseif typeName == 'cudnn.ReLU' or typeName == 'nn.ReLU' then
+                param = {}
+            else
+                print(string.format('Unknown class %s', typeName))
+                os.exit(0)
+            end
+            table.insert(state, {typeName, param})
+        else
+            print(string.format('pass %s', typeName))
+        end
+    end
+    return state
+end
+
+
+function saveModel(model, output_path)
+    local state = modelToList(model)
+    torch.save(output_path, state)
+end
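
This script flattens a Torch7 model into a flat list of `{typeName, param}` pairs: it recurses into `nn.Sequential`/`nn.ConcatTable`, takes `layer:parameters()` for convolution, batch-norm (appending `running_mean`/`running_var` via `tableMerge`), LSTM, and join layers, records empty tables for pooling/ReLU, and aborts on anything unknown. The Python side that reads this file back is not part of this commit; a hypothetical sketch of a consumer, assuming the legacy `torch.utils.serialization.load_lua` loader from pre-1.0 PyTorch and a made-up file name:

```python
# Hypothetical reader for the flattened state saved by convert_t7.lua.
# load_lua is the legacy Torch7 deserializer shipped with PyTorch < 1.0.
from torch.utils.serialization import load_lua

state = load_lua('crnn_flat.t7')  # file name is illustrative
for entry in state:
    type_name, params = entry[0], entry[1]
    if not params:        # pooling / ReLU layers were saved with empty tables
        continue
    # params holds tensors in the order layer:parameters() returned them on
    # the Lua side (weights, biases, then batch-norm running stats); nested
    # nn.Sequential entries appear as nested lists.
    print(type_name, len(params))
```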
