Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 16fed6e

Browse files
authored
[FEATURE] add oneDNN support for numpy transpose (#20419)
* numpy transpose onednn usage
* remove unnecessary whitespace
* remove unnecessary whitespace
* remove unnecessary param
* formatting changes, cleanup
* remove unnecessary lines
* template convert param
* newline at end
* remove unused declarations
* whitespace, guard comments
* sanity fix
* formatting
* separate error tests transpose
* formatting
* separate transpose error tests
* transpose header dnnl
* format files sanity
* move include transpose
* unify param templates
* format, rename funcs
* switch include order
* dont sort includes for transpose
* remove clang off section
* delete unnecessary newline
* add newlines
* remove whitespace
* remove whitespace
1 parent af1622e commit 16fed6e

7 files changed

Lines changed: 211 additions & 84 deletions

File tree

src/operator/nn/dnnl/dnnl_base-inl.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,6 @@ struct ConvolutionParam;
181181
struct DeconvolutionParam;
182182
struct SoftmaxParam;
183183
struct SoftmaxOutputParam;
184-
struct TransposeParam;
185184
struct ReshapeParam;
186185
struct LayerNormParam;
187186
bool SupportDNNLAct(const ActivationParam& param);
@@ -194,7 +193,7 @@ bool SupportDNNLDeconv(const DeconvolutionParam& params, const NDArray& input);
194193
bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& input, const NDArray& output);
195194
bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& input, const NDArray& output);
196195
bool SupportDNNLSoftmaxOutput(const SoftmaxOutputParam& param);
197-
bool SupportDNNLTranspose(const TransposeParam& param, const NDArray& data);
196+
bool SupportDNNLTranspose(const NDArray& data);
198197
bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
199198
bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs);
200199
bool SupportDNNLReshape(const NDArray& input, const NDArray& output);

