-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
275 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# coding=utf-8 | ||
import torch | ||
from torch import nn | ||
import numpy as np | ||
from torch.autograd import Variable | ||
from torch.utils.data import DataLoader,TensorDataset | ||
|
||
|
||
def train_predict(x_3d, y_3d, model, un_idx_train, un_idx_valid, | ||
epoch, batch_sz, learn_rate, w_decay,use_gpu): | ||
img_channel, img_height, img_width = x_3d.shape | ||
# x_3d : [img_channel,img_height,img_width] | ||
x_2d = np.transpose(x_3d.reshape(img_channel, img_height * img_width)) # [num,band] | ||
y_2d = np.transpose(y_3d.reshape(img_channel, img_height * img_width)) | ||
train_dataset = TensorDataset(torch.tensor(x_2d[un_idx_train, :], dtype=torch.float32), | ||
torch.tensor(y_2d[un_idx_train, :], dtype=torch.float32)) | ||
valid_label_x = torch.tensor(x_2d[un_idx_valid, :], dtype=torch.float32) | ||
valid_label_y = torch.tensor(y_2d[un_idx_valid, :], dtype=torch.float32) | ||
data_loader = DataLoader(train_dataset, batch_size=batch_sz, shuffle=True) | ||
iter_num = un_idx_train.size // batch_sz | ||
loss_fc = nn.MSELoss() | ||
# While constructing the network, transfer the model to GPU (pytorch) | ||
if (use_gpu): | ||
model = model.cuda() | ||
loss_fc = loss_fc.cuda() | ||
|
||
optimizer = torch.optim.Adam \ | ||
(model.parameters(), lr=learn_rate, betas=(0.9, 0.99), weight_decay=w_decay) | ||
|
||
# Training loss & Valid loss | ||
Tra_ls, Val_ls = [], [] | ||
for _epoch in range(0, epoch): | ||
model.train() | ||
tra_ave_ls = 0 | ||
for i, data in enumerate(data_loader): | ||
train_x, train_y = data | ||
# While traning, transfer the data to GPU | ||
if (use_gpu): | ||
train_x, train_y = train_x.cuda(), train_y.cuda() | ||
predict_y = model(Variable(train_x)) | ||
loss = loss_fc(model(Variable(train_x)), train_y) | ||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
tra_ave_ls += loss.item() | ||
tra_ave_ls /= iter_num | ||
Tra_ls.append(tra_ave_ls) | ||
model.eval() | ||
if (use_gpu): | ||
valid_label_x, valid_label_y = valid_label_x.cuda(), valid_label_y.cuda() | ||
val_ls = loss_fc(model(valid_label_x), valid_label_y).item() | ||
Val_ls.append(val_ls) | ||
# print('epoch [{}/{}],train:{:.4f}, valid:{:.4f}'. | ||
# format(_epoch + 1, epoch, tra_ave_ls, val_ls)) | ||
# # if _epoch % 5 == 0 : print('epoch [{}/{}],train:{:.4f}, valid:{:.4f}'. | ||
# format(_epoch + 1, epoch, tra_ave_ls,val_ls)) | ||
|
||
# Prediction | ||
model.eval() | ||
x_2d = torch.tensor(x_2d, dtype=torch.float32) | ||
if (use_gpu): | ||
x_2d = x_2d.cuda() | ||
prediction_y = model(x_2d) # [num, band] | ||
loss_fn = torch.nn.MSELoss(reduce=False, size_average=False) | ||
if (use_gpu): | ||
loss_fn = loss_fn.cuda() | ||
input_y = torch.autograd.Variable(torch.from_numpy(y_2d)).float() # [num, band] | ||
if (use_gpu): | ||
input_y = input_y.cuda() | ||
loss = loss_fn(input_y, prediction_y) | ||
if (use_gpu): | ||
loss,prediction_y= loss.cpu(),prediction_y.cpu() | ||
loss_m1 = np.sum(loss.detach().numpy(), axis=1).reshape(img_height, img_width) # axis=1,[num, 1] | ||
prediction_y = prediction_y.detach().numpy().transpose(). \ | ||
reshape([img_channel, img_height, img_width, ]) | ||
|
||
return model, loss_m1, prediction_y, Tra_ls, Val_ls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# coding=utf-8 | ||
import numpy as np | ||
from torch import nn | ||
from torch.nn import init | ||
|
||
|
||
def initNetParams(net): | ||
'''Init net parameters.''' | ||
for m in net.modules(): | ||
if isinstance(m, nn.Conv2d): | ||
m.weight.data.normal_(0, 0.001) | ||
#init.xavier_uniform(m.weight) | ||
init.constant(m.bias, 0) | ||
#if m.bias: | ||
|
||
elif isinstance(m, nn.BatchNorm2d): | ||
init.constant(m.weight, 1) | ||
init.constant(m.bias, 0) | ||
elif isinstance(m, nn.Linear): | ||
nn.init.kaiming_normal_(m.weight.data) | ||
m.bias.data.fill_(0) | ||
|
||
|
||
# for evaluating the performance of the anomaly change detection result | ||
def plot_roc(predict, ground_truth): | ||
""" | ||
INPUTS: | ||
predict - anomalous change intensity map | ||
ground_truth - 0or1 | ||
OUTPUTS: | ||
X, Y for ROC plotting | ||
auc | ||
""" | ||
max_value = np.max(ground_truth) | ||
if max_value != 1: | ||
ground_truth = ground_truth / max_value | ||
|
||
# initial point(1.0, 1.0) | ||
x = 1.0 | ||
y = 1.0 | ||
hight_g, width_g = ground_truth.shape | ||
hight_p, width_p = predict.shape | ||
if hight_p != hight_g: | ||
predict = np.transpose(predict) | ||
|
||
ground_truth = ground_truth.reshape(-1) | ||
equals_one1 = np.where(ground_truth == 1) | ||
predict = predict.reshape(-1) | ||
# compuate the number of positive and negagtive pixels of the ground_truth | ||
pos_num = np.sum(ground_truth == 1) | ||
neg_num = np.sum(ground_truth == 0) | ||
# step in axis of X and Y | ||
x_step = 1.0 / neg_num | ||
y_step = 1.0 / pos_num | ||
# ranking the result map | ||
index = np.argsort(list(predict)) | ||
# predict = sorted(predict) | ||
ground_truth = ground_truth[index] | ||
equals_one2 = np.where(ground_truth == 1) | ||
""" | ||
for i in ground_truth: | ||
when ground_truth[i] = 1, TP minus 1,one y_step in the y axis, go down | ||
when ground_truth[i] = 0, FP minus 1,one x_step in the x axis, go left | ||
""" | ||
X = np.zeros(ground_truth.shape) | ||
Y = np.zeros(ground_truth.shape) | ||
for idx in range(0, hight_g * width_g): | ||
if ground_truth[idx] == 1: | ||
y = y - y_step | ||
else: | ||
x = x - x_step | ||
X[idx] = x | ||
Y[idx] = y | ||
|
||
auc = -np.trapz(Y, X) | ||
if auc < 0.5: | ||
auc = -np.trapz(X, Y) | ||
t = X | ||
X = Y | ||
Y = t | ||
|
||
return X, Y, auc | ||
|
||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# coding=utf-8 | ||
import numpy as np | ||
import torch | ||
import time | ||
import scipy.io as sio | ||
from PIL import Image | ||
import os | ||
import models | ||
import common_func | ||
import Train_func | ||
|
||
|
||
def train(x, y, idx_tra, idx_vld,h1, h2, lrn_rt,w_decay,ground_truth): | ||
model_1 = models.AutoEncoder(in_dim=127, hid_dim1=h1, hid_dim2=h2) | ||
model_2 = models.AutoEncoder(in_dim=127, hid_dim1=h1, hid_dim2=h2) | ||
model_1.apply(common_func.initNetParams) | ||
model_2.apply(common_func.initNetParams) # function:initNetParams | ||
epoch, bth_sz, = 200, 256 | ||
use_gpu = torch.cuda.is_available() # False | ||
model_1, ls_m1, prdt_y, T_ls_1, V_ls_1 = Train_func.train_predict( | ||
x, y, model_1, idx_tra, idx_vld, epoch, bth_sz, lrn_rt, w_decay,use_gpu) | ||
model_2, ls_m2, prdt_x, T_ls_2, V_ls_2 = Train_func.train_predict( | ||
y, x, model_2, idx_tra, idx_vld, epoch, bth_sz, lrn_rt, w_decay,use_gpu) | ||
loss_result = np.minimum(ls_m1, ls_m2) | ||
X, Y, auc = common_func.plot_roc(loss_result.transpose(), ground_truth) | ||
print("auc is ", auc,'\n') | ||
return loss_result,ls_m1,ls_m2,prdt_y,prdt_x | ||
|
||
|
||
if __name__ == '__main__': | ||
start = time.time() | ||
os.environ['CUDA_VISIBLE_DEVICES'] = "2" | ||
use_gpu = torch.cuda.is_available() | ||
|
||
# Step1 : Read Data | ||
path_name = '/data/meiqi.hu/PycharmProjects/HyperspectralACD/AE/ACDA/' | ||
EX, img_2, train_smp, valid_smp = 'EX1', 'img_2', 'un_idx_train1', 'un_idx_valid1' | ||
ground_truth = Image.open(path_name + 'ref_EX1.bmp') | ||
if EX == 'EX2': | ||
img_2, train_smp, valid_smp = 'img_3', 'un_idx_train2', 'un_idx_valid2' | ||
ground_truth = Image.open(path_name + 'ref_EX2.bmp') | ||
# read image data | ||
# img_data : img_1,img_2,img_3(de-striping, noise-whitening and spectrally binning) | ||
data_filename = 'img_data.mat' | ||
data = sio.loadmat(path_name + data_filename) | ||
img_x0 = data['img_1'] | ||
img_y0 = data[img_2] | ||
input_x = img_x0.transpose((2, 1, 0)) | ||
input_y = img_y0.transpose((2, 1, 0)) | ||
# read pre-train samples from pretraining result of USFA | ||
# for different training strategy(only replace the training samples) | ||
TrainSmp_filename = 'groundtruth_samples.mat' # groundtruth_samples random_samples pretrain_samples | ||
TrainSmp = sio.loadmat(path_name + TrainSmp_filename) | ||
un_idx_train = TrainSmp[train_smp].squeeze() | ||
un_idx_valid = TrainSmp[valid_smp].squeeze() | ||
img_channel, img_height, img_width = input_x.shape | ||
|
||
# Step2 : for experiemntal result | ||
Loss_result = np.zeros([img_height, img_width], dtype=float) | ||
h1, h2 = 60, 40 # 127, 127 | ||
learn_rate, w_decay = 0.001, 0.001 | ||
iter = 1 | ||
Loss_result = np.zeros([img_height, img_width], dtype=float) | ||
for i in np.arange(1, 1 + iter): | ||
print('epoch i =', i) | ||
loss_result, ls_m1, ls_m2, prdt_y, prdt_x = train(input_x, input_y, un_idx_train, un_idx_valid, h1, h2, | ||
learn_rate, w_decay,ground_truth) | ||
Loss_result = Loss_result + loss_result | ||
Loss_result = Loss_result / iter | ||
X, Y, auc = common_func.plot_roc(Loss_result.transpose(), ground_truth) | ||
|
||
print("auc is ", auc, '\n') | ||
print("-------------Ending---------------") | ||
print(" ") | ||
print(EX) | ||
end = time.time() | ||
print("共用时", (end - start), "秒") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#coding=utf-8 | ||
|
||
from torch import nn | ||
import torch.nn.functional as F | ||
import torch | ||
import numpy as np | ||
from torch.autograd import Variable | ||
import sys | ||
sys.path.append('/data/meiqi.hu/PycharmProjects/HyperspectralACD/MyFunction/') | ||
import common_func | ||
from PIL import Image | ||
import random | ||
|
||
|
||
|
||
class AutoEncoder(nn.Module): | ||
def __init__(self, in_dim, hid_dim1,hid_dim2): | ||
super(AutoEncoder, self).__init__() | ||
self.out_dim = in_dim | ||
|
||
self.encoder = nn.Sequential( | ||
nn.Linear(in_dim, hid_dim1, bias=True), | ||
nn.ReLU(), | ||
nn.Linear(hid_dim1, hid_dim2, bias=True), | ||
nn.ReLU(), | ||
) | ||
self.decoder = nn.Sequential( | ||
nn.Linear(hid_dim2, hid_dim1, bias=True), | ||
nn.ReLU(), | ||
nn.Linear(hid_dim1, in_dim, bias=True), | ||
nn.ReLU(), | ||
) | ||
|
||
def forward(self, x): | ||
feature = self.encoder(x) | ||
cons_x = self.decoder(feature) | ||
return cons_x |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.