Skip to content

Commit

Permalink
feat: Amazing table detection performance
Browse files Browse the repository at this point in the history
  • Loading branch information
shahrukhqasim committed Nov 12, 2017
1 parent f2e3982 commit 5b85b5e
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 58 deletions.
19 changes: 14 additions & 5 deletions python/network/ModuleA.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,22 @@ def __init__(self, D_in, D_out):
# H1 = Variable(torch.randn(num_words, 100))
# H2 = Variable(torch.randn(num_words, 100))

self.linear1 = torch.nn.Linear(D_in, 100).cuda()
self.linear2 = torch.nn.Linear(100, 100).cuda()
self.linear1 = torch.nn.Linear(D_in, 200).cuda()
self.linear2 = torch.nn.Linear(200, 100).cuda()
self.linear3 = torch.nn.Linear(100, D_out).cuda()

self.linear1.weight.data.uniform_(-30, 30)
self.linear1.bias.data.uniform_(0, 0)
self.linear2.weight.data.uniform_(-30, 30)
self.linear2.bias.data.uniform_(0, 0)
self.linear3.weight.data.uniform_(-30, 30)
self.linear3.bias.data.uniform_(0, 0)

self.activation = torch.nn.Sigmoid()

def forward(self, x):
o1 = self.linear1(x).clamp(min=0)
o2 = self.linear2(o1).clamp(min=0)
return self.linear3(o2).tanh()
o1 = self.activation(self.linear1(x))
o2 = self.activation(self.linear2(o1))
return self.activation(self.linear3(o2))

# return self.linear1(x).tanh()
2 changes: 1 addition & 1 deletion python/network/ModuleB.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class ModuleB(torch.nn.Module):
def __init__(self):
super(ModuleB, self).__init__()
self.gru = torch.nn.GRUCell(100, 100).cuda()
self.gru = torch.nn.GRUCell(500, 100).cuda()

def forward(self, x, hx):
return self.gru.forward(x, hx)
4 changes: 3 additions & 1 deletion python/network/ModuleB2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ def __init__(self, D_in, D_out_1, D_out_2):
self.linear22 = torch.nn.Linear(100, 100).cuda()
self.linear32 = torch.nn.Linear(100, D_out_2).cuda()

self.activation = torch.nn.Sigmoid()

def forward(self, x):
o1 = self.linear1(x).clamp(min=0)
o2 = self.linear2(o1).clamp(min=0)

o12 = self.linear12(x).clamp(min=0)
o22 = self.linear22(o12).clamp(min=0)

return self.linear3(o2).tanh(), self.linear32(o22).tanh()
return self.activation(self.linear3(o2)), self.activation(self.linear32(o22))

# return self.linear1(x).tanh(), self.linear12(x).tanh()
2 changes: 1 addition & 1 deletion python/network/ModuleC.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ def __init__(self, D_in, D_out):
def forward(self, x):
o1 = self.linear1(x).clamp(min=0)
o2 = self.linear2(o1).clamp(min=0)
return self.linear3(o2)
return self.linear3(x)

# return self.linear3(x)
8 changes: 5 additions & 3 deletions python/network/ModuleD.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ def __init__(self, D_in, D_out):
self.linear2 = torch.nn.Linear(100, 100).cuda()
self.linear3 = torch.nn.Linear(100, D_out).cuda()

self.activation = torch.nn.Sigmoid()

def forward(self, x):
o1 = self.linear1(x).clamp(min=0)
o2 = self.linear2(o1).clamp(min=0)
return self.linear3(o2).tanh()
# o1 = self.linear1(x).clamp(min=0)
# o2 = self.linear2(o1).clamp(min=0)
return self.activation(self.linear3(x))

# return self.linear3(x).tanh()
26 changes: 13 additions & 13 deletions python/network/computation_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,25 @@ def set_iterations(self, iterations):

