# dataset.py (forked from Westlake-AI/Markov-Lipschitz-Deep-Learning)
import torch
import numpy as np
from sklearn.datasets import make_s_curve
# Project-local variant of sklearn's swiss-roll generator; unlike the sklearn
# version it accepts `remove`, `center`, and `r` (see LoadData below).
from samples_generator_new import make_swiss_roll
from torchvision import transforms
from torchvision import datasets as torchvisiondatasets
# Shuffles the sample indices and hands out one batch worth of indices at a time.
class SampleIndexGenerator():

    def __init__(self, data, batch_size):
        self.num_train_sample = data.shape[0]
        self.batch_size = batch_size
        self.Reset()

    def Reset(self):
        # Draw a fresh random permutation of all sample indices.
        self.unuse_index = torch.randperm(self.num_train_sample).tolist()

    def CalSampleIndex(self, batch_idx):
        # Pop the next `batch_size` indices off the permutation; `batch_idx`
        # is accepted for API compatibility but is not used here.
        use_index = self.unuse_index[:self.batch_size]
        self.unuse_index = self.unuse_index[self.batch_size:]
        return use_index
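
# A minimal usage sketch (illustrative, not part of the original file): draw
# batch index lists until the permutation is exhausted, then call Reset() to
# start the next epoch.
#
#     data = torch.randn(10, 3)                      # hypothetical toy data
#     gen = SampleIndexGenerator(data, batch_size=4)
#     for batch_idx in range(3):
#         idx = gen.CalSampleIndex(batch_idx)        # e.g. [7, 2, 9, 0]
#         batch = data[idx]                          # last batch may be short
#     gen.Reset()                                    # reshuffle for next epoch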
def dsphere(n=100, d=2, r=1, noise=None, ambient=None):
    """
    Sample `n` data points uniformly from a d-sphere of radius `r`.

    Arguments:
        n {int} -- number of data points to sample
        d {int} -- intrinsic dimension of the sphere (points live in R^{d+1})
        r {float} -- radius of the sphere
        noise {float, default=None} -- standard deviation of additive Gaussian noise
        ambient {int, default=None} -- embed the sphere into a space with this
            ambient dimension; the sphere is randomly rotated in that
            higher-dimensional space
    """
    data = np.random.randn(n, d + 1)
    # Normalize the Gaussian samples onto the sphere of radius r.
    data = r * data / np.sqrt(np.sum(data**2, 1)[:, None])
    if noise:
        data += noise * np.random.randn(*data.shape)
    if ambient:
        assert ambient > d, "Must embed in higher dimensions"
        data = embed(data, ambient)
    return data
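
# `embed` is referenced above but not defined in this file. Below is a minimal
# sketch of a compatible helper, assuming the common zero-padding plus
# random-rotation construction used for the Spheres benchmark; the original
# project may define it elsewhere.
def embed(data, ambient=50):
    """Embed `data` (n x d) into `ambient` dimensions via a random rotation."""
    n, d = data.shape
    assert ambient > d, "Ambient dimension must exceed data dimension"
    # Zero-pad the coordinates up to the ambient dimension.
    base = np.zeros((n, ambient))
    base[:, :d] = data
    # Apply a random orthogonal rotation (Q factor of a random matrix).
    q, _ = np.linalg.qr(np.random.random((ambient, ambient)))
    return base @ q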
def create_sphere_dataset10000(n_samples=500, d=100, n_spheres=11, r=5, seed=42):
    np.random.seed(seed)
    # Rescaling the shift variance by sqrt(d) keeps the big sphere surrounding
    # the inner spheres.
    variance = 10 / np.sqrt(d)
    shift_matrix = np.random.normal(0, variance, [n_spheres, d + 1])

    spheres = []
    n_datapoints = 0
    for i in np.arange(n_spheres - 1):
        sphere = dsphere(n=n_samples, d=d, r=r)
        spheres.append(sphere + shift_matrix[i, :])
        n_datapoints += n_samples

    # Additional big surrounding sphere:
    n_samples_big = 10 * n_samples  # int(n_samples/2)
    big = dsphere(n=n_samples_big, d=d, r=r * 5)
    spheres.append(big)
    n_datapoints += n_samples_big

    # Create the dataset and per-sphere labels.
    dataset = np.concatenate(spheres, axis=0)
    labels = np.zeros(n_datapoints)
    label_index = 0
    for index, data in enumerate(spheres):
        n_sphere_samples = data.shape[0]
        labels[label_index:label_index + n_sphere_samples] = index
        label_index += n_sphere_samples

    # Reorder the 10000 points by interleaving the 20 strides of 500 points,
    # taking one point from each stride per pass, 500 passes in total.
    index_seed = np.linspace(0, 10000, num=20, dtype='int16', endpoint=False)
    arr = np.array([], dtype='int16')
    for i in range(500):
        arr = np.concatenate((arr, index_seed + int(i)))
    dataset = dataset[arr]
    labels = labels[arr]

    # Roughly normalize coordinates into [0, 1].
    return dataset / 22 + 0.5, labels
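
# Illustrative note (not part of the original file): the reordering above
# places one point from each of the 20 strides into every consecutive block
# of 20, so any contiguous slice of the returned dataset is approximately
# class-balanced. Concretely:
#
#     arr[20 * i + j] == 500 * j + i   for i in range(500), j in range(20)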
def create_sphere_dataset5500(n_samples=1500, d=100, bigR=25, n_spheres=11, r=5, seed=42):
    np.random.seed(seed)
    # Rescaling the shift variance by sqrt(d) keeps the big sphere surrounding
    # the inner spheres.
    variance = 10 / np.sqrt(d)
    shift_matrix = np.random.normal(0, variance, [n_spheres, d + 1])

    spheres = []
    n_datapoints = 0
    for i in np.arange(n_spheres - 1):
        sphere = dsphere(n=n_samples, d=d, r=r)
        spheres.append(sphere + shift_matrix[i, :])
        n_datapoints += n_samples

    # Additional big surrounding sphere.
    n_samples_big = 1 * n_samples
    big = dsphere(n=n_samples_big, d=d, r=bigR)
    spheres.append(big)
    n_datapoints += n_samples_big

    # Create the dataset and per-sphere labels.
    dataset = np.concatenate(spheres, axis=0)
    labels = np.zeros(n_datapoints)
    label_index = 0
    for index, data in enumerate(spheres):
        n_sphere_samples = data.shape[0]
        labels[label_index:label_index + n_sphere_samples] = index
        label_index += n_sphere_samples

    # Shuffle the points.
    arr = np.arange(dataset.shape[0])
    np.random.shuffle(arr)
    dataset = dataset[arr]
    labels = labels[arr]

    # Roughly normalize coordinates into [0, 1].
    return dataset / 22 + 0.5, labels
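
# A quick illustrative check of the sphere generators (not part of the original
# file): each returns an (N, d+1) array normalized to roughly [0, 1] plus one
# integer label per sphere.
#
#     data, labels = create_sphere_dataset10000()
#     print(data.shape, labels.shape)   # expected: (10000, 101) (10000,)
#     print(np.unique(labels))          # expected: 11 sphere labels, 0..10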
def LoadData(data_name='SwissRoll', data_num=1500, seed=0, noise=0.0,
             device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
             remove=None, test=False):
    """
    Load one of the supported datasets.

    Arguments:
        data_name {str} -- name of the dataset to load
        data_num {int} -- number of data points to load
        seed {int} -- random seed for data generation
        noise {float} -- noise level for data generation
        device {torch.device} -- device on which the data is stored
        remove {str} -- shape of the region removed from the generated manifold
        test {bool} -- for MNIST, return the slice following the training slice
    """
    # Load the SwissRoll dataset.
    if data_name == 'SwissRoll':
        if remove is None:
            train_data, train_label = make_swiss_roll(
                n_samples=data_num, noise=noise, random_state=seed)
        else:
            train_data, train_label = make_swiss_roll(
                n_samples=data_num, noise=noise, random_state=seed + 1,
                remove=remove, center=[10, 10], r=8)
        train_data = train_data / 20

    # Load the SCurve dataset.
    if data_name == 'SCurve':
        train_data, train_label = make_s_curve(
            n_samples=data_num, noise=noise, random_state=seed)
        train_data = train_data / 2

    # Load the MNIST dataset.
    if data_name == 'MNIST':
        mnist = torchvisiondatasets.MNIST(
            '~/data', train=True, download=True,
            transform=transforms.ToTensor())
        # Flatten the images to 784-dim vectors scaled to [0, 1].
        train_data = mnist.data.float().view(-1, 28 * 28) / 255
        train_label = mnist.targets
        if not test:
            train_data = train_data[:data_num]
            train_label = train_label[:data_num]
        else:
            train_data = train_data[data_num:data_num * 2]
            train_label = train_label[data_num:data_num * 2]

    # Load the Spheres datasets.
    if data_name == 'Spheres5500':
        train_data, train_label = create_sphere_dataset5500(seed=seed)
    if data_name == 'Spheres10000':
        train_data, train_label = create_sphere_dataset10000(seed=seed)

    # Move the data to the target device.
    train_data = torch.as_tensor(train_data).to(device)[:data_num]
    train_label = torch.as_tensor(train_label).to(device)[:data_num]
    return train_data, train_label
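
# A minimal usage sketch (illustrative, not part of the original file):
# generate a 1500-point Swiss roll and iterate over it in shuffled batches.
#
#     if __name__ == '__main__':
#         data, labels = LoadData(data_name='SwissRoll', data_num=1500, seed=0)
#         print(data.shape, labels.shape)   # expected: (1500, 3) and (1500,)
#         gen = SampleIndexGenerator(data, batch_size=256)
#         for batch_idx in range(data.shape[0] // 256):
#             idx = gen.CalSampleIndex(batch_idx)
#             batch = data[idx]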