Skip to content

[TF2] add GAN #1716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 140 additions & 2 deletions chapter_generative-adversarial-networks/gan.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,27 +58,64 @@ import torch
from torch import nn
```


```{.python .input}
#@tab tensorflow
%matplotlib inline
from d2l import tensorflow as d2l
import tensorflow as tf
```

## Generate some "real" data

Since this is going to be the world's lamest example, we simply generate data drawn from a Gaussian.

```{.python .input}
#@tab all
X = d2l.normal(0.0, 1, (1000, 2))
A = d2l.tensor([[1, 2], [-0.1, 0.5]])
b = d2l.tensor([1, 2])
data = d2l.matmul(X, A) + b
```

```{.python .input}
#@tab pytorch
X = d2l.normal(0.0, 1, (1000, 2))
A = d2l.tensor([[1, 2], [-0.1, 0.5]])
b = d2l.tensor([1, 2])
data = d2l.matmul(X, A) + b
```

```{.python .input}
#@tab tensorflow
X = d2l.normal([1000, 2], 0.0, 1, tf.float32)
A = d2l.tensor([[1, 2], [-0.1, 0.5]], tf.float32)
b = d2l.tensor([1, 2], tf.float32)
data = d2l.matmul(X, A) + b
```

Let us see what we got. This should be a Gaussian shifted in some rather arbitrary way with mean $b$ and covariance matrix $A^TA$.

```{.python .input}
#@tab all
d2l.set_figsize()
d2l.plt.scatter(d2l.numpy(data[:100, 0]), d2l.numpy(data[:100, 1]));
print(f'The covariance matrix is\n{d2l.matmul(A.T, A)}')
```

```{.python .input}
#@tab pytorch
d2l.set_figsize()
d2l.plt.scatter(d2l.numpy(data[:100, 0]), d2l.numpy(data[:100, 1]));
print(f'The covariance matrix is\n{d2l.matmul(A.T, A)}')
```

```{.python .input}
#@tab tensorflow
d2l.set_figsize()
d2l.plt.scatter(d2l.numpy(data[:100, 0]), d2l.numpy(data[:100, 1]));
print(f'The covariance matrix is\n{d2l.matmul(tf.transpose(A), A)}')
```


```{.python .input}
#@tab all
batch_size = 8
Expand All @@ -99,6 +136,13 @@ net_G.add(nn.Dense(2))
net_G = nn.Sequential(nn.Linear(2, 2))
```

```{.python .input}
#@tab tensorflow
net_G = tf.keras.Sequential()
net_G.add(tf.keras.layers.Dense(2))
```


## Discriminator

For the discriminator we will be a bit more discriminating: we will use an MLP with 3 layers to make things a bit more interesting.
Expand All @@ -118,6 +162,15 @@ net_D = nn.Sequential(
nn.Linear(3, 1))
```


```{.python .input}
#@tab tensorflow
net_D = tf.keras.Sequential()
net_D.add(tf.keras.layers.Dense(5, activation='tanh'))
net_D.add(tf.keras.layers.Dense(3, activation='tanh'))
net_D.add(tf.keras.layers.Dense(1))
```

## Training

