vae2.py
import tensorflow as tf
# Assumes TensorFlow 2.x with Keras 2, where `keras` and `tf.keras` refer to the same package.
from keras.layers import Input, LSTM, Dense, Reshape, Conv2DTranspose, Lambda, TimeDistributed
from keras.models import Model
from keras.losses import binary_crossentropy
from keras import backend as K
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd
# Constants
SEQUENCE_LENGTH = 300 # Number of timesteps in each sample
NUM_FEATURES = 2 # Number of features in movement data
FRAME_HEIGHT = 512 # Height of video frame
FRAME_WIDTH = 512 # Width of video frame
# Load data
movement_data = pd.read_csv("data/mTBI/eye_motion_trace.csv")
video_data = np.load("data/mTBI/frame_pixels.npy")
print("loaded data")

class VideoVAE:
    def __init__(self, input_shape, latent_dim, frame_shape, sequence_length):
        self.input_shape = input_shape
        self.latent_dim = latent_dim
        self.frame_shape = frame_shape
        self.sequence_length = sequence_length
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()
        self.vae = self.build_vae()
    def build_encoder(self):
        # Encode the (sequence_length, num_features) movement trace into the mean and
        # log-variance of a diagonal Gaussian over the latent space.
        inputs = Input(shape=self.input_shape)
        x = LSTM(128, return_sequences=False)(inputs)
        z_mean = Dense(self.latent_dim)(x)
        z_log_var = Dense(self.latent_dim)(x)
        return Model(inputs, [z_mean, z_log_var], name='encoder')
    def build_decoder(self):
        # Map a latent vector to a full sequence of frames. The dense layer produces one
        # (8, 8, 64) feature map per timestep, which is upsampled 8 -> 512 by six stride-2
        # transposed convolutions so the output matches frame_shape (512, 512, 1).
        latent_inputs = Input(shape=(self.latent_dim,))
        initial_depth = 64
        initial_shape = (8, 8, initial_depth)
        initial_dense_units = int(np.prod(initial_shape)) * self.sequence_length
        x = Dense(initial_dense_units, activation='relu')(latent_inputs)
        x = Reshape((self.sequence_length, *initial_shape))(x)
        for filters in (128, 64, 32, 16, 8):
            x = TimeDistributed(Conv2DTranspose(filters, kernel_size=3, strides=(2, 2), padding='same', activation='relu'))(x)
        x = TimeDistributed(Conv2DTranspose(1, kernel_size=3, strides=(2, 2), padding='same', activation='sigmoid'))(x)
        return Model(latent_inputs, x, name='decoder')
    def build_vae(self):
        inputs = Input(shape=self.input_shape)
        z_mean, z_log_var = self.encoder(inputs)
        # Reparameterisation trick, wrapped in a Lambda layer so sampling stays in the graph.
        def sampling(args):
            mean, log_var = args
            return mean + K.exp(0.5 * log_var) * K.random_normal(shape=(K.shape(mean)[0], self.latent_dim))
        z = Lambda(sampling, name='z')([z_mean, z_log_var])
        outputs = self.decoder(z)
        vae = Model(inputs, outputs, name='vae')
        # KL term is attached via add_loss here; the compiled loss handles only reconstruction.
        kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        vae.add_loss(K.mean(kl_loss))
        return vae

def vae_loss(original_images, reconstructed_images):
    # Reconstruction term only: per-frame binary cross-entropy over flattened pixels.
    # The KL term is added to the model itself in build_vae via add_loss.
    original_shape = tf.shape(original_images)
    reconstruction_shape = tf.shape(reconstructed_images)
    original_images_flattened = tf.reshape(original_images, [-1, original_shape[2] * original_shape[3] * original_shape[4]])
    reconstructed_images_flattened = tf.reshape(reconstructed_images, [-1, reconstruction_shape[2] * reconstruction_shape[3] * reconstruction_shape[4]])
    reconstruction_loss = binary_crossentropy(original_images_flattened, reconstructed_images_flattened)
    return K.mean(reconstruction_loss)
# Model parameters
input_shape = (300, 2)
latent_dim = 100
frame_shape = (512, 512, 1)
sequence_length = 300
video_vae = VideoVAE(input_shape, latent_dim, frame_shape, sequence_length)
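# Optional sanity check: the decoder should emit (sequence_length, 512, 512, 1)
# so that reconstructions line up with the preprocessed video targets.
print("decoder output shape:", video_vae.decoder.output_shape)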
# Optimizer
optimizer = tf.keras.optimizers.Adam()
# Compile model: reconstruction loss only, since the KL term was added in build_vae
video_vae.vae.compile(optimizer=optimizer, loss=vae_loss)
# Data preprocessing function, applied per element by tf.data.Dataset.map
def preprocess_data(movement, video):
    # Assumes each movement row holds SEQUENCE_LENGTH * NUM_FEATURES values and each
    # video sample reshapes to (SEQUENCE_LENGTH, FRAME_HEIGHT, FRAME_WIDTH, 1).
    movement = tf.reshape(tf.cast(movement, tf.float32), (SEQUENCE_LENGTH, NUM_FEATURES))
    video = tf.reshape(tf.cast(video, tf.float32), (SEQUENCE_LENGTH, FRAME_HEIGHT, FRAME_WIDTH, 1))
    video = video / 255.0  # scale pixels to [0, 1] to match the sigmoid decoder output
    return movement, video
# Split data
indices = np.arange(movement_data.shape[0])
train_indices, test_indices = train_test_split(indices, test_size=0.2)
# Create tf.data.Dataset; repeat() so steps_per_epoch/validation_steps work across epochs
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((movement_data.iloc[train_indices].values, video_data[train_indices]))
train_dataset = train_dataset.map(preprocess_data).batch(batch_size).repeat().prefetch(tf.data.experimental.AUTOTUNE)
validation_dataset = tf.data.Dataset.from_tensor_slices((movement_data.iloc[test_indices].values, video_data[test_indices]))
validation_dataset = validation_dataset.map(preprocess_data).batch(batch_size).repeat().prefetch(tf.data.experimental.AUTOTUNE)
# Training parameters
num_epochs = 10
train_steps = len(train_indices) // batch_size
validation_steps = len(test_indices) // batch_size
# Train model
video_vae.vae.fit(train_dataset, epochs=num_epochs, validation_data=validation_dataset, steps_per_epoch=train_steps, validation_steps=validation_steps)
# Save model
video_vae.vae.save('vae2.h5')
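
# NOTE (assumption): because the model carries a custom add_loss term and a custom loss
# function, reloading the HDF5 file is simplest with compile=False, re-compiling afterwards:
#   reloaded = tf.keras.models.load_model('vae2.h5', compile=False)
#   reloaded.compile(optimizer=tf.keras.optimizers.Adam(), loss=vae_loss)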