-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathmodels.py
438 lines (360 loc) · 19.4 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
"""
TensorFlow models for use in this project.
"""
from .utils import *
from .nn_utils import *
from .gp_kernel import *
from tensorflow_probability import distributions as tfd
import tensorflow as tf
# Encoders
class DiagonalEncoder(tf.keras.Model):
    """ Encoder whose posterior is a Normal with diagonal covariance.

    Used by the disjoint VAE and HI-VAE with a Standard Normal prior.
    The network output is split along its last axis: the first `z_size`
    entries are the mean, the remaining ones (after softplus) the scale.
    """

    def __init__(self, z_size, hidden_sizes=(64, 64), **kwargs):
        """
        :param z_size: latent space dimensionality
        :param hidden_sizes: tuple of hidden layer sizes.
                             The tuple length sets the number of hidden layers.
        :param kwargs: accepted and ignored, so that VAE can forward
                       encoder-specific options to any encoder class uniformly
        """
        super(DiagonalEncoder, self).__init__()
        self.z_size = int(z_size)
        self.net = make_nn(2*z_size, hidden_sizes)

    def __call__(self, x):
        params = self.net(x)
        split = self.z_size
        mean, raw_scale = params[..., :split], params[..., split:]
        # softplus keeps the scale strictly positive
        return tfd.MultivariateNormalDiag(loc=mean, scale_diag=tf.nn.softplus(raw_scale))
class JointEncoder(tf.keras.Model):
    """ 1d-convolutional encoder with a factorized (diagonal) Normal posterior.

    Used by joint VAE and HI-VAE with Standard Normal prior, or by GP-VAE with
    a factorized Normal posterior (GP_VAE switches `transpose` on).
    """

    def __init__(self, z_size, hidden_sizes=(64, 64), window_size=3, transpose=False, **kwargs):
        """
        :param z_size: latent space dimensionality
        :param hidden_sizes: tuple of hidden layer sizes.
                             The tuple length sets the number of hidden layers.
        :param window_size: kernel size for Conv1D layer
        :param transpose: True for GP prior | False for Standard Normal prior
        :param kwargs: accepted and ignored (lets VAE forward options uniformly)
        """
        super(JointEncoder, self).__init__()
        self.z_size = int(z_size)
        self.net = make_cnn(2*z_size, hidden_sizes, window_size)
        self.transpose = transpose

    def __call__(self, x):
        params = self.net(x)
        split = self.z_size

        if not self.transpose:
            # Mean / scale are split along the last (channel) axis.
            return tfd.MultivariateNormalDiag(
                loc=params[..., :split],
                scale_diag=tf.nn.softplus(params[..., split:]))

        # Swap the last two axes so that the channel axis comes before the
        # final axis; mean / scale are then split along that channel axis.
        rank = len(x.shape.as_list())
        swap_last_two = list(range(rank - 2)) + [rank - 1, rank - 2]
        params = tf.transpose(params, perm=swap_last_two)
        return tfd.MultivariateNormalDiag(
            loc=params[..., :split, :],
            scale_diag=tf.nn.softplus(params[..., split:, :]))
