|
65 | 65 | import time
|
66 | 66 | import csv
|
67 | 67 | import os
|
68 |
| - |
69 | 68 | import numpy as np
|
70 |
| -import torch; torch.manual_seed(0) |
71 |
| -import torch.nn as nn |
72 |
| -import torch.nn.functional as F |
73 |
| -import torch.utils |
74 |
| -import torch.distributions |
75 |
| -import torchvision |
76 |
| - |
77 | 69 | import sys
|
78 |
| - |
79 | 70 | import argparse
|
| 71 | +import ConfigParser |
| 72 | + |
| 73 | +config = ConfigParser.RawConfigParser() |
| 74 | +config.read('config.properties') |
80 | 75 |
|
81 | 76 | parser = argparse.ArgumentParser(description='Process some integers.')
|
82 | 77 | parser.add_argument('-s','--source',type=str,help="source atlas")
|
@@ -111,148 +106,48 @@ def shuffle_list(l):
|
111 | 106 |
|
112 | 107 |
|
113 | 108 |
|
114 |
| -atlases = ['dosenbach','schaefer','brainnetome','power','craddock','shen','shen368','craddock400'] |
115 |
| -#atlases = ['dosenbach','schaefer','brainnetome','power','craddock','shen'] |
116 |
| -#atlases = ['dosenbach','power','shen'] |
117 |
| -#atlases = ['power','shen'] |
118 |
| -#atlases = ['power','shen'] |
119 |
| -#atlases = ['shen368','craddock400'] |
120 |
| -#atlases = ['dosenbach','schaefer','brainnetome','power','craddock','shen368','craddock400'] |
121 | 109 | atlases = ['dosenbach','schaefer','brainnetome','power','shen','craddock']
|
122 |
| -#atlases = ['shen','schaefer'] |
123 | 110 | atlases = shuffle_list(atlases)
|
124 |
| -#tasks = ["gambling","wm","motor","lang","social","relational","emotion","rest1"] |
125 |
| -#tasks = ["gambling","wm","motor","lang"] |
126 |
| -#tasks = ["social","relational","emotion","rest1"] |
127 |
| -#tasks = ["gambling","wm"]#,"motor","lang","social","relational","emotion","rest1"] |
128 |
| -#tasks = ["social","relational"] |
129 |
| -#tasks = ["gambling","wm","motor","lang","emotion","rest1"] |
130 |
| -tasks = ["rest1","gambling","wm","motor","lang","social","relational","emotion"] |
| 111 | +tasks = ["rest1"]#,"gambling","wm","motor","lang","social","relational","emotion"] |
131 | 112 | tasks = shuffle_list(tasks)
|
132 | 113 |
|
133 | 114 |
|
134 |
| -path = '/data_dustin/store4/Templates/HCP' |
135 |
| - |
136 | 115 | coord = {}
|
137 | 116 | all_data = {}
|
138 |
| -coord['schaefer'] =pd.read_csv('/data_dustin/store4/Templates/schaefer_coords.csv', sep=',',header=None) |
139 |
| -coord['brainnetome'] =pd.read_csv('/data_dustin/store4/Templates/brainnetome_coords.csv', sep=',',header=None) |
140 |
| -coord['shen'] =pd.read_csv('/data_dustin/store4/Templates/shen_coords.csv', sep=',',header=None) |
141 |
| -coord['shen368'] =pd.read_csv('/data_dustin/store4/Templates/shen_368_coords.csv', sep=',',header=None) |
142 |
| -coord['power'] =pd.read_csv('/data_dustin/store4/Templates/power_coords.txt', sep=',',header=None) |
143 |
| -coord['dosenbach'] =pd.read_csv('/data_dustin/store4/Templates/dosenbach_coords.txt', sep=',',header=None) |
144 |
| -coord['craddock'] =pd.read_csv('/data_dustin/store4/Templates/craddock_coords.txt', sep=',',header=None) |
145 |
| -coord['craddock400'] =pd.read_csv('/data_dustin/store4/Templates/craddock_400_coords.txt', sep=',',header=None) |
146 | 117 |
|
147 | 118 | # Loading Atlas ...
|
148 | 119 | tasks = [args.task]
|
149 | 120 | atlases = [args.source,args.target]
|
150 | 121 |
|
| 122 | +dataset_path = {} |
| 123 | +for atlas in tqdm(atlases,desc = 'Loading config file ..'): |
| 124 | + dataset_path[atlas]=config.get('path',atlas) |
| 125 | + coord[atlas]= pd.read_csv(config.get('coord',atlas), sep=',',header=None) |
| 126 | + |
151 | 127 | for atlas in tqdm(atlases,desc = 'Loading Atlases ..'):
|
152 | 128 | zero_nodes = set()
|
153 | 129 |
|
| 130 | + data = sio.loadmat(dataset_path[atlas]) |
154 | 131 | for task in tasks:
|
155 |
| - data = sio.loadmat(os.path.join(path,atlas,task+'.mat')) |
156 | 132 | x = data['all_mats']
|
157 | 133 | idx = np.argwhere(np.all(x[..., :] == 0, axis=0))
|
158 | 134 | p = [p1 for (p1,p2) in idx]
|
159 | 135 | zero_nodes.update(p)
|
160 | 136 |
|
161 | 137 | for task in tasks:
|
162 |
| - data = sio.loadmat(os.path.join(path,atlas,task+'.mat')) |
163 | 138 | x = data['all_mats']
|
164 | 139 | print(atlas,task,x.shape)
|
165 | 140 | np.delete(x,list(zero_nodes),1)
|
166 | 141 | all_data[(atlas,task)] = x
|
167 | 142 |
|
| 143 | + data = {"all_mats":x[0:10,:,:]} |
| 144 | + sio.savemat(atlas+".mat",data) |
| 145 | + |
| 146 | + |
| 147 | + |
| 148 | +all_behav = genfromtxt(config.get('behavior','iq'), delimiter=',') |
| 149 | +all_sex = genfromtxt(config.get('behavior','gender'), delimiter=',') |
168 | 150 |
|
169 |
| -all_behav = genfromtxt('data/268/all_behav.csv', delimiter=',') |
170 |
| -all_sex = genfromtxt('data/268/gender.csv', delimiter=',') |
171 |
| - |
172 |
| - |
173 |
| -class Decoder(nn.Module): |
174 |
| - def __init__(self,p1,p2, latent_dims): |
175 |
| - super(Decoder, self).__init__() |
176 |
| - self.linear1 = nn.Linear(latent_dims, 512) |
177 |
| - self.linear2 = nn.Linear(512, p1*p2) |
178 |
| - self.p1 = p1 |
179 |
| - self.p2 = p2 |
180 |
| - |
181 |
| - def forward(self, z): |
182 |
| - z = F.relu(self.linear1(z)) |
183 |
| - z = torch.sigmoid(self.linear2(z)) |
184 |
| - return z#.reshape((-1, 1, self.p1, self.p2)) |
185 |
| - |
186 |
| -class Encoder(nn.Module): |
187 |
| - def __init__(self, p1,p2,latent_dims): |
188 |
| - super(Encoder, self).__init__() |
189 |
| - self.linear1 = nn.Linear(p1*p2, 512) |
190 |
| - self.linear2 = nn.Linear(512, latent_dims) |
191 |
| - self.linear3 = nn.Linear(512, latent_dims) |
192 |
| - |
193 |
| - self.N = torch.distributions.Normal(0, 1) |
194 |
| - self.N.loc = self.N.loc#.cuda() # hack to get sampling on the GPU |
195 |
| - self.N.scale = self.N.scale#.cuda() |
196 |
| - self.kl = 0 |
197 |
| - |
198 |
| - def forward(self, x): |
199 |
| - x = torch.flatten(x, start_dim=1) |
200 |
| - x = F.relu(self.linear1(x)) |
201 |
| - mu = self.linear2(x) |
202 |
| - sigma = torch.exp(self.linear3(x)) |
203 |
| - z = mu + sigma*self.N.sample(mu.shape) |
204 |
| - self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum() |
205 |
| - return z |
206 |
| - |
207 |
| -class VAE(nn.Module): |
208 |
| - def __init__(self,p1=200,p2=268,latent_dims=2): |
209 |
| - super(VAE, self).__init__() |
210 |
| - self.encoder = Encoder(p1,p2,latent_dims) |
211 |
| - self.decoder = Decoder(p1,p2,latent_dims) |
212 |
| - |
213 |
| - def forward(self, x): |
214 |
| - z = self.encoder(x) |
215 |
| - return self.decoder(z) |
216 |
| - |
217 |
| -def train(autoencoder, source,target, T, epochs=40): |
218 |
| - print(source,target) |
219 |
| - print(T.shape) |
220 |
| - opt = torch.optim.Adam(autoencoder.parameters()) |
221 |
| - p1 = T.shape[0] |
222 |
| - p2 = T.shape[1] |
223 |
| - TEST_FREQUENCY = 5 |
224 |
| - train_elbo = [] |
225 |
| - num_tasks = T.shape[0] |
226 |
| - for epoch in range(epochs): |
227 |
| - y = T |
228 |
| - y = y.reshape(1,p1*p2) # GPU |
229 |
| - opt.zero_grad() |
230 |
| - y[y<0.001] = 0 |
231 |
| - y_hat = autoencoder(y) |
232 |
| - loss = ((y - y_hat)**2).sum() |
233 |
| - #cos = nn.CosineSimilarity(dim=1, eps=1e-6) |
234 |
| - #loss = cos(y,y_hat) |
235 |
| - if torch.isnan(loss): |
236 |
| - return autoencoder,T |
237 |
| - train_elbo.append(loss) |
238 |
| - loss.backward() |
239 |
| - opt.step() |
240 |
| - |
241 |
| - if epoch % TEST_FREQUENCY == 0: |
242 |
| - print("[epoch %03d] average test loss: %.4f" % (epoch, float(sum(train_elbo)/len(train_elbo)))) |
243 |
| - train_elbo = [] |
244 |
| - return autoencoder,y_hat.reshape(p1,p2) |
245 |
| - |
246 |
| - |
247 |
| -def evaluate_vae(svi, test_loader, use_cuda=False): |
248 |
| - test_loss = 0. |
249 |
| - for x, _ in test_loader: |
250 |
| - if use_cuda: |
251 |
| - x = x.cuda() |
252 |
| - test_loss += svi.evaluate_loss(x) |
253 |
| - normalizer_test = len(test_loader.dataset) |
254 |
| - total_epoch_loss_test = test_loss / normalizer_test |
255 |
| - return total_epoch_loss_test |
256 | 151 |
|
257 | 152 | def normalize(x):
|
258 | 153 | if all(v == 0 for v in x):
|
|
0 commit comments