-
Notifications
You must be signed in to change notification settings - Fork 13
/
wavegan.py
302 lines (244 loc) · 11 KB
/
wavegan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.utils.data
class PhaseShuffle(nn.Module):
    """
    Performs phase shuffling: shifts the last (time) axis of a 3D tensor by a
    random integer k drawn uniformly from [-n, n] (n = ``shift_factor``) and
    fills the vacated positions by reflection padding, so the output has the
    same shape as the input.

    If batch shuffle is enabled, a single shift is drawn and applied to the
    entire batch; otherwise each sample in the batch gets its own shift.

    Args:
        shift_factor: maximum absolute shift n; 0 disables the layer.
        batch_shuffle: share one shift across the whole batch.
    """

    def __init__(self, shift_factor, batch_shuffle=False):
        super(PhaseShuffle, self).__init__()
        self.shift_factor = shift_factor
        self.batch_shuffle = batch_shuffle

    @staticmethod
    def _shift(x, k):
        """Shift x by k along its last axis, reflection-padding the gap."""
        if k > 0:
            # Drop the last k steps and mirror k steps in at the front.
            return F.pad(x[..., :-k], (k, 0), mode='reflect')
        if k < 0:
            # Drop the first -k steps and mirror -k steps in at the back.
            return F.pad(x[..., -k:], (0, -k), mode='reflect')
        return x

    def forward(self, x):
        n = self.shift_factor
        # Phase shuffle disabled: identity.
        if n == 0:
            return x

        if self.batch_shuffle:
            # Draw through PyTorch's RNG so torch seeding controls the shift.
            k = int(torch.empty(1).random_(0, 2 * n + 1)) - n
            if k == 0:
                return x
            x_shuffle = self._shift(x, k)
        else:
            # One shift per sample; .tolist() works wherever the tensor lives
            # (unlike .numpy(), which requires a CPU tensor).
            shifts = (torch.empty(x.shape[0]).random_(0, 2 * n + 1) - n).tolist()
            # Group samples that share a shift so each distinct shift costs a
            # single pad operation.
            groups = {}
            for idx, k in enumerate(shifts):
                groups.setdefault(int(k), []).append(idx)
            x_shuffle = x.clone()
            for k, idxs in groups.items():
                if k == 0:
                    continue  # the clone already holds these samples unshifted
                x_shuffle[idxs] = self._shift(x[idxs], k)

        assert x_shuffle.shape == x.shape, "{}, {}".format(x_shuffle.shape,
                                                          x.shape)
        return x_shuffle
