-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
57d93e2
commit 7104e93
Showing
10 changed files
with
310 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import torch | ||
import numpy as np | ||
from torch.autograd import Variable | ||
from torch import nn | ||
from conv_autoencoder.ModuleEncode import ModuleEncode | ||
from conv_autoencoder.ModuleDecode import ModuleDecode | ||
|
||
|
||
class Autoencoder(torch.nn.Module):
    """Autoencoder that feeds the output of ModuleEncode into ModuleDecode."""

    def __init__(self):
        super().__init__()
        self.encoder = ModuleEncode()
        self.decoder = ModuleDecode()

    def forward(self, x):
        """Encode ``x`` and return what the decoder produces from that encoding."""
        encoded = self.encoder(x)
        reconstruction = self.decoder(encoded)
        return reconstruction
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import torch | ||
import numpy as np | ||
from torch.autograd import Variable | ||
from torch import nn | ||
|
||
class ModuleDecode(torch.nn.Module):
    """Decoder stack: 3x3 conv + ReLU stages interleaved with 2x nearest upsampling."""

    def __init__(self):
        super().__init__()

        # Layout written input-to-output and consumed back-to-front below.
        # Integers are conv output channel counts; 'M' marks an upsampling step.
        # A wider variant tried earlier:
        #   [3, 48, 48, 'M', 64, 64, 'M', 128, 128, 128, 'M', 192, 192]
        layout = [3, 48, 48, 'M', 48, 48, 'M', 48, 48, 48, 'M', 48, 48]

        stages = []
        channels = 48  # channel count expected by the first convolution
        for entry in reversed(layout):
            if entry != 'M':
                stages.append(nn.Conv2d(channels, entry, kernel_size=3, padding=1))
                stages.append(nn.ReLU(inplace=True))
                channels = entry
                continue
            stages.append(nn.Upsample(scale_factor=2, mode='nearest'))
        self.image = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential decoder stack and return the result."""
        return self.image(x)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import torch | ||
import numpy as np | ||
from torch.autograd import Variable | ||
from torch import nn | ||
|
||
class ModuleEncode(torch.nn.Module):
    """Encoder stack: 3x3 conv + ReLU stages interleaved with 2x2 max pooling."""

    def __init__(self):
        super().__init__()

        # Integers are conv output channel counts; 'M' marks a pooling step.
        # A wider variant tried earlier:
        #   [48, 48, 'M', 64, 64, 'M', 128, 128, 128, 'M', 192, 192, 192]
        plan = [48, 48, 'M', 48, 48, 'M', 48, 48, 48, 'M', 48, 48, 48]

        modules = []
        previous = 3  # channel count expected by the first convolution
        for step in plan:
            if step == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.extend([
                nn.Conv2d(previous, step, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            ])
            previous = step
        self.features = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the sequential encoder stack and return the result."""
        return self.features(x)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import torch | ||
