Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix scaladoc scalastyle violations in Infer package (#14671)
Browse files Browse the repository at this point in the history
  • Loading branch information
zachgk authored and lanking520 committed Apr 11, 2019
1 parent 596ef3a commit 6a93bda
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ trait ClassifierBase {

/**
* Takes an array of floats and returns corresponding (Label, Score) tuples
* @tparam T The Scala equivalent of the DType used for the input array and return value
* @param input Indexed sequence one-dimensional array of floats/doubles
* @param topK (Optional) How many result (sorting based on the last axis)
* elements to return. Default returns unsorted output.
Expand Down Expand Up @@ -167,6 +168,12 @@ class Classifier(modelPathPrefix: String,
result.toIndexedSeq
}

/**
* Gives the path to the standard location of the synset.txt file
* @throws IllegalArgumentException Thrown when the file does not exist
* @param modelPathPrefix The path to the model directory
* @return The path to the synset.txt file
*/
private[infer] def getSynsetFilePath(modelPathPrefix: String): String = {
val dirPath = modelPathPrefix.substring(0, 1 + modelPathPrefix.lastIndexOf(File.separator))
val d = new File(dirPath)
Expand All @@ -179,6 +186,11 @@ class Classifier(modelPathPrefix: String,
s.getCanonicalPath
}

/**
* Parses the labels from a synset file
* @param synsetFilePath The path to the synset file. Can be gotten from getSynsetFilePath
* @return A IndexedSeq of each element in the file
*/
private[infer] def readSynsetFile(synsetFilePath: String): IndexedSeq[String] = {
val f = io.Source.fromFile(synsetFilePath)
try {
Expand All @@ -188,6 +200,11 @@ class Classifier(modelPathPrefix: String,
}
}

/**
* Creates a predictor with the same modelPath, inputDescriptors, contexts,
* and epoch as the classifier
* @return The new Predictor
*/
private[infer] def getPredictor(): PredictBase = {
new Predictor(modelPathPrefix, inputDescriptors, contexts, epoch)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ class ImageClassifier(modelPathPrefix: String,
protected[infer] val height = inputShape(inputLayout.indexOf('H'))
protected[infer] val width = inputShape(inputLayout.indexOf('W'))

/**
* Get the names and shapes that would be returns by a classify call
* @return a list of (name, shape) tuples
*/
def outputShapes: IndexedSeq[(String, Shape)] = predictor.outputShapes

/**
Expand Down Expand Up @@ -127,6 +131,19 @@ class ImageClassifier(modelPathPrefix: String,
result
}

/**
* Creates a Classifier
*
* @param modelPathPrefix Path prefix from where to load the model artifacts.
* These include the symbol, parameters, and synset.txt.
* Example: file://model-dir/resnet-152 (containing
* resnet-152-symbol.json, resnet-152-0000.params, and synset.txt).
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and type parameters
* @param contexts Device contexts on which you want to run inference; defaults to CPU
* @param epoch Model epoch to load; defaults to 0
* @return A Classifier to perform inference with
*/
private[infer] def getClassifier(modelPathPrefix: String,
inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
Expand Down Expand Up @@ -156,19 +173,16 @@ object ImageClassifier {

/**
* Convert input BufferedImage to NDArray of input shape
*
* <p>
* Note: Caller is responsible to dispose the NDArray
* returned by this method after the use.
* </p>
* @param resizedImage BufferedImage to get pixels from
*
* @param inputImageShape Input shape; for example for resnet it is (3,224,224).
Should be same as inputDescriptor shape.
* @param dType The DataType of the NDArray created from the image
* that should be returned.
* Currently it defaults to Dtype.Float32
* @return NDArray pixels array with shape (3, 224, 224) in CHW format
* @param resizedImage BufferedImage to get pixels from
* @param inputImageShape Input shape; for example for resnet it is (3,224,224).
* Should be same as inputDescriptor shape.
* @param dType The DataType of the NDArray created from the image
* that should be returned.
* Currently it defaults to Dtype.Float32
* @return NDArray pixels array with shape (3, 224, 224) in CHW format
*/
def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape,
dType : DType = DType.Float32): NDArray = {
Expand Down Expand Up @@ -235,4 +249,4 @@ object ImageClassifier {
def loadInputBatch(inputImagePaths: List[String]): Traversable[BufferedImage] = {
inputImagePaths.map(path => ImageIO.read(new File(path)))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ import org.slf4j.LoggerFactory

private[infer] trait MXNetHandler {

/**
* Executes a function within a thread-safe executor
* @param f The function to execute
* @tparam T The return type of the function
* @return Returns the result of the function f
*/
def execute[T](f: => T): T

val executor: ExecutorService
Expand All @@ -31,7 +37,11 @@ private[infer] trait MXNetHandler {

private[infer] object MXNetHandlerType extends Enumeration {

/**
* The internal type of the MXNetHandlerType enumeration
*/
type MXNetHandlerType = Value

val SingleThreadHandler = Value("MXNetSingleThreadHandler")
val OneThreadPerModelHandler = Value("MXNetOneThreadPerModelHandler")
}
Expand Down Expand Up @@ -93,6 +103,10 @@ private[infer] object MXNetSingleThreadHandler extends MXNetThreadPoolHandler(1)

private[infer] object MXNetHandler {

/**
* Creates a handler based on the handlerType
* @return A ThreadPool or Thread Handler
*/
def apply(): MXNetHandler = {
if (handlerType == MXNetHandlerType.OneThreadPerModelHandler) {
new MXNetThreadPoolHandler(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ class ObjectDetector(modelPathPrefix: String,
batchResult.toIndexedSeq
}

/**
* Formats detection results by sorting in descending order of accuracy (topK only)
* and combining with synset labels
* @param predictResultND The results from the objectDetect call
* @param topK The number of top results to return or None for all
* @return The top predicted results as (className, [Accuracy, Xmin, Ymin, Xmax, Ymax])
*/
private[infer] def sortAndReformat(predictResultND: NDArray, topK: Option[Int])
: IndexedSeq[(String, Array[Float])] = {
// iterating over the all the predictions
Expand Down Expand Up @@ -170,6 +177,18 @@ class ObjectDetector(modelPathPrefix: String,
result
}

/**
* Creates an image classifier from the object detector model
* @param modelPathPrefix Path prefix from where to load the model artifacts.
* These include the symbol, parameters, and synset.txt.
* Example: file://model-dir/resnet-152 (containing
* resnet-152-symbol.json, resnet-152-0000.params, and synset.txt).
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and type parameters
* @param contexts Device contexts on which you want to run inference; defaults to CPU
* @param epoch Model epoch to load; defaults to 0
* @return The corresponding image classifier
*/
private[infer] def getImageClassifier(modelPathPrefix: String,
inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,27 @@ import org.slf4j.LoggerFactory
private[infer] trait PredictBase {

/**
* Converts indexed sequences of 1-D array to NDArrays.
* <p>
* This method will take input as IndexedSeq one dimensional arrays and creates the
* NDArray needed for inference. The array will be reshaped based on the input descriptors.
* @param input: An Indexed Sequence of a one-dimensional array of datatype
* Float or Double
An IndexedSequence is needed when the model has more than one input.
* @return Indexed sequence array of outputs
*/
* Converts indexed sequences of 1-D array to NDArrays.
* This method will take input as IndexedSeq one dimensional arrays and creates the
* NDArray needed for inference. The array will be reshaped based on the input descriptors.
* @tparam T The Scala equivalent of the DType used for the input array and return value
* @param input An Indexed Sequence of a one-dimensional array of datatype
* Float or Double
* An IndexedSequence is needed when the model has more than one input.
* @return Indexed sequence array of outputs
*/
def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]])
: IndexedSeq[Array[T]]

/**
* Predict using NDArray as input.
* <p>
* This method is useful when the input is a batch of data
* or when multiple operations on the input have to performed.
* Note: User is responsible for managing allocation/deallocation of NDArrays.
* @param input IndexedSequence NDArrays.
* @return Output of predictions as NDArrays.
*/
* Predict using NDArray as input.
* <p>
* This method is useful when the input is a batch of data
* or when multiple operations on the input have to performed.
* Note: User is responsible for managing allocation/deallocation of NDArrays.
* @param input IndexedSequence NDArrays.
* @return Output of predictions as NDArrays.
*/
def predictWithNDArray(input: IndexedSeq[NDArray]): IndexedSeq[NDArray]

/**
Expand Down Expand Up @@ -248,6 +248,10 @@ class Predictor(modelPathPrefix: String,
resultND
}

/**
* Creates the module backing the Predictor with the same path, epoch, contexts, and inputs
* @return The Module
*/
private[infer] def loadModule(): Module = {
val mod = mxNetHandler.execute(Module.loadCheckpoint(modelPathPrefix, epoch.get,
contexts = contexts, dataNames = inputDescriptors.map(desc => desc.name)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,23 @@ import scala.language.implicitConversions
* The ObjectDetector class helps to run ObjectDetection tasks where the goal
* is to find bounding boxes and corresponding labels for objects in a image.
*
* @param modelPathPrefix Path prefix from where to load the model artifacts.
* These include the symbol, parameters, and synset.txt.
* Example: file://model-dir/ssd_resnet50_512 (containing
* ssd_resnet50_512-symbol.json, ssd_resnet50_512-0000.params,
* and synset.txt)
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and type parameters
* @param contexts Device contexts on which you want to run inference.
* Defaults to CPU.
* @param epoch Model epoch to load; defaults to 0
* @param objDetector A source Scala Object detector
*/
class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.ObjectDetector){

/**
*
* @param modelPathPrefix Path prefix from where to load the model artifacts.
* These include the symbol, parameters, and synset.txt.
* Example: file://model-dir/ssd_resnet50_512 (containing
* ssd_resnet50_512-symbol.json, ssd_resnet50_512-0000.params,
* and synset.txt)
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and type parameters
* @param contexts Device contexts on which you want to run inference.
* Defaults to CPU.
* @param epoch Model epoch to load; defaults to 0
*/
def this(modelPathPrefix: String, inputDescriptors: java.lang.Iterable[DataDesc], contexts:
java.lang.Iterable[Context], epoch: Int)
= this {
Expand Down Expand Up @@ -98,32 +102,78 @@ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.Obj
(ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
}

/**
* Helper to map an implicit conversion
* @param l The value to convert
* @tparam B The desired type
* @tparam A The input type
* @return The converted result
*/
def convert[B, A <% B](l: IndexedSeq[A]): IndexedSeq[B] = l map { a => a: B }

}


object ObjectDetector {
implicit def fromObjectDetector(OD: org.apache.mxnet.infer.ObjectDetector):
ObjectDetector = new ObjectDetector(OD)

implicit def toObjectDetector(jOD: ObjectDetector):
org.apache.mxnet.infer.ObjectDetector = jOD.objDetector

/**
* Loads an input images from file
* @param inputImagePath Path of single input image
* @return BufferedImage Buffered image
*/
def loadImageFromFile(inputImagePath: String): BufferedImage = {
org.apache.mxnet.infer.ImageClassifier.loadImageFromFile(inputImagePath)
}

/**
* Reshape the input image to a new shape
*
* @param img Input image
* @param newWidth New width for rescaling
* @param newHeight New height for rescaling
* @return Rescaled BufferedImage
*/
def reshapeImage(img : BufferedImage, newWidth: Int, newHeight: Int): BufferedImage = {
org.apache.mxnet.infer.ImageClassifier.reshapeImage(img, newWidth, newHeight)
}

/**
* Convert input BufferedImage to NDArray of input shape
* Note: Caller is responsible to dispose the NDArray
* returned by this method after the use.
*
* @param resizedImage BufferedImage to get pixels from
* @param inputImageShape Input shape; for example for resnet it is (3,224,224).
* Should be same as inputDescriptor shape.
* @return NDArray pixels array with shape (3, 224, 224) in CHW format
*/
def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape): NDArray = {
org.apache.mxnet.infer.ImageClassifier.bufferedImageToPixels(resizedImage, inputImageShape)
}

/**
* Loads a batch of images from a folder
* @param inputImagePaths Path to a folder of images
* @return List of buffered images
*/
def loadInputBatch(inputImagePaths: java.lang.Iterable[String]): java.util.List[BufferedImage] = {
org.apache.mxnet.infer.ImageClassifier
.loadInputBatch(inputImagePaths.asScala.toList).toList.asJava
}

/**
* Implicitly convert a Scala ObjectDetector to a Java ObjectDetector
* @param OD The Scala ObjectDetector
* @return The Java ObjectDetector
*/
implicit def fromObjectDetector(OD: org.apache.mxnet.infer.ObjectDetector):
ObjectDetector = new ObjectDetector(OD)

/**
* Implicitly converts a Java ObjectDetector to a Scala ObjectDetector
* @param jOD The Java ObjectDetector
* @return The Scala ObjectDetector
*/
implicit def toObjectDetector(jOD: ObjectDetector):
org.apache.mxnet.infer.ObjectDetector = jOD.objDetector
}
Loading

0 comments on commit 6a93bda

Please sign in to comment.