[BUGFIX] Fix numpy pad operator #19787

Merged
merged 3 commits on Jan 29, 2021
src/api/operator/numpy/np_pad_op.cc (65 changes: 57 additions & 8 deletions)

@@ -51,20 +51,70 @@ inline int String2MXNetPadType(const std::string& s) {
   return 0;
 }
 
+inline Tuple<Tuple<int>> BroadcastPadWidth(int ndim, runtime::ADT adt) {
+  std::vector<mxnet::Tuple<int>> temp;
+  int adt_size = adt.size();
+  if (const runtime::IntegerObj* pad = adt[0].as<runtime::IntegerObj>()) {
+    if (adt_size == 1) {
+      int pad_width = static_cast<int>(pad->value);
+      if (ndim == 1) {
+        temp.emplace_back(mxnet::Tuple<int>({pad_width}));
+        temp.emplace_back(mxnet::Tuple<int>({pad_width}));
+      } else {
+        for (int dim = 0; dim < ndim; dim++) {
+          temp.emplace_back(mxnet::Tuple<int>({pad_width, pad_width}));
+        }
+      }
+    } else {
+      CHECK_EQ(adt_size, 2) << "Invalid Input pad_width";
+      int pad_before = static_cast<int>(pad->value);
+      int pad_after = static_cast<int>(Downcast<runtime::Integer, ObjectRef>(adt[1])->value);
+      if (ndim == 1) {
+        temp.emplace_back(mxnet::Tuple<int>({pad_before}));
+        temp.emplace_back(mxnet::Tuple<int>({pad_after}));
+      } else {
+        for (int dim = 0; dim < ndim; dim++) {
+          temp.emplace_back(mxnet::Tuple<int>({pad_before, pad_after}));
+        }
+      }
+    }
+  } else {
+    if (adt_size == 1) {
+      if (ndim == 1) {
+        runtime::ADT pad_adt = Downcast<runtime::ADT, ObjectRef>(adt[0]);
+        int pad_before =
+            static_cast<int>(Downcast<runtime::Integer, ObjectRef>(pad_adt[0])->value);
+        int pad_after =
+            static_cast<int>(Downcast<runtime::Integer, ObjectRef>(pad_adt[1])->value);
+        temp.emplace_back(mxnet::Tuple<int>({pad_before}));
+        temp.emplace_back(mxnet::Tuple<int>({pad_after}));
+      } else {
+        for (int dim = 0; dim < ndim; dim++) {
+          temp.emplace_back(mxnet::Tuple<int>(adt[0]));
+        }
+      }
+    } else {
+      CHECK_EQ(adt_size, ndim) << "Invalid Input pad_width";
+      for (int dim = 0; dim < ndim; dim++) {
+        temp.emplace_back(mxnet::Tuple<int>(adt[dim]));
+      }
+    }
+  }
+  return Tuple<Tuple<int>>(temp.begin(), temp.end());
+}
+
 MXNET_REGISTER_API("_npi.pad")
     .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
       using namespace runtime;
       const nnvm::Op* op = Op::Get("_npi_pad");
       nnvm::NodeAttrs attrs;
       op::NumpyPadParam param;
+      NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
+      mxnet::TShape ashape = inputs[0]->shape();
+      int ndim = ashape.ndim();
       ADT adt = Downcast<ADT, ObjectRef>(args[1].operator ObjectRef());
-      int ndim = adt.size();
-      std::vector<mxnet::Tuple<int>> temp;
-      int counter = 0;
-      for (counter = 0; counter < ndim; counter++) {
-        temp.emplace_back(mxnet::Tuple<int>(adt[counter]));
-      }
-      param.pad_width = Tuple<Tuple<int>>(temp.begin(), temp.end());
+      // broadcast pad_width to (ndim, 2)
+      param.pad_width = BroadcastPadWidth(ndim, adt);
       param.mode = String2MXNetPadType(args[2].operator std::string());
       if (args[3].type_code() != kNull) {
         param.constant_values = args[3].operator double();
@@ -77,7 +127,6 @@ MXNET_REGISTER_API("_npi.pad")
       SetAttrDict<op::NumpyPadParam>(&attrs);
       int num_inputs = 1;
       int num_outputs = 0;
-      NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
       auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
       *ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
     });
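The key change is that `ndim` now comes from the input array's shape rather than from `pad_width` itself, and the new `BroadcastPadWidth` helper normalizes `pad_width` to one (before, after) pair per axis, following NumPy's broadcasting rules: a bare integer pads both sides of every axis, a single (before, after) pair applies to every axis, and otherwise the caller must supply exactly one pair per axis. A minimal Python sketch of the same normalization, assuming plain NumPy semantics (the function name and asserts here are illustrative, not part of the patch):

```python
import numpy as np

def broadcast_pad_width(ndim, pad_width):
    """Normalize pad_width to ndim (before, after) pairs.

    Accepted forms (NumPy semantics):
      p               -> [(p, p)] * ndim
      (b, a)          -> [(b, a)] * ndim
      ((b, a),)       -> [(b, a)] * ndim
      ((b0, a0), ...) -> one pair per axis (length must equal ndim)
    """
    if np.isscalar(pad_width):                  # bare integer: same pad on both sides everywhere
        return [(pad_width, pad_width)] * ndim
    pad_width = list(pad_width)
    if all(np.isscalar(p) for p in pad_width):  # a single (before, after) pair
        assert len(pad_width) == 2, "Invalid Input pad_width"
        return [tuple(pad_width)] * ndim
    if len(pad_width) == 1:                     # ((before, after),): broadcast to every axis
        return [tuple(pad_width[0])] * ndim
    assert len(pad_width) == ndim, "Invalid Input pad_width"
    return [tuple(p) for p in pad_width]

# e.g. broadcast_pad_width(2, 3)      -> [(3, 3), (3, 3)]
#      broadcast_pad_width(3, (1, 2)) -> [(1, 2), (1, 2), (1, 2)]
```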
tests/python/unittest/test_numpy_op.py (2 changes: 1 addition & 1 deletion)

@@ -8325,7 +8325,7 @@ def __init__(self, pad_width, mode='constant'):
         def hybrid_forward(self,F,A,**kwargs):
             return F.np.pad(A, self._pad_width, mode=self._mode, **kwargs)
 
-    shapes = [(1,5), (2,2), (2,2), (3,3), (2,3), (3,4,5)]
+    shapes = [6, (1,5), (2,2), (2,2), (3,3), (2,3), (3,4,5)]
     dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64]
     mode = ['constant', 'reflect', 'symmetric', 'edge', 'minimum', 'maximum']
     for hybridize, shape, dtype, in itertools.product([False,True], shapes, dtypes):
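The new scalar entry `6` in `shapes` adds a 1-D input, so 1-D arrays now exercise the broadcasting path above. In plain NumPy terms, the behavior the operator is expected to match looks like this (illustrative values only):

```python
import numpy as np

a = np.arange(6)                # [0 1 2 3 4 5]
np.pad(a, 2)                    # scalar pad_width -> [0 0 0 1 2 3 4 5 0 0]
np.pad(a, (1, 3), mode='edge')  # (before, after)  -> [0 0 1 2 3 4 5 5 5 5]
```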