
Commit

Merge pull request apache#7 from terrytangyuan/terry
Learning rate scheduler
yzhliu committed Dec 21, 2015
2 parents 5197470 + a2edc89 commit a6c9ead
Showing 2 changed files with 55 additions and 1 deletion.
54 changes: 54 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
@@ -0,0 +1,54 @@
package ml.dmlc.mxnet

import org.slf4j.LoggerFactory

/**
 * Learning rate scheduler, which adaptively changes the learning rate
 * based on the training progress.
 * @author Yuan Tang
 */
abstract class LRScheduler(protected var baseLR: Float = 0.01f) {
  /**
   * Base class of a learning rate scheduler.
   *
   * The training progress is presented by `numUpdate`, which can be roughly
   * viewed as the number of minibatches executed so far. Its value is
   * non-decreasing, and increases by at most one.
   *
   * The exact value is the upper bound of the number of updates applied to
   * a weight/index.
   *
   * @param numUpdate Int, the maximal number of updates applied to a weight.
   */
  def apply(numUpdate: Int): Float
}

/**
 * Class for reducing the learning rate by a factor.
 *
 * Assume the weight has been updated n times; the learning rate will then
 * be baseLR * factor^(floor(n / step)).
 *
 * @param step Int, reduce the learning rate after every `step` updates
 * @param factor Float, the factor by which the learning rate is multiplied
 */
class FactorScheduler(protected var step: Int, protected var factor: Float) extends LRScheduler {

  protected var count: Int = 0
  private val logger = LoggerFactory.getLogger(classOf[FactorScheduler])

  require(step >= 1, "Schedule step must be greater than or equal to 1 round")
  require(factor < 1.0, "Factor must be less than 1 to make lr reduce")

  def apply(numUpdate: Int): Float = {
    if (numUpdate > this.count + this.step) {
      this.count += this.step
      this.baseLR *= this.factor
      this.logger.info(s"Update[$numUpdate]: Change learning rate to ${this.baseLR}")
    }
    this.baseLR
  }
}
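For reference, a minimal usage sketch of `FactorScheduler` (not part of this diff; the `step` and `factor` values are illustrative):

```scala
// Halve the learning rate every 100 updates; baseLR defaults to 0.01f.
val scheduler = new FactorScheduler(step = 100, factor = 0.5f)

scheduler(50)   // 0.01f   -- still within the first 100 updates
scheduler(101)  // 0.005f  -- past the first step boundary, lr *= 0.5
scheduler(201)  // 0.0025f -- past the second boundary
```

Note that `apply` mutates `baseLR` and `count`, so a scheduler instance tracks its own progress and should not be shared across unrelated training runs.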

2 changes: 1 addition & 1 deletion scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
@@ -7,7 +7,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}

/**
* NDArray API of mxnet
- * @author Yizhi Liu, Terry Tang
+ * @author Yizhi Liu, Yuan Tang
*/
object NDArray {
private val logger = LoggerFactory.getLogger(classOf[NDArray])
