[R] CGAN R demo scripts. close #6040 (#6551)
* add cgan R demo scripts
jeremiedb authored and thirdwing committed Jun 3, 2017
1 parent 3959d1e commit 728d632
Showing 3 changed files with 348 additions and 0 deletions.
104 changes: 104 additions & 0 deletions example/gan/CGAN_mnist_R/CGAN_mnist_setup.R
@@ -0,0 +1,104 @@
require("imager")
require("dplyr")
require("readr")
require("mxnet")

source("iterators.R")

######################################################
### Data import and preparation
### First download the MNIST training data from Kaggle:
### https://www.kaggle.com/c/digit-recognizer/data
######################################################
train <- read_csv('data/train.csv')
train<- data.matrix(train)

train_data <- train[,-1]
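# Rescale pixels from [0, 255] to [-1, 1] to match the generator's tanh output;
# the transpose stores one image per column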
train_data <- t(train_data/255*2-1)
train_label <- as.integer(train[,1])

dim(train_data) <- c(28, 28, 1, ncol(train_data))
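
# Optional sanity check (added, not in the original script): values should lie
# in [-1, 1] and the array should be 28 x 28 x 1 x N
stopifnot(all(train_data >= -1), all(train_data <= 1), length(dim(train_data)) == 4)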

##################################################
#### Model parameters
##################################################
random_dim<- 96
gen_features<- 96
dis_features<- 32
image_depth<- 1
fix_gamma<- T
no_bias<- T
eps<- 1e-5 + 1e-12
batch_size<- 64


##################################################
#### Generator Symbol
##################################################
data = mx.symbol.Variable('data')
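
# The noise input is sampled inside the symbol graph, so each forward pass
# draws fresh random values to concatenate with the one-hot condition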

gen_rand<- mx.symbol.normal(loc=0, scale=1, shape=c(1, 1, random_dim, batch_size), name="gen_rand")
gen_concat<- mx.symbol.Concat(data = list(data, gen_rand), num.args = 2, name="gen_concat")

g1 = mx.symbol.Deconvolution(gen_concat, name='g1', kernel=c(4,4), num_filter=gen_features*4, no_bias=no_bias)
gbn1 = mx.symbol.BatchNorm(g1, name='gbn1', fix_gamma=fix_gamma, eps=eps)
gact1 = mx.symbol.Activation(gbn1, name='gact1', act_type='relu')

g2 = mx.symbol.Deconvolution(gact1, name='g2', kernel=c(3,3), stride=c(2,2), pad=c(1,1), num_filter=gen_features*2, no_bias=no_bias)
gbn2 = mx.symbol.BatchNorm(g2, name='gbn2', fix_gamma=fix_gamma, eps=eps)
gact2 = mx.symbol.Activation(gbn2, name='gact2', act_type='relu')

g3 = mx.symbol.Deconvolution(gact2, name='g3', kernel=c(4,4), stride=c(2,2), pad=c(1,1), num_filter=gen_features, no_bias=no_bias)
gbn3 = mx.symbol.BatchNorm(g3, name='gbn3', fix_gamma=fix_gamma, eps=eps)
gact3 = mx.symbol.Activation(gbn3, name='gact3', act_type='relu')

g4 = mx.symbol.Deconvolution(gact3, name='g4', kernel=c(4,4), stride=c(2,2), pad=c(1,1), num_filter=image_depth, no_bias=no_bias)
G_sym = mx.symbol.Activation(g4, name='G_sym', act_type='tanh')
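
# Optional shape check (an added sketch, mirroring the infer.shape calls used
# in CGAN_train.R): the deconvolution stack should upsample the concatenated
# 1 x 1 input to a 28 x 28 x 1 image
G_shapes<- mx.symbol.infer.shape(G_sym, data = c(1, 1, 10, batch_size))
G_shapes$out.shapes  # expected: c(28, 28, 1, batch_size)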


##################################################
#### Discriminator Symbol
##################################################
data = mx.symbol.Variable('data')
dis_digit = mx.symbol.Variable('digit')
label = mx.symbol.Variable('label')
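
# Broadcast the one-hot digit across the spatial dimensions so it can be
# concatenated channel-wise with the 28 x 28 input image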

dis_digit<- mx.symbol.Reshape(data=dis_digit, shape=c(1,1,10,batch_size), name="digit_reshape")
dis_digit<- mx.symbol.broadcast_to(data=dis_digit, shape=c(28,28,10, batch_size), name="digit_broadcast")

data_concat <- mx.symbol.Concat(list(data, dis_digit), num.args = 2, dim = 1, name='dflat_concat')

d1 = mx.symbol.Convolution(data=data_concat, name='d1', kernel=c(3,3), stride=c(1,1), pad=c(0,0), num_filter=24, no_bias=no_bias)
dbn1 = mx.symbol.BatchNorm(d1, name='dbn1', fix_gamma=fix_gamma, eps=eps)
dact1 = mx.symbol.LeakyReLU(dbn1, name='dact1', act_type='elu', slope=0.25)
pool1 <- mx.symbol.Pooling(data=dact1, name="pool1", pool_type="max", kernel=c(2,2), stride=c(2,2), pad=c(0,0))

d2 = mx.symbol.Convolution(pool1, name='d2', kernel=c(3,3), stride=c(2,2), pad=c(0,0), num_filter=32, no_bias=no_bias)
dbn2 = mx.symbol.BatchNorm(d2, name='dbn2', fix_gamma=fix_gamma, eps=eps)
dact2 = mx.symbol.LeakyReLU(dbn2, name='dact2', act_type='elu', slope=0.25)

d3 = mx.symbol.Convolution(dact2, name='d3', kernel=c(3,3), stride=c(1,1), pad=c(0,0), num_filter=64, no_bias=no_bias)
dbn3 = mx.symbol.BatchNorm(d3, name='dbn3', fix_gamma=fix_gamma, eps=eps)
dact3 = mx.symbol.LeakyReLU(dbn3, name='dact3', act_type='elu', slope=0.25)

d4 = mx.symbol.Convolution(dact3, name='d4', kernel=c(4,4), stride=c(1,1), pad=c(0,0), num_filter=64, no_bias=no_bias)
dbn4 = mx.symbol.BatchNorm(d4, name='dbn4', fix_gamma=fix_gamma, eps=eps)
dact4 = mx.symbol.LeakyReLU(dbn4, name='dact4', act_type='elu', slope=0.25)

# pool4 <- mx.symbol.Pooling(data=dact3, name="pool4", pool_type="avg", kernel=c(4,4), stride=c(1,1), pad=c(0,0))

dflat = mx.symbol.Flatten(dact4, name="dflat")

dfc <- mx.symbol.FullyConnected(data=dflat, name="dfc", num_hidden=1, no_bias=F)
D_sym = mx.symbol.LogisticRegressionOutput(data=dfc, label=label, name='D_sym')
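
# Optional shape check (an added sketch): the conv/pool stack should reduce the
# 28 x 28 x 11 concatenated input to 1 x 1 x 64 before the fully connected layer
D_shapes<- mx.symbol.infer.shape(D_sym, data = c(28, 28, 1, batch_size), digit = c(10, batch_size))
D_shapes$out.shapes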


########################
### Graph
########################
input_shape_G<- c(1, 1, 10, batch_size)
input_shape_D<- c(28, 28, 1, batch_size)

graph.viz(G_sym, type = "graph", direction = "LR")
graph.viz(D_sym, type = "graph", direction = "LR")

182 changes: 182 additions & 0 deletions example/gan/CGAN_mnist_R/CGAN_train.R
@@ -0,0 +1,182 @@
#####################################################
### Training module for GAN
#####################################################

devices<- mx.cpu()

data_shape_G<- c(1, 1, 10, batch_size)
data_shape_D<- c(28, 28, 1, batch_size)
digit_shape_D<- c(10, batch_size)

mx.metric.binacc <- mx.metric.custom("binacc", function(label, pred) {
  res <- mean(label==round(pred))
  return(res)
})

mx.metric.logloss <- mx.metric.custom("logloss", function(label, pred) {
  res <- -mean(label*log(pred)+(1-label)*log(1-pred))
  return(res)
})
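
# Hedged usage sketch of the custom-metric protocol (added for clarity):
# state<- mx.metric.binacc$init()
# state<- mx.metric.binacc$update(mx.nd.array(rep(1, 4)),
#                                 mx.nd.array(rep(0.9, 4)), state)
# mx.metric.binacc$get(state)  # list(name = "binacc", value = 1)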

##############################################
### Define iterators
iter_G<- G_iterator(batch_size = batch_size)
iter_D<- D_iterator(batch_size = batch_size)

exec_G<- mx.simple.bind(symbol = G_sym, data=data_shape_G, ctx = devices, grad.req = "write")
exec_D<- mx.simple.bind(symbol = D_sym, data=data_shape_D, digit=digit_shape_D, ctx = devices, grad.req = "write")

### Initialize parameters - TODO: customize initialization per layer
initializer<- mx.init.Xavier(rnd_type = "gaussian", factor_type = "avg", magnitude = 3)

arg_param_ini_G<- mx.init.create(initializer = initializer, shape.array = mx.symbol.infer.shape(G_sym, data=data_shape_G)$arg.shapes, ctx = mx.cpu())
aux_param_ini_G<- mx.init.create(initializer = initializer, shape.array = mx.symbol.infer.shape(G_sym, data=data_shape_G)$aux.shapes, ctx = mx.cpu())

arg_param_ini_D<- mx.init.create(initializer = initializer, shape.array = mx.symbol.infer.shape(D_sym, data=data_shape_D, digit=digit_shape_D)$arg.shapes, ctx = mx.cpu())
aux_param_ini_D<- mx.init.create(initializer = initializer, shape.array = mx.symbol.infer.shape(D_sym, data=data_shape_D, digit=digit_shape_D)$aux.shapes, ctx = mx.cpu())

mx.exec.update.arg.arrays(exec_G, arg_param_ini_G, match.name=TRUE)
mx.exec.update.aux.arrays(exec_G, aux_param_ini_G, match.name=TRUE)

mx.exec.update.arg.arrays(exec_D, arg_param_ini_D, match.name=TRUE)
mx.exec.update.aux.arrays(exec_D, aux_param_ini_D, match.name=TRUE)

input_names_G <- mxnet:::mx.model.check.arguments(G_sym)
input_names_D <- mxnet:::mx.model.check.arguments(D_sym)


###################################################
#initialize optimizers
optimizer_G<- mx.opt.create(name = "adadelta",
                            rho = 0.92,
                            epsilon = 1e-6,
                            wd = 0,
                            rescale.grad = 1/batch_size,
                            clip_gradient = 1)

updater_G<- mx.opt.get.updater(optimizer = optimizer_G, weights = exec_G$ref.arg.arrays)

optimizer_D<- mx.opt.create(name = "adadelta",
                            rho = 0.92,
                            epsilon = 1e-6,
                            wd = 0,
                            rescale.grad = 1/batch_size,
                            clip_gradient = 1)
updater_D<- mx.opt.get.updater(optimizer = optimizer_D, weights = exec_D$ref.arg.arrays)

####################################
#initialize metric
metric_G<- mx.metric.binacc
metric_G_value<- metric_G$init()

metric_D<- mx.metric.binacc
metric_D_value<- metric_D$init()

iteration<- 1
iter_G$reset()
iter_D$reset()


for (iteration in 1:2400) {

  iter_G$iter.next()
  iter_D$iter.next()

  ### Random input to the Generator to produce a fake sample
  G_values <- iter_G$value()
  G_data <- G_values[input_names_G]
  mx.exec.update.arg.arrays(exec_G, arg.arrays = G_data, match.name=TRUE)
  mx.exec.forward(exec_G, is.train=T)

  ### Feed the Discriminator with the Generator's fake images and with real images
  D_data_fake <- exec_G$ref.outputs$G_sym_output
  D_digit_fake <- G_values$data %>% mx.nd.Reshape(shape=c(-1, batch_size))

  D_values <- iter_D$value()
  D_data_real <- D_values$data
  D_digit_real <- D_values$digit

  ### Train loop on the fake batch
  mx.exec.update.arg.arrays(exec_D, arg.arrays = list(data=D_data_fake, digit=D_digit_fake, label=mx.nd.array(rep(0, batch_size))), match.name=TRUE)
  mx.exec.forward(exec_D, is.train=T)
  mx.exec.backward(exec_D)
  update_args_D<- updater_D(weight = exec_D$ref.arg.arrays, grad = exec_D$ref.grad.arrays)
  mx.exec.update.arg.arrays(exec_D, update_args_D, skip.null=TRUE)

  metric_D_value <- metric_D$update(label = mx.nd.array(rep(0, batch_size)), exec_D$ref.outputs[["D_sym_output"]], metric_D_value)

  ### Train loop on the real batch
  mx.exec.update.arg.arrays(exec_D, arg.arrays = list(data=D_data_real, digit=D_digit_real, label=mx.nd.array(rep(1, batch_size))), match.name=TRUE)
  mx.exec.forward(exec_D, is.train=T)
  mx.exec.backward(exec_D)
  update_args_D<- updater_D(weight = exec_D$ref.arg.arrays, grad = exec_D$ref.grad.arrays)
  mx.exec.update.arg.arrays(exec_D, update_args_D, skip.null=TRUE)

  metric_D_value <- metric_D$update(mx.nd.array(rep(1, batch_size)), exec_D$ref.outputs[["D_sym_output"]], metric_D_value)

  ### Update Generator weights - use a separate executor to write gradients w.r.t. the data input
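  # Rebind D so gradients w.r.t. the data input are written; label the fake
  # batch as real (1), backpropagate, and feed the resulting input gradient to
  # the Generator as its head gradient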
  exec_D_back<- mxnet:::mx.symbol.bind(symbol = D_sym, arg.arrays = exec_D$arg.arrays, aux.arrays = exec_D$aux.arrays, grad.reqs = rep("write", length(exec_D$arg.arrays)), ctx = devices)
  mx.exec.update.arg.arrays(exec_D_back, arg.arrays = list(data=D_data_fake, digit=D_digit_fake, label=mx.nd.array(rep(1, batch_size))), match.name=TRUE)
  mx.exec.forward(exec_D_back, is.train=T)
  mx.exec.backward(exec_D_back)
  D_grads<- exec_D_back$ref.grad.arrays$data
  mx.exec.backward(exec_G, out_grads=D_grads)

  update_args_G<- updater_G(weight = exec_G$ref.arg.arrays, grad = exec_G$ref.grad.arrays)
  mx.exec.update.arg.arrays(exec_G, update_args_G, skip.null=TRUE)

  ### Update metrics
  #metric_G_value <- metric_G$update(values[[label_name]], exec_G$ref.outputs[[output_name]], metric_G_value)

  if (iteration %% 25==0){
    D_metric_result <- metric_D$get(metric_D_value)
    cat(paste0("[", iteration, "] ", D_metric_result$name, ": ", D_metric_result$value, "\n"))
  }

  if (iteration==1 | iteration %% 100==0){

    metric_D_value<- metric_D$init()

    par(mfrow=c(3,3), mar=c(0.1,0.1,0.1,0.1))
    for (i in 1:9) {
      img <- as.array(exec_G$ref.outputs$G_sym_output)[,,,i]
      plot(as.cimg(img), axes=F)
    }

    print(as.numeric(as.array(G_values$digit)))
    print(as.numeric(as.array(D_values$label)))

  }
}

mx.symbol.save(D_sym, filename = "models/D_sym_model_v1.json")
mx.nd.save(exec_D$arg.arrays, filename = "models/D_arg_params_v1.params")
mx.nd.save(exec_D$aux.arrays, filename = "models/D_aux_params_v1.params")

mx.symbol.save(G_sym, filename = "models/G_sym_model_v1.json")
mx.nd.save(exec_G$arg.arrays, filename = "models/G_arg_params_v1.params")
mx.nd.save(exec_G$aux.arrays, filename = "models/G_aux_params_v1.params")


### Inference
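### (loads the saved generator and conditions it on a fixed digit, here 9,
### to sample class-specific images)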
G_sym<- mx.symbol.load("models/G_sym_model_v1.json")
G_arg_params<- mx.nd.load("models/G_arg_params_v1.params")
G_aux_params<- mx.nd.load("models/G_aux_params_v1.params")

digit<- mx.nd.array(rep(9, times=batch_size))
data<- mx.nd.one.hot(indices = digit, depth = 10)
data<- mx.nd.reshape(data = data, shape = c(1,1,-1, batch_size))

exec_G<- mx.simple.bind(symbol = G_sym, data=data_shape_G, ctx = devices, grad.req = "null")
mx.exec.update.arg.arrays(exec_G, G_arg_params, match.name=TRUE)
mx.exec.update.arg.arrays(exec_G, list(data=data), match.name=TRUE)
mx.exec.update.aux.arrays(exec_G, G_aux_params, match.name=TRUE)

mx.exec.forward(exec_G, is.train=F)

par(mfrow=c(3,3), mar=c(0.1,0.1,0.1,0.1))
for (i in 1:9) {
  img <- as.array(exec_G$ref.outputs$G_sym_output)[,,,i]
  plot(as.cimg(img), axes=F)
}
62 changes: 62 additions & 0 deletions example/gan/CGAN_mnist_R/iterators.R
@@ -0,0 +1,62 @@

G_iterator<- function(batch_size){

  batch<- 0
  batch_per_epoch<- 5

  reset<- function(){
    batch<<- 0
  }

  iter.next<- function(){
    batch<<- batch+1
    if (batch>batch_per_epoch) {
      return(FALSE)
    } else {
      return(TRUE)
    }
  }

  value<- function(){
    set.seed(123+batch)
    digit<- mx.nd.array(sample(0:9, size = batch_size, replace = T))
    data<- mx.nd.one.hot(indices = digit, depth = 10)
    data<- mx.nd.reshape(data = data, shape = c(1,1,-1, batch_size))
    return(list(data=data, digit=digit))
  }

  return(list(reset=reset, iter.next=iter.next, value=value, batch_size=batch_size, batch=batch))
}

D_iterator<- function(batch_size){

  batch<- 0
  batch_per_epoch<- 5

  reset<- function(){
    batch<<- 0
  }

  iter.next<- function(){
    batch<<- batch+1
    if (batch>batch_per_epoch) {
      return(FALSE)
    } else {
      return(TRUE)
    }
  }

  value<- function(){
    set.seed(123+batch)
    idx<- sample(length(train_label), size = batch_size, replace = T)
    data<- train_data[,,,idx, drop=F]
    label<- mx.nd.array(train_label[idx])
    digit<- mx.nd.one.hot(indices = label, depth = 10)

    return(list(data=mx.nd.array(data), digit=digit, label=label))
  }

  return(list(reset=reset, iter.next=iter.next, value=value, batch_size=batch_size, batch=batch))
}
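
# Hedged usage sketch (added for clarity): both iterators expose the same
# reset / iter.next / value protocol consumed by CGAN_train.R, e.g.:
# iter<- G_iterator(batch_size = 64)
# iter$reset()
# while (iter$iter.next()) {
#   batch<- iter$value()  # list(data = one-hot condition, digit = digit labels)
# }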

