-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathTrain.py
62 lines (37 loc) · 1.87 KB
/
Train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Graph-Tensor Neural Networks for Network Traffic Data Imputation
#
import LoadData as LD
import numpy as np
import BuildModel as BM
import TrainModel as TM
import DefineParam as DP
import os
import scipy.io as sio
# Expose only GPU 0 to this process.
# NOTE(review): this runs after the project imports above; it only takes
# effect if those imports have not already initialised CUDA — confirm.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
# 1. Input: Parameters
# DP.get_param() returns 14 values; they are bound as module-level names
# because the later stages of this script read several of them (e.g. batchSize).
(pixel_w, pixel_h, batchSize, nPhase,
 nTrainData, nValData, learningRate, nEpoch,
 nOfModel, ncpkt, trainFile, valFile,
 saveDir, modelDir) = DP.get_param()
# 2. Data Loading
print('-------------------------------------\nLoading Data...\n-------------------------------------\n')
trainLabel, valLabel = LD.load_train_data(mat73=False)

# Sampling mask Phi, one batch in size: entries left at 1 are kept, and the
# OD-pair columns zeroed below are treated as missing (unobserved).
# NOTE(review): the (12, 24, 144) layout is hard-coded; it presumably has to
# agree with trainLabel.shape[1:] (in particular L below must not exceed 144)
# — TODO confirm against LoadData.
trainPhi = np.ones([batchSize, 12, 24, 144])
L = trainLabel.shape[3]  # number of OD pairs (size of the last label axis)
missing_rate = 0.7  # fraction of OD pairs hidden from the model
# Mark 70% of the OD pairs as missing; the remaining 30% are the measured ones.
num_missing = round(missing_rate * L)

# Fixed seed so the same OD pairs are masked on every run. This seeds NumPy's
# global RNG, so later np.random calls in this process are reproducible too.
index = np.arange(L, dtype=int)
np.random.seed(1)
np.random.shuffle(index)
# Alternative: persist or reload a fixed permutation instead of reseeding.
# sio.savemat("index.mat", {'index': index})
# idx = sio.loadmat('.\index_144.mat')
# index = idx['randindex1']
# index = index.astype(np.int32)
missing_index = index[:num_missing]
print(missing_index)

# Zero every masked OD pair in one vectorized assignment. This replaces the
# per-index loop, whose body had lost its indentation (an IndentationError).
# An empty missing_index is a harmless no-op.
trainPhi[:, :, :, missing_index] = 0
trainPhi = trainPhi.astype('float32')
# 3. Model Building
print('-------------------------------------\nBuilding Model...\n-------------------------------------\n')
# BM.build_model receives the float32 sampling mask and the masked OD-pair
# indices, and returns 14 handles. They are bound at module level so the
# training step of this script can pass them on.
(sess, saver, Xinput, Xoutput, Epoch_num,
 costMean, costSymmetric, costSparsity, optmAll,
 Yinput, prediction, lambdaStep, softThr, transField) = BM.build_model(trainPhi, missing_index)
# 4. Model Training
print('-------------------------------------\nTraining Model...\n-------------------------------------\n')
# Arguments are grouped by role, one group per line; the positional order
# is exactly what TM.train_model receives.
TM.train_model(
    sess, saver,
    costMean, costSymmetric, costSparsity, optmAll,
    Yinput, prediction,
    trainLabel, valLabel, trainPhi,
    Xinput, Xoutput, Epoch_num,
    lambdaStep, softThr, missing_index, transField,
)
print('-------------------------------------\nTraining Accomplished.\n-------------------------------------\n')