-
Notifications
You must be signed in to change notification settings - Fork 12
/
sagan_models.py
77 lines (58 loc) · 2.71 KB
/
sagan_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy as np
import tensorflow as tf
from spectral import SpectralConv2D, SpectralConv2DTranspose
from attention import SelfAttnModel
def create_generator(image_size=64, z_dim=100, filters=64, kernel_size=4,
                     channels=3):
    """Build the SAGAN generator as a functional Keras model.

    The latent vector is reshaped to a 1x1 feature map and upsampled by
    spectrally-normalised transposed convolutions: one stride-4 stage, then
    stride-2 stages until ``image_size`` is reached. Two ``SelfAttnModel``
    blocks are inserted: one after the first three upsampling stages and one
    just before the output layer.

    The resolution schedule assumes ``SpectralConv2DTranspose`` with
    ``padding='same'`` upsamples by exactly its stride, as Keras
    ``Conv2DTranspose`` does (NOTE(review): confirm in spectral.py).

    Args:
        image_size: Side length of the generated square image. Must be a
            power of two and at least 32: the first three stages end at
            16x16 and every later stage doubles the resolution.
        z_dim: Length of the latent input vector.
        filters: Channel count of the last hidden stage; earlier stages use
            ``filters`` times a power of two (halving at each stage).
        kernel_size: Kernel size of every transposed convolution.
        channels: Number of channels of the generated image (3 for RGB).

    Returns:
        A ``tf.keras.models.Model`` mapping a ``(batch, z_dim)`` input to
        ``[image, attn1, attn2]``: the tanh-activated image followed by the
        second outputs of the two ``SelfAttnModel`` calls.

    Raises:
        ValueError: If ``image_size`` is not a power of two >= 32.
    """
    # Any other size would silently produce an image of the wrong
    # resolution, so reject it before building any layers.
    size = int(image_size)
    if size != image_size or size < 32 or size & (size - 1):
        raise ValueError(
            f"image_size must be a power of two >= 32, got {image_size!r}")

    input_layers = tf.keras.layers.Input((z_dim,))
    x = tf.keras.layers.Reshape((1, 1, z_dim))(input_layers)

    # 1 -> image_size needs log2(image_size) doublings; the first stage's
    # stride of 4 covers two of them, so there are log2(image_size) - 1
    # stages in total (including the output layer).
    repeat_num = int(np.log2(image_size)) - 1
    # The repeat_num - 1 hidden stages each halve the channel count, so
    # start wide enough to finish at exactly `filters` channels.
    curr_filters = filters * 2 ** (repeat_num - 1)

    # Stages 1-3: 1x1 -> 4x4 (stride 4), then 8x8 and 16x16 (stride 2).
    for i in range(3):
        curr_filters //= 2
        strides = 4 if i == 0 else 2
        x = SpectralConv2DTranspose(filters=curr_filters,
                                    kernel_size=kernel_size,
                                    strides=strides,
                                    padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)

    x, attn1 = SelfAttnModel(curr_filters)(x)

    # Remaining hidden stages: 32x32 up to image_size / 2.
    for _ in range(repeat_num - 4):
        curr_filters //= 2
        x = SpectralConv2DTranspose(filters=curr_filters,
                                    kernel_size=kernel_size,
                                    strides=2,
                                    padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)

    x, attn2 = SelfAttnModel(curr_filters)(x)

    # Output stage: final doubling to image_size, projected to `channels`.
    x = SpectralConv2DTranspose(filters=channels,
                                kernel_size=kernel_size,
                                strides=2,
                                padding='same')(x)
    x = tf.keras.layers.Activation('tanh')(x)
    return tf.keras.models.Model(input_layers, [x, attn1, attn2])
def create_discriminator(image_size=64, filters=64, kernel_size=4,
                         channels=3):
    """Build the SAGAN discriminator as a functional Keras model.

    The image is downsampled by stride-2 spectrally-normalised convolutions
    with LeakyReLU(0.1), doubling the channel count at every stage (the
    first stage already uses ``2 * filters``). A ``SelfAttnModel`` block
    follows the first three stages and another follows the remaining
    stages, which bring the feature map down to 4x4. A final single-filter
    4x4 convolution produces the (unactivated) score, which is flattened.

    The schedule assumes ``SpectralConv2D`` with ``padding='same'`` halves
    the resolution at stride 2, as Keras ``Conv2D`` does
    (NOTE(review): confirm in spectral.py).

    Args:
        image_size: Side length of the square input image. Must be a power
            of two and at least 32 so the feature map is exactly 4x4 before
            the final 4x4 convolution.
        filters: Base channel count; stage k uses ``filters * 2**k``.
        kernel_size: Kernel size of the downsampling convolutions (the
            final scoring convolution always uses 4).
        channels: Number of channels of the input image (3 for RGB).

    Returns:
        A ``tf.keras.models.Model`` mapping a
        ``(batch, image_size, image_size, channels)`` input to
        ``[score, attn1, attn2]``: the flattened output of the final
        convolution followed by the second outputs of the two
        ``SelfAttnModel`` calls.

    Raises:
        ValueError: If ``image_size`` is not a power of two >= 32.
    """
    # Other sizes leave a feature map that is not 4x4 at the scoring conv:
    # e.g. 48/96 yield several scores per image, 16 fails to build. Reject
    # them before building any layers.
    size = int(image_size)
    if size != image_size or size < 32 or size & (size - 1):
        raise ValueError(
            f"image_size must be a power of two >= 32, got {image_size!r}")

    input_layers = tf.keras.layers.Input((image_size, image_size, channels))
    curr_filters = filters
    x = input_layers

    # Stages 1-3: image_size -> image_size / 8.
    for _ in range(3):
        curr_filters *= 2
        x = SpectralConv2D(filters=curr_filters,
                           kernel_size=kernel_size,
                           strides=2,
                           padding='same')(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.1)(x)

    x, attn1 = SelfAttnModel(curr_filters)(x)

    # Remaining stages: image_size / 8 -> 4x4 (log2(image_size) - 5 halvings).
    for _ in range(int(np.log2(image_size)) - 5):
        curr_filters *= 2
        x = SpectralConv2D(filters=curr_filters,
                           kernel_size=kernel_size,
                           strides=2,
                           padding='same')(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.1)(x)

    x, attn2 = SelfAttnModel(curr_filters)(x)

    # Kernel 4 matches the 4x4 map; with SpectralConv2D's default padding
    # (presumably 'valid', as in Keras Conv2D -- verify) this gives 1x1x1.
    x = SpectralConv2D(filters=1, kernel_size=4)(x)
    x = tf.keras.layers.Flatten()(x)
    return tf.keras.models.Model(input_layers, [x, attn1, attn2])