|
| 1 | +# Author: Remi Flamary <[email protected]> |
| 2 | + |
| 3 | +# |
| 4 | +# License: MIT License |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import matplotlib.pylab as pl |
| 8 | +import matplotlib.pyplot as plt |
| 9 | +from numpy import linalg as la |
| 10 | +import scipy.sparse |
| 11 | +import ot |
| 12 | +import ot.plot |
| 13 | +from ot.datasets import make_1D_gauss as gauss |
| 14 | +import pandas as pd |
| 15 | +from sklearn import preprocessing |
| 16 | +from sklearn.preprocessing import normalize |
| 17 | +from scipy.io import loadmat |
| 18 | +import os |
| 19 | + |
| 20 | +from scipy.io import loadmat |
| 21 | +import ot |
| 22 | +import ot.plot |
| 23 | +from ot.datasets import make_1D_gauss as gauss |
| 24 | +import pandas as pd |
| 25 | +import random |
| 26 | +import scipy.io as sio |
| 27 | +from scipy.spatial.distance import cdist |
| 28 | + |
| 29 | +import ot |
| 30 | +import ot.plot |
| 31 | +from ot.datasets import make_1D_gauss as gauss |
| 32 | +import pandas as pd |
| 33 | +import random |
| 34 | +from numpy import savetxt |
| 35 | +import sys |
| 36 | +from matplotlib import pyplot as plt |
| 37 | +from scipy.stats.stats import pearsonr |
| 38 | +#num_time_points = 1 |
| 39 | +from tqdm import tqdm |
| 40 | +from sklearn.linear_model import Ridge |
| 41 | +from sklearn.model_selection import GridSearchCV |
| 42 | +from sklearn.model_selection import cross_val_predict |
| 43 | +from sklearn.model_selection import cross_val_score |
| 44 | +from sklearn.model_selection import KFold |
| 45 | +from sklearn.model_selection import RepeatedKFold |
| 46 | +from sklearn.model_selection import PredefinedSplit |
| 47 | +from sklearn.feature_selection import SelectPercentile |
| 48 | +from sklearn.feature_selection import f_regression |
| 49 | +from sklearn.model_selection import train_test_split |
| 50 | +from sklearn.pipeline import Pipeline |
| 51 | +from sklearn import metrics |
| 52 | +from sklearn.metrics import accuracy_score |
| 53 | +from sklearn import svm |
| 54 | +from time import sleep |
| 55 | +from sklearn import linear_model |
| 56 | + |
| 57 | +from sklearn.feature_selection import mutual_info_classif |
| 58 | +from sklearn.feature_selection import SelectKBest |
| 59 | +from sklearn.preprocessing import normalize |
| 60 | +import numpy as np |
| 61 | +import scipy.io as sio |
| 62 | +import h5py |
| 63 | + |
| 64 | +from numpy import genfromtxt |
| 65 | +import time |
| 66 | +import csv |
| 67 | +import os |
| 68 | + |
| 69 | +import numpy as np |
| 70 | +import torch; torch.manual_seed(0) |
| 71 | +import torch.nn as nn |
| 72 | +import torch.nn.functional as F |
| 73 | +import torch.utils |
| 74 | +import torch.distributions |
| 75 | +import torchvision |
| 76 | + |
| 77 | +import sys |
| 78 | + |
| 79 | +import argparse |
| 80 | + |
def dir_path(path):
    """Check that ``path`` names an existing directory.

    Returns ``path`` unchanged when ``os.path.isdir(path)`` is true.

    Raises:
        argparse.ArgumentTypeError: if ``path`` is not an existing directory.
    """
    if not os.path.isdir(path):
        raise argparse.ArgumentTypeError("{0} is not a valid path".format(path))
    return path
| 86 | + |
# Command-line interface. Parsing happens at module level, so the resulting
# ``args`` namespace is available to every function defined below.
parser = argparse.ArgumentParser(
    description='Map source-atlas time series onto a target atlas '
                'using a precomputed mapping.')
# source, target and mapping are required: the loading code joins the atlas
# names into file paths and atlas_ot() reads the mapping file, so a missing
# value would otherwise only surface later as an opaque TypeError.
parser.add_argument('-s', '--source', type=str, required=True, help="source atlas")
parser.add_argument('-t', '--target', type=str, required=True, help="target atlas")
parser.add_argument('-task', '--task', type=str, default="rest1", help="task")
parser.add_argument('-m', '--mapping', required=True, help="path to mapping")

args = parser.parse_args()
| 94 | + |
| 95 | + |
| 96 | +#argv = sys.argv[1:] |
def shuffle_list(l):
    """Return the items of ``l`` as a NumPy array in random order.

    The permutation is drawn from NumPy's global random state
    (``np.random.shuffle``) and printed before it is applied.
    """
    order = np.arange(len(l))
    np.random.shuffle(order)
    print(order)
    items = np.array(l)
    return items[order]