First we define a function to update the discriminator.
Expand Down Expand Up @@ -162,6 +215,29 @@ def update_D(X, Z, net_D, net_G, loss, trainer_D):
return loss_D
```


```{.python .input}
#@tab tensorflow
#@save
@tf.function
def update_D(X, Z, net_D, net_G, loss, trainer_D):
"""Update discriminator."""
batch_size = X.shape[0]
ones = tf.ones_like(batch_size, tf.float32)# , device=X.device)
zeros = tf.zeros_like(batch_size, tf.float32)# , device=X.device)
with tf.GradientTape() as D_tape:
real_Y = net_D(X, training=True)
fake_X = net_G(X, training=True)
# Do not need to compute gradient for `net_G`, detach it from
# computing gradients.
fake_Y = net_D(tf.stop_gradient(fake_X), training=False)
loss_D = (loss(ones, real_Y) + loss(zeros, fake_Y)) / 2
grads = D_tape.gradient(loss_D, net_D.trainable_variables)
trainer_D.apply_gradients(zip(grads, net_D.trainable_weights))
return float(loss_D)
```


The generator is updated similarly. Here we reuse the cross-entropy loss but change the label of the fake data from $0$ to $1$.

```{.python .input}
Expand Down Expand Up @@ -199,6 +275,25 @@ def update_G(Z, net_D, net_G, loss, trainer_G):
return loss_G
```

```{.python .input}
#@tab tensorflow
#@save
@tf.function
def update_G(Z, net_D, net_G, loss, trainer_G):
"""Update generator."""
batch_size = Z.shape[0]
ones = tf.ones_like(batch_size, tf.float32)
with tf.GradientTape() as G_tape:
# We could reuse `fake_X` from `update_D` to save computation
fake_X = net_G(Z, training=True)
# Recomputing `fake_Y` is needed since `net_D` is changed
fake_Y = net_D(fake_X, training=True)
loss_G = loss(ones, fake_Y)
grads = G_tape.gradient(loss_G, net_G.trainable_variables)
trainer_G.apply_gradients(zip(grads, net_G.trainable_weights))
return float(loss_G)
```

Both the discriminator and the generator performs a binary logistic regression with the cross-entropy loss. We use Adam to smooth the training process. In each iteration, we first update the discriminator and then the generator. We visualize both losses and generated examples.

```{.python .input}
Expand Down Expand Up @@ -276,6 +371,45 @@ def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
f'{metric[2] / timer.stop():.1f} examples/sec')
```

```{.python .input}
#@tab tensorflow
def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
loss = tf.keras.losses.BinaryCrossentropy(reduction=
tf.keras.losses.Reduction.SUM)
tf.keras.initializers.RandomNormal(net_D, stddev=0.02)
tf.keras.initializers.RandomNormal(net_G, stddev=0.02)
trainer_D = tf.keras.optimizers.Adam(learning_rate=lr_D)
trainer_G = tf.keras.optimizers.Adam(learning_rate=lr_G)
net_D.compile(optimizer=trainer_D)
net_G.compile(optimizer=trainer_G)
animator = d2l.Animator(xlabel='epoch', ylabel='loss',
xlim=[1, num_epochs], nrows=2, figsize=(5, 5),
legend=['discriminator', 'generator'])
animator.fig.subplots_adjust(hspace=0.3)
for epoch in range(num_epochs):
# Train one epoch
timer = d2l.Timer()
metric = d2l.Accumulator(3) # loss_D, loss_G, num_examples
for (X,) in data_iter:
batch_size = X.shape[0]
Z = d2l.normal([batch_size, latent_dim], 0, 1, tf.float32)
metric.add(update_D(X, Z, net_D, net_G, loss, trainer_D),
update_G(Z, net_D, net_G, loss, trainer_G),
batch_size)
# Visualize generated examples
Z = d2l.normal([100, latent_dim], 0, 1, tf.float32)
fake_X = net_G(tf.stop_gradient(Z), training=False).numpy()
animator.axes[1].cla()
animator.axes[1].scatter(data[:, 0], data[:, 1])
animator.axes[1].scatter(fake_X[:, 0], fake_X[:, 1])
animator.axes[1].legend(['real', 'generated'])
# Show the losses
loss_D, loss_G = metric[0]/metric[2], metric[1]/metric[2]
animator.add(epoch + 1, (loss_D, loss_G))
print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
f'{metric[2] / timer.stop():.1f} examples/sec')
```

Now we specify the hyperparameters to fit the Gaussian distribution.

```{.python .input}
Expand All @@ -302,3 +436,7 @@ train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G,
:begin_tab:`pytorch`
[Discussions](https://discuss.d2l.ai/t/1082)
:end_tab:

:begin_tab:`tensorflow`
[Discussions](https://discuss.d2l.ai/t/2534)
:end_tab: