Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix some glitches
Browse files Browse the repository at this point in the history
  • Loading branch information
mdespriee committed Sep 14, 2018
1 parent 48b8160 commit 82dfeb8
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class DCASGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f,
mon *= this.momentum
mon += monUpdated
} else {
require(this.momentum == 0f, s"momentum should be 0, not ${this.momentum}")
require(this.momentum == 0, "momentum should be zero when state is provided")
mon = monUpdated
}
previousWeight.set(weight)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class NAG(val learningRate: Float = 0.01f, momentum: Float = 0.0f,
resdGrad += momentum * mom
weight += -lr * resdGrad
} else {
require(momentum == 0f, s"momentum should be 0, not ${this.momentum}")
require(momentum == 0f, "momentum should be zero when state is provided")
// adder = -lr * (resdGrad + this.wd * weight)
// we write in this way to get rid of memory leak
val adder = this.wd * weight
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class SGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f,
weight += mom
adder.dispose()
} else {
require(momentum == 0f, s"momentum should be 0, not ${this.momentum}")
require(momentum == 0f, "momentum should be zero when state is provided")
// adder = -lr * (resdGrad + this.wd * weight)
// we write in this way to get rid of memory leak
val adder = this.wd * weight
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private[infer] object MXNetHandlerType extends Enumeration {
private[infer] class MXNetThreadPoolHandler(numThreads: Int = 1)
extends MXNetHandler {

require(numThreads > 0, s"numThreads should be a positive number, you passed $numThreads")
require(numThreads > 0, s"Invalid numThreads $numThreads")

private val logger = LoggerFactory.getLogger(classOf[MXNetThreadPoolHandler])
private var threadCount: Int = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.mxnet.NDArray
* @author Yizhi Liu
*/
class MXNDArray(@transient private var ndArray: NDArray) extends Serializable {
require(ndArray != null, "ndArray must be non-null")
require(ndArray != null, "Undefined ndArray")
private val arrayBytes: Array[Byte] = ndArray.serialize()

def get: NDArray = {
Expand Down

0 comments on commit 82dfeb8

Please sign in to comment.