| 102 | + |
| 103 | + |
| 104 | + |
| 105 | + |
# Candidate atlases and tasks, each placed in a random order.
# NOTE: both names are rebound from the command-line arguments further down
# in this file, before any data is loaded.
atlases = shuffle_list([
    'dosenbach',
    'schaefer',
    'brainnetome',
    'power',
    'shen',
    'craddock',
])
tasks = shuffle_list([
    "rest1",
    "gambling",
    "wm",
    "motor",
    "lang",
    "social",
    "relational",
    "emotion",
])
| 111 | + |
| 112 | + |
# Root directory; the loading code below reads <path>/<atlas>/<task>.mat.
path = '/data_dustin/store4/Templates/HCP'

# Loaded arrays, keyed by (atlas, task).
all_data = {}

# Coordinate file for each atlas, read as a header-less comma-separated table.
_coord_files = {
    'schaefer': '/data_dustin/store4/Templates/schaefer_coords.csv',
    'brainnetome': '/data_dustin/store4/Templates/brainnetome_coords.csv',
    'shen': '/data_dustin/store4/Templates/shen_coords.csv',
    'shen368': '/data_dustin/store4/Templates/shen_368_coords.csv',
    'power': '/data_dustin/store4/Templates/power_coords.txt',
    'dosenbach': '/data_dustin/store4/Templates/dosenbach_coords.txt',
    'craddock': '/data_dustin/store4/Templates/craddock_coords.txt',
    'craddock400': '/data_dustin/store4/Templates/craddock_400_coords.txt',
}
coord = {
    atlas_name: pd.read_csv(coord_file, sep=',', header=None)
    for atlas_name, coord_file in _coord_files.items()
}
| 125 | + |
# Loading Atlas ...
# Only the task and the two atlases named on the command line are loaded;
# this rebinds the ``tasks`` / ``atlases`` names assigned earlier in the file.
tasks = [args.task]
atlases = [args.source, args.target]

for atlas in tqdm(atlases, desc='Loading Atlases ..'):
    zero_nodes = set()
    loaded = {}

    # First pass: read every file exactly once and collect the axis-1 indices
    # that are zero across all of axis 0 for at least one axis-2 position.
    for task in tasks:
        data = sio.loadmat(os.path.join(path, atlas, task + '.mat'))
        x = data['all_mats']
        loaded[task] = x
        idx = np.argwhere(np.all(x == 0, axis=0))
        zero_nodes.update(idx[:, 0].tolist())

    # Second pass: register the arrays, reusing what was read above instead
    # of loading each .mat file from disk a second time.
    for task in tasks:
        x = loaded[task]
        print(atlas, task, x.shape)
        # NOTE(review): the previous code called
        # ``np.delete(x, list(zero_nodes), 1)`` here and discarded the
        # result. np.delete returns a new array, so nothing was ever removed.
        # The arrays are deliberately stored un-pruned: atlas_ot() sizes the
        # mapping matrix from the full coordinate table, so dropping entries
        # here would presumably break that shape agreement. Confirm against
        # how the mapping file was produced before pruning with
        # ``x = np.delete(x, sorted(zero_nodes), axis=1)``.
        all_data[(atlas, task)] = x
| 146 | + |
| 147 | + |
# NOTE: this definition rebinds the name ``normalize`` that the top of the
# file imports from sklearn.preprocessing.
def normalize(x):
    """Min-max rescale a one-dimensional sequence of numbers.

    Returns ``(x - min(x)) / (max(x) - min(x))``. When every element is zero
    (including the empty case), or when the largest and smallest elements
    are equal, ``x`` itself is returned unchanged.
    """
    if all(v == 0 for v in x):
        return x
    lo = min(x)
    hi = max(x)
    if hi == lo:
        return x
    return (x - lo) / (hi - lo)
| 156 | + |
| 157 | +from scipy import stats |
def corr2_coeff(A, B):
    """Correlate every row of ``A`` with every row of ``B``.

    Both inputs are 2-D arrays with the same number of columns. Each row is
    centred on its own mean; the result has shape
    ``(A.shape[0], B.shape[0])`` and holds the centred dot products divided
    by the square root of the product of the rows' sums of squares.
    """
    a_centered = A - A.mean(axis=1, keepdims=True)
    b_centered = B - B.mean(axis=1, keepdims=True)

    # Per-row sums of squared deviations.
    a_ss = np.square(a_centered).sum(axis=1)
    b_ss = np.square(b_centered).sum(axis=1)

    cross = a_centered @ b_centered.T
    return cross / np.sqrt(np.outer(a_ss, b_ss))
