-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Build dynamic RNN based network
- Loading branch information
1 parent
ee1c944
commit 175023b
Showing
21 changed files
with
666 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import os | ||
import sys | ||
from network import data_features_dumper as dr | ||
from network import computation_graph | ||
from network.trainer import Trainer | ||
|
||
|
||
# path = '/home/srq/Datasets/tables/unlv-for-nlp/train' | ||
# glove_path = '/media/srq/Seagate Expansion Drive/Models/GloVe/glove.840B.300d.txt' | ||
# | ||
# data_reader = dr.DataReader(path, glove_path, 'train') | ||
# data_reader.load() | ||
|
||
|
||
trainer = Trainer() | ||
trainer.init(dump_features_again=False) | ||
trainer.train() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import torch | ||
import numpy as np | ||
from torch.autograd import Variable | ||
|
||
|
||
class ModuleA(torch.nn.Module): | ||
def __init__(self, D_in, D_out): | ||
super(ModuleA, self).__init__() | ||
|
||
# H1 = Variable(torch.randn(num_words, 100)) | ||
# H2 = Variable(torch.randn(num_words, 100)) | ||
|
||
self.linear1 = torch.nn.Linear(D_in, 100) | ||
self.linear2 = torch.nn.Linear(100, 100) | ||
self.linear3 = torch.nn.Linear(100, D_out) | ||
|
||
def forward(self, x): | ||
o1 = self.linear1(x).clamp(min=0) | ||
o2 = self.linear2(o1).clamp(min=0) | ||
return self.linear3(o2).clamp(min=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import torch | ||
import numpy as np | ||
from torch.autograd import Variable | ||
|
||
|
||
class ModuleB(torch.nn.Module): | ||
def __init__(self): | ||
super(ModuleB, self).__init__() | ||
self.gru = torch.nn.GRUCell(500, 100) | ||
|
||
def forward(self, x, hx): | ||
return self.gru.forward(x, hx) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch | ||
import numpy as np | ||
from torch.autograd import Variable | ||
|
||
|
||
class ModuleB2(torch.nn.Module): | ||
def __init__(self, D_in, D_out_1, D_out_2): | ||
super(ModuleB2, self).__init__() | ||
|
||
# H1 = Variable(torch.randn(num_words, 100)) | ||
# H2 = Variable(torch.randn(num_words, 100)) | ||
|
||
self.linear1 = torch.nn.Linear(D_in, 100) | ||
self.linear2 = torch.nn.Linear(100, 100) | ||
self.linear3 = torch.nn.Linear(100, D_out_1) | ||
|
||
# | ||
# H12 = Variable(torch.randn(num_words, 100)) | ||
# H22 = Variable(torch.randn(num_words, 100)) | ||
|
||
self.linear12 = torch.nn.Linear(D_in, 100) | ||
self.linear22 = torch.nn.Linear(100, 100) | ||
self.linear32 = torch.nn.Linear(100, D_out_2) | ||
|
||
def forward(self, x): | ||
o1 = self.linear1(x).clamp(min=0) | ||
o2 = self.linear2(o1).clamp(min=0) | ||
|
||
o12 = self.linear12(x).clamp(min=0) | ||
o22 = self.linear22(o12).clamp(min=0) | ||
|
||
return self.linear3(o2).clamp(min=0), self.linear32(o22).clamp(min=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import torch | ||
import numpy as np | ||
from torch.autograd import Variable | ||
|
||
|
||
class ModuleC(torch.nn.Module): | ||
def __init__(self, D_in, D_out): | ||
super(ModuleC, self).__init__() | ||
|
||
# H1 = Variable(torch.randn(num_words, 100)) | ||
# H2 = Variable(torch.randn(num_words, 100)) | ||
|
||
self.linear1 = torch.nn.Linear(D_in, 100) | ||
self.linear2 = torch.nn.Linear(100, 100) | ||
self.linear3 = torch.nn.Linear(100, D_out) | ||
|
||
def forward(self, x): | ||
o1 = self.linear1(x).clamp(min=0) | ||
o2 = self.linear2(o1).clamp(min=0) | ||
return self.linear3(o2).clamp(min=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import torch | ||
import numpy as np | ||
from torch.autograd import Variable | ||
|
||
|
||
class ModuleCollect(torch.nn.Module): | ||
def __init__(self, D_in, num_words): | ||
super(ModuleCollect, self).__init__() | ||
self.D_in = D_in | ||
self.num_words = num_words | ||
|
||
def forward(self, x, indices): | ||
y = Variable(torch.zeros(self.num_words, self.D_in * 5)) | ||
|
||
for i in range(self.num_words): | ||
y[i, 0:self.D_in] = x[i] | ||
y[i, self.D_in * 1:self.D_in * 2] = x[np.maximum(indices[i, 0], 0)] * int(indices[i, 0] != -1) | ||
y[i, self.D_in * 2:self.D_in * 3] = x[np.maximum(indices[i, 1], 0)] * int(indices[i, 1] != -1) | ||
y[i, self.D_in * 3:self.D_in * 4] = x[np.maximum(indices[i, 2], 0)] * int(indices[i, 2] != -1) | ||
y[i, self.D_in * 4:self.D_in * 5] = x[np.maximum(indices[i, 3], 0)] * int(indices[i, 3] != -1) | ||
|
||
return y |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import torch | ||
import numpy as np | ||
from torch.autograd import Variable | ||
|
||
|
||
class ModuleD(torch.nn.Module): | ||
def __init__(self, D_in, D_out): | ||
super(ModuleD, self).__init__() | ||
|
||
self.linear1 = torch.nn.Linear(D_in, 100) | ||
self.linear2 = torch.nn.Linear(100, 100) | ||
self.linear3 = torch.nn.Linear(100, D_out) | ||
|
||
def forward(self, x): | ||
o1 = self.linear1(x).clamp(min=0) | ||
o2 = self.linear2(o1).clamp(min=0) | ||
return self.linear3(o2).clamp(min=0) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import numpy as np | ||
import torch | ||
from torch.autograd import Variable | ||
from network.ModuleA import ModuleA | ||
from network.ModuleB import ModuleB | ||
from network.ModuleB2 import ModuleB2 | ||
from network.ModuleC import ModuleC | ||
from network.ModuleD import ModuleD | ||
from network.ModuleCollect import ModuleCollect | ||
|
||
|
||
class SimpleDocProcModel(torch.nn.Module): | ||
def __init__(self): | ||
super(SimpleDocProcModel, self).__init__() | ||
self.k = 10 | ||
self.D_in = 300 + self.k | ||
|
||
self.A = ModuleA(self.D_in, 100) | ||
self.B = ModuleB() | ||
self.B2 = ModuleB2(100, 100, 100) | ||
self.C = ModuleC(100, 2) | ||
self.D = ModuleD(100, 100) | ||
# self.Cat = ModuleCollect(100, self.N) | ||
self.iterations = 1 | ||
|
||
def set_iterations(self, iterations): | ||
self.iterations = iterations | ||
|
||
def concat(self, x, indices, num_words): | ||
y = Variable(torch.zeros(num_words, self.D_in * 5)) | ||
|
||
for i in range(num_words): | ||
y[i, 0:self.D_in] = x[i] | ||
y[i, self.D_in * 1:self.D_in * 2] = x[np.maximum(indices[i, 0], 0)] * int(indices[i, 0] != -1) | ||
y[i, self.D_in * 2:self.D_in * 3] = x[np.maximum(indices[i, 1], 0)] * int(indices[i, 1] != -1) | ||
y[i, self.D_in * 3:self.D_in * 4] = x[np.maximum(indices[i, 2], 0)] * int(indices[i, 2] != -1) | ||
y[i, self.D_in * 4:self.D_in * 5] = x[np.maximum(indices[i, 3], 0)] * int(indices[i, 3] != -1) | ||
|
||
return y | ||
|
||
def forward(self, indices, vv, num_words): | ||
uu = self.A.forward(vv) | ||
hh = Variable(torch.zeros(num_words,100)) | ||
|
||
for i in range(self.iterations): | ||
ww = self.concat(uu, indices) | ||
bb = self.B.forward(ww, hh, num_words) | ||
oo, hh = self.B2.forward(bb) | ||
ll = self.C.forward(oo) | ||
uu = self.D.forward(hh) | ||
|
||
return ll |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import os | ||
import nltk | ||
import ssl | ||
import numpy as np | ||
import random | ||
from .glove_reader import GLoVe | ||
import cv2 | ||
import json | ||
from network.neighbor_graph_builder import NeighborGraphBuilder | ||
from random import shuffle | ||
from network.document_features import DocumentFeatures | ||
import pickle | ||
|
||
|
||
# TODO: Tackle - Won't work for very large dataset because images are loaded into memory | ||
class DataFeaturesDumper: | ||
path = '' | ||
tokens_set = set() | ||
docs = [] | ||
queue = [] | ||
|
||
def __init__(self, path, glove_path, cache_name): | ||
self.path = path | ||
self.glove_path = glove_path | ||
self.cache_name = cache_name | ||
|
||
def dump_doc(self, all_tokens, all_tokens_rects, image, file_name): | ||
N = len(all_tokens) | ||
height, width = np.shape(image) | ||
class_one_hot = np.zeros((N, 2)) | ||
rect_matrix = np.zeros((N, 4)) | ||
embeddings_matrix = np.zeros((N, 300)) | ||
for i in range(N): | ||
token_rect = all_tokens_rects[i] | ||
index = 0 if image[int(token_rect['y'] + token_rect['height'] / 2), int( | ||
token_rect['x'] + token_rect['width'] / 2)] == 0 else 1 | ||
class_one_hot[i, index] = 1 | ||
rect_matrix[i, 0] = token_rect['x'] / width | ||
rect_matrix[i, 1] = token_rect['y'] / height | ||
rect_matrix[i, 2] = token_rect['width'] / width | ||
rect_matrix[i, 3] = token_rect['height'] / height | ||
embedding = self.glove_reader.get_vector(all_tokens[i]) | ||
if embedding is None: | ||
embedding = np.ones((300)) * (-1) | ||
embeddings_matrix[i] = embedding | ||
|
||
|
||
graph_builder = NeighborGraphBuilder(all_tokens_rects, image) | ||
neighbor_graph, neighbor_distance_matrix = graph_builder.get_neighbor_matrix() | ||
neighbor_distance_matrix[:, 0] = neighbor_distance_matrix[:, 0] / width | ||
neighbor_distance_matrix[:, 1] = neighbor_distance_matrix[:, 1] / height | ||
neighbor_distance_matrix[:, 2] = neighbor_distance_matrix[:, 2] / width | ||
neighbor_distance_matrix[:, 3] = neighbor_distance_matrix[:, 3] / height | ||
document = DocumentFeatures(embeddings_matrix, rect_matrix, neighbor_distance_matrix, neighbor_graph, class_one_hot) | ||
with open(file_name, 'wb') as f: | ||
pickle.dump(document, f, pickle.HIGHEST_PROTOCOL) | ||
|
||
def get_glove_reader(self): | ||
return self.glove_reader | ||
|
||
def build_glove(self): | ||
# Find all the unique words first | ||
ii = 0 | ||
for i in os.listdir(self.path): | ||
full_example_path = os.path.join(self.path, i) | ||
|
||
json_path = os.path.join(full_example_path, 'ocr_gt.json') | ||
|
||
with open(json_path) as data_file: | ||
ocr_data = json.load(data_file) | ||
|
||
for i in range(len(ocr_data)): | ||
word_data = ocr_data[i] | ||
tokens = nltk.word_tokenize(word_data['word']) | ||
for j in tokens: | ||
self.tokens_set.add(j) | ||
print("Unique words are %d" % len(self.tokens_set)) | ||
self.glove_reader = GLoVe(self.glove_path, self.tokens_set) | ||
self.glove_reader.load(self.cache_name) | ||
|
||
def load(self): | ||
print("Loading data") | ||
self.build_glove() | ||
|
||
ii = 0 | ||
for i in os.listdir(self.path): | ||
full_example_path = os.path.join(self.path, i) | ||
|
||
image_path = os.path.join(full_example_path, 'image.png') | ||
json_path = os.path.join(full_example_path, 'ocr_gt.json') | ||
document_dump_path = os.path.join(full_example_path, '__dump__.pickle') | ||
image = cv2.imread(image_path, 0) | ||
height, width = np.shape(image) | ||
|
||
all_words = [] | ||
all_words_rects = [] | ||
|
||
with open(json_path) as data_file: | ||
ocr_data = json.load(data_file) | ||
|
||
for i in range(len(ocr_data)): | ||
word_data = ocr_data[i] | ||
all_words.append(word_data['word']) | ||
all_words_rects.append(word_data['rect']) | ||
|
||
all_tokens = [] | ||
all_tokens_rects = [] | ||
class_indices = [] | ||
|
||
for i in range(len(all_words)): | ||
tokens = nltk.word_tokenize(all_words[i]) | ||
all_tokens.extend(tokens) | ||
word_rect = all_words_rects[i] | ||
divided_width = word_rect['width'] / len(tokens) | ||
# If a word contains more than one token, just | ||
# divide along width | ||
for j in range(len(tokens)): | ||
token_rect = dict(word_rect) | ||
token_rect['x'] += int(j*divided_width) | ||
token_rect['width'] = int(divided_width) | ||
all_tokens_rects.append(token_rect) | ||
assert(len(all_tokens) == len(all_tokens_rects)) | ||
|
||
for i in range(len(all_tokens)): | ||
token_rect = all_tokens_rects[i] | ||
class_indices.append(0 if image[int(token_rect['y'] + token_rect['height']/2), int(token_rect['x'] + token_rect['width']/2)] == 0 else 1) | ||
|
||
self.dump_doc(all_tokens, all_tokens_rects, image, document_dump_path) | ||
ii += 1 | ||
print("Loaded %d" % ii) | ||
|
||
print("Data loaded") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
|
||
class DocumentFeatures: | ||
def __init__(self, tokens_embeddings, tokens_rects, neighbor_distance_matrix, tokens_neighbor_matrix, tokens_classes): | ||
self.embeddings = tokens_embeddings | ||
self.rects = tokens_rects | ||
self.distances = neighbor_distance_matrix | ||
self.neighbor_graph = tokens_neighbor_matrix | ||
self.classes = tokens_classes | ||
|
Oops, something went wrong.