from torch import nn | ||
from torch.autograd import Variable | ||
|
||
# Shape experiment: pass a 1x3x300x300 tensor through a 5x5 convolution
# (3 input channels, 5 output channels, default stride/padding) and print
# the size of what comes out.
C = nn.Conv2d(3, 5, (5, 5))
image = Variable(torch.FloatTensor(1, 3, 300, 300))
image2 = C(image)
print(image2.size())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import numpy as np | ||
import configparser as cp | ||
from network.silknet import LoadInterface | ||
from network.silknet.FolderDataReader import FolderDataReader | ||
from interface import implements | ||
import os | ||
import torch | ||
from torch.autograd import Variable | ||
import cv2 | ||
import random | ||
from conv_autoencoder.Autoencoder import Autoencoder | ||
import sys | ||
|
||
|
||
class DataLoader(implements(LoadInterface)):
    """Turns one sample folder into a batch of random, resized image crops.

    ``load_datum`` reads ``<full_path>/image.png`` and returns a dict whose
    'X' and 'Y' entries are the same float32 array of shape
    (num_crops, 3, 512, 512).
    """

    def __init__(self, num_crops):
        """Remember how many crops to cut from every image."""
        self.num_crops = num_crops
        # Gaussian kernel kept as a public attribute for compatibility; it is
        # not used when building a datum (see NOTE in load_datum).
        self.K = cv2.getGaussianKernel(5, 5)

    def load_datum(self, full_path):
        """Load ``image.png`` from ``full_path`` and cut ``num_crops`` random crops.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        image_file = os.path.join(full_path, 'image.png')
        image = cv2.imread(image_file, 1)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with a clear message rather than on the unpack below.
            raise FileNotFoundError('Could not read image: %s' % image_file)
        height, width, _ = np.shape(image)

        # Crops are normally at least 100 px per side; clamp the lower bound
        # so images smaller than that do not make random.randint raise.
        min_crop_width = min(100, width)
        min_crop_height = min(100, height)

        patches_x = np.zeros((self.num_crops, 3, 512, 512), dtype=np.float32)

        for j in range(self.num_crops):
            crop_width = random.randint(min_crop_width, width)
            crop_height = random.randint(min_crop_height, height)
            x = random.randint(0, width - crop_width)
            y = random.randint(0, height - crop_height)
            patch = image[y:y + crop_height, x:x + crop_width, :]
            patch_x = cv2.resize(patch, dsize=(512, 512))

            # Swap axes 0 and 2 so the channel axis comes first.
            patches_x[j] = np.swapaxes(patch_x, 0, 2).astype(np.float32)

        # NOTE(review): the previous version also built a copy of every patch
        # filtered with self.K but never returned it; 'Y' has always been the
        # same array as 'X'. Confirm whether 'Y' was meant to be that filtered
        # copy before changing the training target.
        datum = dict()
        datum['X'] = patches_x
        datum['Y'] = patches_x

        return datum
|
||
|
||
class ConvolutionalAutoencoder:
    """Training / evaluation driver for Autoencoder, configured by config.ini."""

    def __init__(self):
        """Read every setting from the [conv_auto_encoder] section of config.ini."""
        parser = cp.ConfigParser()
        parser.read('config.ini')
        section = parser['conv_auto_encoder']
        self.train_path = section['train_data_path']
        self.test_path = section['test_data_path']
        self.validation_data_path = section['validation_data_path']
        self.learning_rate = float(section['learning_rate'])
        self.from_scratch = int(section['from_scratch']) == 1
        self.model_path = section['model_path']
        self.save_after = int(section['save_after'])
        self.batch_size = int(section['batch_size'])
        self.manual_mode_loaded = False

    def train(self):
        """Optimise the autoencoder with Adam on an MSE loss, saving periodically."""
        assert not self.manual_mode_loaded
        reader = FolderDataReader(self.train_path, DataLoader(self.batch_size))
        reader.init()
        net = Autoencoder().cuda()

        # Resume from the stored weights unless configured to start fresh.
        if not self.from_scratch:
            print("Loaded")
            net.load_state_dict(torch.load(self.model_path))

        mse = torch.nn.MSELoss(size_average=True)
        adam = torch.optim.Adam(net.parameters(), lr=self.learning_rate)
        for iteration in range(1000000):
            # Checkpoint every `save_after` iterations (including iteration 0).
            if iteration % self.save_after == 0:
                print("Saving model")
                torch.save(net.state_dict(), self.model_path)

            batch, _epoch, _sample_id = reader.next_element()
            inputs = Variable(torch.from_numpy(batch['X'])).cuda()
            targets = Variable(torch.from_numpy(batch['Y'])).cuda()

            loss = mse(net(inputs), targets)

            adam.zero_grad()
            loss.backward()
            adam.step()
            print("%5d Loss = %f" % (iteration, loss.data[0]))

    def prepare_for_manual_testing(self):
        """Load the saved model once so get_feature_map can be called."""
        self.manual_model = Autoencoder().cuda()
        stored_weights = torch.load(self.model_path)
        self.manual_model.load_state_dict(stored_weights)
        self.manual_mode_loaded = True

    # image: Numpy array (N,N,3)
    def get_feature_map(self, image):
        """Return the encoder output for ``image`` with axes 0 and 2 swapped back."""
        assert self.manual_mode_loaded
        batch = np.zeros((1, 3, 512, 512)).astype(np.float32)
        resized = cv2.resize(image, dsize=(512, 512))
        batch[0] = np.swapaxes(resized, 0, 2).astype(np.float32)

        encoder_input = Variable(torch.from_numpy(batch)).cuda()
        encoded = self.manual_model.encoder(encoder_input)

        feature_maps = encoded.cpu().data.numpy()
        return np.swapaxes(feature_maps[0], 0, 2)

    def test(self):
        """Show validation inputs next to the model's reconstructions, one by one."""
        assert not self.manual_mode_loaded
        reader = FolderDataReader(self.validation_data_path, DataLoader(1))
        reader.init()
        net = Autoencoder().cuda()
        net.load_state_dict(torch.load(self.model_path))
        for _ in range(1000000):
            sample, _epoch, _sample_id = reader.next_element()
            inputs = Variable(torch.from_numpy(sample['X'])).cuda()
            prediction = net(inputs)

            original = np.swapaxes(sample['Y'][0], 0, 2).astype(np.uint8)
            reconstructed = prediction.cpu().data.numpy()
            result = np.swapaxes(reconstructed[0], 0, 2)
            # Scale by the maximum before display.
            result = result / np.max(result)
            cv2.imshow('a', original)
            cv2.imshow('b', result)
            cv2.waitKey(0)