class BandedJointEncoder(tf.keras.Model):
    def __init__(self, z_size, hidden_sizes=(64, 64), window_size=3, data_type=None, **kwargs):
        """ Encoder with 1d-convolutional network and multivariate Normal posterior
        Used by GP-VAE with proposed banded covariance matrix
        :param z_size: latent space dimensionality
        :param hidden_sizes: tuple of hidden layer sizes.
                             The tuple length sets the number of hidden layers.
        :param window_size: kernel size for Conv1D layer
        :param data_type: needed for some data specific modifications, e.g:
                          tf.nn.softplus is a more common and correct choice, however
                          tf.nn.sigmoid provides more stable performance on Physionet dataset
        :param kwargs: accepted and ignored (lets VAE forward options uniformly)
        """
        super(BandedJointEncoder, self).__init__()
        self.z_size = int(z_size)
        # 3*z_size output channels: z_size for the mean, 2*z_size for the
        # banded precision-factor entries.
        self.net = make_cnn(3*z_size, hidden_sizes, window_size)
        self.data_type = data_type

    def __call__(self, x):
        """ Return q(z|x) as a MultivariateNormalTriL per latent dimension.

        For each (batch, latent) pair an upper-bidiagonal matrix B of size
        time_length x time_length is filled with network outputs: its diagonal
        (i, i) and first superdiagonal (i, i+1). The identity is added to B, B is
        inverted with a triangular solve, and the transpose of the inverse is
        used as the lower-triangular scale of the posterior.
        Requires static batch and time dimensions (read via shape.as_list()).
        """
        mapped = self.net(x)

        batch_size = mapped.shape.as_list()[0]
        time_length = mapped.shape.as_list()[1]

        # Obtain mean and precision matrix components
        num_dim = len(mapped.shape.as_list())
        perm = list(range(num_dim - 2)) + [num_dim - 1, num_dim - 2]
        mapped_transposed = tf.transpose(mapped, perm=perm)
        mapped_mean = mapped_transposed[:, :self.z_size]
        mapped_covar = mapped_transposed[:, self.z_size:]

        # tf.nn.sigmoid provides more stable performance on Physionet dataset
        if self.data_type == 'physionet':
            mapped_covar = tf.nn.sigmoid(mapped_covar)
        else:
            mapped_covar = tf.nn.softplus(mapped_covar)

        # Per (batch, latent): 2*time_length values; the first time_length fill
        # the diagonal, the next time_length-1 fill the superdiagonal (the last
        # value is dropped below).
        mapped_reshaped = tf.reshape(mapped_covar, [batch_size, self.z_size, 2*time_length])

        # Sparse indices [batch, latent, row, col] of the diagonal and superdiagonal.
        dense_shape = [batch_size, self.z_size, time_length, time_length]
        idxs_1 = np.repeat(np.arange(batch_size), self.z_size*(2*time_length-1))
        idxs_2 = np.tile(np.repeat(np.arange(self.z_size), (2*time_length-1)), batch_size)
        idxs_3 = np.tile(np.concatenate([np.arange(time_length), np.arange(time_length-1)]), batch_size*self.z_size)
        idxs_4 = np.tile(np.concatenate([np.arange(time_length), np.arange(1, time_length)]), batch_size*self.z_size)
        idxs_all = np.stack([idxs_1, idxs_2, idxs_3, idxs_4], axis=1)

        # ~10x faster on CPU than on GPU
        with tf.device('/cpu:0'):
            # Obtain covariance matrix from precision one
            mapped_values = tf.reshape(mapped_reshaped[:, :, :-1], [-1])
            prec_sparse = tf.sparse.SparseTensor(indices=idxs_all, values=mapped_values, dense_shape=dense_shape)
            prec_sparse = tf.sparse.reorder(prec_sparse)
            # Dense + sparse -> dense. tf.sparse.add is available in both TF1 (>=1.13)
            # and TF2, unlike the TF1-only top-level alias tf.sparse_add.
            # Note: despite the name, this matrix is UPPER bidiagonal.
            prec_tril = tf.sparse.add(tf.zeros(prec_sparse.dense_shape, dtype=tf.float32), prec_sparse)
            eye = tf.eye(num_rows=prec_tril.shape.as_list()[-1], batch_shape=prec_tril.shape.as_list()[:-2])
            prec_tril = prec_tril + eye
            cov_tril = tf.linalg.triangular_solve(matrix=prec_tril, rhs=eye, lower=False)
            # Guard against inf/nan from an ill-conditioned solve.
            cov_tril = tf.where(tf.math.is_finite(cov_tril), cov_tril, tf.zeros_like(cov_tril))

        # Inverse of an upper-triangular matrix is upper-triangular; transpose
        # it to obtain the lower-triangular scale expected by MultivariateNormalTriL.
        num_dim = len(cov_tril.shape)
        perm = list(range(num_dim - 2)) + [num_dim - 1, num_dim - 2]
        cov_tril_lower = tf.transpose(cov_tril, perm=perm)
        z_dist = tfd.MultivariateNormalTriL(loc=mapped_mean, scale_tril=cov_tril_lower)
        return z_dist
