From b6c971f4533d79c189b9f6149cf6c59b33a68d3e Mon Sep 17 00:00:00 2001
From: terrytangyuan
Date: Sun, 20 Dec 2015 21:06:54 -0600
Subject: [PATCH 1/5] Added class for learning rate scheduler

---
 .../scala/ml/dmlc/mxnet/LRScheduler.scala     | 47 +++++++++++++++++++
 .../main/scala/ml/dmlc/mxnet/NDArray.scala    |  2 +-
 2 files changed, 48 insertions(+), 1 deletion(-)
 create mode 100644 scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
new file mode 100644
index 000000000000..d63bccf1564a
--- /dev/null
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
@@ -0,0 +1,47 @@
+package ml.dmlc.mxnet
+import org.slf4j.LoggerFactory
+
+/**
+ * Class for learning rate scheduler, which adaptively changes the learning rate
+ * based on the training progress.
+ * @author Yuan Tang
+ */
+
+
+abstract class LRScheduler(var baseLR: Float = 0.01) {
+  /**
+   * Base class of a learning rate scheduler
+   */
+
+  def apply(numUpdate: Int): Unit
+}
+
+class FactorScheduler(var step: Int, var factor: Float) extends LRScheduler {
+  /**
+   * Class for reducing learning rate in factor
+   */
+
+  var count: Int = 0
+  private val logger = LoggerFactory.getLogger(classOf[FactorScheduler])
+
+  if (step < 1) {
+    throw new IllegalArgumentException("Schedule step must be greater than or equal to 1 round")
+  }
+  if (factor >= 1.0) {
+    throw new IllegalArgumentException("Factor must be less than 1 to reduce lr")
+  }
+
+  def apply(numUpdate: Int): Float = {
+
+    if (numUpdate > this.count + this.step) {
+      this.count += this.step
+      this.baseLR *= this.factor
+      this.logger.info(s"""Update$numUpdate: Change learning rate to ${this.baseLR}%.5f""")
+    }
+    this.baseLR
+  }
+}
+
+
+
+
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
index a36c87292d64..9b8a55b99c77 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
@@ -7,7 +7,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 
 /**
  * NDArray API of mxnet
- * @author Yizhi Liu, Terry Tang
+ * @author Yizhi Liu, Yuan Tang
  */
 object NDArray {
   private val logger = LoggerFactory.getLogger(classOf[NDArray])

From 15790ba0ad3f3ead8f0f6aa68a1226ca9772538f Mon Sep 17 00:00:00 2001
From: terrytangyuan
Date: Sun, 20 Dec 2015 21:16:41 -0600
Subject: [PATCH 2/5] Documentation for LRScheduler

---
 .../scala/ml/dmlc/mxnet/LRScheduler.scala     | 26 ++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
index d63bccf1564a..10bc6bd77817 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
@@ -2,24 +2,38 @@ package ml.dmlc.mxnet
 import org.slf4j.LoggerFactory
 
 /**
- * Class for learning rate scheduler, which adaptively changes the learning rate
+ * Learning rate scheduler, which adaptively changes the learning rate
  * based on the training progress.
  * @author Yuan Tang
  */
 
-
 abstract class LRScheduler(var baseLR: Float = 0.01) {
   /**
    * Base class of a learning rate scheduler
+   *
+   * The training progress is represented by `numUpdate`, which can be roughly
+   * viewed as the number of minibatches executed so far. Its value is
+   * non-decreasing and increases by at most one.
+   *
+   * The exact value is the upper bound of the number of updates applied to
+   * a weight/index.
+   *
+   * @param numUpdate Int, the maximal number of updates applied to a weight.
    */
-
   def apply(numUpdate: Int): Unit
 }
 
+/**
+ * Class for reducing the learning rate by a factor
+ *
+ * Assume the weight has been updated n times; the learning rate will then
+ * be baseLR * factor^^(floor(n/step))
+ *
+ * @param step Int, the number of updates after which the learning rate is reduced
+ * @param factor Float, the factor for reducing the learning rate
+ *
+ */
 class FactorScheduler(var step: Int, var factor: Float) extends LRScheduler {
-  /**
-   * Class for reducing learning rate in factor
-   */
 
   var count: Int = 0
   private val logger = LoggerFactory.getLogger(classOf[FactorScheduler])

From f980f05eba49bb93fb33d96adb53ce8f5274a850 Mon Sep 17 00:00:00 2001
From: terrytangyuan
Date: Sun, 20 Dec 2015 21:23:45 -0600
Subject: [PATCH 3/5] Fixed type mismatch

---
 .../src/main/scala/ml/dmlc/mxnet/LRScheduler.scala | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
index 10bc6bd77817..adf8070e39e3 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
@@ -7,7 +7,7 @@ import org.slf4j.LoggerFactory
  * @author Yuan Tang
  */
 
-abstract class LRScheduler(var baseLR: Float = 0.01) {
+abstract class LRScheduler(var baseLR: Double = 0.01) {
   /**
    * Base class of a learning rate scheduler
    *
@@ -20,7 +20,7 @@ abstract class LRScheduler(var baseLR: Float = 0.01) {
    *
    * @param numUpdate Int, the maximal number of updates applied to a weight.
    */
-  def apply(numUpdate: Int): Unit
+  def apply(numUpdate: Int): Double
 }
 
 /**
@@ -45,17 +45,14 @@ class FactorScheduler(var step: Int, var factor: Float) extends LRScheduler {
     throw new IllegalArgumentException("Factor must be less than 1 to reduce lr")
   }
 
-  def apply(numUpdate: Int): Float = {
+  def apply(numUpdate: Int): Double = {
 
     if (numUpdate > this.count + this.step) {
       this.count += this.step
       this.baseLR *= this.factor
-      this.logger.info(s"""Update$numUpdate: Change learning rate to ${this.baseLR}%.5f""")
+      this.logger.info(s"""Update$numUpdate: Change learning rate to ${this.baseLR}""")
     }
     this.baseLR
   }
 }
-
-
-
 

From 68dbd8eba02741dc2e72b5670d7c6bdf630482a4 Mon Sep 17 00:00:00 2001
From: terrytangyuan
Date: Sun, 20 Dec 2015 22:06:06 -0600
Subject: [PATCH 4/5] Change error throw to require

---
 .../core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
index adf8070e39e3..bd99b49e7b1f 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
@@ -38,12 +38,8 @@ class FactorScheduler(var step: Int, var factor: Float) extends LRScheduler {
   var count: Int = 0
   private val logger = LoggerFactory.getLogger(classOf[FactorScheduler])
 
-  if (step < 1) {
-    throw new IllegalArgumentException("Schedule step must be greater than or equal to 1 round")
-  }
-  if (factor >= 1.0) {
-    throw new IllegalArgumentException("Factor must be less than 1 to reduce lr")
-  }
+  require(step >= 1, "Schedule step must be greater than or equal to 1 round")
+  require(factor < 1.0, "Factor must be less than 1 to reduce lr")
 
   def apply(numUpdate: Int): Double = {
 

From a2edc89900a90fc7a0ba90f5239e34c388b59c4d Mon Sep 17 00:00:00 2001
From: terrytangyuan
Date: Sun, 20 Dec 2015 23:02:16 -0600
Subject: [PATCH 5/5] Change to Float and make member vars protected

---
 .../src/main/scala/ml/dmlc/mxnet/LRScheduler.scala | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
index bd99b49e7b1f..35e05786fe74 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
@@ -7,7 +7,7 @@ import org.slf4j.LoggerFactory
  * @author Yuan Tang
  */
 
-abstract class LRScheduler(var baseLR: Double = 0.01) {
+abstract class LRScheduler(protected var baseLR: Float = 0.01f) {
   /**
    * Base class of a learning rate scheduler
    *
@@ -20,7 +20,7 @@ abstract class LRScheduler(var baseLR: Double = 0.01) {
    *
    * @param numUpdate Int, the maximal number of updates applied to a weight.
    */
-  def apply(numUpdate: Int): Double
+  def apply(numUpdate: Int): Float
 }
 
 /**
@@ -33,20 +33,20 @@ abstract class LRScheduler(var baseLR: Double = 0.01) {
  * @param factor Float, the factor for reducing the learning rate
  *
  */
-class FactorScheduler(var step: Int, var factor: Float) extends LRScheduler {
+class FactorScheduler(protected var step: Int, protected var factor: Float) extends LRScheduler {
 
-  var count: Int = 0
+  protected var count: Int = 0
   private val logger = LoggerFactory.getLogger(classOf[FactorScheduler])
 
   require(step >= 1, "Schedule step must be greater than or equal to 1 round")
   require(factor < 1.0, "Factor must be less than 1 to reduce lr")
 
-  def apply(numUpdate: Int): Double = {
+  def apply(numUpdate: Int): Float = {
 
     if (numUpdate > this.count + this.step) {
       this.count += this.step
      this.baseLR *= this.factor
-      this.logger.info(s"""Update$numUpdate: Change learning rate to ${this.baseLR}""")
+      this.logger.info(s"Update$numUpdate: Change learning rate to ${this.baseLR}")
     }
     this.baseLR
   }
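
For reference, a minimal usage sketch of the scheduler as it stands after [PATCH 5/5]. The import path matches the package added in [PATCH 1/5]; the `FactorSchedulerExample` object, the chosen `step` and `factor` values, and the direct call pattern are illustrative assumptions only (in real training, an optimizer implementation would be the one calling `apply` with its running update count).

    import ml.dmlc.mxnet.FactorScheduler

    // Hypothetical demo object, not part of the patches above.
    object FactorSchedulerExample {
      def main(args: Array[String]): Unit = {
        // Halve the default base learning rate (0.01f) every 100 updates.
        val scheduler = new FactorScheduler(step = 100, factor = 0.5f)

        // apply(numUpdate) returns the learning rate in effect for that update,
        // roughly baseLR * factor^floor(numUpdate / step); the scheduler keeps
        // internal state (count) and reduces the rate at most once per call.
        println(scheduler(1))    // 0.01   (no reduction yet)
        println(scheduler(101))  // 0.005  (101 > count(0) + step(100): first reduction)
        println(scheduler(201))  // 0.0025 (201 > count(100) + step(100): second reduction)
      }
    }

Note that because `apply` reduces the rate at most once per call, a caller that skips update counts (e.g. jumping from 1 to 300) sees the reduction applied gradually over subsequent calls rather than all at once.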