diff --git a/example/gan/CGAN_mnist_R/CGAN_mnist_setup.R b/example/gan/CGAN_mnist_R/CGAN_mnist_setup.R
new file mode 100644
index 000000000000..f3ebf0fe1718
--- /dev/null
+++ b/example/gan/CGAN_mnist_R/CGAN_mnist_setup.R
@@ -0,0 +1,104 @@
+require("imager")
+require("dplyr")
+require("readr")
+require("mxnet")
+
+source("iterators.R")
+
+######################################################
+### Data import and preparation
+### First download the MNIST training data from Kaggle:
+### https://www.kaggle.com/c/digit-recognizer/data
+######################################################
+train <- read_csv('data/train.csv')
+train <- data.matrix(train)
+
+train_data <- train[,-1]
+train_data <- t(train_data/255*2-1)
+train_label <- as.integer(train[,1])
+
+dim(train_data) <- c(28, 28, 1, ncol(train_data))
+
+##################################################
+#### Model parameters
+##################################################
+random_dim <- 96
+gen_features <- 96
+dis_features <- 32
+image_depth <- 1
+fix_gamma <- TRUE
+no_bias <- TRUE
+eps <- 1e-5 + 1e-12
+batch_size <- 64
+
+
+##################################################
+#### Generator Symbol
+##################################################
+data = mx.symbol.Variable('data')
+
+gen_rand <- mx.symbol.normal(loc=0, scale=1, shape=c(1, 1, random_dim, batch_size), name="gen_rand")
+gen_concat <- mx.symbol.Concat(data = list(data, gen_rand), num.args = 2, name="gen_concat")
+
+g1 = mx.symbol.Deconvolution(gen_concat, name='g1', kernel=c(4,4), num_filter=gen_features*4, no_bias=no_bias)
+gbn1 = mx.symbol.BatchNorm(g1, name='gbn1', fix_gamma=fix_gamma, eps=eps)
+gact1 = mx.symbol.Activation(gbn1, name='gact1', act_type='relu')
+
+g2 = mx.symbol.Deconvolution(gact1, name='g2', kernel=c(3,3), stride=c(2,2), pad=c(1,1), num_filter=gen_features*2, no_bias=no_bias)
+gbn2 = mx.symbol.BatchNorm(g2, name='gbn2', fix_gamma=fix_gamma, eps=eps)
+gact2 = mx.symbol.Activation(gbn2, name='gact2', act_type='relu')
+
+g3 = mx.symbol.Deconvolution(gact2, name='g3', kernel=c(4,4), stride=c(2,2), pad=c(1,1), num_filter=gen_features, no_bias=no_bias)
+gbn3 = mx.symbol.BatchNorm(g3, name='gbn3', fix_gamma=fix_gamma, eps=eps)
+gact3 = mx.symbol.Activation(gbn3, name='gact3', act_type='relu')
+
+g4 = mx.symbol.Deconvolution(gact3, name='g4', kernel=c(4,4), stride=c(2,2), pad=c(1,1), num_filter=image_depth, no_bias=no_bias)
+G_sym = mx.symbol.Activation(g4, name='G_sym', act_type='tanh')
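+
+### Optional sanity check (a minimal sketch, not part of the original example):
+### infer the generator's output shape from the conditioning input shape to
+### confirm the deconvolution stack produces 28x28x1 images per batch element.
+# G_shapes <- mx.symbol.infer.shape(G_sym, data = c(1, 1, 10, batch_size))
+# G_shapes$out.shapes  # expected: 28 28 1 batch_size for G_sym_output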
+
+
+##################################################
+#### Discriminator Symbol
+##################################################
+data = mx.symbol.Variable('data')
+dis_digit = mx.symbol.Variable('digit')
+label = mx.symbol.Variable('label')
+
+dis_digit <- mx.symbol.Reshape(data=dis_digit, shape=c(1,1,10,batch_size), name="digit_reshape")
+dis_digit <- mx.symbol.broadcast_to(data=dis_digit, shape=c(28,28,10, batch_size), name="digit_broadcast")
+
+data_concat <- mx.symbol.Concat(list(data, dis_digit), num.args = 2, dim = 1, name='data_concat')
+
+d1 = mx.symbol.Convolution(data=data_concat, name='d1', kernel=c(3,3), stride=c(1,1), pad=c(0,0), num_filter=24, no_bias=no_bias)
+dbn1 = mx.symbol.BatchNorm(d1, name='dbn1', fix_gamma=fix_gamma, eps=eps)
+dact1 = mx.symbol.LeakyReLU(dbn1, name='dact1', act_type='elu', slope=0.25)
+pool1 <- mx.symbol.Pooling(data=dact1, name="pool1", pool_type="max", kernel=c(2,2), stride=c(2,2), pad=c(0,0))
+
+d2 = mx.symbol.Convolution(pool1, name='d2', kernel=c(3,3), stride=c(2,2), pad=c(0,0), num_filter=32, no_bias=no_bias)
+dbn2 = mx.symbol.BatchNorm(d2, name='dbn2', fix_gamma=fix_gamma, eps=eps)
+dact2 = mx.symbol.LeakyReLU(dbn2, name='dact2', act_type='elu', slope=0.25)
+
+d3 = mx.symbol.Convolution(dact2, name='d3', kernel=c(3,3), stride=c(1,1), pad=c(0,0), num_filter=64, no_bias=no_bias)
+dbn3 = mx.symbol.BatchNorm(d3, name='dbn3', fix_gamma=fix_gamma, eps=eps)
+dact3 = mx.symbol.LeakyReLU(dbn3, name='dact3', act_type='elu', slope=0.25)
+
+d4 = mx.symbol.Convolution(dact3, name='d4', kernel=c(4,4), stride=c(1,1), pad=c(0,0), num_filter=64, no_bias=no_bias)
+dbn4 = mx.symbol.BatchNorm(d4, name='dbn4', fix_gamma=fix_gamma, eps=eps)
+dact4 = mx.symbol.LeakyReLU(dbn4, name='dact4', act_type='elu', slope=0.25)
+
+# pool4 <- mx.symbol.Pooling(data=dact3, name="pool4", pool_type="avg", kernel=c(4,4), stride=c(1,1), pad=c(0,0))
+
+dflat = mx.symbol.Flatten(dact4, name="dflat")
+
+dfc <- mx.symbol.FullyConnected(data=dflat, name="dfc", num_hidden=1, no_bias=FALSE)
+D_sym = mx.symbol.LogisticRegressionOutput(data=dfc, label=label, name='D_sym')
+
+
+########################
+### Graph
+########################
+input_shape_G <- c(1, 1, 10, batch_size)
+input_shape_D <- c(28, 28, 1, batch_size)
+
+graph.viz(G_sym, type = "graph", direction = "LR")
+graph.viz(D_sym, type = "graph", direction = "LR")
+
diff --git a/example/gan/CGAN_mnist_R/CGAN_train.R b/example/gan/CGAN_mnist_R/CGAN_train.R
new file mode 100644
index 000000000000..6778d6b9c2b1
--- /dev/null
+++ b/example/gan/CGAN_mnist_R/CGAN_train.R
@@ -0,0 +1,182 @@
+#####################################################
+### Training module for the conditional GAN (CGAN)
+#####################################################
+
+devices <- mx.cpu()
+
+data_shape_G <- c(1, 1, 10, batch_size)
+data_shape_D <- c(28, 28, 1, batch_size)
+digit_shape_D <- c(10, batch_size)
+
+mx.metric.binacc <- mx.metric.custom("binacc", function(label, pred) {
+  res <- mean(label==round(pred))
+  return(res)
+})
+
+mx.metric.logloss <- mx.metric.custom("logloss", function(label, pred) {
+  res <- -mean(label*log(pred)+(1-label)*log(1-pred))
+  return(res)
+})
+
+##############################################
+### Define iterators
+iter_G <- G_iterator(batch_size = batch_size)
+iter_D <- D_iterator(batch_size = batch_size)
+
+exec_G <- mx.simple.bind(symbol = G_sym, data=data_shape_G, ctx = devices, grad.req = "write")
+exec_D <- mx.simple.bind(symbol = D_sym, data=data_shape_D, digit=digit_shape_D, ctx = devices, grad.req = "write")
+
+### initialize parameters - To Do - personalise each layer
+initializer <- mx.init.Xavier(rnd_type = "gaussian", factor_type = "avg", magnitude = 3)
+
+arg_param_ini_G <- mx.init.create(initializer = initializer, shape.array = mx.symbol.infer.shape(G_sym, data=data_shape_G)$arg.shapes, ctx = mx.cpu())
+aux_param_ini_G <- mx.init.create(initializer = initializer, shape.array = mx.symbol.infer.shape(G_sym, data=data_shape_G)$aux.shapes, ctx = mx.cpu())
+
+arg_param_ini_D <- mx.init.create(initializer = initializer, shape.array = mx.symbol.infer.shape(D_sym, data=data_shape_D, digit=digit_shape_D)$arg.shapes, ctx = mx.cpu())
+aux_param_ini_D <- mx.init.create(initializer = initializer, shape.array = mx.symbol.infer.shape(D_sym, data=data_shape_D, digit=digit_shape_D)$aux.shapes, ctx = mx.cpu())
+
+mx.exec.update.arg.arrays(exec_G, arg_param_ini_G, match.name=TRUE)
+mx.exec.update.aux.arrays(exec_G, aux_param_ini_G, match.name=TRUE)
+
+mx.exec.update.arg.arrays(exec_D, arg_param_ini_D, match.name=TRUE)
+mx.exec.update.aux.arrays(exec_D, aux_param_ini_D, match.name=TRUE)
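+
+### Optional check (a small sketch, not required by the example): listing the
+### dimensions of the bound argument arrays is a quick way to confirm both
+### executors were initialized with the expected shapes.
+# sapply(exec_G$ref.arg.arrays, dim)
+# sapply(exec_D$ref.arg.arrays, dim)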
+
+input_names_G <- mxnet:::mx.model.check.arguments(G_sym)
+input_names_D <- mxnet:::mx.model.check.arguments(D_sym)
+
+
+###################################################
+# initialize optimizers
+optimizer_G <- mx.opt.create(name = "adadelta",
+                             rho=0.92,
+                             epsilon = 1e-6,
+                             wd=0,
+                             rescale.grad=1/batch_size,
+                             clip_gradient=1)
+
+updater_G <- mx.opt.get.updater(optimizer = optimizer_G, weights = exec_G$ref.arg.arrays)
+
+optimizer_D <- mx.opt.create(name = "adadelta",
+                             rho=0.92,
+                             epsilon = 1e-6,
+                             wd=0,
+                             rescale.grad=1/batch_size,
+                             clip_gradient=1)
+updater_D <- mx.opt.get.updater(optimizer = optimizer_D, weights = exec_D$ref.arg.arrays)
+
+####################################
+# initialize metric
+metric_G <- mx.metric.binacc
+metric_G_value <- metric_G$init()
+
+metric_D <- mx.metric.binacc
+metric_D_value <- metric_D$init()
+
+iteration <- 1
+iter_G$reset()
+iter_D$reset()
+
+
+for (iteration in 1:2400) {
+
+  iter_G$iter.next()
+  iter_D$iter.next()
+
+  ### Random input to Generator to produce fake sample
+  G_values <- iter_G$value()
+  G_data <- G_values[input_names_G]
+  mx.exec.update.arg.arrays(exec_G, arg.arrays = G_data, match.name=TRUE)
+  mx.exec.forward(exec_G, is.train=TRUE)
+
+  ### Feed the Discriminator with the Generator's fake images and with real images
+  D_data_fake <- exec_G$ref.outputs$G_sym_output
+  D_digit_fake <- G_values$data %>% mx.nd.Reshape(shape=c(-1, batch_size))
+
+  D_values <- iter_D$value()
+  D_data_real <- D_values$data
+  D_digit_real <- D_values$digit
+
+  ### Train loop on fake
+  mx.exec.update.arg.arrays(exec_D, arg.arrays = list(data=D_data_fake, digit=D_digit_fake, label=mx.nd.array(rep(0, batch_size))), match.name=TRUE)
+  mx.exec.forward(exec_D, is.train=TRUE)
+  mx.exec.backward(exec_D)
+  update_args_D <- updater_D(weight = exec_D$ref.arg.arrays, grad = exec_D$ref.grad.arrays)
+  mx.exec.update.arg.arrays(exec_D, update_args_D, skip.null=TRUE)
+
+  metric_D_value <- metric_D$update(label = mx.nd.array(rep(0, batch_size)), exec_D$ref.outputs[["D_sym_output"]], metric_D_value)
+
+  ### Train loop on real
+  mx.exec.update.arg.arrays(exec_D, arg.arrays = list(data=D_data_real, digit=D_digit_real, label=mx.nd.array(rep(1, batch_size))), match.name=TRUE)
+  mx.exec.forward(exec_D, is.train=TRUE)
+  mx.exec.backward(exec_D)
+  update_args_D <- updater_D(weight = exec_D$ref.arg.arrays, grad = exec_D$ref.grad.arrays)
+  mx.exec.update.arg.arrays(exec_D, update_args_D, skip.null=TRUE)
+
+  metric_D_value <- metric_D$update(mx.nd.array(rep(1, batch_size)), exec_D$ref.outputs[["D_sym_output"]], metric_D_value)
+
+  ### Update Generator weights - use a separate executor so the data gradients get written
+  exec_D_back <- mxnet:::mx.symbol.bind(symbol = D_sym, arg.arrays = exec_D$arg.arrays, aux.arrays = exec_D$aux.arrays, grad.reqs = rep("write", length(exec_D$arg.arrays)), ctx = devices)
+  mx.exec.update.arg.arrays(exec_D_back, arg.arrays = list(data=D_data_fake, digit=D_digit_fake, label=mx.nd.array(rep(1, batch_size))), match.name=TRUE)
+  mx.exec.forward(exec_D_back, is.train=TRUE)
+  mx.exec.backward(exec_D_back)
+  D_grads <- exec_D_back$ref.grad.arrays$data
+  mx.exec.backward(exec_G, out_grads=D_grads)
+
+  update_args_G <- updater_G(weight = exec_G$ref.arg.arrays, grad = exec_G$ref.grad.arrays)
+  mx.exec.update.arg.arrays(exec_G, update_args_G, skip.null=TRUE)
+
+  ### Update metrics
+  # metric_G_value <- metric_G$update(values[[label_name]], exec_G$ref.outputs[[output_name]], metric_G_value)
+
+  if (iteration %% 25==0){
+    D_metric_result <- metric_D$get(metric_D_value)
+    cat(paste0("[", iteration, "] ", D_metric_result$name, ": ", D_metric_result$value, "\n"))
+  }
+
+  if (iteration==1 | iteration %% 100==0){
+
+    metric_D_value <- metric_D$init()
+
+    par(mfrow=c(3,3), mar=c(0.1,0.1,0.1,0.1))
+    for (i in 1:9) {
+      img <- as.array(exec_G$ref.outputs$G_sym_output)[,,,i]
+      plot(as.cimg(img), axes=FALSE)
+    }
+
+    print(as.numeric(as.array(G_values$digit)))
+    print(as.numeric(as.array(D_values$label)))
+
+  }
+}
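+
+### Note: G_sym ends in a tanh, so generated pixels lie in [-1, 1], matching the
+### real-data scaling (train_data/255*2-1) from the setup script. If a plotting
+### path expects [0, 1], rescale first; a minimal sketch:
+# img <- (as.array(exec_G$ref.outputs$G_sym_output)[,,,1] + 1) / 2
+# plot(as.cimg(img), axes=FALSE)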
+
+mx.symbol.save(D_sym, filename = "models/D_sym_model_v1.json")
+mx.nd.save(exec_D$arg.arrays, filename = "models/D_arg_params_v1.params")
+mx.nd.save(exec_D$aux.arrays, filename = "models/D_aux_params_v1.params")
+
+mx.symbol.save(G_sym, filename = "models/G_sym_model_v1.json")
+mx.nd.save(exec_G$arg.arrays, filename = "models/G_arg_params_v1.params")
+mx.nd.save(exec_G$aux.arrays, filename = "models/G_aux_params_v1.params")
+
+
+### Inference
+G_sym <- mx.symbol.load("models/G_sym_model_v1.json")
+G_arg_params <- mx.nd.load("models/G_arg_params_v1.params")
+G_aux_params <- mx.nd.load("models/G_aux_params_v1.params")
+
+digit <- mx.nd.array(rep(9, times=batch_size))
+data <- mx.nd.one.hot(indices = digit, depth = 10)
+data <- mx.nd.reshape(data = data, shape = c(1,1,-1, batch_size))
+
+exec_G <- mx.simple.bind(symbol = G_sym, data=data_shape_G, ctx = devices, grad.req = "null")
+mx.exec.update.arg.arrays(exec_G, G_arg_params, match.name=TRUE)
+mx.exec.update.arg.arrays(exec_G, list(data=data), match.name=TRUE)
+mx.exec.update.aux.arrays(exec_G, G_aux_params, match.name=TRUE)
+
+mx.exec.forward(exec_G, is.train=FALSE)
+
+par(mfrow=c(3,3), mar=c(0.1,0.1,0.1,0.1))
+for (i in 1:9) {
+  img <- as.array(exec_G$ref.outputs$G_sym_output)[,,,i]
+  plot(as.cimg(img), axes=FALSE)
+}
diff --git a/example/gan/CGAN_mnist_R/iterators.R b/example/gan/CGAN_mnist_R/iterators.R
new file mode 100644
index 000000000000..fa113c554b75
--- /dev/null
+++ b/example/gan/CGAN_mnist_R/iterators.R
@@ -0,0 +1,62 @@
+
+G_iterator <- function(batch_size){
+
+  batch <- 0
+  batch_per_epoch <- 5
+
+  reset <- function(){
+    batch <<- 0
+  }
+
+  iter.next <- function(){
+    batch <<- batch+1
+    if (batch > batch_per_epoch) {
+      return(FALSE)
+    } else {
+      return(TRUE)
+    }
+  }
+
+  value <- function(){
+    set.seed(123+batch)
+    digit <- mx.nd.array(sample(0:9, size = batch_size, replace = TRUE))
+    data <- mx.nd.one.hot(indices = digit, depth = 10)
+    data <- mx.nd.reshape(data = data, shape = c(1,1,-1, batch_size))
+    return(list(data=data, digit=digit))
+  }
+
+  return(list(reset=reset, iter.next=iter.next, value=value, batch_size=batch_size, batch=batch))
+}
+
+D_iterator <- function(batch_size){
+
+  batch <- 0
+  batch_per_epoch <- 5
+
+  reset <- function(){
+    batch <<- 0
+  }
+
+  iter.next <- function(){
+    batch <<- batch+1
+    if (batch > batch_per_epoch) {
+      return(FALSE)
+    } else {
+      return(TRUE)
+    }
+  }
+
+  value <- function(){
+    set.seed(123+batch)
+    idx <- sample(length(train_label), size = batch_size, replace = TRUE)
+    data <- train_data[,,,idx, drop=FALSE]
+    label <- mx.nd.array(train_label[idx])
+    digit <- mx.nd.one.hot(indices = label, depth = 10)
+
+    return(list(data=mx.nd.array(data), digit=digit, label=label))
+  }
+
+  return(list(reset=reset, iter.next=iter.next, value=value, batch_size=batch_size, batch=batch))
+}
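+
+### Minimal smoke test (a sketch; assumes train_data and train_label are already
+### loaded by CGAN_mnist_setup.R, and iter_d is a throwaway name): draw one batch
+### from the discriminator iterator and inspect the shapes.
+# iter_d <- D_iterator(batch_size = 64)
+# iter_d$reset(); iter_d$iter.next()
+# str(lapply(iter_d$value(), dim))  # data: 28x28x1x64, digit: 10x64, label: 64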