# Decoders
class Decoder(tf.keras.Model):
    def __init__(self, output_size, hidden_sizes=(64, 64)):
        """ Decoder parent class with no specified output distribution
        :param output_size: output dimensionality
        :param hidden_sizes: tuple of hidden layer sizes.
                             The tuple length sets the number of hidden layers.
        """
        super(Decoder, self).__init__()
        self.net = make_nn(output_size, hidden_sizes)

    def __call__(self, x):
        """ Map latent input `x` to an output distribution.

        Abstract: subclasses (e.g. BernoulliDecoder, GaussianDecoder) choose the
        distribution. Previously this silently returned None, which only failed
        later with a confusing error; it now fails immediately.
        :raises NotImplementedError: always, on the base class
        """
        raise NotImplementedError(
            "%s must override __call__ to return an output distribution" % type(self).__name__)
class BernoulliDecoder(Decoder):
    """ Decoder with a Bernoulli output distribution (used for HMNIST).

    The raw network output is used as the Bernoulli logits.
    """

    def __call__(self, x):
        return tfd.Bernoulli(logits=self.net(x))
class GaussianDecoder(Decoder):
    """ Decoder with a Gaussian output distribution (used for SPRITES and Physionet).

    The network predicts only the mean; the standard deviation is fixed to 1.
    """

    def __call__(self, x):
        loc = self.net(x)
        # Constant unit scale (standard deviation), same shape as the mean.
        unit_scale = tf.ones(tf.shape(loc), dtype=tf.float32)
        return tfd.Normal(loc=loc, scale=unit_scale)
# Image preprocessor
class ImagePreprocessor(tf.keras.Model):
    def __init__(self, image_shape, hidden_sizes=(256, ), kernel_size=3):
        """ 2d-convolutional preprocessor applied to image-shaped inputs before encoding
        :param image_shape: input image shape; stored for VAE.encode, which reshapes
                            inputs to it. Its last entry is passed to make_2d_cnn as
                            the network's output size.
        :param hidden_sizes: tuple of hidden layer sizes.
                             The tuple length sets the number of hidden layers.
        :param kernel_size: kernel/filter width and height. Defaults to the integer 3
                            (previously the float 3.), since Keras convolution layers
                            expect an int or a tuple of ints as kernel size.
        """
        super(ImagePreprocessor, self).__init__()
        self.image_shape = image_shape
        self.net = make_2d_cnn(image_shape[-1], hidden_sizes, kernel_size)

    def __call__(self, x):
        return self.net(x)