src/operator/nn/dnnl/dnnl_ops-inl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ void DNNLLayerNormBackward(const nnvm::NodeAttrs& attrs,
180180

181181
void DNNLSum(const dnnl::memory& arr1, const dnnl::memory& arr2, const dnnl::memory& out);
182182

183+
template <class ParamType>
183184
void DNNLTransposeForward(const nnvm::NodeAttrs& attrs,
184185
const OpContext& ctx,
185186
const NDArray& data,
src/operator/nn/dnnl/dnnl_transpose-inl.h (new file)

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file dnnl_transpose-inl.h
22+
* \author Rafal Litka
23+
*/
24+
25+
#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_TRANSPOSE_INL_H_
26+
#define MXNET_OPERATOR_NN_DNNL_DNNL_TRANSPOSE_INL_H_
27+
#if MXNET_USE_ONEDNN == 1
28+
29+
#include "./dnnl_base-inl.h"
30+
#include "./dnnl_ops-inl.h"
31+
32+
#include "../../numpy/np_matrix_op-inl.h"
33+
34+
namespace mxnet {
35+
namespace op {
36+
37+
bool SupportDNNLTranspose(const NDArray& data);
38+
39+
class DNNLTransposeFwd {
40+
public:
41+
std::shared_ptr<dnnl::memory> data_;
42+
std::shared_ptr<dnnl::memory> out_;
43+
std::shared_ptr<dnnl::memory::desc> dst_md_;
44+
std::shared_ptr<dnnl::reorder> transpose_;
45+
DNNLTransposeFwd(const NumpyTransposeParam& param, const NDArray& data);
46+
void SetNewMem(const NDArray& data, const NDArray& output);
47+
const dnnl::reorder& GetFwd() const;
48+
void Execute() const;
49+
};
50+
51+
DNNLTransposeFwd& GetTransposeForward(const NumpyTransposeParam& param, const NDArray& data);
52+
53+
template <class ParamType>
54+
NumpyTransposeParam ConvertParamsToNumpy(const ParamType& param);
55+
56+
template <class ParamType>
57+
void DNNLTransposeForward(const nnvm::NodeAttrs& attrs,
58+
const OpContext& ctx,
59+
const NDArray& data,
60+
const OpReqType& req,
61+
const NDArray& output) {
62+
const ParamType& org_param = nnvm::get<ParamType>(attrs.parsed);
63+
auto param = ConvertParamsToNumpy<ParamType>(org_param);
64+
auto fwd = GetTransposeForward(param, data);
65+
fwd.SetNewMem(data, output);
66+
fwd.Execute();
67+
}
68+
69+
} // namespace op
70+
} // namespace mxnet
71+
72+
#endif // MXNET_USE_ONEDNN == 1
73+
#endif // MXNET_OPERATOR_NN_DNNL_DNNL_TRANSPOSE_INL_H_

src/operator/nn/dnnl/dnnl_transpose.cc

Lines changed: 75 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@
2525

2626
#if MXNET_USE_ONEDNN == 1
2727

28-
#include <dnnl.hpp>
29-
3028
#include "../../tensor/matrix_op-inl.h"
3129

30+
#include "./dnnl_transpose-inl.h"
31+
3232
namespace mxnet {
3333
namespace op {
3434

35-
bool SupportDNNLTranspose(const TransposeParam& param, const NDArray& data) {
35+
bool SupportDNNLTranspose(const NDArray& data) {
3636
auto data_ndim = data.shape().ndim();
3737

3838
if (data_ndim > 4 || data_ndim == 0 || data.shape().Size() == 0 ||
@@ -42,107 +42,104 @@ bool SupportDNNLTranspose(const TransposeParam& param, const NDArray& data) {
4242
return true;
4343
}
4444

45-
typedef ParamOpSign<TransposeParam> DNNLTransposeSignature;
46-
47-
class DNNLTransposeForward {
48-
public:
49-
std::shared_ptr<dnnl::memory> data_;
50-
std::shared_ptr<dnnl::memory> out_;
51-
std::shared_ptr<dnnl::memory::desc> dst_md_;
52-
std::shared_ptr<dnnl::reorder> transpose_;
53-
54-
public:
55-
DNNLTransposeForward(const TransposeParam& param, const NDArray& data) {
56-
auto shape = data.shape();
57-
auto data_ndim = shape.ndim();
58-
auto axes_ndim = param.axes.ndim();
59-
auto axes = mxnet::TShape(data_ndim, -1);
60-
if (axes_ndim == 0) {
61-
for (int i = 0; i < data_ndim; i++) {
62-
axes[i] = data_ndim - i - 1;
63-
}
64-
} else {
65-
axes = param.axes;
66-
}
45+
typedef ParamOpSign<NumpyTransposeParam> DNNLTransposeSignature;
6746

68-
auto engine = CpuEngine::Get()->get_engine();
69-
auto in_mem = data.GetDNNLData();
70-
auto src_md = in_mem->get_desc();
71-
data_ = std::make_shared<dnnl::memory>(src_md, engine, nullptr);
72-
73-
dnnl_dims_t strides;
74-
dnnl_dims_t sh;
75-
dim_t total_stride = 1;
76-
for (int i = data_ndim - 1; i >= 0; i--) {
77-
sh[i] = shape[i];
78-
strides[axes[i]] = total_stride;
79-
total_stride *= shape[axes[i]];
47+
DNNLTransposeFwd::DNNLTransposeFwd(const NumpyTransposeParam& param, const NDArray& data) {
48+
auto shape = data.shape();
49+
auto data_ndim = shape.ndim();
50+
auto axes_ndim = param.axes.ndim();
51+
auto axes = mxnet::TShape(data_ndim, -1);
52+
if (!ndim_is_known(axes_ndim)) {
53+
for (int i = 0; i < data_ndim; i++) {
54+
axes[i] = data_ndim - i - 1;
8055
}
56+
} else {
57+
axes = param.axes;
58+
}
8159

82-
dnnl_memory_desc_t dst_fmt;
83-
dnnl_memory_desc_init_by_strides(&dst_fmt, data_ndim, sh, dnnl_f32, strides);
60+
auto engine = CpuEngine::Get()->get_engine();
61+
auto in_mem = data.GetDNNLData();
62+
auto src_md = in_mem->get_desc();
63+
data_ = std::make_shared<dnnl::memory>(src_md, engine, nullptr);
64+
65+
dnnl_dims_t strides;
66+
dnnl_dims_t sh;
67+
dim_t total_stride = 1;
68+
for (int i = data_ndim - 1; i >= 0; i--) {
69+
sh[i] = shape[i];
70+
strides[axes[i]] = total_stride;
71+
total_stride *= shape[axes[i]];
72+
}
8473

85-
dst_md_ = std::make_shared<dnnl::memory::desc>(dst_fmt);
86-
out_ = std::make_shared<dnnl::memory>(*dst_md_, engine, nullptr);
74+
dnnl_memory_desc_t dst_fmt;
75+
dnnl_memory_desc_init_by_strides(&dst_fmt, data_ndim, sh, dnnl_f32, strides);
8776

88-
transpose_ = std::make_shared<dnnl::reorder>(*data_, *out_);
89-
}
77+
dst_md_ = std::make_shared<dnnl::memory::desc>(dst_fmt);
78+
out_ = std::make_shared<dnnl::memory>(*dst_md_, engine, nullptr);
9079

91-
void SetNewMem(const NDArray& data, const NDArray& output) {
92-
if (data.IsDNNLData()) {
93-
this->data_->set_data_handle(data.GetDNNLData()->get_data_handle());
94-
} else {
95-
MSHADOW_TYPE_SWITCH(
96-
data.dtype(), DTYPE, { this->data_->set_data_handle(data.data().dptr<DTYPE>()); });
97-
}
80+
transpose_ = std::make_shared<dnnl::reorder>(*data_, *out_);
81+
}
9882

99-
CHECK(!output.IsDNNLData());
83+
void DNNLTransposeFwd::SetNewMem(const NDArray& data, const NDArray& output) {
84+
if (data.IsDNNLData()) {
85+
this->data_->set_data_handle(data.GetDNNLData()->get_data_handle());
86+
} else {
10087
MSHADOW_TYPE_SWITCH(
101-
output.dtype(), DTYPE, { this->out_->set_data_handle(output.data().dptr<DTYPE>()); });
88+
data.dtype(), DTYPE, { this->data_->set_data_handle(data.data().dptr<DTYPE>()); });
10289
}
10390

104-
const dnnl::reorder& GetFwd() const {
105-
return *transpose_;
106-
}
91+
CHECK(!output.IsDNNLData());
92+
MSHADOW_TYPE_SWITCH(
93+
output.dtype(), DTYPE, { this->out_->set_data_handle(output.data().dptr<DTYPE>()); });
94+
}
10795

108-
void Execute() const {
109-
auto stream = DNNLStream::Get();
110-
dnnl_args_map_t net_args;
111-
net_args.insert({{DNNL_ARG_FROM, *(data_)}, {DNNL_ARG_TO, *(out_)}});
112-
stream->RegisterPrimArgs(*transpose_, net_args);
113-
stream->Submit();
114-
}
115-
};
96+
const dnnl::reorder& DNNLTransposeFwd::GetFwd() const {
97+
return *transpose_;
98+
}
99+
100+
void DNNLTransposeFwd::Execute() const {
101+
auto stream = DNNLStream::Get();
102+
dnnl_args_map_t net_args;
103+
net_args.insert({{DNNL_ARG_FROM, *(data_)}, {DNNL_ARG_TO, *(out_)}});
104+
stream->RegisterPrimArgs(*transpose_, net_args);
105+
stream->Submit();
106+
}
116107

117-
static DNNLTransposeForward& GetTransposeForward(const TransposeParam& param, const NDArray& data) {
108+
DNNLTransposeFwd& GetTransposeForward(const NumpyTransposeParam& param, const NDArray& data) {
118109
#if DMLC_CXX11_THREAD_LOCAL
119-
static thread_local std::unordered_map<DNNLTransposeSignature, DNNLTransposeForward, OpHash> fwds;
110+
static thread_local std::unordered_map<DNNLTransposeSignature, DNNLTransposeFwd, OpHash> fwds;
120111
#else
121-
static MX_THREAD_LOCAL std::unordered_map<DNNLTransposeSignature, DNNLTransposeForward, OpHash>
122-
fwds;
112+
static MX_THREAD_LOCAL std::unordered_map<DNNLTransposeSignature, DNNLTransposeFwd, OpHash> fwds;
123113
#endif
124114
DNNLTransposeSignature key(param);
125115
key.AddSign(data);
126116

127117
auto it = fwds.find(key);
128118
if (it == fwds.end()) {
129-
DNNLTransposeForward fwd(param, data);
119+
DNNLTransposeFwd fwd(param, data);
130120
it = AddToCache(&fwds, key, fwd);
131121
}
132122
return it->second;
133123
}
134124

135-
void DNNLTransposeForward(const nnvm::NodeAttrs& attrs,
136-
const OpContext& ctx,
137-
const NDArray& data,
138-
const OpReqType& req,
139-
const NDArray& output) {
140-
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
125+
template <>
126+
NumpyTransposeParam ConvertParamsToNumpy<NumpyTransposeParam>(const NumpyTransposeParam& param) {
127+
NumpyTransposeParam numpy_param;
128+
numpy_param.axes = common::CanonicalizeAxes(param.axes);
129+
return numpy_param;
130+
}
141131

142-
auto fwd = GetTransposeForward(param, data);
143-
fwd.SetNewMem(data, output);
144-
fwd.Execute();
132+
template <>
133+
NumpyTransposeParam ConvertParamsToNumpy<TransposeParam>(const TransposeParam& param) {
134+
NumpyTransposeParam numpy_param;
135+
if (param.axes.ndim() == 0) {
136+
numpy_param.axes = mxnet::TShape(-1, 0);
137+
} else {
138+
numpy_param.axes = param.axes;
139+
}
140+
return numpy_param;
145141
}
142+
146143
} // namespace op
147144
} // namespace mxnet
148-
#endif
145+
#endif // MXNET_USE_ONEDNN == 1

src/operator/numpy/np_matrix_op-inl.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ struct NumpyTransposeParam : public dmlc::Parameter<NumpyTransposeParam> {
4949
"By default, reverse the dimensions, otherwise permute "
5050
"the axes according to the values given.");
5151
}
52+
53+
bool operator==(const NumpyTransposeParam& other) const {
54+
return this->axes == other.axes;
55+
}
56+
5257
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
5358
std::ostringstream axes_s;
5459
axes_s << axes;
@@ -1868,4 +1873,15 @@ void NumpyDiagIndicesFromForward(const nnvm::NodeAttrs& attrs,
18681873
} // namespace op
18691874
} // namespace mxnet
18701875

1876+
namespace std {
1877+
template <>
1878+
struct hash<mxnet::op::NumpyTransposeParam> {
1879+
size_t operator()(const mxnet::op::NumpyTransposeParam& val) {
1880+
size_t ret = 0;
1881+
ret = dmlc::HashCombine(ret, val.axes);
1882+
return ret;
1883+
}
1884+
};
1885+
} // namespace std
1886+
18711887
#endif // MXNET_OPERATOR_NUMPY_NP_MATRIX_OP_INL_H_

src/operator/numpy/np_matrix_op.cc

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
#include <set>
2727
#include "./np_matrix_op-inl.h"
2828
#include "../nn/concat-inl.h"
29-
29+
#if MXNET_USE_ONEDNN == 1
30+
#include "../nn/dnnl/dnnl_ops-inl.h"
31+
#include "../nn/dnnl/dnnl_base-inl.h"
32+
#include "../nn/dnnl/dnnl_transpose-inl.h"
33+
#endif
3034
namespace mxnet {
3135
namespace op {
3236

@@ -100,6 +104,38 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
100104
SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret);
101105
return shape_is_known(*in_attrs) && shape_is_known(*out_attrs);
102106
}
107+
#if MXNET_USE_ONEDNN == 1
108+
109+
static void NumpyTransposeComputeExCPU(const nnvm::NodeAttrs& attrs,
110+
const OpContext& ctx,
111+
const std::vector<NDArray>& inputs,
112+
const std::vector<OpReqType>& req,
113+
const std::vector<NDArray>& outputs) {
114+
if (req[0] == kNullOp) {
115+
return;
116+
}
117+
CHECK(req[0] == kWriteTo || req[0] == kAddTo)
118+
<< "Transpose only supports kNullOp, kWriteTo and kAddTo";
119+
CHECK_EQ(inputs.size(), 1U);
120+
CHECK_EQ(outputs.size(), 1U);
121+
122+
if (SupportDNNLTranspose(inputs[0]) && req[0] == kWriteTo) {
123+
DNNLRun(DNNLTransposeForward<NumpyTransposeParam>, attrs, ctx, inputs[0], req[0], outputs[0]);
124+
return;
125+
}
126+
FallBackCompute(NumpyTranspose<cpu>, attrs, ctx, inputs, req, outputs);
127+
}
128+
129+
inline static bool NumpyTransposeStorageType(const nnvm::NodeAttrs& attrs,
130+
const int dev_mask,
131+
DispatchMode* dispatch_mode,
132+
std::vector<int>* in_attrs,
133+
std::vector<int>* out_attrs) {
134+
CHECK_EQ(in_attrs->size(), 1U);
135+
CHECK_EQ(out_attrs->size(), 1U);
136+
return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
137+
}
138+
#endif
103139

104140
NNVM_REGISTER_OP(_npi_transpose)
105141
.set_num_inputs(1)
@@ -134,6 +170,11 @@ NNVM_REGISTER_OP(_npi_transpose)
134170
[](const NodeAttrs& attrs) {
135171
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
136172
})
173+
#if MXNET_USE_ONEDNN == 1
174+
.set_attr<bool>("TIsDNNL", true)
175+
.set_attr<FComputeEx>("FComputeEx<cpu>", NumpyTransposeComputeExCPU)
176+
.set_attr<FInferStorageType>("FInferStorageType", NumpyTransposeStorageType)
177+
#endif
137178
.set_attr<nnvm::FListInputNames>("FListInputNames",
138179
[](const NodeAttrs& attrs) {
139180
return std::vector<std::string>{"a"};

0 commit comments

Comments
 (0)