Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-836] RNN Example for Scala #11753

Merged
merged 7 commits into from
Aug 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object BucketIo {
type ReadContent = String => String

/**
 * Reads the whole file at `path` as UTF-8 and marks sentence/line boundaries
 * (". " or "\n") with the special " <eos> " token expected by the vocabulary
 * builder.
 *
 * @param path path of the text file to read
 * @return the file content with end-of-sentence markers inserted
 */
def defaultReadContent(path: String): String = {
  // Open with an explicit encoding so behavior does not depend on the
  // platform default charset.
  val source = Source.fromFile(path, "UTF-8")
  try {
    source.mkString.replaceAll("\\. |\n", " <eos> ")
  } finally {
    source.close() // avoid leaking the underlying file handle
  }
}

def defaultBuildVocab(path: String): Map[String, Int] = {
Expand All @@ -56,7 +56,7 @@ object BucketIo {
val tmp = sentence.split(" ").filter(_.length() > 0)
for (w <- tmp) yield theVocab(w)
}
words.toArray
words
}

def defaultGenBuckets(sentences: Array[String], batchSize: Int,
Expand Down Expand Up @@ -162,8 +162,6 @@ object BucketIo {
labelBuffer.append(NDArray.zeros(_batchSize, buckets(iBucket)))
}

private val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))

private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize, _defaultBucketKey))
tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
}
Expand Down Expand Up @@ -208,12 +206,13 @@ object BucketIo {
tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
}
val batchProvideLabel = ListMap("softmax_label" -> labelBuf.shape)
new DataBatch(IndexedSeq(dataBuf) ++ initStateArrays,
IndexedSeq(labelBuf),
getIndex(),
getPad(),
this.buckets(bucketIdx).asInstanceOf[AnyRef],
batchProvideData, batchProvideLabel)
val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
lanking520 marked this conversation as resolved.
Show resolved Hide resolved
new DataBatch(IndexedSeq(dataBuf.copy()) ++ initStateArrays,
IndexedSeq(labelBuf.copy()),
getIndex(),
getPad(),
this.buckets(bucketIdx).asInstanceOf[AnyRef],
batchProvideData, batchProvideLabel)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@

package org.apache.mxnetexamples.rnn

import org.apache.mxnet.Symbol
import org.apache.mxnet.{Shape, Symbol}

import scala.collection.mutable.ArrayBuffer

/**
* @author Depeng Liang
*/
object Lstm {

final case class LSTMState(c: Symbol, h: Symbol)
Expand All @@ -35,27 +32,22 @@ object Lstm {
/**
 * Builds one LSTM cell step of the unrolled network.
 *
 * @param numHidden number of hidden units per gate
 * @param inData    input symbol for this time step
 * @param prevState cell (c) and hidden (h) state from the previous time step
 * @param param     shared weight/bias symbols for this layer
 * @param seqIdx    time-step index (used only for symbol naming)
 * @param layerIdx  layer index (used only for symbol naming)
 * @param dropout   dropout ratio applied to the input; 0 disables dropout
 * @return the next LSTM state (c, h)
 */
def lstm(numHidden: Int, inData: Symbol, prevState: LSTMState,
    param: LSTMParam, seqIdx: Int, layerIdx: Int, dropout: Float = 0f): LSTMState = {
  val inDataa = {
    if (dropout > 0f) Symbol.api.Dropout(data = Some(inData), p = Some(dropout))
    else inData
  }
  // Input->hidden and hidden->hidden projections compute all four gate
  // pre-activations at once (hence num_hidden * 4).
  val i2h = Symbol.api.FullyConnected(data = Some(inDataa), weight = Some(param.i2hWeight),
    bias = Some(param.i2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_i2h")
  val h2h = Symbol.api.FullyConnected(data = Some(prevState.h), weight = Some(param.h2hWeight),
    bias = Some(param.h2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_h2h")
  val gates = i2h + h2h
  // Split the fused projection into the four gates:
  // 0 = input gate, 1 = input transform, 2 = forget gate, 3 = output gate.
  val sliceGates = Symbol.api.SliceChannel(data = Some(gates), num_outputs = 4,
    name = s"t${seqIdx}_l${layerIdx}_slice")
  val ingate = Symbol.api.Activation(data = Some(sliceGates.get(0)), act_type = "sigmoid")
  val inTransform = Symbol.api.Activation(data = Some(sliceGates.get(1)), act_type = "tanh")
  val forgetGate = Symbol.api.Activation(data = Some(sliceGates.get(2)), act_type = "sigmoid")
  val outGate = Symbol.api.Activation(data = Some(sliceGates.get(3)), act_type = "sigmoid")
  val nextC = (forgetGate * prevState.c) + (ingate * inTransform)
  // Pass act_type by name for consistency with the other Activation calls
  // (the original relied on positional argument order here).
  val nextH = outGate * Symbol.api.Activation(data = Some(nextC), act_type = "tanh")
  LSTMState(c = nextC, h = nextH)
}

Expand All @@ -74,11 +66,11 @@ object Lstm {
val lastStatesBuf = ArrayBuffer[LSTMState]()
for (i <- 0 until numLstmLayer) {
paramCellsBuf.append(LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
lastStatesBuf.append(LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
h = Symbol.Variable(s"l${i}_init_h_beta")))
h = Symbol.Variable(s"l${i}_init_h_beta")))
}
val paramCells = paramCellsBuf.toArray
val lastStates = lastStatesBuf.toArray
Expand All @@ -87,10 +79,10 @@ object Lstm {
// embeding layer
val data = Symbol.Variable("data")
var label = Symbol.Variable("softmax_label")
val embed = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize,
"weight" -> embedWeight, "output_dim" -> numEmbed))
val wordvec = Symbol.SliceChannel()()(
Map("data" -> embed, "num_outputs" -> seqLen, "squeeze_axis" -> 1))
val embed = Symbol.api.Embedding(data = Some(data), input_dim = inputSize,
weight = Some(embedWeight), output_dim = numEmbed, name = "embed")
val wordvec = Symbol.api.SliceChannel(data = Some(embed),
num_outputs = seqLen, squeeze_axis = Some(true))

val hiddenAll = ArrayBuffer[Symbol]()
var dpRatio = 0f
Expand All @@ -101,22 +93,23 @@ object Lstm {
for (i <- 0 until numLstmLayer) {
if (i == 0) dpRatio = 0f else dpRatio = dropout
val nextState = lstm(numHidden, inData = hidden,
prevState = lastStates(i),
param = paramCells(i),
seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
prevState = lastStates(i),
param = paramCells(i),
seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
hidden = nextState.h
lastStates(i) = nextState
}
// decoder
if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout))
hiddenAll.append(hidden)
}
val hiddenConcat = Symbol.Concat()(hiddenAll: _*)(Map("dim" -> 0))
val pred = Symbol.FullyConnected("pred")()(Map("data" -> hiddenConcat, "num_hidden" -> numLabel,
"weight" -> clsWeight, "bias" -> clsBias))
label = Symbol.transpose()(label)()
label = Symbol.Reshape()()(Map("data" -> label, "target_shape" -> "(0,)"))
val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> pred, "label" -> label))
val hiddenConcat = Symbol.api.Concat(data = hiddenAll.toArray, num_args = hiddenAll.length,
dim = Some(0))
val pred = Symbol.api.FullyConnected(data = Some(hiddenConcat), num_hidden = numLabel,
weight = Some(clsWeight), bias = Some(clsBias))
label = Symbol.api.transpose(data = Some(label))
label = Symbol.api.Reshape(data = Some(label), target_shape = Some(Shape(0)))
val sm = Symbol.api.SoftmaxOutput(data = Some(pred), label = Some(label), name = "softmax")
sm
}

