Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1287] Fix scalastyle #14669

Merged
merged 1 commit into from
Apr 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,21 +180,24 @@ class FeedForward private(

// Initialize the predictor module for running prediction.
private def initPredictor(inputShapes: Map[String, Shape]): Unit = {
var shouldInit = true
if (this.predExec != null) {
val (argShapes, _, _) = symbol.inferShape(inputShapes)
require(argShapes != null, "Shape inference failed." +
s"Known shapes are $inputShapes for symbol arguments ${symbol.listArguments()} " +
s"and aux states ${symbol.listAuxiliaryStates()}")
val predShapes = this.predExec.argArrays.map(_.shape)
if (argShapes.sameElements(predShapes)) {
return
shouldInit = false
}
}
// for now only use the first device
val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = inputShapes)
predExec.copyParamsFrom(_argParams, _auxParams)
ExecutorManager.checkArguments(symbol)
this.predExec = predExec
if(shouldInit) {
// for now only use the first device
val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = inputShapes)
predExec.copyParamsFrom(_argParams, _auxParams)
ExecutorManager.checkArguments(symbol)
this.predExec = predExec
}
}

// Initialize the iterator given input.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,13 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
allowMissing: Boolean = false,
forceInit: Boolean = false,
allowExtra: Boolean = false): Unit = {
if (paramsInitialized && !forceInit) {
return
if (!paramsInitialized || forceInit) {
require(binded, "call bind before initializing the parameters")
this._currModule.initParams(initializer, argParams, auxParams,
allowMissing, forceInit, allowExtra)
this.paramsDirty = false
this.paramsInitialized = true
}
require(binded, "call bind before initializing the parameters")
this._currModule.initParams(initializer, argParams, auxParams,
allowMissing, forceInit, allowExtra)
this.paramsDirty = false
this.paramsInitialized = true
}

/**
Expand Down Expand Up @@ -218,28 +217,27 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[

if (this.binded) {
logger.warn("Already bound, ignoring bind()")
return
}
} else {
require(sharedModule.isEmpty,
"sharedModule for BucketingModule is not supported")

require(sharedModule.isEmpty,
"sharedModule for BucketingModule is not supported")

this.forTraining = forTraining
this.inputsNeedGrad = inputsNeedGrad
this.binded = true

val (sym, dNames, lNames) = this.symGen(this.defaultBucketKey)
val module = new Module(sym, dNames, lNames, this.contexts,
this.workLoadList, this.fixedParamNames)
module.bind(dataShapes, labelShapes, forTraining, inputsNeedGrad,
forceRebind = false, sharedModule = None, gradReq)
this._currModule = module
this._currBucketKey = this.defaultBucketKey
this._buckets(this.defaultBucketKey) = module

// copy back saved params, if already initialized
if (this.paramsInitialized) {
this.setParams(argParams, auxParams)
this.forTraining = forTraining
this.inputsNeedGrad = inputsNeedGrad
this.binded = true

val (sym, dNames, lNames) = this.symGen(this.defaultBucketKey)
val module = new Module(sym, dNames, lNames, this.contexts,
this.workLoadList, this.fixedParamNames)
module.bind(dataShapes, labelShapes, forTraining, inputsNeedGrad,
forceRebind = false, sharedModule = None, gradReq)
this._currModule = module
this._currBucketKey = this.defaultBucketKey
this._buckets(this.defaultBucketKey) = module

// copy back saved params, if already initialized
if (this.paramsInitialized) {
this.setParams(argParams, auxParams)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,36 +121,35 @@ class Module(symbolVar: Symbol,
allowMissing: Boolean = false,
forceInit: Boolean = false,
allowExtra: Boolean = false): Unit = {
if (paramsInitialized && !forceInit) {
return
}
require(binded, "call bind before initializing the parameters")
if (!paramsInitialized || forceInit) {
require(binded, "call bind before initializing the parameters")

if (this.argParams == null) {
val paramArrays =
execGroup.paramArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = nds(0).dtype))
this.argParams = this.paramNames.zip(paramArrays).toMap
}
if (this.argParams == null) {
val paramArrays =
execGroup.paramArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = nds(0).dtype))
this.argParams = this.paramNames.zip(paramArrays).toMap
}

if (this.auxParams == null) {
val auxArrays =
execGroup.auxArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = nds(0).dtype))
this.auxParams = this.auxNames.zip(auxArrays).toMap
}
if (this.auxParams == null) {
val auxArrays =
execGroup.auxArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = nds(0).dtype))
this.auxParams = this.auxNames.zip(auxArrays).toMap
}

this.argParams.foreach { case (name, arr) =>
impl(name, arr, allowMissing, Option(initializer), argParams)
}
this.argParams.foreach { case (name, arr) =>
impl(name, arr, allowMissing, Option(initializer), argParams)
}

this.auxParams.foreach { case (name, arr) =>
impl(name, arr, allowMissing, Option(initializer), auxParams)
}
this.auxParams.foreach { case (name, arr) =>
impl(name, arr, allowMissing, Option(initializer), auxParams)
}

this.paramsInitialized = true
this.paramsDirty = false
this.paramsInitialized = true
this.paramsDirty = false

// copy the initialized parameters to devices
this.execGroup.setParams(this.argParams, this.auxParams, allowExtra = allowExtra)
// copy the initialized parameters to devices
this.execGroup.setParams(this.argParams, this.auxParams, allowExtra = allowExtra)
}
}

// Internal helper for parameter initialization
Expand Down Expand Up @@ -246,64 +245,64 @@ class Module(symbolVar: Symbol,

if (binded) {
logger.warn("Already binded, ignoring bind()")
return
}
} else {
this.forTraining = forTraining
this.inputsNeedGrad = inputsNeedGrad
this.binded = true

this.forTraining = forTraining
this.inputsNeedGrad = inputsNeedGrad
this.binded = true
if (!forTraining) {
require(!inputsNeedGrad, "Invalid inputsNeedGrad (cannot be true if not forTraining)")
} else {
// this is not True, as some module might not contains a loss function
// that consumes the labels
// require(labelShapes != None)
}

if (!forTraining) {
require(!inputsNeedGrad, "Invalid inputsNeedGrad (cannot be true if not forTraining)")
} else {
// this is not True, as some module might not contains a loss function
// that consumes the labels
// require(labelShapes != None)
}
this.dataShapesVar = dataShapes
this.labelShapesVar = labelShapes

this.dataShapesVar = dataShapes
this.labelShapesVar = labelShapes

val sharedGroup =
sharedModule.map(sharedModuleInst => {
require(sharedModuleInst.binded && sharedModuleInst.paramsInitialized,
s"bind() and initParams() must be called first on shared module.")
sharedModuleInst.execGroup
})

val inputTypes = this.dataShapesVar.map(dataDesc => (dataDesc.name, dataDesc.dtype)).toMap ++
labelShapes.map(shapes => shapes.map(dataDesc => (dataDesc.name, dataDesc.dtype)).toMap)
.getOrElse(Map.empty[String, DType])

execGroup = new Builder(symbol, contexts, paramNames)
.setWorkLoadList(workLoads)
.setDataShapes(dataShapes)
.setLabelShapes(labelShapes.orNull)
.setForTraining(forTraining)
.setInputsNeedGrad(inputsNeedGrad)
.setSharedGroup(sharedGroup.orNull)
.setFixedParamNames(fixedParamNames.orNull)
.setGradReq(gradReq)
.setInputTypes(inputTypes)
.build()

if (sharedModule.isDefined) {
paramsInitialized = true
argParams = sharedModule.get.argParams
auxParams = sharedModule.get.auxParams
} else if (paramsInitialized) {
// if the parameters are already initialized, we are re-binding
// so automatically copy the already initialized params
execGroup.setParams(argParams, auxParams)
}
val sharedGroup =
sharedModule.map(sharedModuleInst => {
require(sharedModuleInst.binded && sharedModuleInst.paramsInitialized,
s"bind() and initParams() must be called first on shared module.")
sharedModuleInst.execGroup
})

sharedModule.foreach {
case sharedModuleInst: Module =>
if (sharedModuleInst.optimizerInitialized) {
borrowOptimizer(sharedModuleInst)
}
case _ =>
val inputTypes = this.dataShapesVar.map(dataDesc => (dataDesc.name, dataDesc.dtype)).toMap ++
labelShapes.map(shapes => shapes.map(dataDesc => (dataDesc.name, dataDesc.dtype)).toMap)
.getOrElse(Map.empty[String, DType])

execGroup = new Builder(symbol, contexts, paramNames)
.setWorkLoadList(workLoads)
.setDataShapes(dataShapes)
.setLabelShapes(labelShapes.orNull)
.setForTraining(forTraining)
.setInputsNeedGrad(inputsNeedGrad)
.setSharedGroup(sharedGroup.orNull)
.setFixedParamNames(fixedParamNames.orNull)
.setGradReq(gradReq)
.setInputTypes(inputTypes)
.build()

if (sharedModule.isDefined) {
paramsInitialized = true
argParams = sharedModule.get.argParams
auxParams = sharedModule.get.auxParams
} else if (paramsInitialized) {
// if the parameters are already initialized, we are re-binding
// so automatically copy the already initialized params
execGroup.setParams(argParams, auxParams)
}

sharedModule.foreach {
case sharedModuleInst: Module =>
if (sharedModuleInst.optimizerInitialized) {
borrowOptimizer(sharedModuleInst)
}
case _ =>
}
}

}

/**
Expand Down
Loading