|
||
|
||
|
||
if __name__ == '__main__':
    # Exactly one command-line argument is expected.
    if len(sys.argv) != 2:
        print("Error")
        sys.exit()

    arg = sys.argv[1]

    # 'train' starts training; any other value runs the visual test loop.
    train = (arg == 'train')

    trainer = ConvolutionalAutoencoder()
    action = trainer.train if train else trainer.test
    action()
20 changes: 20 additions & 0 deletions
20
python/conv_autoencoder/prepare_dataset_for_conv_autoencoders.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import torch | ||
import numpy as np | ||
import cv2 | ||
import os | ||
import shutil | ||
|
||
images_path = '/home/srq/Datasets/tables/uw3/correctgt-zones/'
output_path = '/home/srq/Datasets/tables/uw3-for-auto-encoders'

# Copy every PNG in images_path to <output_path>/<file stem>/image.png.
for i in os.listdir(images_path):
    if not i.endswith('.png'):
        continue
    image_id = os.path.splitext(i)[0]
    print(image_id)
    output_create_dir = os.path.join(output_path, image_id)
    # makedirs also creates output_path itself on a first run, and
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(output_create_dir, exist_ok=True)
    image_out = os.path.join(output_create_dir, 'image.png')
    shutil.copy(os.path.join(images_path, i), image_out)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import cv2 | ||
import numpy as np | ||
|
||
path = '/home/srq/Datasets/tables/uw3/correctgt-zones/D05DBIN.png'

# Flag 0 loads the image as a single-channel (grayscale) array.
image = cv2.imread(path, 0)
if image is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError('Could not read image: %s' % path)

height, width = np.shape(image)

# Scale so that the longer side becomes 500 pixels.
scale_factor = 500 / max(height, width)

image_resized = cv2.resize(image, None, fx=scale_factor, fy=scale_factor)

# getGaussianKernel returns a ksize x 1 column of coefficients. Passing that
# column to filter2D smooths along one axis only, so apply it separably along
# both axes to get a 2-D Gaussian blur.
K = cv2.getGaussianKernel(3, 3)
image_smoothed = cv2.sepFilter2D(image_resized, -1, K, K)

cv2.imshow('t', image_smoothed)
cv2.waitKey(0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import unittest | ||
import torch | ||
from conv_autoencoder.ModuleEncode import ModuleEncode | ||
from conv_autoencoder.ModuleDecode import ModuleDecode | ||
from torch.autograd import Variable | ||
|
||
|
||
class NetsTests(unittest.TestCase):
    """Shape checks for the ModuleEncode / ModuleDecode pair."""

    def setUp(self):
        super().setUp()
        self.encoder = ModuleEncode()
        self.decoder = ModuleDecode()

    def test_encode_decode(self):
        """Encoding then decoding a 512x512 image must restore its shape."""
        # torch.rand gives defined values; an uninitialised FloatTensor may
        # contain arbitrary memory contents.
        test_random_image = torch.rand(1, 3, 512, 512)
        image_variable = Variable(test_random_image)

        output_features = self.encoder(image_variable)
        print("Encoded size", output_features.size())
        # Three 2x2 poolings: 512 -> 64 per side, with 48 feature channels.
        self.assertEqual(tuple(output_features.size()), (1, 48, 64, 64))

        image = self.decoder(output_features)
        print("Decoded size", image.size())
        # Three 2x upsamplings bring the encoding back to the input shape.
        self.assertEqual(tuple(image.size()), (1, 3, 512, 512))


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import random | ||
|
||
|
||
def x(width, height):
    """Print one random crop as ``left top crop_width crop_height``.

    The crop is at least 100 px per side and lies fully inside a
    ``width`` x ``height`` image.
    """
    crop_width = random.randint(100, width)
    crop_height = random.randint(100, height)
    # random.randint includes both endpoints, so the largest valid offset is
    # exactly (size - crop size). Subtracting one more made a crop spanning
    # the full dimension call randint(0, -1), which raises ValueError.
    left = random.randint(0, width - crop_width)
    top = random.randint(0, height - crop_height)

    print(left, top, crop_width, crop_height)


# Sample a few hundred crops to eyeball the generated ranges.
for _ in range(300):
    x(3000, 2500)