Expand All @@ -131,35 +124,35 @@ object Lstm {
var lastStates = Array[LSTMState]()
for (i <- 0 until numLstmLayer) {
paramCells = paramCells :+ LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
lastStates = lastStates :+ LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
h = Symbol.Variable(s"l${i}_init_h_beta"))
h = Symbol.Variable(s"l${i}_init_h_beta"))
}
assert(lastStates.length == numLstmLayer)

val data = Symbol.Variable("data")

var hidden = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize,
"weight" -> embedWeight, "output_dim" -> numEmbed))
var hidden = Symbol.api.Embedding(data = Some(data), input_dim = inputSize,
weight = Some(embedWeight), output_dim = numEmbed, name = "embed")

var dpRatio = 0f
// stack LSTM
for (i <- 0 until numLstmLayer) {
if (i == 0) dpRatio = 0f else dpRatio = dropout
val nextState = lstm(numHidden, inData = hidden,
prevState = lastStates(i),
param = paramCells(i),
seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
prevState = lastStates(i),
param = paramCells(i),
seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
hidden = nextState.h
lastStates(i) = nextState
}
// decoder
if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
val fc = Symbol.FullyConnected("pred")()(Map("data" -> hidden, "num_hidden" -> numLabel,
"weight" -> clsWeight, "bias" -> clsBias))
val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> fc))
if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout))
val fc = Symbol.api.FullyConnected(data = Some(hidden),
num_hidden = numLabel, weight = Some(clsWeight), bias = Some(clsBias))
val sm = Symbol.api.SoftmaxOutput(data = Some(fc), name = "softmax")
var output = Array(sm)
for (state <- lastStates) {
output = output :+ state.c
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ import org.apache.mxnet.module.BucketingModule
import org.apache.mxnet.module.FitParams

/**
* Bucketing LSTM examples
* @author Yizhi Liu
*/
* Bucketing LSTM examples
*/
class LstmBucketing {
@Option(name = "--data-train", usage = "training set")
private val dataTrain: String = "example/rnn/sherlockholmes.train.txt"
Expand Down Expand Up @@ -61,6 +60,60 @@ object LstmBucketing {
Math.exp(loss / labelArr.length).toFloat
}

/**
 * Trains the bucketing LSTM language model end-to-end: builds the vocabulary,
 * constructs bucketed train/validation iterators, and fits a BucketingModule.
 *
 * @param trainData      path to the training text file
 * @param validationData path to the validation text file
 * @param ctx            contexts (CPUs/GPUs) to train on
 * @param numEpoch       number of training epochs
 */
def runTraining(trainData : String, validationData : String,
                ctx : Array[Context], numEpoch : Int): Unit = {
  val batchSize = 32
  // Sentences are padded up to the nearest bucket length.
  val buckets = Array(10, 20, 30, 40, 50, 60)
  val numHidden = 200
  val numEmbed = 200
  val numLstmLayer = 2

  logger.info("Building vocab ...")
  val vocab = BucketIo.defaultBuildVocab(trainData)

  // Symbol generator invoked by the BucketingModule for each bucket key
  // (the key is the sequence length for that bucket).
  def BucketSymGen(key: AnyRef):
    (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
    val seqLen = key.asInstanceOf[Int]
    val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
      numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
    (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
  }

  // Initial cell/hidden states, one pair per LSTM layer.
  val initC = (0 until numLstmLayer).map(l =>
    (s"l${l}_init_c_beta", (batchSize, numHidden))
  )
  val initH = (0 until numLstmLayer).map(l =>
    (s"l${l}_init_h_beta", (batchSize, numHidden))
  )
  val initStates = initC ++ initH

  val dataTrain = new BucketSentenceIter(trainData, vocab,
    buckets, batchSize, initStates)
  val dataVal = new BucketSentenceIter(validationData, vocab,
    buckets, batchSize, initStates)

  val model = new BucketingModule(
    symGen = BucketSymGen,
    defaultBucketKey = dataTrain.defaultBucketKey,
    contexts = ctx)

  val fitParams = new FitParams()
  fitParams.setEvalMetric(
    new CustomMetric(perplexity, name = "perplexity"))
  fitParams.setKVStore("device")
  fitParams.setOptimizer(
    new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
  fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
  fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))

  logger.info("Start training ...")
  model.fit(
    trainData = dataTrain,
    evalData = Some(dataVal),
    numEpoch = numEpoch, fitParams)
  logger.info("Finished training...")
}

def main(args: Array[String]): Unit = {
val inst = new LstmBucketing
val parser: CmdLineParser = new CmdLineParser(inst)
Expand All @@ -71,56 +124,7 @@ object LstmBucketing {
else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
else Array(Context.cpu(0))

val batchSize = 32
val buckets = Array(10, 20, 30, 40, 50, 60)
val numHidden = 200
val numEmbed = 200
val numLstmLayer = 2

logger.info("Building vocab ...")
val vocab = BucketIo.defaultBuildVocab(inst.dataTrain)

def BucketSymGen(key: AnyRef):
(Symbol, IndexedSeq[String], IndexedSeq[String]) = {
val seqLen = key.asInstanceOf[Int]
val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
(sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
}

val initC = (0 until numLstmLayer).map(l =>
(s"l${l}_init_c_beta", (batchSize, numHidden))
)
val initH = (0 until numLstmLayer).map(l =>
(s"l${l}_init_h_beta", (batchSize, numHidden))
)
val initStates = initC ++ initH

val dataTrain = new BucketSentenceIter(inst.dataTrain, vocab,
buckets, batchSize, initStates)
val dataVal = new BucketSentenceIter(inst.dataVal, vocab,
buckets, batchSize, initStates)

val model = new BucketingModule(
symGen = BucketSymGen,
defaultBucketKey = dataTrain.defaultBucketKey,
contexts = contexts)

val fitParams = new FitParams()
fitParams.setEvalMetric(
new CustomMetric(perplexity, name = "perplexity"))
fitParams.setKVStore("device")
fitParams.setOptimizer(
new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))

logger.info("Start training ...")
model.fit(
trainData = dataTrain,
evalData = Some(dataVal),
numEpoch = inst.numEpoch, fitParams)
logger.info("Finished training...")
runTraining(inst.dataTrain, inst.dataVal, contexts, 5)
} catch {
case ex: Exception =>
logger.error(ex.getMessage, ex)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# RNN Example for MXNet Scala
This folder contains the following examples, written with the new Scala type-safe API:
- [x] LSTM Bucketing
- [x] CharRNN Inference : Generate similar text based on the model
- [x] CharRNN Training: Training the language model using RNN

These examples are only for illustration and are not tuned to achieve the best accuracy.

## Setup
### Download the Network Definition, Weights and Training Data
`obama.zip` contains the training input (Obama's speeches) for the CharRNN examples, and the `sherlockholmes` files contain the data for LSTM Bucketing
```bash
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/obama.zip
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.train.txt
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.valid.txt
```
### Unzip the file
```bash
unzip obama.zip
```
### Argument Configuration
Then you need to define the arguments that you would like to pass in the model:

#### LSTM Bucketing
```bash
--data-train
<path>/sherlockholmes.train.txt
--data-val
<path>/sherlockholmes.valid.txt
--cpus
<comma-separated cpu ids, e.g. 0,1>
--gpus
<comma-separated gpu ids, e.g. 0,1>
```
#### TrainCharRnn
```bash
--data-path
<path>/obama.txt
--save-model-path
<path>/
```
#### TestCharRnn
```bash
--data-path
<path>/obama.txt
--model-prefix
<path>/obama
```
Loading