@@ -51,20 +51,70 @@ inline int String2MXNetPadType(const std::string& s) {
5151 return 0 ;
5252}
5353
54+ inline Tuple<Tuple<int >> BroadcastPadWidth (int ndim, runtime::ADT adt) {
55+ std::vector<mxnet::Tuple<int >> temp;
56+ int adt_size = adt.size ();
57+ if (const runtime::IntegerObj* pad = adt[0 ].as <runtime::IntegerObj>()) {
58+ if (adt_size == 1 ) {
59+ int pad_width = static_cast <int >(pad->value );
60+ if (ndim == 1 ) {
61+ temp.emplace_back (mxnet::Tuple<int >({pad_width}));
62+ temp.emplace_back (mxnet::Tuple<int >({pad_width}));
63+ } else {
64+ for (int dim = 0 ; dim < ndim; dim++) {
65+ temp.emplace_back (mxnet::Tuple<int >({pad_width, pad_width}));
66+ }
67+ }
68+ } else {
69+ CHECK_EQ (adt_size, 2 ) << " Invalid Input pad_width" ;
70+ int pad_before = static_cast <int >(pad->value );
71+ int pad_after = static_cast <int >(Downcast<runtime::Integer, ObjectRef>(adt[1 ])->value );
72+ if (ndim == 1 ) {
73+ temp.emplace_back (mxnet::Tuple<int >({pad_before}));
74+ temp.emplace_back (mxnet::Tuple<int >({pad_after}));
75+ } else {
76+ for (int dim = 0 ; dim < ndim; dim++) {
77+ temp.emplace_back (mxnet::Tuple<int >({pad_before, pad_after}));
78+ }
79+ }
80+ }
81+ } else {
82+ if (adt_size == 1 ) {
83+ if (ndim == 1 ) {
84+ runtime::ADT pad_adt = Downcast<runtime::ADT, ObjectRef>(adt[0 ]);
85+ int pad_before =
86+ static_cast <int >(Downcast<runtime::Integer, ObjectRef>(pad_adt[0 ])->value );
87+ int pad_after =
88+ static_cast <int >(Downcast<runtime::Integer, ObjectRef>(pad_adt[1 ])->value );
89+ temp.emplace_back (mxnet::Tuple<int >({pad_before}));
90+ temp.emplace_back (mxnet::Tuple<int >({pad_after}));
91+ } else {
92+ for (int dim = 0 ; dim < ndim; dim++) {
93+ temp.emplace_back (mxnet::Tuple<int >(adt[0 ]));
94+ }
95+ }
96+ } else {
97+ CHECK_EQ (adt_size, ndim) << " Invalid Input pad_width" ;
98+ for (int dim = 0 ; dim < ndim; dim++) {
99+ temp.emplace_back (mxnet::Tuple<int >(adt[dim]));
100+ }
101+ }
102+ }
103+ return Tuple<Tuple<int >>(temp.begin (), temp.end ());
104+ }
105+
54106MXNET_REGISTER_API (" _npi.pad" )
55107.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
56108 using namespace runtime ;
57109 const nnvm::Op* op = Op::Get (" _npi_pad" );
58110 nnvm::NodeAttrs attrs;
59111 op::NumpyPadParam param;
112+ NDArray* inputs[] = {args[0 ].operator mxnet::NDArray*()};
113+ mxnet::TShape ashape = inputs[0 ]->shape ();
114+ int ndim = ashape.ndim ();
60115 ADT adt = Downcast<ADT, ObjectRef>(args[1 ].operator ObjectRef ());
61- int ndim = adt.size ();
62- std::vector<mxnet::Tuple<int >> temp;
63- int counter = 0 ;
64- for (counter = 0 ; counter < ndim; counter++) {
65- temp.emplace_back (mxnet::Tuple<int >(adt[counter]));
66- }
67- param.pad_width = Tuple<Tuple<int >>(temp.begin (), temp.end ());
116+ // broadcast pad_width to (ndim, 2)
117+ param.pad_width = BroadcastPadWidth (ndim, adt);
68118 param.mode = String2MXNetPadType (args[2 ].operator std::string ());
69119 if (args[3 ].type_code () != kNull ) {
70120 param.constant_values = args[3 ].operator double ();
@@ -77,7 +127,6 @@ MXNET_REGISTER_API("_npi.pad")
77127 SetAttrDict<op::NumpyPadParam>(&attrs);
78128 int num_inputs = 1 ;
79129 int num_outputs = 0 ;
80- NDArray* inputs[] = {args[0 ].operator mxnet::NDArray*()};
81130 auto ndoutputs = Invoke (op, &attrs, num_inputs, inputs, &num_outputs, nullptr );
82131 *ret = reinterpret_cast <mxnet::NDArray*>(ndoutputs[0 ]);
83132});
0 commit comments