@@ -208,26 +208,28 @@ class CustomSubgraphProperty: public SubgraphProperty {
208208 arg_dev_type.clear ();
209209 arg_dev_id.clear ();
210210 for (size_t i=0 ; i < in_arg_names.size (); i++) {
211- arg_names.push_back (in_arg_names[i].c_str ());
212- const NDArray &in_arg = *(in_args_ptr[i]);
211+ if (in_args_ptr[i] != nullptr ) {
212+ arg_names.push_back (in_arg_names[i].c_str ());
213+ const NDArray &in_arg = *(in_args_ptr[i]);
213214
214215#if MXNET_USE_MKLDNN == 1
215- // reorder data if in MKLDNN format
216- if (in_arg.IsMKLDNNData ()) {
217- in_arg.Reorder2DefaultAsync ();
218- in_arg.WaitToRead ();
219- }
216+ // reorder data if in MKLDNN format
217+ if (in_arg.IsMKLDNNData ()) {
218+ in_arg.Reorder2DefaultAsync ();
219+ in_arg.WaitToRead ();
220+ }
220221#endif
221222
222- // pull out parts of NDArray to send to backend
223- arg_data.push_back (in_arg.data ().dptr_ );
224- arg_shapes.push_back (in_arg.shape ().data ());
225- arg_dims.push_back (in_arg.shape ().ndim ());
226- arg_types.push_back (in_arg.dtype ());
227- arg_verIDs.push_back (in_arg.version ());
228- const char * arg_ctx_str = in_arg.ctx ().dev_mask () == Context::kCPU ? " cpu" : " gpu" ;
229- arg_dev_type.push_back (arg_ctx_str);
230- arg_dev_id.push_back (in_arg.ctx ().real_dev_id ());
223+ // pull out parts of NDArray to send to backend
224+ arg_data.push_back (in_arg.data ().dptr_ );
225+ arg_shapes.push_back (in_arg.shape ().data ());
226+ arg_dims.push_back (in_arg.shape ().ndim ());
227+ arg_types.push_back (in_arg.dtype ());
228+ arg_verIDs.push_back (in_arg.version ());
229+ const char * arg_ctx_str = in_arg.ctx ().dev_mask () == Context::kCPU ? " cpu" : " gpu" ;
230+ arg_dev_type.push_back (arg_ctx_str);
231+ arg_dev_id.push_back (in_arg.ctx ().real_dev_id ());
232+ }
231233 }
232234
233235 // convert input aux
@@ -240,26 +242,28 @@ class CustomSubgraphProperty: public SubgraphProperty {
240242 aux_dev_type.clear ();
241243 aux_dev_id.clear ();
242244 for (size_t i=0 ; i < in_aux_names.size (); i++) {
243- aux_names.push_back (in_aux_names[i].c_str ());
244- const auto &in_aux = *(in_aux_ptr[i]);
245+ if (in_aux_ptr[i] != nullptr ) {
246+ aux_names.push_back (in_aux_names[i].c_str ());
247+ const auto &in_aux = *(in_aux_ptr[i]);
245248
246249#if MXNET_USE_MKLDNN == 1
247- // reorder data if in MKLDNN format
248- if (in_aux.IsMKLDNNData ()) {
249- in_aux.Reorder2DefaultAsync ();
250- in_aux.WaitToRead ();
251- }
250+ // reorder data if in MKLDNN format
251+ if (in_aux.IsMKLDNNData ()) {
252+ in_aux.Reorder2DefaultAsync ();
253+ in_aux.WaitToRead ();
254+ }
252255#endif
253256
254- // pull out parts of NDArray to send to backend
255- aux_data.push_back (in_aux.data ().dptr_ );
256- aux_shapes.push_back (in_aux.shape ().data ());
257- aux_dims.push_back (in_aux.shape ().ndim ());
258- aux_types.push_back (in_aux.dtype ());
259- aux_verIDs.push_back (in_aux.version ());
260- const char * aux_ctx_str = in_aux.ctx ().dev_mask () == Context::kCPU ? " cpu" : " gpu" ;
261- aux_dev_type.push_back (aux_ctx_str);
262- aux_dev_id.push_back (in_aux.ctx ().real_dev_id ());
257+ // pull out parts of NDArray to send to backend
258+ aux_data.push_back (in_aux.data ().dptr_ );
259+ aux_shapes.push_back (in_aux.shape ().data ());
260+ aux_dims.push_back (in_aux.shape ().ndim ());
261+ aux_types.push_back (in_aux.dtype ());
262+ aux_verIDs.push_back (in_aux.version ());
263+ const char * aux_ctx_str = in_aux.ctx ().dev_mask () == Context::kCPU ? " cpu" : " gpu" ;
264+ aux_dev_type.push_back (aux_ctx_str);
265+ aux_dev_id.push_back (in_aux.ctx ().real_dev_id ());
266+ }
263267 }
264268
265269 // remove all graph attrs, some cannot be saved to json
@@ -285,13 +289,17 @@ class CustomSubgraphProperty: public SubgraphProperty {
285289 for (unsigned oid = 0 ; oid < node->num_outputs (); oid++) {
286290 const uint32_t out_entry_id = indexed_graph.entry_id (nid, oid);
287291 mxnet::TShape& shape = shapes[out_entry_id];
288- ss << shape;
292+ if (shape.ndim () == -1 )
293+ ss << " [None]" ;
294+ else
295+ ss << shape;
289296 if (oid < node->num_outputs ()-1 ) ss << " ," ;
290297 }
291298 ss << " ]" ;
292299 node->attrs .dict [MX_STR_SHAPE] = ss.str ();
293300 }
294301 }
302+
295303 // set dtype attrs for each node in the graph
296304 if (g.HasAttr (" dtype" )) {
297305 std::vector<int > dtypes = g.GetAttr <std::vector<int > >(" dtype" );
0 commit comments