class UpsampleConvLayer(nn.Module):
    """Optional upsampling followed by a zero-padded 1-D convolution.

    When ``upsample`` is truthy the input is first resized along its last
    axis by ``nn.Upsample(scale_factor=upsample)`` (nearest-neighbour, the
    ``nn.Upsample`` default). It is then zero-padded by ``kernel_size // 2``
    on both sides and convolved. For odd kernels with stride 1 this keeps the
    (upsampled) length unchanged.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        kernel_size: convolution kernel width.
        stride: convolution stride.
        upsample: upsampling scale factor, or None/0 to skip upsampling.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        self.upsample = upsample
        if upsample:
            self.upsample_layer = nn.Upsample(scale_factor=upsample)
        # NOTE: despite its name this layer pads with zeros, not reflection.
        # The attribute name is kept so existing code referring to it keeps working.
        self.reflection_pad = nn.ConstantPad1d(kernel_size // 2, value=0)
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        resized = self.upsample_layer(x) if self.upsample else x
        return self.conv1d(self.reflection_pad(resized))
class WaveGANGenerator(nn.Module):
    """WaveGAN generator: maps latent vectors to 1-D waveforms.

    A linear layer projects the latent vector to ``16 * model_size`` channels
    of length 16. Five stages follow; each multiplies the length by 4
    (16 -> 16384 samples) while reducing channels
    16d -> 8d -> 4d -> 2d -> d -> num_channels. A stage is either
    upsample-by-4 + conv (``upsample=True``, attributes ``upSampConv1..5``)
    or a stride-4 transposed conv (``upsample=False``, attributes
    ``tconv1..5``); the unused family's attributes are None. The last stage
    is followed by tanh and, optionally, a learned post-processing filter.

    Args:
        model_size: base channel multiplier ``d``.
        ngpus: stored on the instance; not used by this class.
        num_channels: number of output channels ``c``.
        latent_dim: size of the input latent vector.
        post_proc_filt_len: length of the post-processing filter; a falsy
            value disables it.
        verbose: print intermediate tensor shapes during ``forward``.
        upsample: use upsample+conv stages instead of transposed convs.
    """

    def __init__(self, model_size=64, ngpus=1, num_channels=1, latent_dim=100,
                 post_proc_filt_len=512, verbose=False, upsample=True):
        super(WaveGANGenerator, self).__init__()
        self.ngpus = ngpus
        self.model_size = model_size  # d
        self.num_channels = num_channels  # c
        self.latent_dim = latent_dim
        self.post_proc_filt_len = post_proc_filt_len
        self.verbose = verbose
        self.fc1 = nn.DataParallel(nn.Linear(latent_dim, 256 * model_size))

        # Both stage families exist as attributes; the unused one stays None.
        for i in range(1, 6):
            setattr(self, 'tconv{}'.format(i), None)
        for i in range(1, 6):
            setattr(self, 'upSampConv{}'.format(i), None)

        self.upsample = upsample
        # Channel counts at the boundaries of the five stages.
        channels = [16 * model_size, 8 * model_size, 4 * model_size,
                    2 * model_size, model_size, num_channels]
        for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
            if self.upsample:
                setattr(self, 'upSampConv{}'.format(i), nn.DataParallel(
                    UpsampleConvLayer(c_in, c_out, 25, stride=1, upsample=4)))
            else:
                # kernel 25, stride 4, padding 11, output_padding 1 => length x4.
                setattr(self, 'tconv{}'.format(i), nn.DataParallel(
                    nn.ConvTranspose1d(c_in, c_out, 25, stride=4, padding=11,
                                       output_padding=1)))

        if post_proc_filt_len:
            self.ppfilter1 = nn.DataParallel(
                nn.Conv1d(num_channels, num_channels, post_proc_filt_len))

        # He-initialise Linear and ConvTranspose1d weights. Conv1d layers
        # (upsample stages, post-processing filter) keep PyTorch's default init.
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)

    def _stages(self):
        """Return the five active length-expanding stages, in order."""
        if self.upsample:
            return (self.upSampConv1, self.upSampConv2, self.upSampConv3,
                    self.upSampConv4, self.upSampConv5)
        return (self.tconv1, self.tconv2, self.tconv3, self.tconv4, self.tconv5)

    def forward(self, x):
        """Generate waveforms.

        Args:
            x: latent batch, assumed shape (batch, latent_dim).
        Returns:
            Tensor of shape (batch, num_channels, 16384).
        """
        x = F.relu(self.fc1(x).view(-1, 16 * self.model_size, 16))
        if self.verbose:
            print(x.shape)

        stages = self._stages()
        for stage in stages[:-1]:
            x = F.relu(stage(x))
            if self.verbose:
                print(x.shape)
        output = torch.tanh(stages[-1](x))
        if self.verbose:
            print(output.shape)

        if self.post_proc_filt_len:
            # "Same" filtering: pad post_proc_filt_len - 1 zeros in total so the
            # length is preserved; for even lengths the extra one goes left.
            pad_left = self.post_proc_filt_len // 2
            pad_right = (self.post_proc_filt_len - 1) // 2
            output = self.ppfilter1(F.pad(output, (pad_left, pad_right)))
            if self.verbose:
                print(output.shape)
        return output
class WaveGANDiscriminator(nn.Module):
    """WaveGAN discriminator: scores 1-D waveforms with a sigmoid output.

    Five Conv1d stages (kernel 25, stride 4, padding 11) each shrink the time
    axis by 4x while channels grow num_channels -> d -> 2d -> 4d -> 8d -> 16d.
    Each stage is followed by leaky ReLU, and the first four by phase shuffle.
    The features are then flattened and passed through a linear layer and a
    sigmoid.

    The flatten step requires exactly 16 time steps after the last conv, which
    a 16384-sample input yields (16384 -> 4096 -> 1024 -> 256 -> 64 -> 16).

    Args:
        model_size: base channel multiplier ``d``.
        ngpus: stored on the instance; not used by this class.
        num_channels: number of input channels ``c``.
        shift_factor: phase-shuffle range ``n``.
        alpha: negative slope of the leaky ReLUs.
        batch_shuffle: share one phase shift across the batch.
        verbose: print intermediate tensor shapes during ``forward``.
    """

    def __init__(self, model_size=64, ngpus=1, num_channels=1, shift_factor=2, alpha=0.2, batch_shuffle=False, verbose=False):
        super(WaveGANDiscriminator, self).__init__()
        self.model_size = model_size  # d
        self.ngpus = ngpus
        self.num_channels = num_channels  # c
        self.shift_factor = shift_factor  # n
        self.alpha = alpha
        self.verbose = verbose
        # Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        self.conv1 = nn.DataParallel(
            nn.Conv1d(num_channels, model_size, 25, stride=4, padding=11))
        self.conv2 = nn.DataParallel(
            nn.Conv1d(model_size, 2 * model_size, 25, stride=4, padding=11))
        self.conv3 = nn.DataParallel(
            nn.Conv1d(2 * model_size, 4 * model_size, 25, stride=4, padding=11))
        self.conv4 = nn.DataParallel(
            nn.Conv1d(4 * model_size, 8 * model_size, 25, stride=4, padding=11))
        self.conv5 = nn.DataParallel(
            nn.Conv1d(8 * model_size, 16 * model_size, 25, stride=4, padding=11))
        self.ps1 = PhaseShuffle(shift_factor, batch_shuffle=batch_shuffle)
        self.ps2 = PhaseShuffle(shift_factor, batch_shuffle=batch_shuffle)
        self.ps3 = PhaseShuffle(shift_factor, batch_shuffle=batch_shuffle)
        self.ps4 = PhaseShuffle(shift_factor, batch_shuffle=batch_shuffle)
        self.fc1 = nn.DataParallel(nn.Linear(256 * model_size, 1))
        # He-initialise every conv and linear weight.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        """Return per-sample real/fake probabilities of shape (batch, 1)."""
        convs = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)
        # Phase shuffle follows every conv stage except the last.
        shuffles = (self.ps1, self.ps2, self.ps3, self.ps4, None)
        for conv, shuffle in zip(convs, shuffles):
            x = F.leaky_relu(conv(x), negative_slope=self.alpha)
            if self.verbose:
                print(x.shape)
            if shuffle is not None:
                x = shuffle(x)
        # Flatten per sample. Fixing the batch size makes view() raise on a
        # wrong input length instead of silently folding extra time steps
        # into the batch dimension.
        x = x.view(x.size(0), 256 * self.model_size)
        if self.verbose:
            print(x.shape)
        return torch.sigmoid(self.fc1(x))
def load_wavegan_generator(filepath, model_size=64, ngpus=1, num_channels=1,
                           latent_dim=100, post_proc_filt_len=512, **kwargs):
    """Build a WaveGANGenerator and load a state dict from ``filepath``.

    The architecture arguments must match those used when the checkpoint was
    saved, otherwise ``load_state_dict`` raises.

    Recognised ``kwargs``:
        upsample, verbose: forwarded to the WaveGANGenerator constructor.
            ``upsample`` must match the checkpoint, because the upsample and
            transposed-conv variants have different parameter names.
        map_location: passed to ``torch.load``, e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine.
    Any other keyword is ignored, so a larger config dict can be splatted in.
    """
    ctor_kwargs = {key: kwargs[key] for key in ('upsample', 'verbose') if key in kwargs}
    model = WaveGANGenerator(model_size=model_size, ngpus=ngpus,
                             num_channels=num_channels, latent_dim=latent_dim,
                             post_proc_filt_len=post_proc_filt_len, **ctor_kwargs)
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(filepath, map_location=kwargs.get('map_location'))
    model.load_state_dict(state_dict)
    return model
def load_wavegan_discriminator(filepath, model_size=64, ngpus=1, num_channels=1,
                               shift_factor=2, alpha=0.2, **kwargs):
    """Build a WaveGANDiscriminator and load a state dict from ``filepath``.

    The architecture arguments must match those used when the checkpoint was
    saved, otherwise ``load_state_dict`` raises.

    Recognised ``kwargs``:
        batch_shuffle, verbose: forwarded to the WaveGANDiscriminator
            constructor.
        map_location: passed to ``torch.load``, e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine.
    Any other keyword is ignored, so a larger config dict can be splatted in.
    """
    ctor_kwargs = {key: kwargs[key] for key in ('batch_shuffle', 'verbose') if key in kwargs}
    model = WaveGANDiscriminator(model_size=model_size, ngpus=ngpus,
                                 num_channels=num_channels,
                                 shift_factor=shift_factor, alpha=alpha,
                                 **ctor_kwargs)
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(filepath, map_location=kwargs.get('map_location'))
    model.load_state_dict(state_dict)
    return model