Commit 5900440 (parent 038f332): 10 changed files with 317 additions and 23 deletions.
network/ModuleB2.py (new file)
@@ -0,0 +1,50 @@
import torch
import numpy as np
from torch.autograd import Variable


class ModuleB2(torch.nn.Module):
    def __init__(self, D_in, D_out_1, D_out_2):
        super(ModuleB2, self).__init__()

        # Two parallel three-layer branches: both read the same input,
        # one producing D_out_1 features, the other D_out_2.
        self.linear1 = torch.nn.Linear(D_in, 100).cuda()
        self.linear2 = torch.nn.Linear(100, 100).cuda()
        self.linear3 = torch.nn.Linear(100, D_out_1).cuda()

        self.linear12 = torch.nn.Linear(D_in, 100).cuda()
        self.linear22 = torch.nn.Linear(100, 100).cuda()
        self.linear32 = torch.nn.Linear(100, D_out_2).cuda()

        # Wide uniform weight initialization; biases start at zero.
        self.linear1.weight.data.uniform_(-30, 30)
        self.linear1.bias.data.uniform_(0, 0)
        self.linear2.weight.data.uniform_(-30, 30)
        self.linear2.bias.data.uniform_(0, 0)
        self.linear3.weight.data.uniform_(-30, 30)
        self.linear3.bias.data.uniform_(0, 0)
        self.linear12.weight.data.uniform_(-30, 30)
        self.linear12.bias.data.uniform_(0, 0)
        self.linear22.weight.data.uniform_(-30, 30)
        self.linear22.bias.data.uniform_(0, 0)
        self.linear32.weight.data.uniform_(-30, 30)
        self.linear32.bias.data.uniform_(0, 0)

        self.activation = torch.nn.Tanh()

    def forward(self, x):
        # clamp(min=0) is ReLU on the hidden layers; the two heads
        # pass through Tanh on the way out.
        o1 = self.linear1(x).clamp(min=0)
        o2 = self.linear2(o1).clamp(min=0)

        o12 = self.linear12(x).clamp(min=0)
        o22 = self.linear22(o12).clamp(min=0)

        return self.activation(self.linear3(o2)), self.activation(self.linear32(o22))
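
A minimal smoke test for the two-headed module, a sketch assuming a CUDA device is available (the sizes are hypothetical; the shapes follow the constructor above):

import torch
from torch.autograd import Variable
from network.ModuleB2 import ModuleB2

m = ModuleB2(100, 100, 100)          # D_in and both head widths set to 100
x = Variable(torch.randn(7, 100)).cuda()  # 7 hypothetical "words"
out1, out2 = m(x)
print(out1.size(), out2.size())      # (7, 100) and (7, 100)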
network/computation_graph_neighbor_parse.py (new file)
@@ -0,0 +1,51 @@
import numpy as np
import torch
from torch.autograd import Variable
from network.ModuleA import ModuleA
from network.ModuleB import ModuleB
from network.ModuleB2 import ModuleB2
from network.ModuleC import ModuleC
from network.ModuleD import ModuleD


class ComputationGraphTableNeighborParse(torch.nn.Module):
    def __init__(self):
        super(ComputationGraphTableNeighborParse, self).__init__()
        self.k = 8
        self.D_in = 300 + self.k  # 300-dim embeddings plus k geometric features

        self.A = ModuleA(self.D_in, 100)
        self.B = ModuleB()
        self.B2 = ModuleB2(100, 100, 100)
        self.C = ModuleC(100, 2)
        self.D = ModuleD(100, 100)
        self.iterations = 1

    def set_iterations(self, iterations):
        self.iterations = iterations

    def concat(self, x, indices, indices_not_found, num_words):
        # Gather each word's own feature vector plus those of its four
        # neighbors into one 500-dim row; slots whose neighbor index was
        # missing are zeroed via the indices_not_found mask.
        y = Variable(torch.zeros(num_words, 100 * 5)).cuda()
        y[:, 0:100] = x[indices[:, 0]]
        y[:, 100:200] = x[indices[:, 1]]
        y[:, 200:300] = x[indices[:, 2]]
        y[:, 300:400] = x[indices[:, 3]]
        y[:, 400:500] = x[indices[:, 4]]
        y[indices_not_found] = 0

        return y

    def forward(self, indices, indices_not_found, vv, num_words):
        uu = self.A.forward(vv)
        hh = Variable(torch.zeros(num_words, 100)).cuda()

        # Iterative message passing: gather neighbor features, update the
        # hidden state, and re-project; the final iteration's logits are returned.
        for i in range(self.iterations):
            ww = self.concat(uu, indices, indices_not_found, num_words)
            bb = self.B.forward(ww, hh)
            oo, hh = self.B2.forward(bb)
            ll = self.C.forward(oo)
            uu = self.D.forward(hh)

        return ll
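
A toy sketch of the gathering that concat performs, with made-up sizes (CPU tensors for brevity, one neighbor per word instead of four):

import torch

x = torch.arange(0, 12).view(3, 4).float()            # 3 "words", 4 features each
indices = torch.LongTensor([[0, 1], [1, 2], [2, 0]])  # column 0: self, column 1: neighbor
y = torch.zeros(3, 8)
y[:, 0:4] = x[indices[:, 0]]   # each word's own features
y[:, 4:8] = x[indices[:, 1]]   # its neighbor's features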
Dense module (new file)
@@ -0,0 +1,26 @@
import torch
import numpy as np
from torch.autograd import Variable


class Dense(torch.nn.Module):
    # Builds an MLP from a compact config list: integers add Linear layers,
    # 'R'/'S'/'T' add ReLU/Sigmoid/Tanh activations.
    def __init__(self, D_in, config=[300, 'S', 100, 'S', 100, 'T']):
        super(Dense, self).__init__()
        layers = []
        last = D_in
        for v in config:
            if v == 'R':
                layers += [torch.nn.ReLU()]
            elif v == 'S':
                layers += [torch.nn.Sigmoid()]
            elif v == 'T':
                layers += [torch.nn.Tanh()]
            else:
                num_next = int(v)
                layers += [torch.nn.Linear(last, num_next)]
                last = num_next

        self.output = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.output(x)
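
With the default config, Dense(300) expands to Linear(300, 300), Sigmoid, Linear(300, 100), Sigmoid, Linear(100, 100), Tanh. A quick check with a hypothetical toy input:

import torch
from torch.autograd import Variable

net = Dense(300)                   # default config: [300, 'S', 100, 'S', 100, 'T']
x = Variable(torch.randn(4, 300))  # 4 samples, 300 features
print(net(x).size())               # torch.Size([4, 100])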
TableData (modified file)
@@ -1,12 +1,15 @@
 class TableData:
     def __init__(self, tokens_embeddings, tokens_rects, neighbor_distance_matrix, tokens_neighbor_matrix,
-                 tokens_share_row_matrix, tokens_share_col_matrix, tokens_share_cell_matrix):
+                 tokens_share_row_matrix, tokens_share_col_matrix, tokens_share_cell_matrix,
+                 neighbors_same_row, neighbors_same_col, neighbors_same_cell):
         self.embeddings = tokens_embeddings
         self.rects = tokens_rects
         self.distances = neighbor_distance_matrix
         self.neighbor_graph = tokens_neighbor_matrix
         self.row_share = tokens_share_row_matrix
         self.col_share = tokens_share_col_matrix
         self.cell_share = tokens_share_cell_matrix
+        self.neighbors_same_row = neighbors_same_row
+        self.neighbors_same_col = neighbors_same_col
+        self.neighbors_same_cell = neighbors_same_cell
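
A construction sketch with placeholder arrays; the shapes are assumptions inferred from how the trainer consumes the fields (rects and distances are concatenated with the 300-dim embeddings to form a 308-dim input, and neighbors_same_cell[:, 0] supplies the per-word label):

import numpy as np

n = 12  # hypothetical number of tokens
table = TableData(
    tokens_embeddings=np.zeros((n, 300), dtype=np.float32),   # e.g. GloVe vectors
    tokens_rects=np.zeros((n, 4), dtype=np.float32),          # normalized x, y, w, h
    neighbor_distance_matrix=np.zeros((n, 4), dtype=np.float32),
    tokens_neighbor_matrix=-np.ones((n, 4), dtype=np.int64),  # -1 marks a missing neighbor
    tokens_share_row_matrix=np.zeros((n, n)),
    tokens_share_col_matrix=np.zeros((n, n)),
    tokens_share_cell_matrix=np.zeros((n, n)),
    neighbors_same_row=np.zeros((n, 4), dtype=np.int64),
    neighbors_same_col=np.zeros((n, 4), dtype=np.int64),
    neighbors_same_cell=np.zeros((n, 4), dtype=np.int64),
)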
Trainer (new file)
@@ -0,0 +1,129 @@
# Written against the pre-0.4 PyTorch API (Variable, size_average, loss.data[0]).
import numpy as np
import configparser as cp
from network.data_features_dumper import DataFeaturesDumper
from network.silknet import LoadInterface
from network.silknet.FolderDataReader import FolderDataReader
from interface import implements
import os
import pickle
from network.computation_graph_neighbor_parse import ComputationGraphTableNeighborParse
import torch
from torch.autograd import Variable
import cv2


class DataLoader(implements(LoadInterface)):
    def load_datum(self, full_path):
        # Each example folder holds one pickled, pre-featurized document.
        with open(os.path.join(full_path, '__dump__.pickle'), 'rb') as f:
            doc = pickle.load(f)
        return doc


class Trainer:
    def __init__(self):
        config = cp.ConfigParser()
        config.read('config.ini')
        self.train_path = config['neighbor_parse']['train_data_path']
        self.test_path = config['neighbor_parse']['test_data_path']
        self.validation_data_path = config['neighbor_parse']['validation_data_path']
        self.glove_path = config['neighbor_parse']['glove_path']
        self.learning_rate = float(config['neighbor_parse']['learning_rate'])
        self.from_scratch = int(config['neighbor_parse']['from_scratch']) == 1
        self.model_path = config['neighbor_parse']['model_path']
        self.save_after = int(config['neighbor_parse']['save_after'])

    def init(self, dump_features_again):
        pass

    def do_plot(self, document, id):
        # Debug visualization: draw each token's rectangle on a 500x500 canvas,
        # colored by whether it shares a cell with token 0.
        rects = document.rects
        canvas = (np.ones((500, 500, 3)) * 255).astype(np.uint8)
        for i in range(len(rects)):
            rect = rects[i]
            color = (255, 0, 0) if document.cell_share[0, i] == 0 else (0, 0, 255)
            cv2.rectangle(canvas, (int(rect[0] * 500), int(rect[1] * 500)),
                          (int((rect[0] + rect[2]) * 500), int((rect[1] + rect[3]) * 500)), color)
        cv2.imshow('test' + id, canvas)
        cv2.waitKey(0)

    def get_example_elements(self, document, id):
        num_words, _ = np.shape(document.rects)
        # Input features: rectangle geometry + neighbor distances + embeddings.
        vv = np.concatenate([document.rects, document.distances, document.embeddings], axis=1).astype(np.float32)
        vv = Variable(torch.from_numpy(vv)).cuda()
        # Binary target: does each word share a cell with its first neighbor?
        y = Variable(torch.from_numpy(document.neighbors_same_cell[:, 0].astype(np.int64)), requires_grad=False).cuda()

        # Trivial baselines: accuracy of always predicting class 0 or class 1.
        baseline_accuracy_1 = 100 * np.sum(document.neighbors_same_cell[:, 0] == 0) / num_words
        baseline_accuracy_2 = 100 * np.sum(document.neighbors_same_cell[:, 0] == 1) / num_words

        # Column 0 is the word itself; the remaining columns are its neighbors,
        # with missing neighbors (-1) clamped to index 0 and masked out below.
        indices = torch.LongTensor(torch.from_numpy(np.concatenate(
            [np.expand_dims(np.arange(num_words, dtype=np.int64), axis=1),
             np.maximum(document.neighbor_graph.astype(np.int64), 0)], axis=1))).cuda()
        # Byte mask over the concatenated 500-dim rows: 1 where a neighbor slot
        # should be zeroed because the neighbor does not exist.
        indices_not_found = torch.ByteTensor(torch.from_numpy(np.repeat(np.concatenate(
            [np.expand_dims(np.zeros(num_words, dtype=np.int64), axis=1),
             document.neighbor_graph.astype(np.int64)], axis=1) == -1, 100).reshape((-1, 500)).astype(
            np.uint8))).cuda()

        return num_words, vv, y, baseline_accuracy_1, baseline_accuracy_2, indices, indices_not_found

    def do_validation(self, model, dataset):
        sum_of_accuracies = 0
        total = 0
        while True:
            document, epoch, id = dataset.next_element()
            num_words, vv, y, baseline_accuracy_1, baseline_accuracy_2, indices, indices_not_found = \
                self.get_example_elements(document, id)

            y_pred = model(indices, indices_not_found, vv, num_words)
            _, predicted = torch.max(y_pred.data, 1)

            accuracy = torch.sum(predicted == y.data)
            accuracy = 100 * accuracy / num_words

            print(accuracy)

            total += 1
            sum_of_accuracies += accuracy

            # Stop after one full pass over the validation set.
            if epoch == 1:
                break

        print("Average validation accuracy = ", sum_of_accuracies / total)

    def train(self):
        dataset = FolderDataReader(self.train_path, DataLoader())
        validation_dataset = FolderDataReader(self.validation_data_path, DataLoader())
        dataset.init()
        validation_dataset.init()
        model = ComputationGraphTableNeighborParse()
        model.set_iterations(4)
        criterion = torch.nn.CrossEntropyLoss(size_average=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        for i in range(1000000):
            if i % 10000 == 0:
                self.do_validation(model, validation_dataset)

            document, epoch, id = dataset.next_element()
            num_words, vv, y, baseline_accuracy_1, baseline_accuracy_2, indices, indices_not_found = \
                self.get_example_elements(document, id)

            for j in range(1):  # inner loop currently runs a single step per document
                y_pred = model(indices, indices_not_found, vv, num_words)
                _, predicted = torch.max(y_pred.data, 1)
                accuracy = torch.sum(predicted == y.data)
                accuracy = 100 * accuracy / num_words

                # Fraction of words predicted as class 0 ("yes") and class 1 ("no").
                yes_pred = torch.sum(predicted == 0)
                yes_pred = 100 * yes_pred / num_words
                no_pred = torch.sum(predicted == 1)
                no_pred = 100 * no_pred / num_words

                loss = criterion(y_pred, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print("%3dx%3d Loss = %f" % (i, j, loss.data[0]),
                      "Accuracy: %03.2f" % accuracy,
                      "Yes: %03.2f" % yes_pred, "No: %03.2f" % no_pred,
                      "Base Yes: %03.2f" % baseline_accuracy_1,
                      "Base No: %03.2f" % baseline_accuracy_2,
                      torch.sum(y_pred).data[0])