Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
* Fix ndarray situation
Browse files Browse the repository at this point in the history
  • Loading branch information
hanke580 committed Jun 24, 2020
1 parent 9171f09 commit e30a842
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/api/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ MXNET_REGISTER_API("_npi.vstack")
inputs_vec[i] = args[i].operator mxnet::NDArray*();
}
NDArray** inputs = inputs_vec.data();
auto ndoutputs = Invoke(op, &attrs, param.num_args, &inputs[0], &num_outputs, nullptr);
auto ndoutputs = Invoke(op, &attrs, param.num_args, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

Expand Down
12 changes: 9 additions & 3 deletions src/api/operator/numpy/random/np_multinomial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,19 @@ MXNET_REGISTER_API("_npi.multinomial")
const nnvm::Op* op = Op::Get("_npi_multinomial");
nnvm::NodeAttrs attrs;
op::NumpyMultinomialParam param;
std::vector<NDArray*> inputs;
NDArray** inputs = new NDArray*[1]();
int num_inputs = 0;

//parse int
// parse int
param.n = arg[0].operator int();

// parse pvals
if (args[1].type_code() == kNull) {
param.pvals = dmlc::nullopt;
} else if (args[1].type_code() == kNDArrayHandle) {
param.pvals = dmlc::nullopt;
inputs[0] = args[1].operator mxnet::NDArray*();
num_inputs = 1;
} else {
param.pvals = Tuple<double>(args[1].operator ObjectRef());
}
Expand All @@ -61,7 +66,8 @@ MXNET_REGISTER_API("_npi.multinomial")
attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::NumpyMultinomialParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, 0, nullptr, 0, nullptr);
inputs = num_inputs == 0 ? nullptr : inputs;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, 0, nullptr);
*ret = ndoutputs[0];

});
Expand Down

0 comments on commit e30a842

Please sign in to comment.