
Fix bugs in distributed example of TrainMnist.scala in MXNet-Scala #5648

Closed
wants to merge 13 commits
5 changes: 2 additions & 3 deletions scala-package/README.md
@@ -79,11 +79,10 @@ java -Xmx4G -cp \
--cpus=0,1,2,3
```

-If you've compiled with `USE_DIST_KVSTORE` enabled, the python tools in `mxnet/tracker` can be used to launch distributed training.
-The following command runs the above example using 2 worker nodes (and 2 server nodes) in local. Refer to [Distributed Training](http://mxnet.io/how_to/multi_devices.html) for more details.
+If you've compiled with `USE_DIST_KVSTORE` enabled, the python tools in `tools` can be used to launch distributed training. Assuming you are in the `mxnet` directory, the following command runs the above example using 2 worker nodes (and 2 server nodes) locally. Refer to [Distributed Training](http://mxnet.io/how_to/multi_devices.html) for more details.

```bash
-tracker/dmlc_local.py -n 2 -s 2 \
+tools/launch.py -n 2 -s 2 --launcher local \
java -Xmx4G -cp \
scala-package/assembly/{your-architecture}/target/*:scala-package/examples/target/*:scala-package/examples/target/classes/lib/* \
ml.dmlc.mxnet.examples.imclassification.TrainMnist \
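For context, `tools/launch.py` starts one scheduler plus the requested number of server and worker processes, and tells each child its role and the cluster layout through `DMLC_*` environment variables rather than command-line flags. Below is a minimal Scala sketch of that contract for the local case; the object name, the `127.0.0.1:9091` scheduler address, and the hard-coded counts are illustrative assumptions, not launch.py's actual implementation.

```scala
import scala.collection.JavaConverters._

// Sketch of a local launcher: spawn one scheduler, then the servers and
// workers, each inheriting the cluster layout via DMLC_* environment variables.
object LocalLauncherSketch {
  def main(args: Array[String]): Unit = {
    val numWorkers = 2
    val numServers = 2
    val command = args.toSeq // the training command, e.g. the java invocation above

    def spawn(role: String): Process = {
      val pb = new ProcessBuilder(command.asJava)
      val env = pb.environment()
      env.put("DMLC_ROLE", role)               // "scheduler", "server" or "worker"
      env.put("DMLC_PS_ROOT_URI", "127.0.0.1") // scheduler address (assumed)
      env.put("DMLC_PS_ROOT_PORT", "9091")     // scheduler port (assumed)
      env.put("DMLC_NUM_WORKER", numWorkers.toString)
      env.put("DMLC_NUM_SERVER", numServers.toString)
      pb.inheritIO().start()
    }

    // Every child gets the same DMLC_PS_ROOT_* values so all of them can find
    // the scheduler; only DMLC_ROLE differs per process.
    val roles = "scheduler" +: (Seq.fill(numServers)("server") ++
      Seq.fill(numWorkers)("worker"))
    roles.map(spawn).foreach(_.waitFor())
  }
}
```

Under this scheme only the worker processes ever reach the training code; scheduler and server processes are expected to block inside the parameter server, which is exactly the dispatch the TrainMnist.scala change below implements.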
scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala
@@ -113,25 +113,14 @@ object TrainMnist {
      if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt))
      else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
      else Array(Context.cpu(0))

-      val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
-      envs.put("DMLC_ROLE", inst.role)
-      if (inst.schedulerHost != null) {
-        require(inst.schedulerPort > 0, "scheduler port not specified")
-        envs.put("DMLC_PS_ROOT_URI", inst.schedulerHost)
-        envs.put("DMLC_PS_ROOT_PORT", inst.schedulerPort.toString)
-        require(inst.numWorker > 0, "Num of workers must > 0")
-        envs.put("DMLC_NUM_WORKER", inst.numWorker.toString)
-        require(inst.numServer > 0, "Num of servers must > 0")
-        envs.put("DMLC_NUM_SERVER", inst.numServer.toString)
-        logger.info("Init PS environments")
-        KVStoreServer.init(envs.toMap)
-      }
-
-      if (inst.role != "worker") {
-        logger.info("Start KVStoreServer for scheduler & servers")
+      // Read the cluster setup from environment variables set by the launcher
+      val envs: Map[String, String] = sys.env
+      if (envs.contains("DMLC_ROLE") && envs("DMLC_ROLE") != "worker") {
+        logger.info("Start KVStoreServer for scheduler & servers: " + envs("DMLC_ROLE"))
+        KVStoreServer.init(envs)
        KVStoreServer.start()
      } else {
+        // Run locally or as worker
        ModelTrain.fit(dataDir = inst.dataDir,
          batchSize = inst.batchSize, numExamples = inst.numExamples, devs = devs,
          network = net, dataLoader = getIterator(dataShape),
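One thing the new code gives up is the old `require` validation: if the launcher-provided environment is incomplete, the failure now surfaces inside `KVStoreServer.init` rather than as a clear message. If stricter checks are wanted, the same validation can be applied to the inherited environment before starting the server. A minimal sketch, assuming the `ml.dmlc.mxnet` package is on the classpath; the object name and messages are illustrative.

```scala
import ml.dmlc.mxnet.KVStoreServer

// Sketch: validate the launcher-provided environment before starting a
// scheduler/server process, mirroring the require() checks the old code
// applied to its command-line flags.
object ServerBootstrapSketch {
  def main(args: Array[String]): Unit = {
    val envs: Map[String, String] = sys.env
    val needed = Seq("DMLC_ROLE", "DMLC_PS_ROOT_URI", "DMLC_PS_ROOT_PORT",
      "DMLC_NUM_WORKER", "DMLC_NUM_SERVER")
    for (key <- needed) {
      require(envs.contains(key), s"$key not set; expected it from the launcher")
    }
    require(envs("DMLC_NUM_WORKER").toInt > 0, "Num of workers must be > 0")
    require(envs("DMLC_NUM_SERVER").toInt > 0, "Num of servers must be > 0")
    KVStoreServer.init(envs)
    KVStoreServer.start() // runs the parameter-server loop for this process
  }
}
```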