diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 43c382823955..7199c0527e82 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -10,7 +10,25 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} * @author Yizhi Liu */ class Symbol(private[mxnet] val handle: SymbolHandle) { - def +(other: Symbol): Symbol = Symbol.create("_Plus", this, other) + def +(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Plus")(Array(this, other)) + def +[@specialized(Int, Float, Double) V](other: V): Symbol = { + Symbol.createFromListedSymbols("_PlusScalar")(Array(this), Map("scalar" -> other.toString)) + } + + def -(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Minus")(Array(this, other)) + def -[@specialized(Int, Float, Double) V](other: V): Symbol = { + Symbol.createFromListedSymbols("_MinusScalar")(Array(this), Map("scalar" -> other.toString)) + } + + def *(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Mul")(Array(this, other)) + def *[@specialized(Int, Float, Double) V](other: V): Symbol = { + Symbol.createFromListedSymbols("_MulScalar")(Array(this), Map("scalar" -> other.toString)) + } + + def /(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Div")(Array(this, other)) + def /[@specialized(Int, Float, Double) V](other: V): Symbol = { + Symbol.createFromListedSymbols("_DivScalar")(Array(this), Map("scalar" -> other.toString)) + } override def clone(): Symbol = { val clonedHandle = new SymbolHandleRef @@ -40,7 +58,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { * Get a new grouped symbol whose output contains all the internal outputs of this symbol. * @return The internal of the symbol. 
*/ - def getInternals: Symbol = { + def getInternals(): Symbol = { val newHandle = new SymbolHandleRef checkCall(_LIB.mxSymbolGetInternals(handle, newHandle)) new Symbol(handle = newHandle.value) @@ -555,6 +573,10 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { bind(ctx, args, argsGrad, "write", Nil, null) } + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray]): Executor = { + bind(ctx, args, argsGrad, "write", Nil, null) + } + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray]): Executor = { bind(ctx, args, argsGrad, "write", Nil, null) } @@ -674,11 +696,148 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { } object Symbol { - private type SymbolCreateFunc = Map[String, Any] => Symbol + private type SymbolCreateNamedFunc = Map[String, Any] => Symbol private val logger = LoggerFactory.getLogger(classOf[Symbol]) private val functions: Map[String, SymbolFunction] = initSymbolModule() private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3) + // TODO: _CrossDeviceCopy + + def pow(sym1: Symbol, sym2: Symbol): Symbol = { + Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2)) + } + + def pow[@specialized(Int, Float, Double) V](sym: Symbol, number: V): Symbol = { + Symbol.createFromListedSymbols("_PowerScalar")(Array(sym), Map("scalar" -> number.toString)) + } + + def pow[@specialized(Int, Float, Double) V](number: V, sym: Symbol): Symbol = { + Symbol.createFromListedSymbols("_PowerScalar")(Array(sym), + Map("scalar" -> number.toString, "scalar_on_left" -> "True")) + } + + /** + * Take absolute value of the src + * @param src Source symbolic input to the function + */ + def abs(src: Symbol): Symbol = { + createFromListedSymbols("abs")(Array(src)) + } + + /** + * Take sign value of the src + * @param src Source symbolic input to the function + */ + def sign(src: Symbol): Symbol = { + createFromListedSymbols("sign")(Array(src)) + } + + /** + * Take round value of the src + * @param src Source input to the function + */ + def round(src: Symbol): Symbol = { + createFromListedSymbols("round")(Array(src)) + } + + /** + * Take ceil value of the src + * src Source input to the function + */ + def ceil(src: Symbol): Symbol = { + createFromListedSymbols("ceil")(Array(src)) + } + + /** + * Take floor value of the src + * @param src Source input to the function + */ + def floor(src: Symbol): Symbol = { + createFromListedSymbols("floor")(Array(src)) + } + + /** + * Take square of the src + * @param src Source symbolic input to the function + */ + def square(src: Symbol): Symbol = { + createFromListedSymbols("square")(Array(src)) + } + + /** + * Take sqrt of the src + * src Source symbolic input to the function + */ + def sqrt(src: Symbol): Symbol = { + createFromListedSymbols("sqrt")(Array(src)) + } + + /** + * Take rsqrt of the src + * @param src Source symbolic input to the function + */ + def rsqrt(src: Symbol): Symbol = { + createFromListedSymbols("rsqrt")(Array(src)) + } + + /** + * Take exp of the src + * @param src Source symbolic input to the function + */ + def exp(src: Symbol): Symbol = { + createFromListedSymbols("exp")(Array(src)) + } + + /** + * Take log of the src + * @param src Source symbolic input to the function + */ + def log(src: Symbol): Symbol = { + createFromListedSymbols("log")(Array(src)) + } + + /** + * Take cos of the src + * @param src Source symbolic input to the function + */ + def cos(src: Symbol): Symbol = { + createFromListedSymbols("cos")(Array(src)) + } + + /** + * Take sin of the src + * 
@param src Source symbolic input to the function + */ + def sin(src: Symbol): Symbol = { + createFromListedSymbols("sin")(Array(src)) + } + + def max(left: Symbol, right: Symbol): Symbol = { + createFromListedSymbols("_Maximum")(Array(left, right)) + } + + def max[@specialized(Int, Float, Double) V](left: Symbol, right: V): Symbol = { + createFromListedSymbols("_MaximumScalar")(Array(left), Map("scalar" -> right.toString)) + } + + def max[@specialized(Int, Float, Double) V](left: V, right: Symbol): Symbol = { + createFromListedSymbols("_MaximumScalar")(Array(right), + Map("scalar" -> left.toString, "scalar_on_left" -> "True")) + } + + def min(left: Symbol, right: Symbol): Symbol = { + createFromListedSymbols("_Minimum")(Array(left, right)) + } + + def min[@specialized(Int, Float, Double) V](left: Symbol, right: V): Symbol = { + createFromListedSymbols("_MinimumScalar")(Array(left), Map("scalar" -> right.toString)) + } + + def min[@specialized(Int, Float, Double) V](left: V, right: Symbol): Symbol = { + createFromListedSymbols("_MinimumScalar")(Array(right), + Map("scalar" -> left.toString, "scalar_on_left" -> "True")) + } + /** * Create a symbolic variable with specified name. * @param name Name of the variable. @@ -693,85 +852,420 @@ object Symbol { sym } - def FullyConnected: SymbolCreateFunc = { - FullyConnected(null) + /** + * Get output from a symbol and pass 0 gradient back + * + * Parameters + * ---------- + * data : Symbol. Input data. + */ + def BlockGrad(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("BlockGrad", name, attr) + } + + /** + * Crop the 2nd and 3rd dims of the input data, either to the size given by h_w or to the width + * and height of the second input symbol + * + * Parameters + * ---------- + * num_args : int, required. + * Number of inputs for crop; + * if it equals one, then we will use h_w for the crop height and width, + * else if it equals two, + * then we will use the height and width of the second input symbol, + * which we name crop_like here + * offset : Shape(tuple), optional, default=(0, 0), crop offset coordinate: (y, x) + * h_w : Shape(tuple), optional, default=(0, 0), crop height and width: (h, w) + * center_crop : boolean, optional, default=False. + * If set to true, it will crop at the center, + * otherwise it will crop using the shape of crop_like + */ + def Crop(name: String = null, attr: Map[String, String] = null)( + inputs: Array[Symbol], params: Map[String, Any] = null): Symbol = { + createFromListedSymbolsNoCheck("Crop", name, attr)(inputs, params) + } + + /** + * Apply dropout to input + * + * Parameters + * ---------- + * data : Symbol. Input data to dropout. + * p : float, optional, default=0.5. Fraction of the input that gets dropped out at training time + */ + def Dropout(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("Dropout", name, attr) + } + + /** + * Apply a sparse regularization to the output of a sigmoid activation function. + * + * Parameters + * ---------- + * data : Symbol. Input data. + * sparseness_target : float, optional, default=0.1. The sparseness target + * penalty : float, optional, default=0.001. The tradeoff parameter for the sparseness penalty + * momentum : float, optional, default=0.9.
The momentum for running average + */ + def IdentityAttachKLSparseReg(name: String = null, + attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("IdentityAttachKLSparseReg", name, attr) + } + + /** + * Apply activation function to input. + * + * Parameters + * ---------- + * data : Symbol. Input data to activation function. + * act_type : {'elu', 'leaky', 'prelu', 'rrelu'}, optional, default='leaky' + * Activation function to be applied. + * slope : float, optional, default=0.25. Init slope for the activation. (For leaky and elu only) + * lower_bound : float, optional, default=0.125. Lower bound of random slope. (For rrelu only) + * upper_bound : float, optional, default=0.334. Upper bound of random slope. (For rrelu only) + */ + def LeakyReLU(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("LeakyReLU", name, attr) } - def FullyConnected(attr: Map[String, String]): SymbolCreateFunc = { - createNoCheck("FullyConnected", attr) + /** + * Apply local response normalization to the input. + * + * Parameters + * ---------- + * data : Symbol. Input data to the normalization operator. + * alpha : float, optional, default=0.0001, + * value of the alpha variance scaling parameter in the normalization formula + * beta : float, optional, default=0.75, + * value of the beta power parameter in the normalization formula + * knorm : float, optional, default=2, value of the k parameter in normalization formula + * nsize : int (non-negative), required, normalization window width in elements. + */ + def LRN(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("LRN", name, attr) } - def Activation: SymbolCreateFunc = { - Activation(null) + /** + * Use mean absolute error regression for final output, this is used on final output of a net. + * + * Parameters + * ---------- + * data : Symbol. Input data to function. + * label : Symbol. Input label to function. + * grad_scale : float, optional, default=1. Scale the gradient by a float factor + */ + def MAERegressionOutput(name: String = null, + attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("MAERegressionOutput", name, attr) } - def Activation(attr: Map[String, String]): SymbolCreateFunc = { - createNoCheck("Activation", attr) + /** + * Reshape input to target shape + * + * Parameters + * ---------- + * data : Symbol. Input data to reshape. + * target_shape : Shape(tuple), required. Target new shape. One and only one dim can be 0, + * in which case it will be inferred from the rest of dims + */ + def Reshape(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("Reshape", name, attr) } - def Convolution(attr: Map[String, String]): SymbolCreateFunc = { - createNoCheck("Convolution", attr) + /** + * Slice the input along the channel dimension into multiple outputs with equally divided channels + * + * Parameters + * ---------- + * num_outputs : int, required. Number of outputs to be sliced. + */ + def SliceChannel(name: String = null, attr: Map[String, String] = null)( + inputs: Array[Symbol], params: Map[String, Any] = null): Symbol = { + createFromListedSymbolsNoCheck("SliceChannel", name, attr)(inputs, params) } - def Convolution: Map[String, Any] => Symbol = { - Convolution(null) + /** + * Apply softmax activation to input. + * This is intended for internal layers. For output (loss layer) please use SoftmaxOutput.
+ * If type=instance, + * this operator will compute a softmax for each instance in the batch; this is the default mode. + * If type=channel, + * this operator will compute a num_channel-class softmax at each position of each instance; + * this can be used for fully convolutional networks, image segmentation, etc. + * + * Parameters + * ---------- + * data : Symbol. Input data to activation function. + * type : {'channel', 'instance'}, optional, default='instance'. Softmax Mode. + * If set to instance, + * this operator will compute a softmax for each instance in the batch; + * this is the default mode. + * If set to channel, + * this operator will compute a num_channel-class softmax + * at each position of each instance; + * this can be used for fully convolutional networks, image segmentation, etc. + */ + def SoftmaxActivation(name: String = null, + attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("SoftmaxActivation", name, attr) } - def BatchNorm: Map[String, Any] => Symbol = { - createNoCheck("BatchNorm") + /** + * Apply matrix multiplication to input then add a bias. + * + * Parameters + * ---------- + * data : Symbol. Input data to the FullyConnectedOp. + * weight : Symbol. Weight matrix. + * bias : Symbol. Bias parameter. + * num_hidden : int, required. Number of hidden nodes of the output. + * no_bias : boolean, optional, default=False. Whether to disable bias parameter. + */ + def FullyConnected(name: String = null, + attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("FullyConnected", name, attr) } - def Pooling: Map[String, Any] => Symbol = { - createNoCheck("Pooling") + /** + * Apply activation function to input. + * Softmax Activation is only available with CUDNN on GPU and will be computed + * at each location across channels if the input is 4D. + * + * Parameters + * ---------- + * data : Symbol. Input data to activation function. + * act_type : {'relu', 'sigmoid', 'softrelu', 'tanh'}, required. + * Activation function to be applied. + */ + def Activation(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("Activation", name, attr) + } + + /** + * Apply convolution to input then add a bias. + * + * Parameters + * ---------- + * data : Symbol. Input data to the ConvolutionOp. + * weight : Symbol. Weight matrix. + * bias : Symbol. Bias parameter. + * kernel : Shape(tuple), required. Convolution kernel size: (y, x) + * stride : Shape(tuple), optional, default=(1, 1). Convolution stride: (y, x) + * dilate : Shape(tuple), optional, default=(1, 1). Convolution dilate: (y, x) + * pad : Shape(tuple), optional, default=(0, 0). Pad for convolution: (y, x) + * num_filter : int (non-negative), required. Convolution filter(channel) number + * num_group : int (non-negative), optional, default=1 + * Number of group partitions. + * This option is not supported by CuDNN; + * you can use SliceChannel to split into num_group groups, + * apply convolution to each, and concat the results to achieve the same effect. + * workspace : long (non-negative), optional, default=512. Tmp workspace for convolution (MB). + * no_bias : boolean, optional, default=False. Whether to disable bias parameter. + */ + def Convolution(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("Convolution", name, attr) } - def Flatten: Map[String, Any] => Symbol = { - createNoCheck("Flatten") + /** + * Apply deconvolution to input then add a bias.
+ * + * Parameters + * ---------- + * data : Symbol. Input data to the DeconvolutionOp. + * weight : Symbol. Weight matrix. + * bias : Symbol. Bias parameter. + * kernel : Shape(tuple), required, deconvolution kernel size: (y, x) + * stride : Shape(tuple), optional, default=(1, 1), deconvolution stride: (y, x) + * pad : Shape(tuple), optional, default=(0, 0), pad for deconvolution: (y, x) + * num_filter : int (non-negative), required, deconvolution filter(channel) number + * num_group : int (non-negative), optional, default=1, number of group partitions + * workspace : long (non-negative), optional, default=512. Tmp workspace for deconvolution (MB) + * no_bias : boolean, optional, default=True. Whether to disable bias parameter. + */ + def Deconvolution(name: String = null, + attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("Deconvolution", name, attr) + } + + /** + * Perform spatial pooling on inputs. + * + * Parameters + * ---------- + * data : Symbol. Input data to the pooling operator. + * kernel : Shape(tuple), required, pooling kernel size: (y, x) + * pool_type : {'avg', 'max', 'sum'}, required. Pooling type to be applied. + * stride : Shape(tuple), optional, default=(1, 1), stride for pooling (y, x) + * pad : Shape(tuple), optional, default=(0, 0), pad for pooling: (y, x) + */ + def Pooling(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("Pooling", name, attr) + } + + /** + * Flatten input + * Parameters + * ---------- + * data : Symbol. Input data to flatten. + */ + def Flatten(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("Flatten", name, attr) + } + + /** + * Perform a softmax transformation on input, backprop with logloss. + * + * Parameters + * ---------- + * data : Symbol. Input data to softmax. + * label : Symbol. Label data. + * grad_scale : float, optional, default=1. Scale the gradient by a float factor + * ignore_label : float, optional, default=-1. + * the ignore_label will not work in backward, + * and this will only be used when multi_output=true + * multi_output : boolean, optional, default=False. + * If set to true, for a (n,k,x_1,..,x_n) dimensional input tensor, + * softmax will generate n*x_1*...*x_n outputs, each has k classes + * use_ignore : boolean, optional, default=False. + * If set to true, + * the ignore_label value will not contribute to the backward gradient + */ + def SoftmaxOutput(name: String = null, + attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("SoftmaxOutput", name, attr) } - def SoftmaxOutput: Map[String, Any] => Symbol = { - createNoCheck("SoftmaxOutput") + /** + * Cast array to a different data type. + * Parameters + * ---------- + * data : Symbol, Input data to cast function. + * dtype : {Int, Double, Short, Float}, required, Target data type. + */ + def Cast(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("Cast", name, attr) } - def Cast: Map[String, Any] => Symbol = { - createNoCheck("Cast") + /** + * Perform an elementwise sum over all the inputs. + * + * Parameters + * ---------- + * num_args : int, required. Number of inputs to be summed.
+ */ + def ElementWiseSum(name: String = null, + attr: Map[String, String] = null)( + symbols: Array[Symbol], params: Map[String, Any] = null): Symbol = { + createFromListedSymbolsNoCheck("ElementWiseSum", name, attr)(symbols, params) } - def ElementWiseSum(name: String, inputs: Symbol *): Symbol = { - create("ElementWiseSum", inputs.toArray, Map("name" -> name), null) + /** + * Apply batch normalization to input. + * + * Parameters + * ---------- + * data : Symbol, Input data to batch normalization + * eps : float, optional, default=0.001, Epsilon to prevent div 0 + * momentum : float, optional, default=0.9, Momentum for moving average + * fix_gamma : boolean, optional, default=True, Fix gamma while training + */ + def BatchNorm(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("BatchNorm", name, attr) } - def ElementWiseSum(inputs: Seq[Symbol], name: String): Symbol = { - create("ElementWiseSum", inputs.toArray, Map("name" -> name), null) + /** + * Perform nearest neighbor/bilinear up sampling to inputs + * + * Parameters + * ---------- + * data : Symbol[]. Array of tensors to upsample + * scale : int (non-negative), required. Up sampling scale + * num_filter : int (non-negative), optional, default=0. + * Input filter. Only used by nearest sample_type. + * sample_type : {'bilinear', 'nearest'}, required, upsampling method + * multi_input_mode : {'concat', 'sum'}, optional, default='concat' + * How to handle multiple inputs. + * concat means concatenate upsampled images along the channel dimension. + * sum means add all images together, + * only available for nearest neighbor upsampling. + * num_args : int, required. Number of inputs to be upsampled. + * For nearest neighbor upsampling, this can be 1-N; + * the size of output will be (scale*h_0, scale*w_0) + * and all other inputs will be upsampled to the same size. + * For bilinear upsampling this must be 2; 1 input and 1 weight. + */ + def UpSampling(name: String = null, attr: Map[String, String] = null)( + inputs: Array[Symbol], params: Map[String, Any] = null): Symbol = { + createFromListedSymbolsNoCheck("UpSampling", name, attr)(inputs, params) } - def Concat(inputs: Seq[Symbol], - paramKwargs: Map[String, Any], - attr: Map[String, String] = null): Symbol = { - create("Concat", inputs.toArray, - paramKwargs.map { case (k, v) => (k, v.toString) }, attr) + /** + * Perform a feature concat on the channel dim (dim 1) over all the inputs. + * + * Parameters + * ---------- + * data : Symbol[]. List of tensors to concatenate + * num_args : int, required. Number of inputs to be concatenated. + * dim : int, optional, default='1'. The dimension to be concatenated. + */ + def Concat(name: String = null, attr: Map[String, String] = null)( + inputs: Array[Symbol], params: Map[String, Any] = null): Symbol = { + createFromListedSymbolsNoCheck("Concat", name, attr)(inputs, params) } - // Use Logistic regression for final output, this is used on final output of a net. - // Logistic regression is suitable for binary classification or probability prediction tasks. - def LogisticRegressionOutput(inputs: Seq[Symbol], attr: Map[String, String] = null): Symbol = { - create("LogisticRegressionOutput", inputs.toArray, null, attr) + /** + * Use Logistic regression for final output, this is used on final output of a net. + * Logistic regression is suitable for binary classification or probability prediction tasks. + * Parameters + * ---------- + * data : Symbol. Input data to function. + * label : Symbol.
Input label to function. + * grad_scale : float, optional, default=1. Scale the gradient by a float factor + */ + def LogisticRegressionOutput(name: String = null, + attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("LogisticRegressionOutput", name, attr) } - // Use linear regression for final output, this is used on final output of a net. - def LinearRegressionOutput(inputs: Seq[Symbol], attr: Map[String, String] = null): Symbol = { - create("LinearRegressionOutput", inputs.toArray, null, attr) + /** + * Use linear regression for final output, this is used on final output of a net. + * Parameters + * ---------- + * data : Symbol. Input data to function. + * label : Symbol. Input label to function. + * grad_scale : float, optional, default=1. Scale the gradient by a float factor + */ + def LinearRegressionOutput(name: String = null, + attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("LinearRegressionOutput", name, attr) } /** * Apply swapaxis to input. - * @param data Input data to the SwapAxisOp. - * @param dim1 (non-negative), default=0, the first axis to be swapped. - * @param dim2 (non-negative), default=0, the second axis to be swapped. + * + * Parameters + * ---------- + * data : Symbol. Input data to the SwapAxisOp. + * dim1 : int (non-negative), default=0, the first axis to be swapped. + * dim2 : int (non-negative), default=0, the second axis to be swapped. */ - def SwapAxis(data: Symbol, dim1: Int = 0, dim2: Int = 0, - attr: Map[String, String] = null): Symbol = { - createNoCheck("SwapAxis")(Map("data" -> data, "dim1" -> dim1, "dim2" -> dim2)) + def SwapAxis(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("SwapAxis", name, attr) + } + + /** + * Get embedding for one-hot input + * + * Parameters + * ---------- + * data : Symbol, Input data to the EmbeddingOp. + * weight : Symbol, Embedding weight matrix. 
+ * input_dim : int, input dim of one-hot encoding + * output_dim : int, output dim of embedding + */ + def Embedding(name: String = null, attr: Map[String, String] = null): SymbolCreateNamedFunc = { + createFromNamedSymbolsNoCheck("Embedding", name, attr) } /** @@ -819,10 +1313,9 @@ object Symbol { * @param attr Attributes set to the resulting symbol * @return the resulting symbol */ - def create(operator: String, - symbols: Array[Symbol], - paramKwargs: Map[String, String], - attr: Map[String, String]): Symbol = { + def createFromListedSymbols( + operator: String, name: String = null, attr: Map[String, String] = null)( + symbols: Array[Symbol], paramKwargs: Map[String, String] = null): Symbol = { val function = functions(operator) require(function != null, s"invalid operator name $operator") @@ -834,11 +1327,11 @@ object Symbol { val paramKeys: Array[String] = ( if (addkeyVarNumArgs) Array[String](function.keyVarNumArgs) else Array.empty[String] - ) ++ (params - "name").keys + ) ++ params.keys val paramVals: Array[String] = ( if (addkeyVarNumArgs) Array[String](symbols.length.toString) else Array.empty[String] - ) ++ (params - "name").values + ) ++ params.values // create atomic symbol val symHandle = new SymbolHandleRef @@ -849,15 +1342,11 @@ object Symbol { val attrAll = AttrScope.current.get(Option(attr)) s.setAttr(attrAll) val hint = operator.toLowerCase - val managedName = NameManager.current.get(params.get("name"), hint) + val managedName = NameManager.current.get(Option(name), hint) s.compose(managedName, symbols) s } - def create(operator: String, symbols: Symbol*): Symbol = { - create(operator, symbols.toArray, null, null) - } - /** * Activation Operator of Neural Net. * The parameters listed below can be passed in as keyword arguments. @@ -866,10 +1355,9 @@ object Symbol { * @param attr Attributes set to the resulting symbol * @return the resulting symbol */ - private def create(operator: String, - symbols: Map[String, Symbol], - paramKwargs: Map[String, String], - attr: Map[String, String]): Symbol = { + def createFromNamedSymbols( + operator: String, name: String = null, attr: Map[String, String] = null)( + symbols: Map[String, Symbol], paramKwargs: Map[String, String] = null): Symbol = { val function = functions(operator) require(function != null, s"invalid operator name $operator") require(function.keyVarNumArgs == null || function.keyVarNumArgs.isEmpty, @@ -878,10 +1366,10 @@ object Symbol { val paramKeys = if (paramKwargs == null) Array.empty[String] - else (paramKwargs - "name").keys.toArray + else paramKwargs.keys.toArray val paramVals = if (paramKwargs == null) Array.empty[String] - else (paramKwargs - "name").values.toArray + else paramKwargs.values.toArray val symHandle = new SymbolHandleRef checkCall(_LIB.mxSymbolCreateAtomicSymbol( function.handle, paramKeys, paramVals, symHandle)) @@ -890,25 +1378,16 @@ object Symbol { val attrAll = AttrScope.current.get(Option(attr)) s.setAttr(attrAll) val hint = operator.toLowerCase - val managedName = NameManager.current.get(paramKwargs.get("name"), hint) + val managedName = NameManager.current.get(Option(name), hint) s.compose(managedName, symbols) s } - def create(operator: String, symbols: Map[String, Symbol]): Symbol = { - create(operator, symbols, null, null) - } - - def create(operator: String, - symbols: Map[String, Symbol], - paramKwargs: Map[String, String]): Symbol = { - create(operator, symbols, paramKwargs, null) - } - // a more friendly interface for creating symbols // all values except symbols in kwargs will be 
cast to String using its toString() method - def createNoCheck(operator: String, attr: Map[String, String] = null)( - kwargs: Map[String, Any]): Symbol = { + def createFromNamedSymbolsNoCheck( + operator: String, name: String = null, attr: Map[String, String] = null)( + kwargs: Map[String, Any]): Symbol = { val symbolArgs = kwargs.filter { case (key, value) => value.isInstanceOf[Symbol] }.map { case (key, value) => @@ -919,7 +1398,18 @@ }.map { case (key, value) => (key, value.toString) } - create(operator, symbolArgs, strArgs, attr) + createFromNamedSymbols(operator, name, attr)(symbolArgs, strArgs) + } + + // a more friendly interface for creating symbols + // all values except symbols in kwargs will be cast to String using its toString() method + def createFromListedSymbolsNoCheck( + operator: String, name: String = null, attr: Map[String, String] = null)( + symbols: Array[Symbol], kwargs: Map[String, Any] = null): Symbol = { + val args = + if (kwargs == null) null + else kwargs.map { case (key, value) => (key, value.toString) } + createFromListedSymbols(operator, name, attr)(symbols, args) + } /** @@ -962,3 +1452,29 @@ } private case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) + +object SymbolConversions { + implicit def int2Scalar(x: Int): SymbolConversions[Int] = new SymbolConversions(x) + implicit def double2Scalar(x: Double): SymbolConversions[Double] = new SymbolConversions(x) + implicit def float2Scalar(x: Float): SymbolConversions[Float] = new SymbolConversions(x) +} + +class SymbolConversions[@specialized(Int, Float, Double) V](val value: V) { + def +(other: Symbol): Symbol = { + other + value + } + + def -(other: Symbol): Symbol = { + Symbol.createFromListedSymbols("_MinusScalar")(Array(other), + Map("scalar" -> value.toString, "scalar_on_left" -> "True")) + } + + def *(other: Symbol): Symbol = { + other * value + } + + def /(other: Symbol): Symbol = { + Symbol.createFromListedSymbols("_DivScalar")(Array(other), + Map("scalar" -> value.toString, "scalar_on_left" -> "True")) + } +} diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala index 8e99553af3b8..74d935bc6a98 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala @@ -2,9 +2,11 @@ package ml.dmlc.mxnet import ml.dmlc.mxnet.Base.Shape import ml.dmlc.mxnet.CheckUtils._ + import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.scalatest.{Matchers, BeforeAndAfterAll, FunSuite} import org.scalacheck.Gen + import scala.collection.mutable class OperatorSuite extends FunSuite with BeforeAndAfterAll @@ -12,7 +14,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll private def checkElementwiseSumWithShape(shape: Shape, n: Int) = { // forward val inputs = (0 until n).map(i => Symbol.Variable(s"arg $i")) - val out = Symbol.ElementWiseSum("esum", inputs: _*) + val out = Symbol.ElementWiseSum(name = "esum")(inputs.toArray) val arr = (0 until n).map(_ => Random.uniform(-10, 10, shape)) val arrGrad = (0 until n).map(_ => NDArray.empty(shape)) val exec = out.bind(Context.cpu(), args = arr, argsGrad = arrGrad) @@ -45,7 +47,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll val targetDim = shapes.map(_(dimension)).sum val inputs = (0 until shapes.size).map(i => Symbol.Variable(s"arg$i")) - val out = Symbol.Concat(inputs, Map("name" -> "conc", "dim" ->
dimension)) + val out = Symbol.Concat(name = "conc")(inputs.toArray, Map("dim" -> dimension)) val arr = shapes.map { shape => val nd = NDArray.empty(shape) nd.set(shape(dimension)) @@ -119,12 +121,12 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll } test("regression") { - checkRegression(Symbol.LogisticRegressionOutput( - Array(Symbol.Variable("data"), Symbol.Variable("label"))), + checkRegression(Symbol.LogisticRegressionOutput()( + Map("data" -> Symbol.Variable("data"), "label" -> Symbol.Variable("label"))), (x: Float) => 1.0f / (1.0f + Math.exp(-x).toFloat), (x: Float, y: Float) => x - y) - checkRegression(Symbol.LinearRegressionOutput( - Array(Symbol.Variable("data"), Symbol.Variable("label"))), + checkRegression(Symbol.LinearRegressionOutput()( + Map("data" -> Symbol.Variable("data"), "label" -> Symbol.Variable("label"))), (x: Float) => x, (x: Float, y: Float) => x - y) } @@ -146,8 +148,8 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll // [[ 2., 2., 2., 2.], // [ 2., 2., 2., 2.], // [ 2., 2., 2., 2.]]] - val swap0 = Symbol.SwapAxis(data = data, dim1 = 0, dim2 = 2) - val swap = Symbol.SwapAxis(data = swap0, dim1 = 1, dim2 = 2) + val swap0 = Symbol.SwapAxis()(Map("data" -> data, "dim1" -> 0, "dim2" -> 2)) + val swap = Symbol.SwapAxis()(Map("data" -> swap0, "dim1" -> 1, "dim2" -> 2)) val exec = swap.bind(Context.cpu(), args = Array(arrData)) exec.forward() val out = exec.outputs(0) @@ -171,4 +173,477 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll assert(CheckUtils.reldiff(axis0.toArray, Array(1f, 1f, 1f, 2f, 2f, 2f)) < 1e-6f) } } + + test("scalar op") { + val data = Symbol.Variable("data") + val shape = Vector(3, 4) + val dataTmp = NDArray.ones(shape) * 5 + + val test = { + import ml.dmlc.mxnet.SymbolConversions._ + 2 / (4 - ((1 + data + 1) * 2 / 5) - 0.2) + } + + val (npout1, npout) = { + import ml.dmlc.mxnet.NDArrayConversions._ + val npout1 = 4 - ((1 + dataTmp + 1) * 2 / 5) - 0.2f + val npout = 2 / npout1 + (npout1, npout) + } + + checkSymbolicForward(test, Array(dataTmp), Array(npout)) + + val npoutGrad = new NDArrayConversions(2f * (2f * 2f / 5f)) / (npout1 * npout1) + + checkSymbolicBackward(test, Array(dataTmp), Array(NDArray.ones(shape) * 2), Array(npoutGrad)) + } + + test("scalar pow") { + val data = Symbol.Variable("data") + val shape = Vector(1, 1) + val dataTmp = NDArray.ones(shape) * 3 + val dataTmpPowered = NDArray.ones(shape) * 9 + val test = Symbol.pow(data, 2) + // TODO: check numeric gradient + checkSymbolicForward(test, Array(dataTmp), Array(dataTmpPowered)) + checkSymbolicBackward(test, Array(dataTmp), Array(NDArray.ones(shape)), Array(dataTmp * 2)) + } + + test("symbol pow") { + val shape = Vector(1, 1) + + val data = Symbol.Variable("data") + val dataTmp = NDArray.ones(shape) * 2 + + val exp = Symbol.Variable("exp") + val expTmp = NDArray.ones(shape) * 3 + + val test = Symbol.pow(data, exp) + + // TODO: check numeric gradient + checkSymbolicForward(test, Seq(dataTmp, expTmp), Seq(NDArray.ones(shape) * 8)) + + val dataDir = NDArray.ones(shape) * 4 * expTmp // dataTmp**(expTmp - 1) * expTmp + // expDir = dataTmp**(expTmp) * log(dataTmp) + val expDir = NDArray.ones(shape) * 8 * (NDArray.ones(shape) * Math.log(2).toFloat) + checkSymbolicBackward(test, Seq(dataTmp, expTmp), + Seq(NDArray.ones(shape)), Seq(dataDir, expDir)) + } + + test("pow fn") { + val shape = Vector(3, 4) + val exp = Symbol.Variable("exp") + val y = Symbol.pow(2, exp) + val x = NDArray.ones(shape) * 3 + // TODO: check numeric gradient + 
checkSymbolicForward(y, Seq(x), Seq(NDArray.ones(shape) * 8)) // 2**x + checkSymbolicBackward(y, Seq(x), Seq(NDArray.ones(shape)), + // log(2) * 2**x + Seq(NDArray.ones(shape) * 8 * Math.log(2).toFloat)) + } + + test("embedding") { + val inDim = 10 + val outDim = 4 + val batch = 24 + + val data = Symbol.Variable("data") + val embed = Symbol.Embedding(name = "embed")( + Map("data" -> data, "input_dim" -> inDim, "output_dim" -> outDim)) + // TODO + // scalastyle:off println + println(s"Embeded symbol: ${embed.toJson}") + // scalastyle:on println + } + + // check ops handle duplicate input correctly. + test("binary op duplicate input") { + val data = Symbol.Variable("data") + val shape = Vector(3, 4) + val dataTmp = NDArray.ones(shape) * 5 + val arrData = dataTmp.copy() + val arrGrad = NDArray.ones(shape) * 3 + val outGrad = NDArray.ones(shape) + val square = data * data + val exeSquare = square.bind(Context.cpu(), args = Array(arrData), argsGrad = Array(arrGrad)) + exeSquare.forward() + assert(reldiff(exeSquare.outputs.head, dataTmp * dataTmp) < 1e-6f) + exeSquare.backward(outGrad) + assert(reldiff(arrGrad, dataTmp * 2f) < 1e-6f) + } + + test("sign") { + val data = Symbol.Variable("data") + val shape = Vector(3, 4) + val dataTmp = NDArray.ones(shape) * 5 + val arrData = dataTmp.copy() + val arrGrad = NDArray.ones(shape) * 3 + + val test = Symbol.sign(data) + val exeTest = test.bind(Context.cpu(), args = Array(arrData), argsGrad = Array(arrGrad)) + exeTest.forward() + val out = exeTest.outputs.head + val npout = NDArray.sign(dataTmp) + assert(reldiff(out, npout) < 1e-6) + + val outGrad = NDArray.ones(shape) * 2 + exeTest.backward(outGrad) + arrGrad.toArray.foreach(elem => assert(elem === 0f +- 1e-3f)) + } + + test("round, ceil, floor") { + val data = Symbol.Variable("data") + val shape = Vector(3, 4) + val dataTmp = NDArray.ones(shape) * 5.543f + val arrData = dataTmp.copy() + val arrGrad = NDArray.ones(shape) * 2 + + val test = Symbol.round(data) + Symbol.ceil(data) + Symbol.floor(data) + val exeTest = test.bind(Context.cpu(), args = Array(arrData)) + exeTest.forward() + val out = exeTest.outputs.head + val npout = NDArray.round(dataTmp) + NDArray.ceil(dataTmp) + NDArray.floor(dataTmp) + assert(reldiff(out, npout) < 1e-6) + } + + test("rsqrt, cos, sin") { + val data = Symbol.Variable("data") + val shape = Vector(3, 4) + val dataTmp = NDArray.ones(shape) * 5 + val arrData = dataTmp.copy() + val arrGrad = NDArray.ones(shape) * 3 + + val test = Symbol.rsqrt(data) + Symbol.cos(data) + Symbol.sin(data) + val exeTest = test.bind(Context.cpu(), args = Array(arrData), argsGrad = Array(arrGrad)) + exeTest.forward() + val out = exeTest.outputs.head + val npout = { + import ml.dmlc.mxnet.NDArrayConversions._ + 1 / NDArray.sqrt(dataTmp) + NDArray.cos(dataTmp) + NDArray.sin(dataTmp) + } + assert(reldiff(out, npout) < 1e-6) + + val outGrad = NDArray.ones(shape) * 2 + val npoutGrad = { + import ml.dmlc.mxnet.NDArrayConversions._ + outGrad * -(1 / (2 * dataTmp * NDArray.sqrt(dataTmp))) + + outGrad * -1 * NDArray.sin(dataTmp) + outGrad * NDArray.cos(dataTmp) + } + exeTest.backward(outGrad) + assert(reldiff(arrGrad, npoutGrad) < 1e-6) + } + + test("maximum") { + val data1 = Symbol.Variable("data") + val data2 = Symbol.Variable("data") + val shape = Vector(3, 4) + val dataTmp1 = Random.uniform(0, 100, shape) + val dataTmp2 = Random.uniform(0, 100, shape) + + val arrData1 = dataTmp1.copy() + val arrData2 = dataTmp2.copy() + + val test = Symbol.max(data1, data2) + val exeTest = test.bind(Context.cpu(), args = 
Array(arrData1, arrData2)) + exeTest.forward() + val out = exeTest.outputs.head + val expected = (dataTmp1.toArray zip dataTmp2.toArray).map { case (a, b) => Math.max(a, b) } + assert(reldiff(out.toArray, expected) < 1e-6) + } + + test("minimum") { + val data1 = Symbol.Variable("data") + val data2 = Symbol.Variable("data") + val shape = Vector(3, 4) + val dataTmp1 = Random.uniform(0, 100, shape) + val dataTmp2 = Random.uniform(0, 100, shape) + + val arrData1 = dataTmp1.copy() + val arrData2 = dataTmp2.copy() + + val test = Symbol.min(data1, data2) + val exeTest = test.bind(Context.cpu(), args = Array(arrData1, arrData2)) + exeTest.forward() + val out = exeTest.outputs.head + val expected = (dataTmp1.toArray zip dataTmp2.toArray).map { case (a, b) => Math.min(a, b) } + assert(reldiff(out.toArray, expected) < 1e-6) + } + + test("maximum minimum scalar") { + val data = Symbol.Variable("data") + val shape = Vector(3, 4) + val dataTmp = NDArray.ones(shape) * 2 + + val arrData = dataTmp.copy() + + val test = Symbol.max(data, 3) + Symbol.max(9, data) + Symbol.min(5, data) + Symbol.min(data, 4) + val exeTest = test.bind(Context.cpu(), args = Array(arrData)) + exeTest.forward() + val out = exeTest.outputs.head + // 3 + 9 + 2 + 2 + assert(reldiff(out, NDArray.ones(shape) * 16) < 1e-6) + } + + test("abs") { + val data = Symbol.Variable("data") + val shape = Vector(3, 4) + val dataTmp = NDArray.ones(shape) * 5 + val arrData = dataTmp.copy() + val arrGrad = NDArray.ones(shape) * 3 + + val test = Symbol.abs(data) + val exeTest = test.bind(Context.cpu(), args = Array(arrData), argsGrad = Array(arrGrad)) + exeTest.forward() + val out = exeTest.outputs.head + val npout = NDArray.abs(dataTmp) + assert(reldiff(out, npout) < 1e-6) + + val outGrad = NDArray.ones(shape) * 2 + val npoutGrad = outGrad * NDArray.sign(dataTmp) + exeTest.backward(outGrad) + assert(reldiff(arrGrad, npoutGrad) < 1e-6) + } + + // configure A: input --> conv --> deconv --> output. 
+ // the convolution and deconvolution use matching parameters, which ensures + // the input shape is the same as the output shape, and that the weights are shared between conv + // and deconv; + // If the input values of forward() and backward() are the same, then + // their output values should also be the same; + private def checkDeconvolutionForwardBackward(inputShape: Shape, + numFilter: Int, + kernel: (Int, Int), + stride: (Int, Int), + pad: (Int, Int)): Unit = { + require(inputShape(1) == numFilter) + val data = Symbol.Variable(name = "data") + val conv = Symbol.Convolution(name = "conv")(Map( + "data" -> data, "kernel" -> kernel, "stride" -> stride, "pad" -> pad, + "num_filter" -> numFilter, "no_bias" -> "true")) + val deconv = Symbol.Deconvolution(name = "deconv")(Map( + "data" -> conv, "kernel" -> kernel, "stride" -> stride, "pad" -> pad, + "num_filter" -> numFilter, "no_bias" -> "true")) + + val argNames = deconv.listArguments() + val (argShapes, outShapes, _) = deconv.inferShape(Map("data" -> inputShape)) + val inputData = Random.uniform(-5, 5, inputShape) + val outGrad = inputData + val convWeight = Random.normal(0, 1, Vector(numFilter, inputShape(1), kernel._1, kernel._2)) + val args: Map[String, NDArray] = + Map("data" -> inputData, "conv_weight" -> convWeight, "deconv_weight" -> convWeight) + val argsGrad: Seq[NDArray] = argShapes.map(NDArray.empty(_)) + + val exe = deconv.bind(Context.cpu(), args = args, argsGrad = argsGrad) + exe.forward() + val out = exe.outputs.head + exe.backward(outGrad) + assert(reldiff(out, argsGrad.head) < 1e-6) + } + + test("deconvolution forward & backward") { + checkDeconvolutionForwardBackward( + inputShape = Vector(1, 1, 5, 5), + numFilter = 1, + kernel = (3, 3), + stride = (1, 1), + pad = (1, 1) + ) + checkDeconvolutionForwardBackward( + inputShape = Vector(32, 3, 28, 28), + numFilter = 3, + kernel = (3, 3), + stride = (1, 1), + pad = (1, 1) + ) + checkDeconvolutionForwardBackward( + inputShape = Vector(10, 3, 403, 403), + numFilter = 3, + kernel = (7, 7), + stride = (5, 5), + pad = (2, 2) + ) + } + + // configure A: input --> conv --> output.
+ // configure B: input --> deconv --> output + // the convolution and deconvolution use matching parameters, which ensures + // the input shape is the same as the output shape; + // During backward(), if the input of A equals output of B, and the output + // of A equals input of B, then the weight gradients should be the same; + private def checkDeconvolutionGradient(inputShape: Shape, + numFilter: Int, + pad: (Int, Int)): Unit = { + val stride = (1, 1) + val kernel = (2 * pad._1 + 1, 2 * pad._2 + 1) + val dataConv = Symbol.Variable(name = "data_conv") + val conv = Symbol.Convolution(name = "conv")(Map( + "data" -> dataConv, "kernel" -> kernel, "stride" -> stride, "pad" -> pad, + "num_filter" -> numFilter, "no_bias" -> "true")) + val dataDeconv = Symbol.Variable(name = "data_deconv") + val deconv = Symbol.Deconvolution(name = "deconv")(Map( + "data" -> dataDeconv, "kernel" -> kernel, "stride" -> stride, "pad" -> pad, + "num_filter" -> numFilter, "no_bias" -> "true")) + + val convData = Random.uniform(-5, 5, inputShape) + val convArgs = Map("data_conv" -> convData, + "conv_weight" -> Random.normal(0, 1, Vector(numFilter, inputShape(1), kernel._1, kernel._2))) + + val convArgsGrad = Seq(NDArray.zeros(convData.shape), + NDArray.zeros(Vector(numFilter, inputShape(1), kernel._1, kernel._2))) + val exeConv = conv.bind(Context.cpu(), args = convArgs, argsGrad = convArgsGrad) + val convOutGrad = Random.normal(0, 2, exeConv.outputs.head.shape) + exeConv.backward(convOutGrad) + + val deconvData = convOutGrad + val deconvArgs = Map("data_deconv" -> deconvData, "deconv_weight" -> convArgs("conv_weight")) + val deconvArgsGrad = Seq(NDArray.zeros(deconvData.shape), + NDArray.zeros(Vector(numFilter, inputShape(1), kernel._1, kernel._2))) + val exeDeconv = deconv.bind(Context.cpu(), args = deconvArgs, argsGrad = deconvArgsGrad) + val deconvOutGrad = convData + exeDeconv.backward(deconvOutGrad) + assert(reldiff(convArgsGrad(1), deconvArgsGrad(1)) < 1e-6) + } + + test("deconvolution gradient") { + checkDeconvolutionGradient( + inputShape = Vector(1, 3, 5, 5), + numFilter = 3, + pad = (1, 1) + ) + checkDeconvolutionGradient( + inputShape = Vector(5, 3, 100, 100), + numFilter = 3, + pad = (3, 3) + ) + } + + private def checkNearestUpSamplingWithShape(shapes: Seq[Shape], + scale: Int, + rootScale: Int): Unit = { + val arr = shapes.zipWithIndex.map { case (shape, i) => + (s"arg_$i", Random.uniform(-10, 10, shape)) + }.toMap + + val arrGrad = shapes.zipWithIndex.map { case (shape, i) => + (s"arg_$i", NDArray.zeros(shape)) + }.toMap + + val up = Symbol.UpSampling()((0 until shapes.size).map(i => Symbol.Variable(s"arg_$i")).toArray, + Map("sample_type" -> "nearest", "scale" -> rootScale)) + val exe = up.bind(Context.cpu(), args = arr, argsGrad = arrGrad) + exe.forward(isTrain = true) + exe.backward(exe.outputs) + for (k <- 0 until shapes.size) { + val name = s"arg_$k" + val expected = + arr(name).toArray.map(_ * Math.pow(rootScale, 2).toFloat * Math.pow(scale, 2 * k).toFloat) + val real = arrGrad(name).toArray + (expected zip real) foreach { case (e, r) => + assert(r === e +- 0.1f) + } + } + } + + test("nearest upsampling") { + for (rootScale <- 1 to 3) { + for (scale <- 1 to 3) { + for (numShape <- 1 to 3) { + for (base <- 1 to 3) { + val shapes = (0 until numShape).map(i => + Vector(1, 3, base * rootScale * Math.pow(scale, numShape - 1 - i).toInt, + base * rootScale * Math.pow(scale, numShape - 1 - i).toInt)) + checkNearestUpSamplingWithShape(shapes, scale, rootScale) + } + } + } + } + } + + test("batch norm") { + val data =
Symbol.Variable("data") + val test = Symbol.BatchNorm(name = "bn")(Map("data" -> data, "fix_gamma" -> "False")) + // scalastyle:off println + println(s"BatchNorm: ${test.toJson}") + // scalastyle:on println + // TODO: check numeric gradient + } + + /** + * Compare forward call to expected value. + * @param sym output symbol + * @param location list of numpy arrays corresponding to sym.list_arguments + * @param expected list of arrays corresponding to sym.outputs + * @param checkEps relative error to check to + */ + private def checkSymbolicForward(sym: Symbol, + location: Seq[NDArray], + expected: Seq[NDArray], + checkEps: Float = 1e-5f): Unit = { + val arrData = location.map(_.copy()) + val arrGrad = location.map(array => NDArray.empty(array.shape)) + + val executor = sym.bind(Context.cpu(), args = arrData, argsGrad = arrGrad) + + val inps = executor.argArrays + assert(inps.size === location.size, + s"Executor argArrays and and location len do not match." + + s"Got ${inps.size} inputs and ${location.size} locations") + + for ((inp, source) <- location zip executor.argArrays) { + source.set(inp) + } + for (g <- executor.gradArrays) { + if (g != null) { + g.set(0f) + } + } + + assert(executor.outputs.length === 1) + + executor.forward() + + for ((expect, output) <- expected zip executor.outputs) { + assert(reldiff(expect, output) <= checkEps) + } + } + + /** + * Compare backwards call to expected value. + * @param sym output symbol + * @param location list of numpy arrays corresponding to sym.list_arguments + * @param grad list of numpy arrays corresponding to sym.outputs for incoming gradient + * @param expected list of arrays corresponding to sym.outputs + * @param checkEps relative error to check to + */ + private def checkSymbolicBackward(sym: Symbol, + location: Seq[NDArray], + grad: Seq[NDArray], + expected: Seq[NDArray], + checkEps: Float = 1e-5f): Unit = { + val arrData = location.map(_.copy()) + val arrGrad = location.map(array => NDArray.empty(array.shape)) + val outGrad = grad.map(_.copy()).toArray + + val executor = sym.bind(Context.cpu(), args = arrData, argsGrad = arrGrad) + + val inps = executor.argArrays + assert(inps.size === location.size, + s"Executor argArrays and and location len do not match." 
+ + s"Got ${inps.size} inputs and ${location.size} locations") + for ((inp, source) <- location zip executor.argArrays) { + source.set(inp) + } + for (g <- executor.gradArrays) { + if (g != null) { + g.set(0f) + } + } + + executor.forward() + executor.backward(outGrad) + + for ((expect, grad) <- expected zip executor.gradArrays) { + assert(reldiff(expect, grad) <= checkEps) + } + } } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala index 3e70574f0914..71fb8f257191 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala @@ -6,14 +6,14 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll { test("symbol compose") { val data = Symbol.Variable("data") - var net1 = Symbol.FullyConnected(Map("data" -> data, "name" -> "fc1", "num_hidden" -> 10)) - net1 = Symbol.FullyConnected(Map("data" -> net1, "name" -> "fc2", "num_hidden" -> 100)) + var net1 = Symbol.FullyConnected(name = "fc1")(Map("data" -> data, "num_hidden" -> 10)) + net1 = Symbol.FullyConnected(name = "fc2")(Map("data" -> net1, "num_hidden" -> 100)) assert(net1.listArguments().toArray === Array("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias")) - var net2 = Symbol.FullyConnected(Map("name" -> "fc3", "num_hidden" -> 10)) - net2 = Symbol.Activation(Map("data" -> net2, "act_type" -> "relu")) - net2 = Symbol.FullyConnected(Map("data" -> net2, "name" -> "fc4", "num_hidden" -> 20)) + var net2 = Symbol.FullyConnected(name = "fc3")(Map("num_hidden" -> 10)) + net2 = Symbol.Activation()(Map("data" -> net2, "act_type" -> "relu")) + net2 = Symbol.FullyConnected(name = "fc4")(Map("data" -> net2, "num_hidden" -> 20)) // scalastyle:off println println(s"net2 debug info:\n${net2.debugStr}") // scalastyle:on println @@ -28,20 +28,20 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll { test("symbol internal") { val data = Symbol.Variable("data") - val oldfc = Symbol.FullyConnected(Map("data" -> data, "name" -> "fc1", "num_hidden" -> 10)) - val net1 = Symbol.FullyConnected(Map("data" -> oldfc, "name" -> "fc2", "num_hidden" -> 100)) + val oldfc = Symbol.FullyConnected(name = "fc1")(Map("data" -> data, "num_hidden" -> 10)) + val net1 = Symbol.FullyConnected(name = "fc2")(Map("data" -> oldfc, "num_hidden" -> 100)) assert(net1.listArguments().toArray === Array("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias")) - val internal = net1.getInternals + val internal = net1.getInternals() val fc1 = internal.get("fc1_output") assert(fc1.listArguments() === oldfc.listArguments()) } test("symbol infer type") { val data = Symbol.Variable("data") - val f32data = Symbol.Cast(Map("data" -> data, "dtype" -> "float32")) - val fc1 = Symbol.FullyConnected(Map("data" -> f32data, "name" -> "fc1", "num_hidden" -> 128)) - val mlp = Symbol.SoftmaxOutput(Map("data" -> fc1, "name" -> "softmax")) + val f32data = Symbol.Cast()(Map("data" -> data, "dtype" -> "float32")) + val fc1 = Symbol.FullyConnected(name = "fc1")(Map("data" -> f32data, "num_hidden" -> 128)) + val mlp = Symbol.SoftmaxOutput(name = "softmax")(Map("data" -> fc1)) val (arg, out, aux) = mlp.inferType(Map("data" -> classOf[Double])) assert(arg.toArray === Array(classOf[Double], classOf[Float], classOf[Float], classOf[Float])) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/train/ConvSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/train/ConvSuite.scala index 
98d0ec770738..4ef62a5f23f5 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/train/ConvSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/train/ConvSuite.scala @@ -16,23 +16,23 @@ class ConvSuite extends FunSuite with BeforeAndAfterAll { val batchSize = 100 val data = Symbol.Variable("data") - val conv1 = Symbol.Convolution(Map("data" -> data, "name" -> "conv1", - "num_filter" -> 32, "kernel" -> (3, 3), "stride" -> (2, 2))) - val bn1 = Symbol.BatchNorm(Map("data" -> conv1, "name" -> "bn1")) - val act1 = Symbol.Activation(Map("data" -> bn1, "name" -> "relu1", "act_type" -> "relu")) - val mp1 = Symbol.Pooling(Map("data" -> act1, "name" -> "mp1", - "kernel" -> (2, 2), "stride" -> (2, 2), "pool_type" -> "max")) + val conv1 = Symbol.Convolution(name = "conv1")(Map("data" -> data, "num_filter" -> 32, + "kernel" -> (3, 3), "stride" -> (2, 2))) + val bn1 = Symbol.BatchNorm(name = "bn1")(Map("data" -> conv1)) + val act1 = Symbol.Activation(name = "relu1")(Map("data" -> bn1, "act_type" -> "relu")) + val mp1 = Symbol.Pooling(name = "mp1")(Map("data" -> act1, "kernel" -> (2, 2), + "stride" -> (2, 2), "pool_type" -> "max")) - val conv2 = Symbol.Convolution(Map("data" -> mp1, "name" -> "conv2", "num_filter" -> 32, - "kernel" -> (3, 3), "stride" -> (2, 2))) - val bn2 = Symbol.BatchNorm(Map("data" -> conv2, "name" -> "bn2")) - val act2 = Symbol.Activation(Map("data" -> bn2, "name" -> "relu2", "act_type" -> "relu")) - val mp2 = Symbol.Pooling(Map("data" -> act2, "name" -> "mp2", - "kernel" -> (2, 2), "stride" -> (2, 2), "pool_type" -> "max")) + val conv2 = Symbol.Convolution(name = "conv2")(Map("data" -> mp1, "num_filter" -> 32, + "kernel" -> (3, 3), "stride" -> (2, 2))) + val bn2 = Symbol.BatchNorm(name = "bn2")(Map("data" -> conv2)) + val act2 = Symbol.Activation(name = "relu2")(Map("data" -> bn2, "act_type" -> "relu")) + val mp2 = Symbol.Pooling(name = "mp2")(Map("data" -> act2, "kernel" -> (2, 2), + "stride" -> (2, 2), "pool_type" -> "max")) - val fl = Symbol.Flatten(Map("data" -> mp2, "name" -> "flatten")) - val fc2 = Symbol.FullyConnected(Map("data" -> fl, "name" -> "fc2", "num_hidden" -> 10)) - val softmax = Symbol.SoftmaxOutput(Map("data" -> fc2, "name" -> "sm")) + val fl = Symbol.Flatten(name = "flatten")(Map("data" -> mp2)) + val fc2 = Symbol.FullyConnected(name = "fc2")(Map("data" -> fl, "num_hidden" -> 10)) + val softmax = Symbol.SoftmaxOutput(name = "sm")(Map("data" -> fc2)) val numEpoch = 1 val model = new FeedForward(softmax, Context.cpu(), numEpoch = numEpoch,
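For orientation, here is a minimal usage sketch of the reworked Scala Symbol API that this patch introduces; it simply mirrors the patterns exercised in the updated test suites above and is illustrative only, not part of the diff:

import ml.dmlc.mxnet._
import ml.dmlc.mxnet.SymbolConversions._

// Layer factories are now curried: (name, attr) come first, then a Map of inputs and parameters.
val data = Symbol.Variable("data")
val fc1 = Symbol.FullyConnected(name = "fc1")(Map("data" -> data, "num_hidden" -> 128))
val act1 = Symbol.Activation(name = "relu1")(Map("data" -> fc1, "act_type" -> "relu"))
val net = Symbol.SoftmaxOutput(name = "softmax")(Map("data" -> act1))

// Symbol-scalar arithmetic works through the new operator overloads plus the
// SymbolConversions implicits, including scalars on the left-hand side.
val scaled = 2 / (4 - ((1 + data + 1) * 2 / 5) - 0.2)
val clipped = Symbol.min(Symbol.max(data, 0), 6)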