Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1385] Improved Scala Init and Macros warning messages #14656

Merged
merged 4 commits into from
Apr 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
* @param out NDArray to store the output
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(buf: Array[Byte], flag: Int,
Expand All @@ -57,6 +58,10 @@ object Image {
/**
* Same imageDecode with InputStream
* @param inputStream the inputStream of the image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
* @param out NDArray to store the output
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(inputStream: InputStream, flag: Int = 1,
Expand All @@ -79,6 +84,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image to mxnet's default RGB format
* (instead of opencv's default BGR).
* @param out NDArray to store the output
* @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
*/
def imRead(filename: String, flag: Option[Int] = None,
Expand All @@ -99,6 +105,7 @@ object Image {
* @param w Width of resized image.
* @param h Height of resized image.
* @param interp Interpolation method (default=cv2.INTER_LINEAR).
* @param out NDArray to store the output
* @return org.apache.mxnet.NDArray
*/
def imResize(src: org.apache.mxnet.NDArray, w: Int, h: Int,
Expand All @@ -124,6 +131,7 @@ object Image {
* @param typeOf Filling type (default=cv2.BORDER_CONSTANT).
* @param value (Deprecated! Use ``values`` instead.) Fill with single value.
* @param values Fill with value(RGB[A] or gray), up to 4 channels.
* @param out NDArray to store the output
* @return org.apache.mxnet.NDArray
*/
def copyMakeBorder(src: org.apache.mxnet.NDArray, top: Int, bot: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ import scala.util.Try
*/
@AddNDArrayFunctions(false)
object NDArray extends NDArrayBase {
/**
* method to convert NDArrayFunctionReturn to NDArray
* @param ret the returned NDArray list
* @return NDArray result
*/
implicit def getFirstResult(ret: NDArrayFuncReturn): NDArray = ret(0)
private val logger = LoggerFactory.getLogger(classOf[NDArray])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ object Image {
org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None)
}

/**
* Decode image with OpenCV.
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param buf Buffer containing binary encoded image
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(buf: Array[Byte]): NDArray = {
imDecode(buf, 1, true)
}
Expand All @@ -52,6 +58,12 @@ object Image {
org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None)
}

/**
* Same imageDecode with InputStream
*
* @param inputStream the inputStream of the image
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(inputStream: InputStream): NDArray = {
imDecode(inputStream, 1, true)
}
Expand All @@ -69,6 +81,12 @@ object Image {
org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None)
}

/**
* Read and decode image with OpenCV.
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param filename Name of the image file to be loaded.
* @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
*/
def imRead(filename: String): NDArray = {
imRead(filename, 1, true)
}
Expand All @@ -86,6 +104,13 @@ object Image {
org.apache.mxnet.Image.imResize(src, w, h, interpVal, None)
}

/**
* Resize image with OpenCV.
* @param src source image in NDArray
* @param w Width of resized image.
* @param h Height of resized image.
* @return org.apache.mxnet.NDArray
*/
def imResize(src: NDArray, w: Int, h: Int): NDArray = {
imResize(src, w, h, null)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,41 @@ object Base {
class RefFloat(val value: Float = 0)
class RefString(val value: String = null)

/**
* This C Pointer Address point to the
* actual memory piece created in MXNet Engine
*/
type CPtrAddress = Long

/**
* NDArrayHandle is the C pointer to
* the NDArray
*/
type NDArrayHandle = CPtrAddress
/**
* FunctionHandle is the C pointer to
* the ids of the operators
*/
type FunctionHandle = CPtrAddress
/**
* KVStoreHandle is the C pointer to
* the KVStore
*/
type KVStoreHandle = CPtrAddress
/**
* ExecutorHandle is the C pointer to
* the Executor
*/
type ExecutorHandle = CPtrAddress
/**
* SymbolHandle is the C pointer to
* the Symbol
*/
type SymbolHandle = CPtrAddress

@throws(classOf[UnsatisfiedLinkError])
private def tryLoadInitLibrary(): Unit = {
var userDir : File = new File(System.getProperty("user.dir"))
val userDir : File = new File(System.getProperty("user.dir"))
var nativeDir : File = new File(userDir, "init-native")
if (!nativeDir.exists()) {
nativeDir = new File(userDir.getParent, "init-native")
Expand All @@ -50,7 +74,6 @@ object Base {
val baseDir = nativeDir.getAbsolutePath

val os = System.getProperty("os.name")
// ref: http://lopica.sourceforge.net/os.html
if (os.startsWith("Linux")) {
System.load(s"$baseDir/target/libmxnet-init-scala.so")
} else if (os.startsWith("Mac")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,25 @@ import org.apache.mxnet.init.Base._
import scala.collection.mutable.ListBuffer

class LibInfo {
/**
* Get the list of the symbol ids
* @param symbolList Pass in an empty ListBuffer and obtain a list of operator IDs
* @return Callback result
*/
@native def mxSymbolListAtomicSymbolCreators(symbolList: ListBuffer[SymbolHandle]): Int

/**
* Get the detailed information of an operator
* @param handle The ID of the operator
* @param name Name of the operator
* @param desc Description of the operator
* @param numArgs Number of arguments
* @param argNames Argument names
* @param argTypes Argument types
* @param argDescs Argument descriptions
* @param keyVarNumArgs Kwargs number
* @return Callback result
*/
@native def mxSymbolGetAtomicSymbolInfo(handle: SymbolHandle,
name: RefString,
desc: RefString,
Expand All @@ -31,6 +49,18 @@ class LibInfo {
argTypes: ListBuffer[String],
argDescs: ListBuffer[String],
keyVarNumArgs: RefString): Int
/**
* Get the name list of all operators
* @param names Names of all operator
* @return Callback result
*/
@native def mxListAllOpNames(names: ListBuffer[String]): Int

/**
* Get operator ID from its name
* @param opName Operator name
* @param opHandle Operator ID
* @return Callback result
*/
@native def nnGetOpHandle(opName: String, opHandle: RefLong): Int
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ import java.security.MessageDigest
import scala.collection.mutable.ListBuffer

/**
* This object will generate the Scala documentation of the new Scala API
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* This object will generate the Scala documentation of the Scala/Java APIs
* The code will be executed during Macros stage and file live in Core stage
*/
private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {

/**
* Main method used to generate code and write to files
* A hash check placed at the end to verify changes
* @param args Input args
*/
def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
Expand All @@ -42,13 +46,25 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
val finalHash = hashCollector.mkString("\n")
}

/**
* Generate MD5 result from an input string
* Encoded in UTF-8
* @param input The input string
* @return A MD5 value from the string
*/
def MD5Generator(input: String): String = {
val md = MessageDigest.getInstance("MD5")
md.update(input.getBytes("UTF-8"))
val digest = md.digest()
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

/**
* Type-safe class body generation for NDArray/Symbol
* @param FILE_PATH File path write the file to
* @param isSymbol Check if write the Symbol API, NDArray otherwise
* @return MD5 String
*/
def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false)
.map { func =>
Expand All @@ -65,6 +81,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
generated)
}

/**
* Generate the Random classes for Symbol/NDArray
* @param FILE_PATH File path write the file to
* @param isSymbol Check if write the Symbol API, NDArray otherwise
* @return MD5 String
*/
def typeSafeRandomClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val generated = typeSafeRandomFunctionsToGenerate(isSymbol)
.map { func =>
Expand All @@ -83,6 +105,16 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
generated)
}

/**
* Non Type-safe interface of Scala Symbol/NDArray
lanking520 marked this conversation as resolved.
Show resolved Hide resolved
* It includes class definition : e.g class SymbolBase
* and function definitions : e.g def softmax(...)(...)(...) : NDArray
* Users can directly use the api by calling NDArray.<function_name>
* It support both positional input or Map input
* @param FILE_PATH File path write the file to
* @param isSymbol Check if write the Symbol API, NDArray otherwise
* @return MD5 String
*/
def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
.map { func =>
Expand Down Expand Up @@ -112,7 +144,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
absFuncs)
}

def javaClassGen(filePath : String) : String = {
/**
* Type-safe interface of Java NDArray
* @param FILE_PATH File path write the file to
* @return MD5 String
*/
def javaClassGen(FILE_PATH : String) : String = {
val notGenerated = Set("Custom")
val absClassFunctions = functionsToGenerate(false, false, true)
val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
Expand All @@ -133,13 +170,19 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
writeFile(
filePath + "javaapi/",
FILE_PATH + "javaapi/",
packageDef,
packageName,
"import org.apache.mxnet.annotation.Experimental",
absFuncs)
}

/**
* Generate Scala docs from the function description
* @param func The function case class
* @param withParam Whether to generate param field
* @return A formatted string for the function description
*/
def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
def fixDesc(desc: String): String = {
var curDesc = desc
Expand Down Expand Up @@ -173,6 +216,14 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
}
}

/**
* Generate the function interface
* e.g: def softmax(data: NDArray, name ...): NDArrayFunctionReturn
* @param func The function case class
* @param isSymbol Check if generate Symbol function, NDArray otherwise
* @param typeParameter Type param specifically used in Random Module
* @return Formatted string for the function
*/
def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: String = ""): String = {
val argDef = ListBuffer[String]()

Expand All @@ -192,6 +243,11 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
|def ${func.name}$typeParameter (${argDef.mkString(", ")}): $returnType""".stripMargin
}

/**
* Generate Java function interface
* @param func The function case class
* @return A formatted string for the function
*/
def generateJavaAPISignature(func : Func) : String = {
val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
var argDef = ListBuffer[String]()
Expand Down Expand Up @@ -250,6 +306,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
}
}

/**
* Write the formatted string to file
* @param FILE_PATH Location of the file writes to
* @param packageDef Package definition
* @param className Class name
* @param imports Packages need to import
* @param absFuncs All formatted functions
* @return A MD5 string
*/
def writeFile(FILE_PATH: String, packageDef: String, className: String,
imports: String, absFuncs: Seq[String]): String = {

Expand Down
Loading