Skip to content

Commit

Permalink
vae notebook size shrink
Browse files Browse the repository at this point in the history
  • Loading branch information
denizyuret committed Dec 15, 2017
2 parents 747e54c + 0a5e75b commit 7d14a44
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 106 deletions.
29 changes: 29 additions & 0 deletions data/mnist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,33 @@ function _mnist_gzload(file)
return(a)
end

function mnistgrid(y; gridsize=(4,4), scale=2.0, shape=(28,28))
y = reshape(y, shape..., size(y)[end])
y = map(x->y[:,:,x]', [1:size(y,3)...])
shp = map(x->Int(round(x*scale)), shape)
y = map(x->Images.imresize(x,shp), y)
gridx, gridy = gridsize
outdims = (gridx*shp[1]+gridx+1,gridy*shp[2]+gridy+1)
out = zeros(outdims...)
for k = 1:gridx+1; out[(k-1)*(shp[1]+1)+1,:] = 1.0; end
for k = 1:gridy+1; out[:,(k-1)*(shp[2]+1)+1] = 1.0; end

x0 = y0 = 2
for k = 1:length(y)
x1 = x0+shp[1]-1
y1 = y0+shp[2]-1
out[x0:x1,y0:y1] = y[k]

y0 = y1+2
if k % gridy == 0
x0 = x1+2
y0 = 2
else
y0 = y1+2
end
end

return convert(Array{Float64,2}, map(x->isnan(x)?0:x, out))
end

nothing
69 changes: 12 additions & 57 deletions examples/variational-autoencoder/vae_mnist.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
for p in ("Knet","ArgParse","PyPlot")
for p in ("Knet","ArgParse","Images")
Pkg.installed(p) == nothing && Pkg.add(p)
end

Expand All @@ -7,8 +7,8 @@ Train a Variational Autoencoder on the MNIST dataset.
"""
module VAE
using Knet
using PyPlot # comment out if not plotting
using ArgParse
using Images
include(Pkg.dir("Knet","data","mnist.jl"))

const F = Float32
Expand Down Expand Up @@ -91,55 +91,15 @@ function weights(nz, nh; atype=Array{F})
return θ, ϕ
end


function plot_reconstruction(θ, ϕ, data, nimg=10)
x, _ = rand(data)
x = mat(x)
x = x[:, rand(1:size(x,2), nimg)]

μ, logσ² = encode(ϕ, x)
z = μ .+ randn!(similar(μ)) .* exp.(logσ²./2)
= decode(θ, z)

x = Array(reshape(x, 28, 28, length(x) ÷ 28^2))
= Array(reshape(x̂, 28, 28, length(x̂) ÷ 28^2))

fig = figure("reconstruction", figsize=(10,3))
clf()
for i=1:nimg
subplot(2, nimg, i)
imshow(x[:,:,i]', cmap="gray") #notice the transpose
ax = gca()
ax[:xaxis][:set_visible](false)
ax[:yaxis][:set_visible](false)

subplot(2, nimg, nimg+i)
imshow(x̂[:,:,i]', cmap="gray") #notice the transpose
ax = gca()
ax[:xaxis][:set_visible](false)
ax[:yaxis][:set_visible](false)
end
# tight_layout()
end

function plot_dream(θ, nimg=20)
function plot_dream(θ; gridsize=(5,5), scale=1.0)
nh, nz = size(θ[1])
atype = θ[1] isa KnetArray ? KnetArray : Array

m, n = gridsize
nimg = m*n
z = convert(atype, randn(F, nz, nimg))
= decode(θ, z)

= Array(reshape(x̂, 28, 28, length(x̂) ÷ 28^2))

fig = figure("dream",figsize=(6,5))
clf()
for i=1:nimg
subplot(4, nimg÷4, i)
imshow(x̂[:,:,i]', cmap="gray") #notice the transpose
ax = gca()
ax[:xaxis][:set_visible](false)
ax[:yaxis][:set_visible](false)
end
= Array(decode(θ, z))
grid = mnistgrid(x̂; gridsize=gridsize, scale=scale)
display(colorview(Gray, grid))
end

function main(args="")
Expand All @@ -154,7 +114,6 @@ function main(args="")
("--nz"; arg_type=Int; default=40; help="encoding dimention")
("--lr"; arg_type=Float64; default=1e-3; help="learning rate")
("--atype"; default=(gpu()>=0 ? "KnetArray{F}" : "Array{F}"); help="array type: Array for cpu, KnetArray for gpu")
("--verb"; arg_type=Int; default=1; help="plot dream and reconstruction if verb > 1")
("--infotime"; arg_type=Int; default=2; help="report every infotime epochs")
end
isa(args, String) && (args=split(args))
Expand All @@ -163,29 +122,25 @@ function main(args="")
return
end
o = parse_args(args, s; as_symbols=true)

atype = eval(parse(o[:atype]))
info("using ", atype)
o[:seed] > 0 && srand(o[:seed])
atype <: KnetArray && rand!(KnetArray(ones(10))) # bug #181 of Knet
atype <: KnetArray && rand!(KnetArray(ones(10))) # bug #181 of Knet

θ, ϕ = weights(o[:nz], o[:nh], atype=atype)
w = [θ; ϕ]
opt = optimizers(w, Adam, lr=o[:lr])

xtrn, ytrn, xtst, ytst = mnist()


report(epoch) = begin
dtrn = minibatch(xtrn, ytrn, o[:batchsize]; xtype=atype)
dtst = minibatch(xtst, ytst, o[:batchsize]; xtype=atype)
println((:epoch, epoch,
:trn, aveloss(θ, ϕ, dtrn),
:tst, aveloss(θ, ϕ, dtst)))
if o[:verb] > 1
plot_reconstruction(θ, ϕ, dtrn)
plot_dream(θ)
end
end

report(0); tic()
Expand Down
149 changes: 100 additions & 49 deletions examples/variational-autoencoder/vae_mnist_demo.ipynb

Large diffs are not rendered by default.

0 comments on commit 7d14a44

Please sign in to comment.