From 73406530f6847374f6e4c10d368e3b10046fd0e1 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 24 Jul 2018 14:38:53 -0700 Subject: [PATCH] Gan and Customop fixes --- .../scala/org/apache/mxnetexamples/customop/Data.scala | 9 +++++++-- .../scala/org/apache/mxnetexamples/gan/GanMnist.scala | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala index d61269c131ff..230c56e38678 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala @@ -20,6 +20,7 @@ package org.apache.mxnetexamples.customop import org.apache.mxnet.{DataIter, IO, Shape} object Data { + // return train and val iterators for mnist def mnistIterator(dataPath: String, batchSize: Int, inputShape: Shape): (DataIter, DataIter) = { val flat = if (inputShape.length == 3) "False" else "True" @@ -29,7 +30,9 @@ object Data { "input_shape" -> inputShape.toString(), "batch_size" -> s"$batchSize", "shuffle" -> "True", - "flat" -> flat + "flat" -> flat, + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val trainDataIter = IO.MNISTIter(trainParams) val testParams = Map( @@ -37,7 +40,9 @@ object Data { "label" -> s"$dataPath/t10k-labels-idx1-ubyte", "input_shape" -> inputShape.toString(), "batch_size" -> s"$batchSize", - "flat" -> flat + "flat" -> flat, + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val testDataIter = IO.MNISTIter(testParams) (trainDataIter, testDataIter) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala index 6186989b74f6..de2be8df3c75 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala @@ -130,7 +130,9 @@ object GanMnist { "label" -> s"$dataPath/train-labels-idx1-ubyte", "input_shape" -> s"(1, 28, 28)", "batch_size" -> s"$batchSize", - "shuffle" -> "True" + "shuffle" -> "True", + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val mnistIter = IO.MNISTIter(params)