Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1385] Fix scaladoc scalastyle violations in Infer package #14671

Merged
merged 1 commit into from
Apr 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ trait ClassifierBase {

/**
* Takes an array of floats and returns corresponding (Label, Score) tuples
* @tparam T The Scala equivalent of the DType used for the input array and return value
* @param input Indexed sequence one-dimensional array of floats/doubles
* @param topK (Optional) How many result (sorting based on the last axis)
* elements to return. Default returns unsorted output.
Expand Down Expand Up @@ -167,6 +168,12 @@ class Classifier(modelPathPrefix: String,
result.toIndexedSeq
}

/**
* Gives the path to the standard location of the synset.txt file
* @throws IllegalArgumentException Thrown when the file does not exist
* @param modelPathPrefix The path to the model directory
* @return The path to the synset.txt file
*/
private[infer] def getSynsetFilePath(modelPathPrefix: String): String = {
val dirPath = modelPathPrefix.substring(0, 1 + modelPathPrefix.lastIndexOf(File.separator))
val d = new File(dirPath)
Expand All @@ -179,6 +186,11 @@ class Classifier(modelPathPrefix: String,
s.getCanonicalPath
}

/**
* Parses the labels from a synset file
* @param synsetFilePath The path to the synset file. Can be gotten from getSynsetFilePath
* @return A IndexedSeq of each element in the file
*/
private[infer] def readSynsetFile(synsetFilePath: String): IndexedSeq[String] = {
val f = io.Source.fromFile(synsetFilePath)
try {
Expand All @@ -188,6 +200,11 @@ class Classifier(modelPathPrefix: String,
}
}

/**
* Creates a predictor with the same modelPath, inputDescriptors, contexts,
* and epoch as the classifier
* @return The new Predictor
*/
private[infer] def getPredictor(): PredictBase = {
new Predictor(modelPathPrefix, inputDescriptors, contexts, epoch)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ class ImageClassifier(modelPathPrefix: String,
protected[infer] val height = inputShape(inputLayout.indexOf('H'))
protected[infer] val width = inputShape(inputLayout.indexOf('W'))

/**
* Get the names and shapes that would be returns by a classify call
* @return a list of (name, shape) tuples
*/
def outputShapes: IndexedSeq[(String, Shape)] = predictor.outputShapes

/**
Expand Down Expand Up @@ -127,6 +131,19 @@ class ImageClassifier(modelPathPrefix: String,
result
}

/**
* Creates a Classifier
*
* @param modelPathPrefix Path prefix from where to load the model artifacts.
* These include the symbol, parameters, and synset.txt.
* Example: file://model-dir/resnet-152 (containing
* resnet-152-symbol.json, resnet-152-0000.params, and synset.txt).
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and type parameters
* @param contexts Device contexts on which you want to run inference; defaults to CPU
* @param epoch Model epoch to load; defaults to 0
* @return A Classifier to perform inference with
*/
private[infer] def getClassifier(modelPathPrefix: String,
inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
Expand Down Expand Up @@ -156,19 +173,16 @@ object ImageClassifier {

/**
* Convert input BufferedImage to NDArray of input shape
*
* <p>
* Note: Caller is responsible to dispose the NDArray
* returned by this method after the use.
* </p>
* @param resizedImage BufferedImage to get pixels from
*
* @param inputImageShape Input shape; for example for resnet it is (3,224,224).
Should be same as inputDescriptor shape.
* @param dType The DataType of the NDArray created from the image
* that should be returned.
* Currently it defaults to Dtype.Float32
* @return NDArray pixels array with shape (3, 224, 224) in CHW format
* @param resizedImage BufferedImage to get pixels from
* @param inputImageShape Input shape; for example for resnet it is (3,224,224).
* Should be same as inputDescriptor shape.
* @param dType The DataType of the NDArray created from the image
* that should be returned.
* Currently it defaults to Dtype.Float32
* @return NDArray pixels array with shape (3, 224, 224) in CHW format
*/
def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape,
dType : DType = DType.Float32): NDArray = {
Expand Down Expand Up @@ -235,4 +249,4 @@ object ImageClassifier {
def loadInputBatch(inputImagePaths: List[String]): Traversable[BufferedImage] = {
inputImagePaths.map(path => ImageIO.read(new File(path)))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ import org.slf4j.LoggerFactory

private[infer] trait MXNetHandler {

/**
* Executes a function within a thread-safe executor
* @param f The function to execute
* @tparam T The return type of the function
* @return Returns the result of the function f
*/
def execute[T](f: => T): T

val executor: ExecutorService
Expand All @@ -31,7 +37,11 @@ private[infer] trait MXNetHandler {

private[infer] object MXNetHandlerType extends Enumeration {

/**
* The internal type of the MXNetHandlerType enumeration
*/
type MXNetHandlerType = Value

val SingleThreadHandler = Value("MXNetSingleThreadHandler")
val OneThreadPerModelHandler = Value("MXNetOneThreadPerModelHandler")
}
Expand Down Expand Up @@ -93,6 +103,10 @@ private[infer] object MXNetSingleThreadHandler extends MXNetThreadPoolHandler(1)

private[infer] object MXNetHandler {

/**
* Creates a handler based on the handlerType
* @return A ThreadPool or Thread Handler
*/
def apply(): MXNetHandler = {
if (handlerType == MXNetHandlerType.OneThreadPerModelHandler) {
new MXNetThreadPoolHandler(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ class ObjectDetector(modelPathPrefix: String,
batchResult.toIndexedSeq
}

/**
* Formats detection results by sorting in descending order of accuracy (topK only)
* and combining with synset labels
* @param predictResultND The results from the objectDetect call
* @param topK The number of top results to return or None for all
* @return The top predicted results as (className, [Accuracy, Xmin, Ymin, Xmax, Ymax])
*/
private[infer] def sortAndReformat(predictResultND: NDArray, topK: Option[Int])
: IndexedSeq[(String, Array[Float])] = {
// iterating over the all the predictions
Expand Down Expand Up @@ -170,6 +177,18 @@ class ObjectDetector(modelPathPrefix: String,
result
}

/**
* Creates an image classifier from the object detector model
* @param modelPathPrefix Path prefix from where to load the model artifacts.
* These include the symbol, parameters, and synset.txt.
* Example: file://model-dir/resnet-152 (containing
* resnet-152-symbol.json, resnet-152-0000.params, and synset.txt).
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and type parameters
* @param contexts Device contexts on which you want to run inference; defaults to CPU
* @param epoch Model epoch to load; defaults to 0
* @return The corresponding image classifier
*/
private[infer] def getImageClassifier(modelPathPrefix: String,
inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,27 @@ import org.slf4j.LoggerFactory
private[infer] trait PredictBase {

/**
* Converts indexed sequences of 1-D array to NDArrays.
* <p>
* This method will take input as IndexedSeq one dimensional arrays and creates the
* NDArray needed for inference. The array will be reshaped based on the input descriptors.
* @param input: An Indexed Sequence of a one-dimensional array of datatype
* Float or Double
An IndexedSequence is needed when the model has more than one input.
* @return Indexed sequence array of outputs
*/
* Converts indexed sequences of 1-D array to NDArrays.
* This method will take input as IndexedSeq one dimensional arrays and creates the
* NDArray needed for inference. The array will be reshaped based on the input descriptors.
* @tparam T The Scala equivalent of the DType used for the input array and return value
* @param input An Indexed Sequence of a one-dimensional array of datatype
* Float or Double
* An IndexedSequence is needed when the model has more than one input.
* @return Indexed sequence array of outputs
*/
def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]])
: IndexedSeq[Array[T]]

/**
* Predict using NDArray as input.
* <p>
* This method is useful when the input is a batch of data
* or when multiple operations on the input have to performed.
* Note: User is responsible for managing allocation/deallocation of NDArrays.
* @param input IndexedSequence NDArrays.
* @return Output of predictions as NDArrays.
*/
* Predict using NDArray as input.
* <p>
* This method is useful when the input is a batch of data
* or when multiple operations on the input have to performed.
* Note: User is responsible for managing allocation/deallocation of NDArrays.
* @param input IndexedSequence NDArrays.
* @return Output of predictions as NDArrays.
*/
def predictWithNDArray(input: IndexedSeq[NDArray]): IndexedSeq[NDArray]

/**
Expand Down Expand Up @@ -248,6 +248,10 @@ class Predictor(modelPathPrefix: String,
resultND
}

/**
* Creates the module backing the Predictor with the same path, epoch, contexts, and inputs
* @return The Module
*/
private[infer] def loadModule(): Module = {
val mod = mxNetHandler.execute(Module.loadCheckpoint(modelPathPrefix, epoch.get,
contexts = contexts, dataNames = inputDescriptors.map(desc => desc.name)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,23 @@ import scala.language.implicitConversions
* The ObjectDetector class helps to run ObjectDetection tasks where the goal
* is to find bounding boxes and corresponding labels for objects in a image.
*
* @param modelPathPrefix Path prefix from where to load the model artifacts.
* These include the symbol, parameters, and synset.txt.
* Example: file://model-dir/ssd_resnet50_512 (containing
* ssd_resnet50_512-symbol.json, ssd_resnet50_512-0000.params,
* and synset.txt)
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and type parameters
* @param contexts Device contexts on which you want to run inference.
* Defaults to CPU.
* @param epoch Model epoch to load; defaults to 0
* @param objDetector A source Scala Object detector
*/
class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.ObjectDetector){

/**
*
* @param modelPathPrefix Path prefix from where to load the model artifacts.
* These include the symbol, parameters, and synset.txt.
* Example: file://model-dir/ssd_resnet50_512 (containing
* ssd_resnet50_512-symbol.json, ssd_resnet50_512-0000.params,
* and synset.txt)
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and type parameters
* @param contexts Device contexts on which you want to run inference.
* Defaults to CPU.
* @param epoch Model epoch to load; defaults to 0
*/
def this(modelPathPrefix: String, inputDescriptors: java.lang.Iterable[DataDesc], contexts:
java.lang.Iterable[Context], epoch: Int)
= this {
Expand Down Expand Up @@ -98,32 +102,78 @@ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.Obj
(ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
}

/**
* Helper to map an implicit conversion
* @param l The value to convert
* @tparam B The desired type
* @tparam A The input type
* @return The converted result
*/
def convert[B, A <% B](l: IndexedSeq[A]): IndexedSeq[B] = l map { a => a: B }

}


object ObjectDetector {
implicit def fromObjectDetector(OD: org.apache.mxnet.infer.ObjectDetector):
ObjectDetector = new ObjectDetector(OD)

implicit def toObjectDetector(jOD: ObjectDetector):
org.apache.mxnet.infer.ObjectDetector = jOD.objDetector

/**
* Loads an input images from file
* @param inputImagePath Path of single input image
* @return BufferedImage Buffered image
*/
def loadImageFromFile(inputImagePath: String): BufferedImage = {
org.apache.mxnet.infer.ImageClassifier.loadImageFromFile(inputImagePath)
}

/**
* Reshape the input image to a new shape
*
* @param img Input image
* @param newWidth New width for rescaling
* @param newHeight New height for rescaling
* @return Rescaled BufferedImage
*/
def reshapeImage(img : BufferedImage, newWidth: Int, newHeight: Int): BufferedImage = {
org.apache.mxnet.infer.ImageClassifier.reshapeImage(img, newWidth, newHeight)
}

/**
* Convert input BufferedImage to NDArray of input shape
* Note: Caller is responsible to dispose the NDArray
* returned by this method after the use.
*
* @param resizedImage BufferedImage to get pixels from
* @param inputImageShape Input shape; for example for resnet it is (3,224,224).
* Should be same as inputDescriptor shape.
* @return NDArray pixels array with shape (3, 224, 224) in CHW format
*/
def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape): NDArray = {
org.apache.mxnet.infer.ImageClassifier.bufferedImageToPixels(resizedImage, inputImageShape)
}

/**
* Loads a batch of images from a folder
* @param inputImagePaths Path to a folder of images
* @return List of buffered images
*/
def loadInputBatch(inputImagePaths: java.lang.Iterable[String]): java.util.List[BufferedImage] = {
org.apache.mxnet.infer.ImageClassifier
.loadInputBatch(inputImagePaths.asScala.toList).toList.asJava
}

/**
* Implicitly convert a Scala ObjectDetector to a Java ObjectDetector
* @param OD The Scala ObjectDetector
* @return The Java ObjectDetector
*/
implicit def fromObjectDetector(OD: org.apache.mxnet.infer.ObjectDetector):
ObjectDetector = new ObjectDetector(OD)

/**
* Implicitly converts a Java ObjectDetector to a Scala ObjectDetector
* @param jOD The Java ObjectDetector
* @return The Scala ObjectDetector
*/
implicit def toObjectDetector(jOD: ObjectDetector):
org.apache.mxnet.infer.ObjectDetector = jOD.objDetector
}
Loading