-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsparsity_experiment.py
63 lines (52 loc) · 2.01 KB
/
sparsity_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Libraries
import os
import time

import numpy as np

import utils
from sista import SISTA
# Parameters
# N: forwarded to the data generators and sizes v0 (presumably the number of
#    support points per measure -- TODO confirm against utils).
# d: length of beta, i.e. the number of cost components stacked along the
#    first axis of D (see the weighted sum building C below).
N, d = 100, 200
# Number of random trials averaged for each sparsity level.
n_experiments = 5
# T is forwarded to both utils.sinkhorn and SISTA.solve; rho and gamma are
# SISTA.solve hyper-parameters; beta0 / v0 are SISTA.solve's initial point.
T, rho, gamma, beta0, v0 = 2.5, 1e-1, 1e-5, np.zeros(d), np.ones(N)
# A recovered coefficient counts as "correct" if within this distance of the truth.
tol_correct = 1e-2
results_path = './results/sparsity.npy'

if n_experiments < 1:
    raise ValueError("n_experiments must be at least 1, got {}".format(n_experiments))

# Generate data for the experiment (shared by every trial and sparsity level)
X, Y = utils.generate_clouds(N, d)
p, q = utils.generate_measures(N, uniform = False)
D = utils.generate_D(X, Y)

# Test the algorithm for every sparsity level s = 1..d, where s is the number
# of non-zero entries of the ground-truth beta. Columns of res:
#   0: sparsity level (% of d)
#   1: mean relative error ||beta - beta_tilde|| / ||beta||
#   2: mean proportion of coefficients recovered within tol_correct
#   3: wall-clock time for all trials at this level (rounded seconds)
res = np.zeros((d, 4))
for sparsity in range(1, d + 1):
    relative_errors = np.zeros(n_experiments)
    correct_proportions = np.zeros(n_experiments)
    t_start = time.perf_counter()
    # Run exactly n_experiments trials. The former range(1, n_experiments)
    # ran one trial too few and left the loop variable undefined when
    # n_experiments == 1.
    for trial in range(n_experiments):
        # Draw a random binary beta with `sparsity` non-zero entries and build
        # the cost C = sum_k beta[k] * D[k] (weighted sum over D's first axis).
        beta = np.zeros(d)
        indices = np.random.choice(d, sparsity, replace = False)
        beta[indices] = 1.
        C = np.sum(D * beta[:, np.newaxis, np.newaxis], axis = 0)
        # Use Sinkhorn to find pi_hat, then round it to 6 decimals
        pi_hat = utils.sinkhorn(C, p, q, T)
        pi_hat = np.round(pi_hat, 6)
        # Use SISTA to solve the inverse problem. The initial point is passed
        # as copies so that any in-place update inside solve (not verified
        # here) cannot leak into later trials.
        sista = SISTA(pi_hat, D, p, q)
        beta_tilde, _ = sista.solve(beta0.copy(), v0.copy(), rho, gamma, T)
        # Record per-trial metrics (||beta|| >= 1 since sparsity >= 1)
        relative_errors[trial] = np.linalg.norm(beta - beta_tilde) / np.linalg.norm(beta)
        correct_proportions[trial] = np.mean(np.abs(beta - beta_tilde) <= tol_correct)
    # Aggregate results for this sparsity level
    t_end = time.perf_counter()
    mean_relative_error = relative_errors.mean()
    mean_correct_proportion = correct_proportions.mean()
    execution_time = np.rint(t_end - t_start).astype(int)
    sparsity_percent = (sparsity / d) * 100
    res[sparsity - 1] = [sparsity_percent, mean_relative_error,
                         mean_correct_proportion, execution_time]
    # Print the results
    print("Sparsity level: {:.1f}%".format(sparsity_percent))
    print("Execution time: {}s".format(execution_time))
    print("Mean relative error: {:.3f}".format(mean_relative_error))
    print("Mean correct proportion: {:.2f}".format(mean_correct_proportion))
    print("")
# Save the results for further analysis, creating the output directory first
# so a long sweep is not lost to a missing folder.
os.makedirs(os.path.dirname(results_path), exist_ok = True)
np.save(results_path, res)