From be0e57c5ae491050ed597095f557b1e0be7983a1 Mon Sep 17 00:00:00 2001 From: Ke Han Date: Wed, 24 Jun 2020 18:08:57 +0800 Subject: [PATCH] * Fix ndarray situation --- src/api/operator/numpy/np_matrix_op.cc | 2 +- src/api/operator/numpy/random/np_multinomial_op.cc | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/api/operator/numpy/np_matrix_op.cc b/src/api/operator/numpy/np_matrix_op.cc index 98841cf971a9..098fd622850b 100644 --- a/src/api/operator/numpy/np_matrix_op.cc +++ b/src/api/operator/numpy/np_matrix_op.cc @@ -631,7 +631,7 @@ MXNET_REGISTER_API("_npi.vstack") inputs_vec[i] = args[i].operator mxnet::NDArray*(); } NDArray** inputs = inputs_vec.data(); - auto ndoutputs = Invoke(op, &attrs, param.num_args, &inputs[0], &num_outputs, nullptr); + auto ndoutputs = Invoke(op, &attrs, param.num_args, inputs, &num_outputs, nullptr); *ret = ndoutputs[0]; }); diff --git a/src/api/operator/numpy/random/np_multinomial_op.cc b/src/api/operator/numpy/random/np_multinomial_op.cc index cbd7c5977fd4..5f4a3f2a1a44 100644 --- a/src/api/operator/numpy/random/np_multinomial_op.cc +++ b/src/api/operator/numpy/random/np_multinomial_op.cc @@ -35,14 +35,19 @@ MXNET_REGISTER_API("_npi.multinomial") const nnvm::Op* op = Op::Get("_npi_multinomial"); nnvm::NodeAttrs attrs; op::NumpyMultinomialParam param; - std::vector inputs; + NDArray** inputs = new NDArray*[1](); + int num_inputs = 0; - //parse int + // parse int param.n = arg[0].operator int(); // parse pvals if (args[1].type_code() == kNull) { param.pvals = dmlc::nullopt; + } else if (args[1].type_code() == kNDArrayHandle) { + param.pvals = dmlc::nullopt; + inputs[0] = args[1].operator mxnet::NDArray*(); + num_inputs = 1; } else { param.pvals = Tuple(args[1].operator ObjectRef()); } @@ -61,7 +66,8 @@ MXNET_REGISTER_API("_npi.multinomial") attrs.parsed = std::move(param); attrs.op = op; SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, 0, nullptr); + inputs = num_inputs == 0 ? nullptr : inputs; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, 0, nullptr); *ret = ndoutputs[0]; });