Skip to content

Commit

Permalink
Merge pull request #1 from ThomasDelteil/patch-33
Browse files Browse the repository at this point in the history
Update info_gan.md
  • Loading branch information
NRauschmayr committed Nov 6, 2018
2 parents 8403706 + 7d7470a commit 64ae54f
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions docs/tutorials/gluon/info_gan.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ split = int(len(images)*0.8)
# Hold out the last 20% of images (and their filenames) as a test set;
# the first 80% become the training set.
test_images = images[split:]
test_filenames = filenames[split:]
train_images = images[:split]
train_filenames = filenames[:split]

# Wrap the training images in a Gluon dataset/dataloader.
# `last_batch='rollover'` carries a short final batch over to the next
# epoch so every batch has exactly `batch_size` samples.
train_data = gluon.data.ArrayDataset(nd.concatenate(train_images))
train_dataloader = gluon.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, last_batch='rollover', num_workers=4)
```

## Generator
Expand Down Expand Up @@ -239,13 +239,13 @@ This function samples `c`, `z`, and concatenates them to create the generator in
def create_generator_input():
    """Sample a generator input vector for one batch.

    Draws random noise ``z``, a one-hot categorical latent code ``c1``,
    and a continuous latent code ``c2``, then concatenates them along
    dim 1 into a single ``(batch_size, z_dim + n_categories + n_continuous)``
    NDArray on ``ctx``.
    """
    # Random noise z ~ N(0, 1)
    z = nd.random_normal(0, 1, shape=(batch_size, z_dim), ctx=ctx)

    # Categorical latent code: a uniformly random class label, one-hot encoded.
    label = nd.array(np.random.randint(n_categories, size=batch_size)).as_in_context(ctx)
    c1 = nd.one_hot(label, depth=n_categories).as_in_context(ctx)

    # Continuous latent code c2 ~ U(-1, 1)
    c2 = nd.random.uniform(-1, 1, shape=(batch_size, n_continuous)).as_in_context(ctx)

    # Concatenate noise and latent codes: this is the generator's input.
    return nd.concat(z, c1, c2, dim=1)
```

Define the training loop.
Expand All @@ -265,8 +265,8 @@ with SummaryWriter(logdir='./logs/') as sw:
print("Epoch", epoch)
starttime = time.time()

d_error_epoch = mx.nd.zeros((1,), ctx=ctx)
g_error_epoch = mx.nd.zeros((1,), ctx=ctx)
d_error_epoch = nd.zeros((1,), ctx=ctx)
g_error_epoch = nd.zeros((1,), ctx=ctx)

for idx, data in enumerate(train_dataloader):
i = i + 1
Expand Down Expand Up @@ -372,11 +372,11 @@ Take some images from the test data, obtain its feature vector from `discriminat
```python
# Size of the flattened feature vector produced by the discriminator's
# intermediate layer for a single image.
feature_size = 8192

# Pre-allocate one row per test image on the compute context.
features = nd.zeros((len(test_images), feature_size), ctx=ctx)

# Run each test image through the discriminator and store its
# flattened feature vector.
for idx, image in enumerate(test_images):
    feature = discriminator(nd.array(image))
    feature = feature.reshape(feature_size,)
    features[idx, :] = feature.copyto(ctx)

Expand Down Expand Up @@ -434,3 +434,5 @@ with open("imagetsne.json", 'w') as outfile:
Load the file with TSNEViewer. You can now inspect whether similar-looking images are grouped near each other or not.

<img src="https://raw.githubusercontent.com/NRauschmayr/web-data/master/mxnet/doc/tutorials/info_gan/tsne.png" style="width:800px;height:600px;">

<!-- INSERT SOURCE DOWNLOAD BUTTONS -->

0 comments on commit 64ae54f

Please sign in to comment.