Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 3122423

Browse files
authored
[BUGFIX] Fix numpy pad operator (#19787)
* Fix numpy pad operator
* Fix sanity-check (lint) failures
1 parent 7810961 commit 3122423

2 files changed

Lines changed: 58 additions & 9 deletions

File tree

src/api/operator/numpy/np_pad_op.cc

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,70 @@ inline int String2MXNetPadType(const std::string& s) {
5151
return 0;
5252
}
5353

54+
// Normalize the user-supplied pad_width argument into the per-axis layout
// expected by NumpyPadParam, mirroring numpy.pad's broadcasting rules:
//   - a single int            -> same pad before and after, on every axis
//   - a pair of ints          -> (before, after) applied to every axis
//   - a single (b, a) tuple   -> broadcast to every axis
//   - exactly ndim tuples     -> used as-is, one per axis
// NOTE(review): for ndim == 1 this emits two 1-element rows ({before}, {after})
// instead of one 2-element row — presumably the layout NumpyPadParam expects
// for 1-D inputs; confirm against the param's shape handling.
inline Tuple<Tuple<int>> BroadcastPadWidth(int ndim, runtime::ADT adt) {
  std::vector<mxnet::Tuple<int>> rows;
  const int n_items = adt.size();
  // Extract a plain int from an ObjectRef known to hold an integer.
  auto as_int = [](const ObjectRef& obj) {
    return static_cast<int>(Downcast<runtime::Integer, ObjectRef>(obj)->value);
  };
  if (const runtime::IntegerObj* first = adt[0].as<runtime::IntegerObj>()) {
    // pad_width given as bare integer(s): one value, or a (before, after) pair.
    int before = static_cast<int>(first->value);
    int after = before;
    if (n_items != 1) {
      CHECK_EQ(n_items, 2) << "Invalid Input pad_width";
      after = as_int(adt[1]);
    }
    if (ndim == 1) {
      rows.emplace_back(mxnet::Tuple<int>({before}));
      rows.emplace_back(mxnet::Tuple<int>({after}));
    } else {
      for (int axis = 0; axis < ndim; ++axis) {
        rows.emplace_back(mxnet::Tuple<int>({before, after}));
      }
    }
  } else {
    // pad_width given as one or more (before, after) tuples.
    if (n_items == 1) {
      if (ndim == 1) {
        // Split the single pair into the two 1-element rows used for 1-D input.
        runtime::ADT pair = Downcast<runtime::ADT, ObjectRef>(adt[0]);
        rows.emplace_back(mxnet::Tuple<int>({as_int(pair[0])}));
        rows.emplace_back(mxnet::Tuple<int>({as_int(pair[1])}));
      } else {
        // Broadcast the single pair across every axis.
        for (int axis = 0; axis < ndim; ++axis) {
          rows.emplace_back(mxnet::Tuple<int>(adt[0]));
        }
      }
    } else {
      // One explicit (before, after) pair per axis is required.
      CHECK_EQ(n_items, ndim) << "Invalid Input pad_width";
      for (int axis = 0; axis < ndim; ++axis) {
        rows.emplace_back(mxnet::Tuple<int>(adt[axis]));
      }
    }
  }
  return Tuple<Tuple<int>>(rows.begin(), rows.end());
}
105+
54106
MXNET_REGISTER_API("_npi.pad")
55107
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
56108
using namespace runtime;
57109
const nnvm::Op* op = Op::Get("_npi_pad");
58110
nnvm::NodeAttrs attrs;
59111
op::NumpyPadParam param;
112+
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
113+
mxnet::TShape ashape = inputs[0]->shape();
114+
int ndim = ashape.ndim();
60115
ADT adt = Downcast<ADT, ObjectRef>(args[1].operator ObjectRef());
61-
int ndim = adt.size();
62-
std::vector<mxnet::Tuple<int>> temp;
63-
int counter = 0;
64-
for (counter = 0; counter < ndim; counter++) {
65-
temp.emplace_back(mxnet::Tuple<int>(adt[counter]));
66-
}
67-
param.pad_width = Tuple<Tuple<int>>(temp.begin(), temp.end());
116+
// broadcast pad_width to (ndim, 2)
117+
param.pad_width = BroadcastPadWidth(ndim, adt);
68118
param.mode = String2MXNetPadType(args[2].operator std::string());
69119
if (args[3].type_code() != kNull) {
70120
param.constant_values = args[3].operator double();
@@ -77,7 +127,6 @@ MXNET_REGISTER_API("_npi.pad")
77127
SetAttrDict<op::NumpyPadParam>(&attrs);
78128
int num_inputs = 1;
79129
int num_outputs = 0;
80-
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
81130
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
82131
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
83132
});

tests/python/unittest/test_numpy_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8325,7 +8325,7 @@ def __init__(self, pad_width, mode='constant'):
83258325
def hybrid_forward(self,F,A,**kwargs):
83268326
return F.np.pad(A, self._pad_width, mode=self._mode, **kwargs)
83278327

8328-
shapes = [(1,5), (2,2), (2,2), (3,3), (2,3), (3,4,5)]
8328+
shapes = [6, (1,5), (2,2), (2,2), (3,3), (2,3), (3,4,5)]
83298329
dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64]
83308330
mode = ['constant', 'reflect', 'symmetric', 'edge', 'minimum', 'maximum']
83318331
for hybridize, shape, dtype, in itertools.product([False,True], shapes, dtypes):

0 commit comments

Comments
 (0)