Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
1st phase scala/java warning killers
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Apr 9, 2019
1 parent 733e54c commit c5daa53
Show file tree
Hide file tree
Showing 12 changed files with 330 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
* @param out NDArray to store the output
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(buf: Array[Byte], flag: Int,
Expand All @@ -57,6 +58,10 @@ object Image {
/**
* Same imageDecode with InputStream
* @param inputStream the inputStream of the image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
* @param out NDArray to store the output
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(inputStream: InputStream, flag: Int = 1,
Expand All @@ -79,6 +84,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image to mxnet's default RGB format
* (instead of opencv's default BGR).
* @param out NDArray to store the output
* @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
*/
def imRead(filename: String, flag: Option[Int] = None,
Expand All @@ -99,6 +105,7 @@ object Image {
* @param w Width of resized image.
* @param h Height of resized image.
* @param interp Interpolation method (default=cv2.INTER_LINEAR).
* @param out NDArray to store the output
* @return org.apache.mxnet.NDArray
*/
def imResize(src: org.apache.mxnet.NDArray, w: Int, h: Int,
Expand All @@ -124,6 +131,7 @@ object Image {
* @param typeOf Filling type (default=cv2.BORDER_CONSTANT).
* @param value (Deprecated! Use ``values`` instead.) Fill with single value.
* @param values Fill with value(RGB[A] or gray), up to 4 channels.
* @param out NDArray to store the output
* @return org.apache.mxnet.NDArray
*/
def copyMakeBorder(src: org.apache.mxnet.NDArray, top: Int, bot: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ import scala.util.Try
*/
@AddNDArrayFunctions(false)
object NDArray extends NDArrayBase {
/**
* method to convert NDArrayFunctionReturn to NDArray
* @param ret the returned NDArray list
* @return NDArray result
*/
implicit def getFirstResult(ret: NDArrayFuncReturn): NDArray = ret(0)
private val logger = LoggerFactory.getLogger(classOf[NDArray])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ object Image {
org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None)
}

/**
* Decode image with OpenCV.
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param buf Buffer containing binary encoded image
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(buf: Array[Byte]): NDArray = {
imDecode(buf, 1, true)
}
Expand All @@ -52,6 +58,12 @@ object Image {
org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None)
}

/**
* Same imageDecode with InputStream
*
* @param inputStream the inputStream of the image
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(inputStream: InputStream): NDArray = {
imDecode(inputStream, 1, true)
}
Expand All @@ -69,6 +81,12 @@ object Image {
org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None)
}

/**
* Read and decode image with OpenCV.
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param filename Name of the image file to be loaded.
* @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
*/
def imRead(filename: String): NDArray = {
imRead(filename, 1, true)
}
Expand All @@ -86,6 +104,13 @@ object Image {
org.apache.mxnet.Image.imResize(src, w, h, interpVal, None)
}

/**
* Resize image with OpenCV.
* @param src source image in NDArray
* @param w Width of resized image.
* @param h Height of resized image.
* @return org.apache.mxnet.NDArray
*/
def imResize(src: NDArray, w: Int, h: Int): NDArray = {
imResize(src, w, h, null)
}
Expand Down
27 changes: 25 additions & 2 deletions scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,41 @@ object Base {
class RefFloat(val value: Float = 0)
class RefString(val value: String = null)

/**
* This C Pointer Address point to the
* actual memory piece created in MXNet Engine
*/
type CPtrAddress = Long

/**
* NDArrayHandle is the C pointer to
* the NDArray
*/
type NDArrayHandle = CPtrAddress
/**
* FunctionHandle is the C pointer to
* the ids of the operators
*/
type FunctionHandle = CPtrAddress
/**
* KVStorHandle is the C pointer to
* the KVStore
*/
type KVStoreHandle = CPtrAddress
/**
* ExecutorHandle is the C pointer to
* the Executor
*/
type ExecutorHandle = CPtrAddress
/**
* SymbolHandle is the C pointer to
* the Symbol
*/
type SymbolHandle = CPtrAddress

@throws(classOf[UnsatisfiedLinkError])
private def tryLoadInitLibrary(): Unit = {
var userDir : File = new File(System.getProperty("user.dir"))
val userDir : File = new File(System.getProperty("user.dir"))
var nativeDir : File = new File(userDir, "init-native")
if (!nativeDir.exists()) {
nativeDir = new File(userDir.getParent, "init-native")
Expand All @@ -50,7 +74,6 @@ object Base {
val baseDir = nativeDir.getAbsolutePath

val os = System.getProperty("os.name")
// ref: http://lopica.sourceforge.net/os.html
if (os.startsWith("Linux")) {
System.load(s"$baseDir/target/libmxnet-init-scala.so")
} else if (os.startsWith("Mac")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,25 @@ import org.apache.mxnet.init.Base._
import scala.collection.mutable.ListBuffer

class LibInfo {
/**
* Get the list of the symbol ids
* @param symbolList pass in an empty ListBuffer and obtain a list of operator ids
* @return callback result
*/
@native def mxSymbolListAtomicSymbolCreators(symbolList: ListBuffer[SymbolHandle]): Int

/**
* Get the detailed information of an operator
* @param handle The id of the operator
* @param name name of the operator
* @param desc description of the operator
* @param numArgs number of arguments
* @param argNames argument names
* @param argTypes argument types
* @param argDescs argument descriptions
* @param keyVarNumArgs Kwargs number
* @return callback result
*/
@native def mxSymbolGetAtomicSymbolInfo(handle: SymbolHandle,
name: RefString,
desc: RefString,
Expand All @@ -31,6 +49,18 @@ class LibInfo {
argTypes: ListBuffer[String],
argDescs: ListBuffer[String],
keyVarNumArgs: RefString): Int
/**
* Get the name list of all operators
* @param names names of all operators
* @return callback result
*/
@native def mxListAllOpNames(names: ListBuffer[String]): Int

/**
* get operator id from its name
* @param opName Operator name
* @param opHandle Operator id
* @return callback result
*/
@native def nnGetOpHandle(opName: String, opHandle: RefLong): Int
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ import java.security.MessageDigest
import scala.collection.mutable.ListBuffer

/**
* This object will generate the Scala documentation of the new Scala API
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* This object will generate the Scala documentation of the Scala/Java APIs
* The code will be executed during Macros stage and file live in Core stage
*/
private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {

/**
* main method used to generate code and write to files
* a hash check placed at the end to verify changes
* @param args input args
*/
def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
Expand All @@ -42,13 +46,25 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
val finalHash = hashCollector.mkString("\n")
}

/**
* generate MD5 result from an input string
* encoded in UTF-8
* @param input the input string
* @return a MD5 value from the string
*/
def MD5Generator(input: String): String = {
val md = MessageDigest.getInstance("MD5")
md.update(input.getBytes("UTF-8"))
val digest = md.digest()
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

/**
* Type-safe class body generation for NDArray/Symbol
* @param FILE_PATH file path write the file to
* @param isSymbol check if write the Symbol API, NDArray otherwise
* @return MD5 String
*/
def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false)
.map { func =>
Expand All @@ -65,6 +81,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
generated)
}

/**
* Generate the Random classes for Symbol/NDArray
* @param FILE_PATH file path write the file to
* @param isSymbol check if write the Symbol API, NDArray otherwise
* @return MD5 String
*/
def typeSafeRandomClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val generated = typeSafeRandomFunctionsToGenerate(isSymbol)
.map { func =>
Expand All @@ -83,6 +105,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
generated)
}

/**
* Non Type-safe interface of Scala Symbol/NDArray
* @param FILE_PATH file path write the file to
* @param isSymbol check if write the Symbol API, NDArray otherwise
* @return MD5 String
*/
def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
.map { func =>
Expand Down Expand Up @@ -112,7 +140,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
absFuncs)
}

def javaClassGen(filePath : String) : String = {
/**
* Type-safe interface of Java NDArray
* @param FILE_PATH file path write the file to
* @return MD5 String
*/
def javaClassGen(FILE_PATH : String) : String = {
val notGenerated = Set("Custom")
val absClassFunctions = functionsToGenerate(false, false, true)
val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
Expand All @@ -133,13 +166,19 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
writeFile(
filePath + "javaapi/",
FILE_PATH + "javaapi/",
packageDef,
packageName,
"import org.apache.mxnet.annotation.Experimental",
absFuncs)
}

/**
* Generate Scala docs from the function description
* @param func the Function case class
* @param withParam whether to generate param field
* @return a formatted string for the function description
*/
def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
def fixDesc(desc: String): String = {
var curDesc = desc
Expand Down Expand Up @@ -173,6 +212,14 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
}
}

/**
* generate the function interface
* e.g: def softmax(data: NDArray, name ...): NDArrayFunctionReturn
* @param func the Function case class
* @param isSymbol check if generate Symbol function, NDArray otherwise
* @param typeParameter type param specifically used in Random Module
* @return formatted string for the function
*/
def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: String = ""): String = {
val argDef = ListBuffer[String]()

Expand All @@ -192,6 +239,11 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
|def ${func.name}$typeParameter (${argDef.mkString(", ")}): $returnType""".stripMargin
}

/**
* Generate Java function interface
* @param func the Function case class
* @return a formatted string for the function
*/
def generateJavaAPISignature(func : Func) : String = {
val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
var argDef = ListBuffer[String]()
Expand Down Expand Up @@ -250,6 +302,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
}
}

/**
* Write the formatted string to file
* @param FILE_PATH location of the file writes to
* @param packageDef package definition
* @param className class name
* @param imports packages need to import
* @param absFuncs all formatted functions
* @return a MD5 string
*/
def writeFile(FILE_PATH: String, packageDef: String, className: String,
imports: String, absFuncs: Seq[String]): String = {

Expand Down
Loading

0 comments on commit c5daa53

Please sign in to comment.