| 165 | + |
def generate_correlation_map(x, y):
    """Correlate every row of ``x`` with every row of ``y``.

    ``x`` and ``y`` are 2-D arrays whose second axis has the same length.
    The result has shape ``(x.shape[0], y.shape[0])``. A small constant
    (1e-6) is added to each row's scale, so the values are slightly damped
    relative to an exact correlation coefficient.

    Raises:
        ValueError: if ``x`` and ``y`` differ in the length of axis 1.
    """
    x_mean = x.mean(axis=1)
    y_mean = y.mean(axis=1)
    n = x.shape[1]
    if y.shape[1] != n:
        raise ValueError('x and y must have the same number of timepoints.')

    # With ddof = n - 1 the divisor inside std() is 1, so each scale is the
    # square root of the row's summed squared deviations.
    x_scale = x.std(axis=1, ddof=n - 1) + 1e-6
    y_scale = y.std(axis=1, ddof=n - 1) + 1e-6

    cross = np.dot(x, y.T) - n * np.outer(x_mean, y_mean)
    return cross / np.outer(x_scale, y_scale)
| 176 | + |
| 177 | + |
| 178 | + |
| 179 | +from collections import defaultdict |
def atlas_ot(source, target, task, mapping_file=None):
    """Project the source-atlas data for ``task`` onto the target atlas.

    The arrays are taken from the module-level ``all_data`` dict and are
    indexed as ``[i, :, j]``, i.e. they are three-dimensional (presumably
    subject x node x time point; confirm against the .mat files).

    Args:
        source: key of the source atlas in ``all_data``.
        target: key of the target atlas in ``all_data`` and ``coord``.
        task: task name forming the second half of the ``all_data`` key.
        mapping_file: CSV file holding the mapping matrix. Defaults to the
            ``--mapping`` command-line argument.

    Returns:
        ``(G, test_time_series_pred)``: ``G`` is the matrix read from the
        mapping file, and ``test_time_series_pred`` has shape
        ``(n2, p2, t)``, where ``n2`` and ``p2`` are the first two
        dimensions of the target array and ``t`` is the third dimension of
        the source array.

    Raises:
        ValueError: if the source array has fewer axis-0 entries than the
            target array, since entry ``i`` of the output is computed from
            entry ``i`` of the source.
    """
    print("Atlas OT .. ")
    source_data = all_data[(source, task)]
    target_data = all_data[(target, task)]

    n = source_data.shape[0]
    num_time_points = source_data.shape[2]
    n2 = target_data.shape[0]
    p2 = target_data.shape[1]

    # Fail before the long loop instead of part-way through it.
    if n2 > n:
        raise ValueError(
            "source data has {0} entries along axis 0 but target has {1}; "
            "cannot map entry-by-entry".format(n, n2))

    if mapping_file is None:
        mapping_file = args.mapping

    # The first row and first column of the file are skipped (presumably
    # header and row labels; confirm). One column is read per row of the
    # target coordinate table.
    G = genfromtxt(mapping_file, delimiter=',', skip_header=1,
                   usecols=range(1, 1 + coord[target].shape[0]))
    G_t = np.transpose(G)  # hoisted: identical for every (i, j)

    test_time_series_pred = np.zeros((n2, p2, num_time_points))
    for i in tqdm(range(n2), desc="applying OT .."):
        for j in range(num_time_points):
            # Shift by a small constant, min-max rescale, then push the
            # vector through the mapping.
            a = normalize(source_data[i, :, j] + 1e-6)
            test_time_series_pred[i, :, j] = G_t.dot(a)

    return G, test_time_series_pred
| 209 | + |
| 210 | + |
def evaluate(c1,c2,y,test_size):
    """Score two feature matrices on how well ridge regression predicts ``y``.

    The first ``test_size`` row indices are shuffled with NumPy's global
    random state; 90% are used for fitting and the remainder are held out.
    A ``Ridge(alpha=1.0)`` model is fitted on each feature matrix in turn.

    Returns:
        ``(s1, s2)``: the correlation coefficient between held-out
        predictions and held-out ``y`` values, for ``c1`` and ``c2``
        respectively.
    """
    order = np.arange(test_size)
    np.random.shuffle(order)
    n_train = int(0.9 * test_size)
    train_idx = order[:n_train]
    held_out_idx = order[n_train:]

    def _held_out_correlation(features):
        # Fit on the training rows, then correlate predictions with truth.
        model = Ridge(alpha=1.0)
        model.fit(features[train_idx, :], y[train_idx])
        predicted = model.predict(features[held_out_idx])
        return np.corrcoef(predicted, y[held_out_idx])[0, 1]

    s1 = _held_out_correlation(c1)
    s2 = _held_out_correlation(c2)
    return s1, s2
| 233 | + |
| 234 | + |
| 235 | + |
| 236 | +from scipy.stats.stats import spearmanr |
| 237 | +import sys |
| 238 | + |
if __name__ == "__main__":
    # Apply the mapping for the atlas pair and task given on the command
    # line and write the projected array to a .mat file in the working
    # directory.
    source, target, task = args.source, args.target, args.task

    # NOTE: this seeds Python's ``random`` module only; the shuffles in this
    # file draw from ``np.random``.
    random.seed(3000)

    G, test_time_series_pred = atlas_ot(source, target, task)
    data = {"data": test_time_series_pred}
    sio.savemat(f"time_series_{target}_ot.mat", data)
| 252 | + |
0 commit comments