-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathestimate_d.py
75 lines (63 loc) · 2.67 KB
/
estimate_d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 17 18:48:13 2022
Dummy code for estimating the ID once the singular values have been calculated
@author: Horvat
"""
from matplotlib import pyplot as plt
import numpy as np
import os
from utils import ID_NF_estimator
# Hyperparameters
n_sigmas = 20  # number N of trained NFs
datadim = 3  # data dimension D
batch_size = 10  # number K of samples the singular values were estimated on
save_path = r'outputs'  # directory where output files are saved

# Define the noise range (variances sig2) used during training.
# In the paper, we used an equidistant range in log-scale from sig2_0 to
# sig2_1, i.e. a geometric progression with n_sigmas points.
sig2_0 = 1e-09
sig2_1 = 2.0
# NOTE: despite its name, `sigmas` holds the sig2 values (variances).
# np.geomspace places both endpoints exactly and, unlike repeated
# multiplication by a constant ratio, does not accumulate rounding error or
# divide by zero when n_sigmas == 1.
sigmas = np.geomspace(sig2_0, sig2_1, n_sigmas)

# Load your singular values here; expected shape is N x K x D, i.e.
# (n_sigmas, batch_size, datadim). The random values below are only a
# placeholder so the script runs end to end.
sing_values_batch = np.abs(np.random.randn(n_sigmas, batch_size, datadim))

# In the paper, we find that averaging the singular values across all K
# samples reduces noise in the estimate; if you are interested in a local
# (per-sample) estimator instead, set "local_estimator = True".
local_estimator = False
data_type = 'image'  # forwarded to ID_NF_estimator as `mode`
# Option for plotting the singular values as a function of sig2.
plot = True
# Make sure the output directory exists: the plot below, ID_NF_estimator
# (which receives save_path) and the final np.save all write into it, and
# 'outputs' does not exist on a fresh checkout.
os.makedirs(save_path, exist_ok=True)


def _plot_sing_values(sig2, curves, out_dir):
    """Plot singular-value curves against sig2 on log-log axes and save as PDF.

    Parameters
    ----------
    sig2 : array of shape (N,)
        Noise levels (variances) used during training; the x-axis.
    curves : array of shape (N, D)
        One curve per column; each column is drawn as a separate line.
    out_dir : str
        Directory in which 'sing_values_vs_sig2.pdf' is written.
    """
    fig, ax = plt.subplots(figsize=(20, 10))
    ax.plot(sig2, curves)  # a 2-D y-array plots one line per column
    ax.set_yscale('log')
    ax.set_xscale('log')
    fig.savefig(os.path.join(out_dir, 'sing_values_vs_sig2.pdf'))
    # Release the figure so it does not linger in pyplot's figure manager.
    plt.close(fig)


if local_estimator:
    # One ID estimate per sample k, each from that sample's own singular values.
    d_hat = np.zeros(batch_size)
    if plot:
        # Only the curves of the first sample (k = 0) are plotted.
        _plot_sing_values(sigmas, sing_values_batch[:, 0, :], save_path)
    for k in range(batch_size):
        # `plot=plot` honours the plot option instead of always plotting.
        d_hat[k] = ID_NF_estimator(sing_values_batch[:, k, :], sigmas, datadim,
                                   mode=data_type, latent_dim=2, plot=plot,
                                   tag=str(k), save_path=save_path)
    print('--estimate mean ', d_hat.mean())
else:
    # Average the singular values over the K samples (axis 1) before estimating.
    sing_values_mean = sing_values_batch.mean(axis=1)
    if plot:
        _plot_sing_values(sigmas, sing_values_mean, save_path)
    # Estimate ID based on ID_NF; see the documentation of this function for details.
    d_hat = ID_NF_estimator(sing_values_mean, sigmas, datadim,
                            mode=data_type, latent_dim=2, plot=plot,
                            tag=str(datadim), save_path=save_path)
    print('--estimate ', d_hat)

np.save(os.path.join(save_path, 'd_hat.npy'), d_hat)