def concat(self, x, indices, indices_not_found, num_words):
y = Variable(torch.zeros(num_words, 100 * 5)).cuda()
y[:, 000:100] = x#[indices[:, 0]]
# y[:, 100:200] = x[indices[:, 1]]
# y[:, 200:300] = x[indices[:, 2]]
# y[:, 300:400] = x[indices[:, 3]]
# y[:, 400:500] = x[indices[:, 4]]

# y[indices_not_found] = 0
y[:, 000:100] = x[indices[:, 0]]
y[:, 100:200] = x[indices[:, 1]]
y[:, 200:300] = x[indices[:, 2]]
y[:, 300:400] = x[indices[:, 3]]
y[:, 400:500] = x[indices[:, 4]]
y[indices_not_found] = 0

return y

def forward(self, indices, indices_not_found, vv, num_words):

uu = self.A.forward(vv)
hh = Variable(torch.zeros(num_words,100)).cuda()

# for i in range(self.iterations):
# # ww = self.concat(uu, indices, indices_not_found, num_words)
# bb = self.B.forward(uu, hh)
# oo, hh = self.B2.forward(bb)
# ll = self.C.forward(oo)
# uu = self.D.forward(hh)
for i in range(self.iterations):
ww = self.concat(uu, indices, indices_not_found, num_words)
bb = self.B.forward(ww, hh)
oo, hh = self.B2.forward(bb)
ll = self.C.forward(oo)
uu = self.D.forward(hh)

return ll
9 changes: 7 additions & 2 deletions python/network/data_features_dumper.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ def load(self):

ii = 0
for i in os.listdir(self.path):
print("On", i)
full_example_path = os.path.join(self.path, i)

image_path = os.path.join(full_example_path, 'image.png')
image_path = os.path.join(full_example_path, 'tables.png')
json_path = os.path.join(full_example_path, 'ocr_gt.json')
document_dump_path = os.path.join(full_example_path, '__dump__.pickle')
image = cv2.imread(image_path, 0)
Expand Down Expand Up @@ -123,7 +124,11 @@ def load(self):

for i in range(len(all_tokens)):
token_rect = all_tokens_rects[i]
class_indices.append(0 if image[int(token_rect['y'] + token_rect['height']/2), int(token_rect['x'] + token_rect['width']/2)] == 0 else 1)
try:
class_indices.append(0 if image[int(token_rect['y'] + token_rect['height']/2), int(token_rect['x'] + token_rect['width']/2)] == 0 else 1)
except:
print(i, all_tokens[i], all_tokens_rects[i])
pass

self.dump_doc(all_tokens, all_tokens_rects, image, document_dump_path)
ii += 1
Expand Down
18 changes: 9 additions & 9 deletions python/network/neighbor_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ def get_neighbor_matrix(self):
new_distance = A['y'] - B['y']
new_distance_abs = abs(new_distance)
# B is above A
if new_distance > 0 and new_distance < min_top:
min_top = new_distance
if new_distance > 0 and new_distance_abs < min_top:
min_top = new_distance_abs
min_index_top = j
# B is below A
elif new_distance < min_bottom:
min_bottom = new_distance
elif new_distance_abs < min_bottom:
min_bottom = new_distance_abs
min_index_bottom = j
if self.vertical_overlap(A, B):
new_distance = A['x'] - B['x']
new_distance_abs = abs(new_distance)
# B is left of A
if new_distance > 0 and new_distance < min_left:
min_left = new_distance
if new_distance > 0 and new_distance_abs < min_left:
min_left = new_distance_abs
min_index_left = j
# B is below A
elif new_distance < min_right:
min_right = new_distance
# B is right of A
elif new_distance_abs < min_right:
min_right = new_distance_abs
min_index_right = j
m[i,0] = min_index_left
m[i,1] = min_index_top
Expand Down
97 changes: 74 additions & 23 deletions python/network/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from network.computation_graph import SimpleDocProcModel
import torch
from torch.autograd import Variable
import cv2


class DataLoader(implements(LoadInterface)):
Expand All @@ -24,42 +25,92 @@ def __init__(self):
config.read('config.ini')
self.train_path = config['quad']['train_data_path']
self.test_path = config['quad']['test_data_path']
self.validation_data_path = config['quad']['validation_data_path']
self.glove_path = config['quad']['glove_path']
self.learning_rate = float(config['quad']['learning_rate'])

