
Commit 7d84b59

[FEATURE] Integrate oneDNN support for add, subtract, multiply, divide. (#20713)
* Integrate oneDNN support for binary elementwise operators.
* Delete template xpu for BinaryOperatorComputeExCPU function.
* Fix binary operators StorageType function.
* Fix SupportDNNLBinary function.
* Fix test_operator, DNNLAlgorithm structure, DNNLData and DNNLBinaryOpForward condition.
* Fix test cases, add oneDNN runtime flag to dispatch, remove node attrs, rename pointers.
* Fix sanity.
1 parent 5e08b7f commit 7d84b59

11 files changed

Lines changed: 370 additions & 17 deletions
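The change routes MXNet's NumPy-style broadcast add, subtract, multiply, and divide through the oneDNN binary primitive when the inputs qualify, and falls back to the existing native kernels otherwise. For orientation, here is a minimal standalone sketch of the oneDNN v2.x binary primitive that the new DNNLBinaryOpFwd class wraps; it is independent of MXNet, and the shape and values are illustrative only.

// Standalone sketch of the oneDNN binary primitive this commit wraps.
// Plain oneDNN v2.x API, no MXNet; shape and values are illustrative.
#include <vector>
#include <dnnl.hpp>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  dnnl::memory::desc md({2, 3}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::ab);
  std::vector<float> a(6, 1.f), b(6, 2.f), c(6, 0.f);
  dnnl::memory src0(md, eng, a.data());
  dnnl::memory src1(md, eng, b.data());
  dnnl::memory dst(md, eng, c.data());

  // Same construction sequence as DNNLBinaryOpFwd:
  // operation desc -> primitive desc -> primitive.
  dnnl::binary::desc bd(dnnl::algorithm::binary_add, md, md, md);
  dnnl::binary::primitive_desc pd(bd, eng);
  dnnl::binary prim(pd);

  prim.execute(strm, {{DNNL_ARG_SRC_0, src0},
                      {DNNL_ARG_SRC_1, src1},
                      {DNNL_ARG_DST, dst}});
  strm.wait();  // c now holds the elementwise sum a + b
  return 0;
}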

src/operator/nn/dnnl/dnnl_base-inl.h

Lines changed: 1 addition & 0 deletions
@@ -199,6 +199,7 @@ bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray
 bool SupportDNNLReshape(const NDArray& input, const NDArray& output);
 bool SupportDNNLSplit(const NDArray& input);
 bool SupportDNNLStack(const std::vector<NDArray>& inputs);
+bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
src/operator/nn/dnnl/dnnl_binary-inl.h (new file)

Lines changed: 86 additions & 0 deletions

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file dnnl_binary-inl.h
 * \author: Adam Grabowski, adam.grabowski@intel.com
 */

#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
#define MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_

#if MXNET_USE_ONEDNN == 1
#include "./dnnl_base-inl.h"
#include "./dnnl_ops-inl.h"
#include <vector>

#include "../../tensor/elemwise_binary_broadcast_op.h"

namespace mxnet {
namespace op {

using binary_fwd_t    = dnnl::binary;
using binary_fwd_pd_t = dnnl::binary::primitive_desc;

class DNNLBinaryOpFwd {
 public:
  template <dnnl::algorithm alg>
  static DNNLBinaryOpFwd& GetBinaryOpForward(const std::vector<NDArray>& inputs,
                                             const std::vector<NDArray>& outputs);
  DNNLBinaryOpFwd(const dnnl::algorithm alg,
                  const std::vector<NDArray>& inputs,
                  const std::vector<NDArray>& outputs);

  void Execute(const std::vector<NDArray>& inputs,
               const std::vector<OpReqType>& req,
               const std::vector<NDArray>& outputs);

 private:
  std::shared_ptr<binary_fwd_t> fwd;
  std::shared_ptr<binary_fwd_pd_t> fwd_pd;
};

template <dnnl::algorithm alg>
DNNLBinaryOpFwd& DNNLBinaryOpFwd::GetBinaryOpForward(const std::vector<NDArray>& inputs,
                                                     const std::vector<NDArray>& outputs) {
  using binary_op_fwd_map = std::unordered_map<OpSignature, DNNLBinaryOpFwd, OpHash>;
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local binary_op_fwd_map fwds;
#else
  static MX_THREAD_LOCAL binary_op_fwd_map fwds;
#endif
  OpSignature key;
  key.AddSign(static_cast<int>(alg));
  key.AddSign(inputs[0]);
  key.AddSign(inputs[1]);
  key.AddSign(outputs[0]);

  auto it = fwds.find(key);
  if (it == fwds.end()) {
    const DNNLBinaryOpFwd fwd(alg, inputs, outputs);
    it = AddToCache(&fwds, key, fwd);
  }
  return it->second;
}

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_USE_ONEDNN == 1
#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
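GetBinaryOpForward is the usual MXNet oneDNN primitive cache: creating a primitive is relatively expensive, so instances are memoized per thread, keyed by the algorithm plus the input/output signatures. A simplified, hypothetical model of the pattern (Primitive, GetCached, and the string key are stand-ins for DNNLBinaryOpFwd, GetBinaryOpForward, and OpSignature):

#include <string>
#include <unordered_map>

struct Primitive {};  // stand-in for DNNLBinaryOpFwd

Primitive& GetCached(const std::string& key) {
  // thread_local avoids locking: each worker thread owns its own cache,
  // mirroring the DMLC_CXX11_THREAD_LOCAL branch above.
  static thread_local std::unordered_map<std::string, Primitive> cache;
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, Primitive{}).first;  // build once, reuse afterwards
  return it->second;
}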
src/operator/nn/dnnl/dnnl_binary.cc (new file)

Lines changed: 78 additions & 0 deletions

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file dnnl_binary.cc
 * \author: Adam Grabowski, adam.grabowski@intel.com
 */

#if MXNET_USE_ONEDNN == 1
#include "./dnnl_binary-inl.h"

namespace mxnet {
namespace op {

DNNLBinaryOpFwd::DNNLBinaryOpFwd(const dnnl::algorithm alg,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<NDArray>& outputs) {
  auto src0_desc = inputs[0].GetDNNLData()->get_desc();
  auto src1_desc = inputs[1].GetDNNLData()->get_desc();
  auto dst_desc  = outputs[0].GetDNNLData()->get_desc();

  dnnl::binary::desc fwd_desc(alg, src0_desc, src1_desc, dst_desc);
  fwd_pd = std::make_shared<binary_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
  fwd    = std::make_shared<binary_fwd_t>(*fwd_pd);
}

void DNNLBinaryOpFwd::Execute(const std::vector<NDArray>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<NDArray>& outputs) {
  auto engine = mxnet::CpuEngine::Get()->get_engine();
  auto src0   = inputs[0].GetDNNLData();
  auto src1   = inputs[1].GetDNNLData();
  dnnl_output_t out_mem;
  if (outputs[0].GetDNNLData()->get_data_handle() == inputs[1].GetDNNLData()->get_data_handle())
    out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[1]);
  else
    out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[0]);

  dnnl_args_map_t args = {
      {DNNL_ARG_SRC_0, *src0},
      {DNNL_ARG_SRC_1, *src1},
      {DNNL_ARG_DST, *out_mem.second},
  };

  DNNLStream::Get()->RegisterPrimArgs(*fwd, args);
  CommitOutput(outputs[0], out_mem);
  DNNLStream::Get()->Submit();
}

bool SupportDNNLBinary(const std::vector<NDArray>& inputs) {
  auto dtype  = inputs[0].dtype();
  auto ndim_0 = inputs[0].shape().ndim();
  auto ndim_1 = inputs[1].shape().ndim();
  return ndim_0 >= 1 && ndim_0 <= 6 && ndim_1 >= 1 && ndim_1 <= 6 &&
         inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 &&
         dtype == mshadow::kFloat32 && dtype == inputs[1].dtype();
}

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_USE_ONEDNN == 1
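A note on Execute: the output may share a buffer with the second input (e.g. b = a - b computed in place), and the in-place hint handed to CreateDNNLMem should name the input the output actually aliases, so that buffer is the one reused rather than overwritten early. A toy, oneDNN-free illustration of that branch (all names hypothetical):

#include <cassert>

int main() {
  float buf_a[4] = {1, 2, 3, 4};
  float buf_b[4] = {5, 6, 7, 8};
  float* src0 = buf_a;
  float* src1 = buf_b;
  float* dst  = buf_b;  // output written in place over the second input

  // Mirrors the branch in DNNLBinaryOpFwd::Execute: choose the aliased
  // input as the in-place candidate so its buffer is reused, not clobbered.
  float* inplace_hint = (dst == src1) ? src1 : src0;
  assert(inplace_hint == src1);
  return 0;
}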

src/operator/numpy/np_elemwise_broadcast_op.h

Lines changed: 47 additions & 0 deletions
@@ -851,6 +851,53 @@ void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
   }
 }
 
+#if MXNET_USE_ONEDNN == 1
+inline bool NumpyBinaryBroadcastStorageType(const nnvm::NodeAttrs& attrs,
+                                            const int dev_mask,
+                                            DispatchMode* dispatch_mode,
+                                            std::vector<int>* in_attrs,
+                                            std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2);
+  CHECK_EQ(out_attrs->size(), 1);
+
+  return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+
+void NumpyDivideBroadcastComputeCPU(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<TBlob>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<TBlob>& outputs);
+
+template <typename OP>
+void NumpyBinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                     const OpContext& ctx,
+                                     const std::vector<mxnet::NDArray>& inputs,
+                                     const std::vector<OpReqType>& req,
+                                     const std::vector<mxnet::NDArray>& outputs) {
+  if (SupportDNNLBinary(inputs)) {
+    const dnnl::algorithm alg = DNNLAlgorithm<OP>::value;
+    DNNLRun(DNNLBinaryOpForward<alg>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  using namespace op::mshadow_op;
+  std::vector<mxnet::TBlob> in_data  = {inputs[0].data(), inputs[1].data()};
+  std::vector<mxnet::TBlob> out_data = {outputs[0].data()};
+  if (std::is_same<OP, plus>::value) {
+    NumpyBinaryBroadcastComputeWithBool<cpu, OP, mixed_plus, mixed_plus>(
+        attrs, ctx, in_data, req, out_data);
+  } else if (std::is_same<OP, minus>::value) {
+    NumpyBinaryBroadcastCompute<cpu, OP, mixed_minus, mixed_rminus>(
+        attrs, ctx, in_data, req, out_data);
+  } else if (std::is_same<OP, mul>::value) {
+    NumpyBinaryBroadcastComputeWithBool<cpu, OP, mixed_mul, mixed_mul>(
+        attrs, ctx, in_data, req, out_data);
+  } else if (std::is_same<OP, div>::value) {
+    NumpyDivideBroadcastComputeCPU(attrs, ctx, in_data, req, out_data);
+  }
+}
+#endif  // MXNET_USE_ONEDNN
+
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
   NNVM_REGISTER_OP(name)                               \
       .set_num_inputs(1)                               \
src/operator/numpy/np_elemwise_broadcast_op_add.cc

Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
                                       op::mshadow_op::plus,
                                       op::mshadow_op::mixed_plus,
                                       op::mshadow_op::mixed_plus>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::plus>)
+    .set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
+#endif  // MXNET_USE_ONEDNN
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_add"});
 
 NNVM_REGISTER_OP(_backward_npi_broadcast_add)

src/operator/numpy/np_elemwise_broadcast_op_mul.cc

Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
                                       op::mshadow_op::mul,
                                       op::mshadow_op::mixed_mul,
                                       op::mshadow_op::mixed_mul>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::mul>)
+    .set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
+#endif  // MXNET_USE_ONEDNN
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});
 
 NNVM_REGISTER_OP(_backward_npi_broadcast_mul)

src/operator/numpy/np_elemwise_broadcast_op_sub.cc

Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
                                       op::mshadow_op::minus,
                                       op::mshadow_op::mixed_minus,
                                       op::mshadow_op::mixed_rminus>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::minus>)
+    .set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
+#endif  // MXNET_USE_ONEDNN
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_sub"});
 
 NNVM_REGISTER_OP(_backward_npi_broadcast_sub)

src/operator/numpy/np_true_divide.cc

Lines changed: 14 additions & 0 deletions
@@ -61,6 +61,16 @@ bool TrueDivideType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+#if MXNET_USE_ONEDNN == 1
+void NumpyDivideBroadcastComputeCPU(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<TBlob>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<TBlob>& outputs) {
+  TrueDivideBroadcastCompute<cpu>(attrs, ctx, inputs, req, outputs);
+}
+#endif  // MXNET_USE_ONEDNN
+
 NNVM_REGISTER_OP(_npi_true_divide)
     .set_num_inputs(2)
     .set_num_outputs(1)
@@ -79,6 +89,10 @@ NNVM_REGISTER_OP(_npi_true_divide)
       return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
     })
     .set_attr<FCompute>("FCompute<cpu>", TrueDivideBroadcastCompute<cpu>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::div>)
+    .set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
+#endif  // MXNET_USE_ONEDNN
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_div"})
     .add_argument("lhs", "NDArray-or-Symbol", "Dividend array")
     .add_argument("rhs", "NDArray-or-Symbol", "Divisor array");

src/operator/tensor/elemwise_binary_broadcast_op.h

Lines changed: 41 additions & 0 deletions
@@ -91,8 +91,14 @@ inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
   int& out_stype  = out_attrs->at(0);
   bool dispatched = false;
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+#if MXNET_USE_ONEDNN == 1
+    if (dev_mask == mshadow::cpu::kDevMask && DNNLEnvSet())
+      dispatched = storage_type_assign(
+          &out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
+#else
     dispatched =
         storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
+#endif  // MXNET_USE_ONEDNN == 1
   }
   if (!dispatched && lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) {
     dispatched =
@@ -116,8 +122,14 @@ inline bool BinaryBroadcastAddStorageType(const nnvm::NodeAttrs& attrs,
   int& out_stype  = out_attrs->at(0);
   bool dispatched = false;
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+#if MXNET_USE_ONEDNN == 1
+    if (dev_mask == mshadow::cpu::kDevMask && DNNLEnvSet())
+      dispatched = storage_type_assign(
+          &out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
+#else
     dispatched =
         storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
+#endif  // MXNET_USE_ONEDNN == 1
   }
   if (!dispatched && ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
                       (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
@@ -788,6 +800,35 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,
   }
 }
 
+#if MXNET_USE_ONEDNN == 1
+template <dnnl::algorithm alg>
+void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs);
+
+// template struct converting op::mshadow_op to dnnl::algorithm
+template <typename OP>
+struct DNNLAlgorithm {};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::plus> {
+  static const dnnl::algorithm value = dnnl::algorithm::binary_add;
+};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::minus> {
+  static const dnnl::algorithm value = dnnl::algorithm::binary_sub;
+};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::mul> {
+  static const dnnl::algorithm value = dnnl::algorithm::binary_mul;
+};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::div> {
+  static const dnnl::algorithm value = dnnl::algorithm::binary_div;
+};
+#endif  // MXNET_USE_ONEDNN == 1
+
 #define MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(name) \
   NNVM_REGISTER_OP(name)                               \
       .set_num_inputs(2)                               \
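The storage-type hunks above add the runtime gate mentioned in the commit message: on oneDNN builds, dense inputs are dispatched to DispatchMode::kFComputeEx (the oneDNN path) only when running on CPU and the oneDNN environment flag is enabled; otherwise the plain FCompute path is kept. A rough, hypothetical re-implementation of a DNNLEnvSet-style switch, assuming it reads an environment variable such as MXNET_ONEDNN_ENABLED (the real helper lives in dnnl_base-inl.h):

#include <cstdlib>
#include <cstring>

// Assumed flag name; treated as enabled unless explicitly set to 0.
bool DnnlEnabled() {
  static const bool enabled = [] {
    const char* v = std::getenv("MXNET_ONEDNN_ENABLED");
    return v == nullptr || std::strcmp(v, "0") != 0;
  }();
  return enabled;  // evaluated once, like the static in the real helper
}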
