Skip to content

Commit edb4933

Browse files
authored
Add files via upload
1 parent 4f8e9b3 commit edb4933

File tree

1 file changed

+252
-0
lines changed

1 file changed

+252
-0
lines changed

code/python/carot.py

+252
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
# Author: Remi Flamary <[email protected]>
2+
3+
#
4+
# License: MIT License
5+
6+
import numpy as np
7+
import matplotlib.pylab as pl
8+
import matplotlib.pyplot as plt
9+
from numpy import linalg as la
10+
import scipy.sparse
11+
import ot
12+
import ot.plot
13+
from ot.datasets import make_1D_gauss as gauss
14+
import pandas as pd
15+
from sklearn import preprocessing
16+
from sklearn.preprocessing import normalize
17+
from scipy.io import loadmat
18+
import os
19+
20+
from scipy.io import loadmat
21+
import ot
22+
import ot.plot
23+
from ot.datasets import make_1D_gauss as gauss
24+
import pandas as pd
25+
import random
26+
import scipy.io as sio
27+
from scipy.spatial.distance import cdist
28+
29+
import ot
30+
import ot.plot
31+
from ot.datasets import make_1D_gauss as gauss
32+
import pandas as pd
33+
import random
34+
from numpy import savetxt
35+
import sys
36+
from matplotlib import pyplot as plt
37+
from scipy.stats.stats import pearsonr
38+
#num_time_points = 1
39+
from tqdm import tqdm
40+
from sklearn.linear_model import Ridge
41+
from sklearn.model_selection import GridSearchCV
42+
from sklearn.model_selection import cross_val_predict
43+
from sklearn.model_selection import cross_val_score
44+
from sklearn.model_selection import KFold
45+
from sklearn.model_selection import RepeatedKFold
46+
from sklearn.model_selection import PredefinedSplit
47+
from sklearn.feature_selection import SelectPercentile
48+
from sklearn.feature_selection import f_regression
49+
from sklearn.model_selection import train_test_split
50+
from sklearn.pipeline import Pipeline
51+
from sklearn import metrics
52+
from sklearn.metrics import accuracy_score
53+
from sklearn import svm
54+
from time import sleep
55+
from sklearn import linear_model
56+
57+
from sklearn.feature_selection import mutual_info_classif
58+
from sklearn.feature_selection import SelectKBest
59+
from sklearn.preprocessing import normalize
60+
import numpy as np
61+
import scipy.io as sio
62+
import h5py
63+
64+
from numpy import genfromtxt
65+
import time
66+
import csv
67+
import os
68+
69+
import numpy as np
70+
import torch; torch.manual_seed(0)
71+
import torch.nn as nn
72+
import torch.nn.functional as F
73+
import torch.utils
74+
import torch.distributions
75+
import torchvision
76+
77+
import sys
78+
79+
import argparse
80+
81+
def dir_path(path):
82+
if os.path.isdir(path):
83+
return path
84+
else:
85+
raise argparse.ArgumentTypeError("{0} is not a valid path".format(path))
86+
87+
parser = argparse.ArgumentParser(description='Process some integers.')
88+
parser.add_argument('-s','--source',type=str,help="source atlas")
89+
parser.add_argument('-t','--target',type=str,help="target atlas")
90+
parser.add_argument('-task','--task',type=str,default="rest1",help="task")
91+
parser.add_argument('-m','--mapping',help="path to mapping")
92+
93+
args = parser.parse_args()
94+
95+
96+
#argv = sys.argv[1:]
97+
def shuffle_list(l):
98+
ids = np.arange(len(l))
99+
np.random.shuffle(ids)
100+
print(ids)
101+
return np.array(l)[ids]
102+
103+
104+
105+
106+
atlases = ['dosenbach','schaefer','brainnetome','power','craddock','shen','shen368','craddock400']
107+
atlases = ['dosenbach','schaefer','brainnetome','power','shen','craddock']
108+
atlases = shuffle_list(atlases)
109+
tasks = ["rest1","gambling","wm","motor","lang","social","relational","emotion"]
110+
tasks = shuffle_list(tasks)
111+
112+
113+
path = '/data_dustin/store4/Templates/HCP'
114+
115+
coord = {}
116+
all_data = {}
117+
coord['schaefer'] =pd.read_csv('/data_dustin/store4/Templates/schaefer_coords.csv', sep=',',header=None)
118+
coord['brainnetome'] =pd.read_csv('/data_dustin/store4/Templates/brainnetome_coords.csv', sep=',',header=None)
119+
coord['shen'] =pd.read_csv('/data_dustin/store4/Templates/shen_coords.csv', sep=',',header=None)
120+
coord['shen368'] =pd.read_csv('/data_dustin/store4/Templates/shen_368_coords.csv', sep=',',header=None)
121+
coord['power'] =pd.read_csv('/data_dustin/store4/Templates/power_coords.txt', sep=',',header=None)
122+
coord['dosenbach'] =pd.read_csv('/data_dustin/store4/Templates/dosenbach_coords.txt', sep=',',header=None)
123+
coord['craddock'] =pd.read_csv('/data_dustin/store4/Templates/craddock_coords.txt', sep=',',header=None)
124+
coord['craddock400'] =pd.read_csv('/data_dustin/store4/Templates/craddock_400_coords.txt', sep=',',header=None)
125+
126+
# Loading Atlas ...
127+
tasks = [args.task]
128+
atlases = [args.source,args.target]
129+
130+
for atlas in tqdm(atlases,desc = 'Loading Atlases ..'):
131+
zero_nodes = set()
132+
133+
for task in tasks:
134+
data = sio.loadmat(os.path.join(path,atlas,task+'.mat'))
135+
x = data['all_mats']
136+
idx = np.argwhere(np.all(x[..., :] == 0, axis=0))
137+
p = [p1 for (p1,p2) in idx]
138+
zero_nodes.update(p)
139+
140+
for task in tasks:
141+
data = sio.loadmat(os.path.join(path,atlas,task+'.mat'))
142+
x = data['all_mats']
143+
print(atlas,task,x.shape)
144+
np.delete(x,list(zero_nodes),1)
145+
all_data[(atlas,task)] = x
146+
147+
148+
def normalize(x):
149+
if all(v == 0 for v in x):
150+
return x
151+
else:
152+
if max(x) == min(x):
153+
return x
154+
else:
155+
return (x-min(x))/(max(x)-min(x))
156+
157+
from scipy import stats
158+
def corr2_coeff(A, B):
159+
# Rowwise mean of input arrays & subtract from input arrays themeselves
160+
A_mA = A - A.mean(1)[:, None]
161+
B_mB = B - B.mean(1)[:, None] # Sum of squares across rows
162+
ssA = (A_mA**2).sum(1)
163+
ssB = (B_mB**2).sum(1)
164+
return np.dot(A_mA, B_mB.T) / np.sqrt(np.dot(ssA[:, None],ssB[None]))
165+
166+
def generate_correlation_map(x, y):
167+
mu_x = x.mean(1)
168+
mu_y = y.mean(1)
169+
n = x.shape[1]
170+
if n != y.shape[1]:
171+
raise ValueError('x and y must ' +'have the same number of timepoints.')
172+
s_x = x.std(1, ddof=n - 1)+1e-6
173+
s_y = y.std(1, ddof=n - 1)+1e-6
174+
cov = np.dot(x,y.T) - n * np.dot(mu_x[:, np.newaxis],mu_y[np.newaxis, :])
175+
return cov / np.dot(s_x[:, np.newaxis], s_y[np.newaxis, :])
176+
177+
178+
179+
from collections import defaultdict
180+
def atlas_ot(source,target,task):
181+
print("Atlas OT .. ")
182+
n =all_data[(source,task)].shape[0]
183+
t =all_data[(source,task)].shape[2]
184+
p1 =all_data[(source,task)].shape[1]
185+
p2 =all_data[(target,task)].shape[1]
186+
n2 =all_data[(target,task)].shape[0]
187+
num_time_points = t
188+
189+
mapping_file = args.mapping#os.chdir(args.mapping)
190+
G = genfromtxt(mapping_file,delimiter=',',skip_header=1,usecols=range(1,1+coord[args.target].shape[0]))
191+
192+
test_time_series_pred = np.zeros((n2,p2,num_time_points))
193+
test_time_series = []
194+
step = 0
195+
id_count_ot = 0
196+
id_count_orig = 0
197+
id_conn = {}
198+
for i in tqdm(range(n2),desc="applying OT .."):
199+
for j in range(num_time_points):
200+
201+
a = all_data[(source,task)][i,:,j]
202+
a = a + 1e-6
203+
a = normalize(a)
204+
b= np.transpose(G).dot(a)
205+
test_time_series_pred[i,:,j] = b
206+
test_time_series.append(test_time_series_pred[i,:,:])
207+
208+
return G, test_time_series_pred
209+
210+
211+
def evaluate(c1,c2,y,test_size):
212+
213+
ids = np.arange(test_size)
214+
np.random.shuffle(ids)
215+
train_test_fraction = 0.9
216+
g3_train_size = int(train_test_fraction*test_size)
217+
g3_test_size = test_size - g3_train_size
218+
219+
g3_train_index = ids[0:g3_train_size]
220+
g3_test_index = ids[g3_train_size:]
221+
222+
clf = Ridge(alpha=1.0)
223+
clf.fit(c1[g3_train_index,:], y[g3_train_index])
224+
all_pred_1 = clf.predict(c1[g3_test_index])
225+
226+
clf.fit(c2[g3_train_index,:], y[g3_train_index])
227+
all_pred_2 = clf.predict(c2[g3_test_index])
228+
229+
#print(all_pred_1.shape,y[g3_test_size])
230+
s1 = np.corrcoef(all_pred_1, y[g3_test_index])[0, 1]
231+
s2 = np.corrcoef(all_pred_2, y[g3_test_index])[0, 1]
232+
return s1,s2
233+
234+
235+
236+
from scipy.stats.stats import spearmanr
237+
import sys
238+
239+
if __name__ == "__main__":
240+
source = args.source
241+
target= args.target
242+
task = args.task
243+
random.seed(3000)
244+
245+
n =all_data[(atlases[0],tasks[0])].shape[0]
246+
G, test_time_series_pred = atlas_ot(source,target,task)
247+
data = {"data":test_time_series_pred}
248+
sio.savemat("time_series_"+target+"_ot.mat",data)
249+
#df = pd.DataFrame(test_time_series_pred)
250+
#df = pd.Panel().to_frame(),stacj().reset_index()#DataFrame(test_time_series_pred)
251+
#df.to_csv("time_series_"+target+"_ot.csv")
252+

0 commit comments

Comments
 (0)