def init(self, dump_features_again):
if dump_features_again:
self.reader = DataFeaturesDumper(self.train_path, self.glove_path, 'train')
self.reader.load()
self.glove_reader = self.reader.get_glove_reader()
# self.reader = DataFeaturesDumper(self.train_path, self.glove_path, 'train')
# self.reader.load()
self.validation_reader = DataFeaturesDumper(self.validation_data_path, self.glove_path, 'validate')
self.validation_reader.load()

def do_plot(self, document, id):
    """Draw a document's word boxes on a white canvas and show it in a window.

    Each entry of ``document.rects`` is read as ``(x, y, width, height)``:
    the first two values give one corner and the last two are added to them
    to get the opposite corner. All four are scaled by the canvas side
    (presumably the rects are normalised to [0, 1] -- confirm against the
    feature dumper). Words whose entry in ``document.classes`` is 0 are
    outlined with colour ``(255, 0, 0)``; every other word with
    ``(0, 0, 255)``.

    The window is titled ``'test' + id``, so ``id`` must be a string. The
    call blocks in ``cv2.waitKey(0)`` until a key is pressed.
    """
    # NOTE(review): the scraped source lost its indentation; the window is
    # assumed to be shown once, after every box has been drawn.
    side = 500
    boxes = document.rects
    labels = document.classes
    canvas = (np.ones((side, side, 3)) * 255).astype(np.uint8)

    for idx, box in enumerate(boxes):
        top_left = (int(box[0] * side), int(box[1] * side))
        bottom_right = (int((box[0] + box[2]) * side), int((box[1] + box[3]) * side))
        if labels[idx] == 0:
            outline = (255, 0, 0)
        else:
            outline = (0, 0, 255)
        cv2.rectangle(canvas, top_left, bottom_right, outline)

    cv2.imshow('test' + id, canvas)
    cv2.waitKey(0)

def get_example_elements(self, document):
    """Convert one loaded document into the CUDA tensors the model consumes.

    Returns a 7-tuple:

    * ``num_words`` -- first dimension of ``document.rects``.
    * ``vv`` -- float32 Variable: ``rects``, ``distances`` and ``embeddings``
      concatenated along axis 1.
    * ``y`` -- int64 Variable of ``document.classes`` (no gradient).
    * ``baseline_accuracy_1`` / ``baseline_accuracy_2`` -- percentage of
      words whose class is 0 / 1 respectively.
    * ``indices`` -- LongTensor whose first column is each word's own row
      number and whose remaining columns are ``document.neighbor_graph``
      with negative entries clamped to 0.
    * ``indices_not_found`` -- ByteTensor reshaped to 500 columns, built
      from the "neighbour == -1" flags but then multiplied by zero.

    Every tensor is moved to the GPU with ``.cuda()``.
    """
    num_words, _ = np.shape(document.rects)

    # Per-word input features, side by side.
    features = np.concatenate(
        [document.rects, document.distances, document.embeddings], axis=1
    ).astype(np.float32)
    vv = Variable(torch.from_numpy(features)).cuda()

    labels = document.classes.astype(np.int64)
    y = Variable(torch.from_numpy(labels), requires_grad=False).cuda()

    # Share of each class in this document.
    baseline_accuracy_1 = 100 * np.sum(document.classes == 0) / num_words
    baseline_accuracy_2 = 100 * np.sum(document.classes == 1) / num_words

    neighbors = document.neighbor_graph.astype(np.int64)

    # Column 0 points at the word itself; missing neighbours (-1) are
    # clamped to row 0 so they remain valid indices.
    own_row = np.expand_dims(np.arange(num_words, dtype=np.int64), axis=1)
    gather_rows = np.concatenate([own_row, np.maximum(neighbors, 0)], axis=1)
    indices = torch.LongTensor(torch.from_numpy(gather_rows)).cuda()

    # One flag per (word, slot), each repeated 100 times. The reshape to
    # 500 columns implies five slots per word: the word plus four neighbours.
    never_missing = np.expand_dims(np.zeros(num_words, dtype=np.int64), axis=1)
    missing = np.concatenate([never_missing, neighbors], axis=1) == -1
    missing_mask = np.repeat(missing, 100).reshape((-1, 500)).astype(np.uint8)
    indices_not_found = torch.ByteTensor(torch.from_numpy(missing_mask)).cuda()
    # The mask is zeroed before it is returned, so no slot is ever reported
    # as missing to the caller.
    indices_not_found = indices_not_found * 0

    return num_words, vv, y, baseline_accuracy_1, baseline_accuracy_2, indices, indices_not_found

