diff --git a/src/api/operator/numpy/np_matrix_op.cc b/src/api/operator/numpy/np_matrix_op.cc index e2fb1230c2d9..9075e153161c 100644 --- a/src/api/operator/numpy/np_matrix_op.cc +++ b/src/api/operator/numpy/np_matrix_op.cc @@ -527,7 +527,7 @@ MXNET_REGISTER_API("_npi.vstack") inputs_vec[i] = args[i].operator mxnet::NDArray*(); } NDArray** inputs = inputs_vec.data(); - auto ndoutputs = Invoke(op, &attrs, param.num_args, &inputs[0], &num_outputs, nullptr); + auto ndoutputs = Invoke(op, &attrs, param.num_args, inputs, &num_outputs, nullptr); *ret = ndoutputs[0]; }); diff --git a/src/api/operator/numpy/random/np_multinomial_op.cc b/src/api/operator/numpy/random/np_multinomial_op.cc index cbd7c5977fd4..5f4a3f2a1a44 100644 --- a/src/api/operator/numpy/random/np_multinomial_op.cc +++ b/src/api/operator/numpy/random/np_multinomial_op.cc @@ -35,14 +35,19 @@ MXNET_REGISTER_API("_npi.multinomial") const nnvm::Op* op = Op::Get("_npi_multinomial"); nnvm::NodeAttrs attrs; op::NumpyMultinomialParam param; - std::vector inputs; + NDArray** inputs = new NDArray*[1](); + int num_inputs = 0; - //parse int + // parse int param.n = arg[0].operator int(); // parse pvals if (args[1].type_code() == kNull) { param.pvals = dmlc::nullopt; + } else if (args[1].type_code() == kNDArrayHandle) { + param.pvals = dmlc::nullopt; + inputs[0] = args[1].operator mxnet::NDArray*(); + num_inputs = 1; } else { param.pvals = Tuple(args[1].operator ObjectRef()); } @@ -61,7 +66,8 @@ MXNET_REGISTER_API("_npi.multinomial") attrs.parsed = std::move(param); attrs.op = op; SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, 0, nullptr); + inputs = num_inputs == 0 ? nullptr : inputs; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, 0, nullptr); *ret = ndoutputs[0]; });