# VAE models
class VAE(tf.keras.Model):
    def __init__(self, latent_dim, data_dim, time_length,
                 encoder_sizes=(64, 64), encoder=DiagonalEncoder,
                 decoder_sizes=(64, 64), decoder=BernoulliDecoder,
                 image_preprocessor=None, beta=1.0, M=1, K=1, **kwargs):
        """ Basic Variational Autoencoder with Standard Normal prior
        :param latent_dim: latent space dimensionality
        :param data_dim: original data dimensionality
        :param time_length: time series duration
        :param encoder_sizes: layer sizes for the encoder network
        :param encoder: encoder model class {Diagonal, Joint, BandedJoint}Encoder
        :param decoder_sizes: layer sizes for the decoder network
        :param decoder: decoder model class {Bernoulli, Gaussian}Decoder
        :param image_preprocessor: 2d-convolutional network used for image data preprocessing
        :param beta: tradeoff coefficient between reconstruction and KL terms in ELBO
                     (only applied when K == 1)
        :param M: number of Monte Carlo samples for ELBO estimation
        :param K: number of importance weights for IWAE model (see: https://arxiv.org/abs/1509.00519)
        :param kwargs: forwarded to the encoder constructor
        """
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        self.data_dim = data_dim
        self.time_length = time_length

        self.encoder = encoder(latent_dim, encoder_sizes, **kwargs)
        self.decoder = decoder(data_dim, decoder_sizes)
        self.preprocessor = image_preprocessor

        self.beta = beta
        self.K = K
        self.M = M

        # Built lazily by _get_prior(). It must be initialised here: otherwise the
        # `self.prior is None` check in _get_prior raises AttributeError for VAE/HI_VAE.
        self.prior = None

    def encode(self, x):
        """ Return the approximate posterior q(z|x).

        If an image preprocessor is set, the first two axes of `x` are merged,
        each item is reshaped to `preprocessor.image_shape` and preprocessed, and
        the result is reshaped back to the original shape of `x`.
        """
        x = tf.identity(x)  # in case x is not a Tensor already...
        if self.preprocessor is not None:
            x_shape = x.shape.as_list()
            new_shape = [x_shape[0] * x_shape[1]] + list(self.preprocessor.image_shape)
            x_reshaped = tf.reshape(x, new_shape)
            x_preprocessed = self.preprocessor(x_reshaped)
            x = tf.reshape(x_preprocessed, x_shape)
        return self.encoder(x)

    def decode(self, z):
        """ Return the decoder's output distribution p(x|z). """
        z = tf.identity(z)  # in case z is not a Tensor already...
        return self.decoder(z)

    def __call__(self, inputs):
        """ Reconstruct `inputs`: sample z ~ q(z|x), then sample x ~ p(x|z). """
        return self.decode(self.encode(inputs).sample()).sample()

    def generate(self, noise=None, num_samples=1):
        """ Return p(x|z) for `noise`, or for `num_samples` draws of z ~ N(0, I). """
        if noise is None:
            # tf.random.normal (not the TF1-only tf.random_normal), consistent
            # with the rest of this class.
            noise = tf.random.normal(shape=(num_samples, self.latent_dim))
        return self.decode(noise)

    def _get_prior(self):
        """ Build (once) and return the Standard Normal prior over latent_dim. """
        if self.prior is None:
            self.prior = tfd.MultivariateNormalDiag(loc=tf.zeros(self.latent_dim, dtype=tf.float32),
                                                    scale_diag=tf.ones(self.latent_dim, dtype=tf.float32))
        return self.prior

    def compute_nll(self, x, y=None, m_mask=None):
        """ Summed negative log-likelihood of `y` (default: `x`) under p(x|z), z ~ q(z|x).

        Used only for evaluation. Non-finite terms are zeroed. If `m_mask` is
        given, only entries where it is True are counted; this is the inverse of
        the training loss, which zeroes those entries.
        """
        assert len(x.shape) == 3, "Input should have shape: [batch_size, time_length, data_dim]"
        if y is None: y = x

        z_sample = self.encode(x).sample()
        x_hat_dist = self.decode(z_sample)
        nll = -x_hat_dist.log_prob(y)  # shape=(BS, TL, D)
        nll = tf.where(tf.math.is_finite(nll), nll, tf.zeros_like(nll))
        if m_mask is not None:
            m_mask = tf.cast(m_mask, tf.bool)
            nll = tf.where(m_mask, nll, tf.zeros_like(nll))  # !!! inverse mask, set zeros for observed
        return tf.reduce_sum(nll)

    def compute_mse(self, x, y=None, m_mask=None, binary=False):
        """ Summed squared error between the mean reconstruction and `y` (default: `x`).

        Used only for evaluation. The reconstruction decodes the posterior mean.
        With `binary=True` it is rounded first. If `m_mask` is given, only entries
        where it is True are counted.
        """
        assert len(x.shape) == 3, "Input should have shape: [batch_size, time_length, data_dim]"
        if y is None: y = x

        z_mean = self.encode(x).mean()
        x_hat_mean = self.decode(z_mean).mean()  # shape=(BS, TL, D)
        if binary:
            x_hat_mean = tf.round(x_hat_mean)
        mse = tf.math.squared_difference(x_hat_mean, y)
        if m_mask is not None:
            m_mask = tf.cast(m_mask, tf.bool)
            mse = tf.where(m_mask, mse, tf.zeros_like(mse))  # !!! inverse mask, set zeros for observed
        return tf.reduce_sum(mse)

    def _compute_loss(self, x, m_mask=None, return_parts=False):
        """ Negative ELBO (or IWAE bound when K > 1), averaged over the batch.

        `x` (and `m_mask`) are tiled M*K times along the batch axis. The
        reconstruction NLL is summed over time and data axes. Non-finite terms
        are zeroed, and entries where `m_mask` is True are excluded.
        - K > 1: KL is estimated from the sample, and importance weights are
          log-mean-exp'ed over K. `beta` is not applied.
        - K == 1: KL is computed by `self.kl_divergence` and weighted by `beta`.
        :return: -elbo, or (-elbo, mean nll, mean kl) if return_parts
        """
        assert len(x.shape) == 3, "Input should have shape: [batch_size, time_length, data_dim]"
        x = tf.identity(x)  # in case x is not a Tensor already...
        x = tf.tile(x, [self.M * self.K, 1, 1])  # shape=(M*K*BS, TL, D)

        if m_mask is not None:
            m_mask = tf.identity(m_mask)  # in case m_mask is not a Tensor already...
            m_mask = tf.tile(m_mask, [self.M * self.K, 1, 1])  # shape=(M*K*BS, TL, D)
            m_mask = tf.cast(m_mask, tf.bool)

        pz = self._get_prior()
        qz_x = self.encode(x)
        z = qz_x.sample()
        px_z = self.decode(z)

        nll = -px_z.log_prob(x)  # shape=(M*K*BS, TL, D)
        nll = tf.where(tf.math.is_finite(nll), nll, tf.zeros_like(nll))
        if m_mask is not None:
            nll = tf.where(m_mask, tf.zeros_like(nll), nll)  # if not HI-VAE, m_mask is always zeros
        nll = tf.reduce_sum(nll, [1, 2])  # shape=(M*K*BS)

        if self.K > 1:
            kl = qz_x.log_prob(z) - pz.log_prob(z)  # shape=(M*K*BS, TL or d)
            # tf.math.is_finite (not the TF1-only tf.is_finite), consistent with the rest of the file
            kl = tf.where(tf.math.is_finite(kl), kl, tf.zeros_like(kl))
            kl = tf.reduce_sum(kl, 1)  # shape=(M*K*BS)

            weights = -nll - kl  # shape=(M*K*BS)
            weights = tf.reshape(weights, [self.M, self.K, -1])  # shape=(M, K, BS)

            elbo = reduce_logmeanexp(weights, axis=1)  # shape=(M, 1, BS)
            elbo = tf.reduce_mean(elbo)  # scalar
        else:
            # if K==1, compute KL analytically
            kl = self.kl_divergence(qz_x, pz)  # shape=(M*K*BS, TL or d)
            kl = tf.where(tf.math.is_finite(kl), kl, tf.zeros_like(kl))
            kl = tf.reduce_sum(kl, 1)  # shape=(M*K*BS)

            elbo = -nll - self.beta * kl  # shape=(M*K*BS) K=1
            elbo = tf.reduce_mean(elbo)  # scalar

        if return_parts:
            nll = tf.reduce_mean(nll)  # scalar
            kl = tf.reduce_mean(kl)  # scalar
            return -elbo, nll, kl
        else:
            return -elbo

    def compute_loss(self, x, m_mask=None, return_parts=False):
        """ Loss for the plain VAE: the missingness mask is deliberately ignored. """
        del m_mask
        return self._compute_loss(x, return_parts=return_parts)

    def kl_divergence(self, a, b):
        """ KL(a || b) via TFP's registered analytic KL. """
        return tfd.kl_divergence(a, b)

    def get_trainable_vars(self):
        """ Run one dummy loss computation so all sub-models get built, then return their variables. """
        self.compute_loss(tf.random.normal(shape=(1, self.time_length, self.data_dim), dtype=tf.float32),
                          tf.zeros(shape=(1, self.time_length, self.data_dim), dtype=tf.float32))
        return self.trainable_variables