def do_validation(self, model, dataset):
    """Score ``model`` on the validation reader and print the mean accuracy.

    Documents are pulled one at a time from ``dataset.next_element()``. For
    each, the predicted class of every word is the arg-max over dimension 1
    of the model output, and the percentage of words predicted correctly is
    printed. The loop ends after processing the first document for which the
    reader reports ``epoch == 1``. The mean of the per-document accuracies is
    then printed and the method waits for a line on standard input before
    returning.
    """
    # NOTE(review): the exit test is an equality on the reader's epoch
    # counter; confirm how that counter behaves when this method is called
    # more than once on the same reader.
    accuracy_sum = 0
    documents_seen = 0

    while True:
        document, epoch, id = dataset.next_element()
        num_words, vv, y, _, _, indices, indices_not_found = self.get_example_elements(document)

        y_pred = model(indices, indices_not_found, vv, num_words)
        predicted = torch.max(y_pred.data, 1)[1]

        # Percentage of this document's words that were labelled correctly.
        accuracy = 100 * torch.sum(predicted == y.data) / num_words
        print(accuracy)

        documents_seen += 1
        accuracy_sum += accuracy

        if epoch == 1:
            break

    print("Average validation accuracy = ", accuracy_sum / documents_seen)

    # Pause until the user presses Enter.
    input()

def train(self):
dataset = FolderDataReader(self.train_path, DataLoader())
validation_dataset = FolderDataReader(self.validation_data_path, DataLoader())
dataset.init()
validation_dataset.init()
model = SimpleDocProcModel()
model.set_iterations(1)
model.set_iterations(2)
criterion = torch.nn.CrossEntropyLoss(size_average=True)
optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
for i in range(10000):
for i in range(1000000):
if i % 10000 == 0:
self.do_validation(model, validation_dataset)

document, epoch, id = dataset.next_element()
num_words, _ = np.shape(document.rects)
vv = np.concatenate([document.rects, document.distances, document.embeddings * 0], axis=1).astype(np.float32)
vv = Variable(torch.from_numpy(vv)).cuda()
y = Variable(torch.from_numpy(document.classes.astype(np.int64)), requires_grad=False).cuda()

baseline_accuracy_1 = 100 * np.sum(document.classes==0) / num_words
baseline_accuracy_2 = 100 * np.sum(document.classes==1) / num_words

indices = torch.LongTensor(torch.from_numpy(np.concatenate(
[np.expand_dims(np.arange(num_words, dtype=np.int64), axis=1),
np.maximum(document.neighbor_graph.astype(np.int64), 0)], axis=1))).cuda()
indices_not_found = torch.ByteTensor(torch.from_numpy(np.repeat(np.concatenate(
[np.expand_dims(np.zeros(num_words, dtype=np.int64), axis=1),
document.neighbor_graph.astype(np.int64)], axis=1) == -1, 100).reshape((-1, 500)).astype(
np.uint8))).cuda()
indices_not_found = indices_not_found*0

for j in range(1):
num_words, vv, y, baseline_accuracy_1, baseline_accuracy_2, indices, indices_not_found = self.get_example_elements(document)


for j in range(2):
y_pred = model(indices, indices_not_found, vv, num_words)
_, predicted = torch.max(y_pred.data, 1)
accuracy = torch.sum(predicted == y.data)
Expand Down

0 comments on commit 5b85b5e

Please sign in to comment.