Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Predict extend (#13473)
Browse files Browse the repository at this point in the history
[1.4.x] Predict extend
  • Loading branch information
lanking520 authored Nov 30, 2018
1 parent e434251 commit eb82da8
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"is more efficient than the reverse." +
"If you really want to up size, set allowUpSizing = true " +
"to enable allocation of new arrays.")
newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context))
newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context, arr.dtype))
if (dArr != null) {
newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context))
newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context, dArr.dtype))
}
} else {
newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnetexamples.infer.predictor

import java.io.File

import scala.io
import org.apache.mxnet._
import org.apache.mxnet.infer.Predictor
import org.apache.mxnetexamples.benchmark.CLIParserBase
import org.kohsuke.args4j.{CmdLineParser, Option}

import scala.collection.JavaConverters._

object PredictorExample {

def loadModel(modelPathPrefix : String, inputDesc : IndexedSeq[DataDesc],
context : Context, epoch : Int): Predictor = {
new Predictor(modelPathPrefix, inputDesc, context, Some(epoch))
}

def doInference(predictor : Predictor, imageND : NDArray): IndexedSeq[NDArray] = {
predictor.predictWithNDArray(IndexedSeq(imageND))
}

def preProcess(imagePath: String, h: Int, w: Int) : NDArray = {
var img = Image.imRead(imagePath)
img = Image.imResize(img, h, w)
// HWC -> CHW
img = NDArray.api.transpose(img, Some(Shape(2, 0, 1)))
img = NDArray.api.expand_dims(img, 0)
img.asType(DType.Float32)
}

def postProcess(modelPathPrefix : String, result : Array[Float]) : String = {
val dirPath = modelPathPrefix.substring(0, 1 + modelPathPrefix.lastIndexOf(File.separator))
val d = new File(dirPath)
require(d.exists && d.isDirectory, s"directory: $dirPath not found")
val f = io.Source.fromFile(dirPath + "synset.txt")
val s = f.getLines().toIndexedSeq
val maxIdx = result.zipWithIndex.maxBy(_._1)._2
printf(s"Predict Result ${s(maxIdx)} with prob ${result(maxIdx)}\n")
s(maxIdx)
}

def main(args : Array[String]): Unit = {
val inst = new CLIParser
val parser: CmdLineParser = new CmdLineParser(inst)

parser.parseArgument(args.toList.asJava)

var context = Context.cpu()
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
context = Context.gpu()
}

val imgWidth = 224
val imgHeight = 224

val inputDesc = IndexedSeq(new DataDesc("data", Shape(1, 3, imgHeight, imgWidth),
DType.Float32, Layout.NCHW))

val predictor = loadModel(inst.modelPathPrefix, inputDesc, context, 0)
val img = preProcess(inst.inputImagePath, imgHeight, imgWidth)
val result = doInference(predictor, img)(0).toArray
postProcess(inst.modelPathPrefix, result)
}

}

class CLIParser extends CLIParserBase{
@Option(name = "--model-path-prefix", usage = "the input model directory")
val modelPathPrefix: String = "/resnet-152/resnet-152"
@Option(name = "--input-image", usage = "the input image")
val inputImagePath: String = "/images/kitten.jpg"
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ package org.apache.mxnetexamples.infer.imageclassifier
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
import java.io.File
import java.net.URL

import org.apache.commons.io.FileUtils
import org.apache.mxnet.{Context, NDArrayCollector}
import org.apache.mxnet.Context
import org.apache.mxnetexamples.Util

import sys.process.Process
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
package org.apache.mxnetexamples.infer.objectdetector

import java.io.File
import java.net.URL

import org.apache.commons.io.FileUtils
import org.apache.mxnet.{Context, NDArrayCollector}
import org.apache.mxnet.Context
import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnetexamples.infer.predictor

import java.io.File

import org.apache.mxnet._
import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory

class PredictorExampleSuite extends FunSuite with BeforeAndAfterAll {
private val logger = LoggerFactory.getLogger(classOf[PredictorExampleSuite])
private var modelDirPrefix = ""
private var inputImagePath = ""
private var context = Context.cpu()

override def beforeAll(): Unit = {
logger.info("Downloading resnet-18 model")

val tempDirPath = System.getProperty("java.io.tmpdir")
logger.info("tempDirPath: %s".format(tempDirPath))

val baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models"

Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json",
tempDirPath + "/resnet18/resnet-18-symbol.json")
Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params",
tempDirPath + "/resnet18/resnet-18-0000.params")
Util.downloadUrl(baseUrl + "/resnet-18/synset.txt",
tempDirPath + "/resnet18/synset.txt")
Util.downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg")

modelDirPrefix = tempDirPath + File.separator + "resnet18/resnet-18"
inputImagePath = tempDirPath + File.separator +
"inputImages/resnet18/Pug-Cookie.jpg"

if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
context = Context.gpu()
}
val props = System.getProperties
props.setProperty("mxnet.disableShapeCheck", "true")
}