class HI_VAE(VAE):
    """ HI-VAE model: the reconstruction term of the ELBO is summed only over
    observed components, i.e. the missingness mask is passed through to the loss.
    """

    def compute_loss(self, x, m_mask=None, return_parts=False):
        loss_options = dict(m_mask=m_mask, return_parts=return_parts)
        return self._compute_loss(x, **loss_options)
class GP_VAE(HI_VAE):
    def __init__(self, *args, kernel="cauchy", sigma=1., length_scale=1.0, kernel_scales=1, **kwargs):
        """ Proposed GP-VAE model with Gaussian Process prior
        :param kernel: Gaussian Process kernel ["cauchy", "diffusion", "rbf", "matern"]
        :param sigma: scale parameter for a kernel function (used by the "cauchy" kernel)
        :param length_scale: length scale parameter for a kernel function
        :param kernel_scales: number of different length scales over latent space dimensions
        :param args, kwargs: forwarded to VAE.__init__
        """
        super(GP_VAE, self).__init__(*args, **kwargs)
        self.kernel = kernel
        self.sigma = sigma
        self.length_scale = length_scale
        self.kernel_scales = kernel_scales

        # The GP prior is over the time axis, so the joint encoder must emit
        # [..., latent, time] parameters.
        if isinstance(self.encoder, JointEncoder):
            self.encoder.transpose = True

        # Precomputed KL components for efficiency
        self.pz_scale_inv = None
        self.pz_scale_log_abs_determinant = None
        self.prior = None

    def decode(self, z):
        """ Swap the last two axes of `z` before decoding (z must have rank > 2). """
        num_dim = len(z.shape)
        assert num_dim > 2
        perm = list(range(num_dim - 2)) + [num_dim - 1, num_dim - 2]
        return self.decoder(tf.transpose(z, perm=perm))

    def _kernel_matrix(self, length_scale):
        """ Return the kernel matrix over time_length steps for `self.kernel`.

        :raises ValueError: if `self.kernel` is not a supported kernel name
        """
        if self.kernel == "rbf":
            return rbf_kernel(self.time_length, length_scale)
        if self.kernel == "diffusion":
            return diffusion_kernel(self.time_length, length_scale)
        if self.kernel == "matern":
            return matern_kernel(self.time_length, length_scale)
        if self.kernel == "cauchy":
            return cauchy_kernel(self.time_length, self.sigma, length_scale)
        raise ValueError("Unknown kernel %r; expected one of "
                         "'cauchy', 'diffusion', 'rbf', 'matern'" % (self.kernel,))

    def _get_prior(self):
        """ Build (once) and return the GP prior p(z).

        Zero-mean MultivariateNormalFullCovariance with loc shape
        [latent_dim, time_length]. Latent dimensions are split into
        `kernel_scales` groups, and group i uses length scale length_scale / 2**i.
        Every group except the last gets ceil(latent_dim / kernel_scales)
        dimensions; the last group gets the remainder.
        :raises ValueError: for an unknown kernel, or if that split would give the
                            last group a negative size
        """
        if self.prior is None:
            # Compute kernel matrices for each latent dimension
            kernel_matrices = [self._kernel_matrix(self.length_scale / 2**i)
                               for i in range(self.kernel_scales)]

            # Combine kernel matrices for each latent dimension
            tiled_matrices = []
            total = 0
            for i, kernel_matrix in enumerate(kernel_matrices):
                if i == self.kernel_scales - 1:
                    multiplier = self.latent_dim - total
                else:
                    multiplier = int(np.ceil(self.latent_dim / self.kernel_scales))
                if multiplier < 0:
                    raise ValueError("Cannot split latent_dim=%d into kernel_scales=%d groups"
                                     % (self.latent_dim, self.kernel_scales))
                total += multiplier
                tiled_matrices.append(tf.tile(tf.expand_dims(kernel_matrix, 0), [multiplier, 1, 1]))
            # tf.concat instead of np.concatenate: the pieces are TF tensors, and
            # np.concatenate only works on them in eager mode.
            kernel_matrix_tiled = tf.concat(tiled_matrices, axis=0)
            assert total == self.latent_dim
            self.prior = tfd.MultivariateNormalFullCovariance(
                loc=tf.zeros([self.latent_dim, self.time_length], dtype=tf.float32),
                covariance_matrix=kernel_matrix_tiled)
        return self.prior

    def kl_divergence(self, a, b):
        """ Batched KL divergence `KL(a || b)` for multivariate Normals.
            See https://github.com/tensorflow/probability/blob/master/tensorflow_probability
                       /python/distributions/mvn_linear_operator.py
            It's used instead of default KL class in order to exploit precomputed components for efficiency.
            The inverse scale and log-determinant of `b` are cached on first use, so
            `b` is assumed to be the same (fixed) prior on every call.
        """

        def squared_frobenius_norm(x):
            """Helper to make KL calculation slightly more readable."""
            return tf.reduce_sum(tf.square(x), axis=[-2, -1])

        def is_diagonal(x):
            """Helper to identify if `LinearOperator` has only a diagonal component."""
            return (isinstance(x, tf.linalg.LinearOperatorIdentity) or
                    isinstance(x, tf.linalg.LinearOperatorScaledIdentity) or
                    isinstance(x, tf.linalg.LinearOperatorDiag))

        # Needed by both branches below. It used to be cached only in the
        # non-diagonal branch, which left it None on the diagonal path.
        if self.pz_scale_log_abs_determinant is None:
            self.pz_scale_log_abs_determinant = b.scale.log_abs_determinant()

        if is_diagonal(a.scale) and is_diagonal(b.scale):
            # Using `stddev` because it handles expansion of Identity cases.
            b_inv_a = (a.stddev() / b.stddev())[..., tf.newaxis]
        else:
            if self.pz_scale_inv is None:
                self.pz_scale_inv = tf.linalg.inv(b.scale.to_dense())
                self.pz_scale_inv = tf.where(tf.math.is_finite(self.pz_scale_inv),
                                             self.pz_scale_inv, tf.zeros_like(self.pz_scale_inv))

            # Broadcast the cached inverse over a's leading (batch) axis.
            a_shape = a.scale.shape
            if len(b.scale.shape) == 3:
                _b_scale_inv = tf.tile(self.pz_scale_inv[tf.newaxis], [a_shape[0]] + [1] * (len(a_shape) - 1))
            else:
                _b_scale_inv = tf.tile(self.pz_scale_inv, [a_shape[0]] + [1] * (len(a_shape) - 1))

            b_inv_a = _b_scale_inv @ a.scale.to_dense()

        # ~10x faster on CPU than on GPU
        with tf.device('/cpu:0'):
            kl_div = (self.pz_scale_log_abs_determinant - a.scale.log_abs_determinant() +
                      0.5 * (-tf.cast(a.scale.domain_dimension_tensor(), a.dtype) +
                      squared_frobenius_norm(b_inv_a) + squared_frobenius_norm(
                      b.scale.solve((b.mean() - a.mean())[..., tf.newaxis]))))
        return kl_div