Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Gan and Customop fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jul 24, 2018
1 parent 836e64f commit 7340653
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.mxnetexamples.customop
import org.apache.mxnet.{DataIter, IO, Shape}

object Data {

// return train and val iterators for mnist
def mnistIterator(dataPath: String, batchSize: Int, inputShape: Shape): (DataIter, DataIter) = {
val flat = if (inputShape.length == 3) "False" else "True"
Expand All @@ -29,15 +30,19 @@ object Data {
"input_shape" -> inputShape.toString(),
"batch_size" -> s"$batchSize",
"shuffle" -> "True",
"flat" -> flat
"flat" -> flat,
"dataLayout" -> "NT",
"labelLayout" -> "N"
)
val trainDataIter = IO.MNISTIter(trainParams)
val testParams = Map(
"image" -> s"$dataPath/t10k-images-idx3-ubyte",
"label" -> s"$dataPath/t10k-labels-idx1-ubyte",
"input_shape" -> inputShape.toString(),
"batch_size" -> s"$batchSize",
"flat" -> flat
"flat" -> flat,
"dataLayout" -> "NT",
"labelLayout" -> "N"
)
val testDataIter = IO.MNISTIter(testParams)
(trainDataIter, testDataIter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ object GanMnist {
"label" -> s"$dataPath/train-labels-idx1-ubyte",
"input_shape" -> s"(1, 28, 28)",
"batch_size" -> s"$batchSize",
"shuffle" -> "True"
"shuffle" -> "True",
"dataLayout" -> "NT",
"labelLayout" -> "N"
)

val mnistIter = IO.MNISTIter(params)
Expand Down

0 comments on commit 7340653

Please sign in to comment.