override def afterAll(): Unit = {
val props = System.getProperties
props.setProperty("mxnet.disableShapeCheck", "false")
}

test("test Predictor With Fixed Shape and random shape") {
val inputDesc = IndexedSeq(new DataDesc("data", Shape(1, 3, 224, 224),
DType.Float32, Layout.NCHW))
val predictor = PredictorExample.loadModel(modelDirPrefix, inputDesc, context, 0)
// fix size
var img = PredictorExample.preProcess(inputImagePath, 224, 224)
var result = PredictorExample.doInference(predictor, img)(0)
var top1 = PredictorExample.postProcess(modelDirPrefix, result.toArray)
assert(top1 === "n02110958 pug, pug-dog")
// random size 512
img = PredictorExample.preProcess(inputImagePath, 512, 512)
result = PredictorExample.doInference(predictor, img)(0)
top1 = PredictorExample.postProcess(modelDirPrefix, result.toArray)
assert(top1 === "n02110958 pug, pug-dog")
// original size
img = PredictorExample.preProcess(inputImagePath, 1024, 576)
result = PredictorExample.doInference(predictor, img)(0)
top1 = PredictorExample.postProcess(modelDirPrefix, result.toArray)
assert(top1 === "n02110958 pug, pug-dog")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
import org.apache.mxnet.module.Module

import scala.collection.mutable.ListBuffer
import scala.util.Try
import org.slf4j.LoggerFactory


/**
* Base Trait for MXNet Predictor classes.
*/
Expand Down Expand Up @@ -76,6 +78,21 @@ class Predictor(modelPathPrefix: String,

private val logger = LoggerFactory.getLogger(classOf[Predictor])

/*
By setting -Dmxnet.disableShapeCheck=true would disable the data Shape
Check of the predictor. Some model may allow different lens of the data
such as Seq2Seq, however there maybe risk of crashes if the lens beyond
the acceptable range of the model
*/
private val traceProperty = "mxnet.disableShapeCheck"
private lazy val shapeCheckDisabled = {
val value = Try(System.getProperty(traceProperty).toBoolean).getOrElse(false)
if (value) {
logger.warn("Shape check is disabled (property {} is set)", traceProperty)
}
value
}

require(inputDescriptors.head.layout.size != 0, "layout size should not be zero")

protected[infer] var batchIndex = inputDescriptors(0).layout.indexOf('N')
Expand Down Expand Up @@ -172,18 +189,20 @@ class Predictor(modelPathPrefix: String,
for((i, d) <- inputBatch.zip(iDescriptors)) {
require(inputBatch(0).shape(batchIndex) == i.shape(batchIndex),
"All inputs should be of same batch size")
require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
s"Input Data Shape: ${i.shape} should match the inputDescriptor " +
s"shape: ${d.shape} except batchSize")
if (!shapeCheckDisabled) {
require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
s"Input Data Shape: ${i.shape} should match the inputDescriptor " +
s"shape: ${d.shape} except batchSize")
}
}

val inputBatchSize = inputBatch(0).shape(batchIndex)

// rebind with the new batchSize
if (batchSize != inputBatchSize) {
logger.info(s"Latency increased due to batchSize mismatch $batchSize vs $inputBatchSize")
val desc = iDescriptors.map((f : DataDesc) => new DataDesc(f.name,
Shape(f.shape.toVector.patch(batchIndex, Vector(inputBatchSize), 1)), f.dtype, f.layout) )
val desc = inputBatch.zip(iDescriptors).map(f => new DataDesc(f._2.name,
f._1.shape, f._2.dtype, f._2.layout))
mxNetHandler.execute(mod.bind(desc, forceRebind = true,
forTraining = false))
}
Expand All @@ -200,7 +219,7 @@ class Predictor(modelPathPrefix: String,

private[infer] def loadModule(): Module = {
val mod = mxNetHandler.execute(Module.loadCheckpoint(modelPathPrefix, epoch.get,
contexts = contexts))
contexts = contexts, dataNames = inputDescriptors.map(desc => desc.name)))
mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false))
mod
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
* This method is useful when the input is a batch of data
* Note: User is responsible for managing allocation/deallocation of input/output NDArrays.
*
* @param input List of NDArrays
* @param input List of NDArrays
* @return Output of predictions as NDArrays
*/
def predictWithNDArray(input: java.util.List[NDArray]):
Expand Down

0 comments on commit eb82da8

Please sign in to comment.