 
 #if MXNET_USE_ONEDNN == 1
 
-#include <dnnl.hpp>
-
 #include "../../tensor/matrix_op-inl.h"
 
+#include "./dnnl_transpose-inl.h"
+
 namespace mxnet {
 namespace op {
 
-bool SupportDNNLTranspose(const TransposeParam& param, const NDArray& data) {
+bool SupportDNNLTranspose(const NDArray& data) {
   auto data_ndim = data.shape().ndim();
 
   if (data_ndim > 4 || data_ndim == 0 || data.shape().Size() == 0 ||
@@ -42,107 +42,104 @@ bool SupportDNNLTranspose(const TransposeParam& param, const NDArray& data) {
   return true;
 }
 
-typedef ParamOpSign<TransposeParam> DNNLTransposeSignature;
-
-class DNNLTransposeForward {
- public:
-  std::shared_ptr<dnnl::memory> data_;
-  std::shared_ptr<dnnl::memory> out_;
-  std::shared_ptr<dnnl::memory::desc> dst_md_;
-  std::shared_ptr<dnnl::reorder> transpose_;
-
- public:
-  DNNLTransposeForward(const TransposeParam& param, const NDArray& data) {
-    auto shape     = data.shape();
-    auto data_ndim = shape.ndim();
-    auto axes_ndim = param.axes.ndim();
-    auto axes      = mxnet::TShape(data_ndim, -1);
-    if (axes_ndim == 0) {
-      for (int i = 0; i < data_ndim; i++) {
-        axes[i] = data_ndim - i - 1;
-      }
-    } else {
-      axes = param.axes;
-    }
+typedef ParamOpSign<NumpyTransposeParam> DNNLTransposeSignature;
 
-    auto engine = CpuEngine::Get()->get_engine();
-    auto in_mem = data.GetDNNLData();
-    auto src_md = in_mem->get_desc();
-    data_       = std::make_shared<dnnl::memory>(src_md, engine, nullptr);
-
-    dnnl_dims_t strides;
-    dnnl_dims_t sh;
-    dim_t total_stride = 1;
-    for (int i = data_ndim - 1; i >= 0; i--) {
-      sh[i]            = shape[i];
-      strides[axes[i]] = total_stride;
-      total_stride *= shape[axes[i]];
+DNNLTransposeFwd::DNNLTransposeFwd(const NumpyTransposeParam& param, const NDArray& data) {
+  auto shape     = data.shape();
+  auto data_ndim = shape.ndim();
+  auto axes_ndim = param.axes.ndim();
+  auto axes      = mxnet::TShape(data_ndim, -1);
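+  // No axes given (axes ndim unknown): default to reversing all axes, as NumPy does.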
+  if (!ndim_is_known(axes_ndim)) {
+    for (int i = 0; i < data_ndim; i++) {
+      axes[i] = data_ndim - i - 1;
     }
+  } else {
+    axes = param.axes;
+  }
 
-    dnnl_memory_desc_t dst_fmt;
-    dnnl_memory_desc_init_by_strides(&dst_fmt, data_ndim, sh, dnnl_f32, strides);
+  auto engine = CpuEngine::Get()->get_engine();
+  auto in_mem = data.GetDNNLData();
+  auto src_md = in_mem->get_desc();
+  data_       = std::make_shared<dnnl::memory>(src_md, engine, nullptr);
+
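+  // The destination keeps the logical input shape but permutes the strides
+  // according to `axes`, so a dense write through this descriptor lays out
+  // the transposed tensor and a single dnnl::reorder performs the copy.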
+  dnnl_dims_t strides;
+  dnnl_dims_t sh;
+  dim_t total_stride = 1;
+  for (int i = data_ndim - 1; i >= 0; i--) {
+    sh[i]            = shape[i];
+    strides[axes[i]] = total_stride;
+    total_stride *= shape[axes[i]];
+  }
 
-    dst_md_ = std::make_shared<dnnl::memory::desc>(dst_fmt);
-    out_    = std::make_shared<dnnl::memory>(*dst_md_, engine, nullptr);
+  dnnl_memory_desc_t dst_fmt;
+  dnnl_memory_desc_init_by_strides(&dst_fmt, data_ndim, sh, dnnl_f32, strides);
 
-    transpose_ = std::make_shared<dnnl::reorder>(*data_, *out_);
-  }
+  dst_md_ = std::make_shared<dnnl::memory::desc>(dst_fmt);
+  out_    = std::make_shared<dnnl::memory>(*dst_md_, engine, nullptr);
 
-  void SetNewMem(const NDArray& data, const NDArray& output) {
-    if (data.IsDNNLData()) {
-      this->data_->set_data_handle(data.GetDNNLData()->get_data_handle());
-    } else {
-      MSHADOW_TYPE_SWITCH(
-          data.dtype(), DTYPE, { this->data_->set_data_handle(data.data().dptr<DTYPE>()); });
-    }
+  transpose_ = std::make_shared<dnnl::reorder>(*data_, *out_);
+}
 
-    CHECK(!output.IsDNNLData());
+void DNNLTransposeFwd::SetNewMem(const NDArray& data, const NDArray& output) {
+  if (data.IsDNNLData()) {
+    this->data_->set_data_handle(data.GetDNNLData()->get_data_handle());
+  } else {
     MSHADOW_TYPE_SWITCH(
-        output.dtype(), DTYPE, { this->out_->set_data_handle(output.data().dptr<DTYPE>()); });
+        data.dtype(), DTYPE, { this->data_->set_data_handle(data.data().dptr<DTYPE>()); });
   }
 
-  const dnnl::reorder& GetFwd() const {
-    return *transpose_;
-  }
+  CHECK(!output.IsDNNLData());
+  MSHADOW_TYPE_SWITCH(
+      output.dtype(), DTYPE, { this->out_->set_data_handle(output.data().dptr<DTYPE>()); });
+}
 
-  void Execute() const {
-    auto stream = DNNLStream::Get();
-    dnnl_args_map_t net_args;
-    net_args.insert({{DNNL_ARG_FROM, *(data_)}, {DNNL_ARG_TO, *(out_)}});
-    stream->RegisterPrimArgs(*transpose_, net_args);
-    stream->Submit();
-  }
-};
+const dnnl::reorder& DNNLTransposeFwd::GetFwd() const {
+  return *transpose_;
+}
+
+void DNNLTransposeFwd::Execute() const {
+  auto stream = DNNLStream::Get();
+  dnnl_args_map_t net_args;
+  net_args.insert({{DNNL_ARG_FROM, *(data_)}, {DNNL_ARG_TO, *(out_)}});
+  stream->RegisterPrimArgs(*transpose_, net_args);
+  stream->Submit();
+}
 
-static DNNLTransposeForward& GetTransposeForward(const TransposeParam& param, const NDArray& data) {
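+// Forward primitives are cached per thread, keyed on the (param, data)
+// signature, so repeated transposes of the same configuration reuse the
+// dnnl::reorder instead of rebuilding it.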
+DNNLTransposeFwd& GetTransposeForward(const NumpyTransposeParam& param, const NDArray& data) {
 #if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<DNNLTransposeSignature, DNNLTransposeForward, OpHash> fwds;
+  static thread_local std::unordered_map<DNNLTransposeSignature, DNNLTransposeFwd, OpHash> fwds;
 #else
-  static MX_THREAD_LOCAL std::unordered_map<DNNLTransposeSignature, DNNLTransposeForward, OpHash>
-      fwds;
+  static MX_THREAD_LOCAL std::unordered_map<DNNLTransposeSignature, DNNLTransposeFwd, OpHash> fwds;
 #endif
   DNNLTransposeSignature key(param);
   key.AddSign(data);
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    DNNLTransposeForward fwd(param, data);
+    DNNLTransposeFwd fwd(param, data);
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
 
-void DNNLTransposeForward(const nnvm::NodeAttrs& attrs,
-                          const OpContext& ctx,
-                          const NDArray& data,
-                          const OpReqType& req,
-                          const NDArray& output) {
-  const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
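+// Both operator flavours are funnelled through NumpyTransposeParam; axes are
+// canonicalized so that equivalent parameter sets produce identical cache keys.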
+template <>
+NumpyTransposeParam ConvertParamsToNumpy<NumpyTransposeParam>(const NumpyTransposeParam& param) {
+  NumpyTransposeParam numpy_param;
+  numpy_param.axes = common::CanonicalizeAxes(param.axes);
+  return numpy_param;
+}
 
-  auto fwd = GetTransposeForward(param, data);
-  fwd.SetNewMem(data, output);
-  fwd.Execute();
+template <>
+NumpyTransposeParam ConvertParamsToNumpy<TransposeParam>(const TransposeParam& param) {
+  NumpyTransposeParam numpy_param;
+  if (param.axes.ndim() == 0) {
+    numpy_param.axes = mxnet::TShape(-1, 0);
+  } else {
+    numpy_param.axes = param.axes;
+  }
+  return numpy_param;
 }
+
 }  // namespace op
 }  // namespace mxnet
-#endif
+#endif  // MXNET_USE_ONEDNN == 1
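
For reference, here is a minimal free-standing sketch of the trick the new
constructor builds on: a transpose expressed as a dnnl::reorder between two
memory descriptors that share a logical shape but use permuted strides. This
is an illustrative example against the plain oneDNN C++ API (no MXNet types),
not code from the commit:

#include <numeric>
#include <vector>

#include <dnnl.hpp>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  // A 2x3 source in dense row-major layout: strides {3, 1}.
  const dnnl::memory::dims shape = {2, 3};
  dnnl::memory::desc src_md(shape, dnnl::memory::data_type::f32,
                            dnnl::memory::dims{3, 1});
  // Same logical shape, but strides {1, 2}: the dense layout of the
  // transposed (3x2) tensor, exactly what the constructor above computes.
  dnnl::memory::desc dst_md(shape, dnnl::memory::data_type::f32,
                            dnnl::memory::dims{1, 2});

  std::vector<float> src_buf(6), dst_buf(6);
  std::iota(src_buf.begin(), src_buf.end(), 0.0f);  // 0 1 2 / 3 4 5

  dnnl::memory src(src_md, eng, src_buf.data());
  dnnl::memory dst(dst_md, eng, dst_buf.data());

  // The reorder copies element-wise between the two layouts; dst_buf ends up
  // holding 0 3 1 4 2 5, i.e. the row-major transpose.
  dnnl::reorder(src, dst).execute(strm, src, dst);
  strm.wait();
  return 0;
}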