Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 50a8ee8

Browse files
author
bgawrych
authored
Improve split operator by oneDNN reorder primitive (#20757)
* Add oneDNN support for array_split operator
* benchmark.py
* refactor
* update
* review fixes
* fix sanity
* fix
* review
* Apply review comments
1 parent 9fa75b4 commit 50a8ee8

6 files changed

Lines changed: 272 additions & 0 deletions

File tree

src/operator/nn/dnnl/dnnl_base-inl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ bool SupportDNNLTranspose(const NDArray& data);
197197
bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
198198
bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs);
199199
bool SupportDNNLReshape(const NDArray& input, const NDArray& output);
200+
bool SupportDNNLSplit(const NDArray& input);
200201
bool SupportDNNLStack(const std::vector<NDArray>& inputs);
201202
} // namespace op
202203

src/operator/nn/dnnl/dnnl_ops-inl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,12 @@ void DNNLSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
132132
const std::vector<OpReqType>& req,
133133
const std::vector<NDArray>& out_data);
134134

135+
void DNNLSplitForward(const nnvm::NodeAttrs& attrs,
136+
const OpContext& ctx,
137+
const std::vector<NDArray>& inputs,
138+
const std::vector<OpReqType>& req,
139+
const std::vector<NDArray>& outputs);
140+
135141
/* For sum */
136142
void DNNLSumForward(const nnvm::NodeAttrs& attrs,
137143
const OpContext& ctx,
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file dnnl_split-inl.h
22+
*/
23+
24+
#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_SPLIT_INL_H_
25+
#define MXNET_OPERATOR_NN_DNNL_DNNL_SPLIT_INL_H_
26+
27+
#if MXNET_USE_ONEDNN == 1
28+
#include <vector>
29+
30+
#include "./dnnl_base-inl.h"
31+
#include "./dnnl_ops-inl.h"
32+
33+
namespace mxnet {
34+
namespace op {
35+
36+
using split_fwd_t = dnnl::reorder;
37+
using split_fwd_pd_t = dnnl::reorder::primitive_desc;
38+
39+
/*!
 * \brief oneDNN-backed forward implementation of the split operator.
 *
 * Split is implemented as one dnnl::reorder per non-empty output: each
 * reorder copies a strided view of the input (the slice along the split
 * axis) into the corresponding dense output buffer.
 */
class DNNLSplitFwd {
 public:
  /*!
   * \brief Bundles the operator's input with its output arrays.
   * Holds references only — the referenced NDArrays must outlive it.
   */
  struct Tensors {
    Tensors(const NDArray& input, const std::vector<NDArray>& outputs);

    const NDArray& input;
    const std::vector<NDArray>& outputs;
  };

  /*!
   * \brief Returns a cached forward object for the given configuration,
   *        creating and caching a new one on first use.
   * \param param      split operator parameters (axis, sections, indices)
   * \param tensors    input/output arrays
   * \param split_pts  begin offsets of each slice along the split axis
   * \param split_axis axis along which the input is split (non-negative)
   */
  static DNNLSplitFwd& GetCached(const SplitParam& param,
                                 const Tensors& tensors,
                                 const TShape& split_pts,
                                 const int split_axis);

  /*! \brief Builds one reorder primitive per non-empty output slice. */
  DNNLSplitFwd(const Tensors& tensors, const TShape& split_pts, const int split_axis);

  /*! \brief Enqueues all reorders on the DNNL stream and submits them. */
  void Execute(const Tensors& tensors,
               const TShape& split_pts,
               const int split_axis,
               const std::vector<OpReqType>& req) const;

 private:
  std::vector<split_fwd_t> split_fwds;    // one reorder per non-empty output
  std::vector<split_fwd_pd_t> split_pds;  // matching primitive descriptors
  dnnl::memory::dims strides;             // row-major strides of the input
};
65+
66+
} // namespace op
67+
} // namespace mxnet
68+
#endif
69+
#endif // MXNET_OPERATOR_NN_DNNL_DNNL_SPLIT_INL_H_

src/operator/nn/dnnl/dnnl_split.cc

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file dnnl_split.cc
22+
*/
23+
24+
#if MXNET_USE_ONEDNN == 1
25+
26+
#include "../../tensor/matrix_op-inl.h"
27+
#include "./dnnl_split-inl.h"
28+
29+
namespace mxnet {
30+
namespace op {
31+
32+
bool SupportDNNLSplit(const NDArray& input) {
33+
static const std::set<int> supported_dtypes = {
34+
mshadow::kFloat32, mshadow::kBfloat16, mshadow::kInt32, mshadow::kInt8, mshadow::kUint8};
35+
return supported_dtypes.count(input.dtype());
36+
}
37+
38+
/*!
 * \brief Forward entry point for the oneDNN split operator.
 *
 * Normalizes the split axis, resolves the split points (either derived
 * from `sections` or taken verbatim from `indices`), then fetches a
 * cached primitive set and executes it.
 */
void DNNLSplitForward(const nnvm::NodeAttrs& attrs,
                      const OpContext& ctx,
                      const std::vector<NDArray>& inputs,
                      const std::vector<OpReqType>& req,
                      const std::vector<NDArray>& outputs) {
  const SplitParam& param = dmlc::get<SplitParam>(attrs.parsed);
  const DNNLSplitFwd::Tensors tensors(inputs[0], outputs);

  const auto& in_shape = tensors.input.shape();
  int axis = param.axis;
  if (axis < 0) {
    axis += in_shape.ndim();  // normalize negative axis to [0, ndim)
  }

  // `sections > 0` means equal-sized splits computed from the shape;
  // otherwise the explicit user-provided indices are used.
  mxnet::TShape split_pts;
  if (param.sections > 0) {
    split_pts = GetSplitIndices(in_shape, axis, param.sections);
  } else {
    split_pts = param.indices;
  }

  const DNNLSplitFwd& fwd = DNNLSplitFwd::GetCached(param, tensors, split_pts, axis);
  fwd.Execute(tensors, split_pts, axis, req);
}
55+
56+
// Bundles the split input with its outputs. Stores references only, so the
// referenced NDArrays must outlive this object (they do for the duration of
// a forward call, which is the only place Tensors is created).
DNNLSplitFwd::Tensors::Tensors(const NDArray& input, const std::vector<NDArray>& outputs)
    : input(input), outputs(outputs) {}
58+
59+
typedef ParamOpSign<SplitParam> DNNLSplitSignature;
60+
61+
/*!
 * \brief Looks up (or creates and caches) the forward object for the given
 *        split configuration. The cache is thread-local, following the
 *        pattern used by the other DNNL operators in this directory.
 */
DNNLSplitFwd& DNNLSplitFwd::GetCached(const SplitParam& param,
                                      const Tensors& tensors,
                                      const TShape& split_pts,
                                      const int split_axis) {
  using cache_t = std::unordered_map<DNNLSplitSignature, DNNLSplitFwd, OpHash>;
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local cache_t fwds;
#else
  static MX_THREAD_LOCAL cache_t fwds;
#endif

  // The key must capture everything that influences the created primitives:
  // parameters, tensor shapes/dtypes, split points and the split axis.
  DNNLSplitSignature key(param);
  key.AddSign(tensors.input);
  key.AddSign(tensors.outputs);
  key.AddSign(split_pts);
  key.AddSign(split_axis);

  auto it = fwds.find(key);
  if (it == fwds.end()) {
    const DNNLSplitFwd fwd(tensors, split_pts, split_axis);
    it = AddToCache(&fwds, key, fwd);
  }
  return it->second;
}
83+
84+
/*!
 * \brief Builds one dnnl::reorder primitive per non-empty output slice.
 *
 * Each reorder's source descriptor views the input through explicit
 * row-major strides (so only the slice along the split axis is read);
 * the destination is a dense default-format buffer of the slice's dims.
 */
DNNLSplitFwd::DNNLSplitFwd(const Tensors& tensors, const TShape& split_pts, const int split_axis) {
  const auto cpu_engine = CpuEngine::Get()->get_engine();
  const auto input      = tensors.input.Reorder2Default();
  const auto& ishape    = input.shape();
  const auto& dtype     = get_dnnl_type(input.dtype());
  const auto format_tag = static_cast<dnnl::memory::format_tag>(GetDefaultFormat(ishape.ndim()));

  // Row-major strides of the default-format input: last dim stride is 1,
  // so start the accumulation from the penultimate dimension.
  strides = dnnl::memory::dims(ishape.ndim(), 1);
  for (int i = ishape.ndim() - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ishape[i + 1];
  }

  // size_t loop index avoids the signed/unsigned comparison of the original
  // `int i < outputs.size()`; reserve to avoid reallocations while building.
  const size_t num_outputs = tensors.outputs.size();
  split_pds.reserve(num_outputs);
  split_fwds.reserve(num_outputs);
  for (size_t i = 0; i < num_outputs; ++i) {
    const auto& out = tensors.outputs[i];
    if (out.shape().Size() == 0) {
      continue;  // zero-sized slice: no primitive needed (Execute skips it too)
    }
    dnnl::memory::dims dnnl_dims(ishape.begin(), ishape.end());
    // The slice after the last split point extends to the end of the axis.
    const int end_split_pt = (i + 1 >= static_cast<size_t>(split_pts.ndim())) ?
                                 ishape[split_axis] :
                                 split_pts[i + 1];
    dnnl_dims[split_axis] = end_split_pt - split_pts[i];

    const auto in_mem_desc  = dnnl::memory::desc(dnnl_dims, dtype, strides);
    const auto out_mem_desc = dnnl::memory::desc(dnnl_dims, dtype, format_tag);

    const auto split_pd = split_fwd_pd_t(cpu_engine, in_mem_desc, cpu_engine, out_mem_desc);
    split_pds.emplace_back(split_pd);
    split_fwds.emplace_back(split_fwd_t(split_pd));
  }
}
115+
116+
void DNNLSplitFwd::Execute(const Tensors& tensors,
117+
const TShape& split_pts,
118+
const int split_axis,
119+
const std::vector<OpReqType>& req) const {
120+
const auto& cpu_engine = CpuEngine::Get()->get_engine();
121+
122+
const auto& input_tensor = tensors.input.Reorder2Default();
123+
int out_idx = 0, primitive_idx = 0;
124+
int axis_offset = strides[split_axis] * GetTypeSize(input_tensor.dtype());
125+
std::byte* input_ptr = reinterpret_cast<std::byte*>(input_tensor.data().dptr_);
126+
127+
for (const auto& out : tensors.outputs) {
128+
if (out.shape().Size() == 0) {
129+
out_idx++;
130+
continue;
131+
}
132+
int offset = split_pts[out_idx] * axis_offset;
133+
auto in_mem = dnnl::memory(split_pds[primitive_idx].src_desc(), cpu_engine, input_ptr + offset);
134+
135+
auto out_mem = CreateDNNLMem(out, split_pds[primitive_idx].dst_desc(), req[out_idx]);
136+
DNNLStream::Get()->RegisterPrimArgs(split_fwds[primitive_idx],
137+
{{DNNL_ARG_SRC, in_mem}, {DNNL_ARG_DST, *out_mem.second}});
138+
139+
CommitOutput(out, out_mem);
140+
++out_idx;
141+
++primitive_idx;
142+
}
143+
DNNLStream::Get()->Submit();
144+
}
145+
146+
} // namespace op
147+
} // namespace mxnet
148+
#endif

src/operator/tensor/matrix_op-inl.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3062,6 +3062,11 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
30623062
(*dict)["squeeze_axis"] = squeeze_axis_s.str();
30633063
(*dict)["sections"] = sections_s.str();
30643064
}
3065+
3066+
// Equality over every field that affects the operator's behavior; used to
// key oneDNN primitive caches on the parameter set.
bool operator==(const SplitParam& other) const {
  if (this->indices != other.indices)
    return false;
  if (this->axis != other.axis)
    return false;
  if (this->squeeze_axis != other.squeeze_axis)
    return false;
  return this->sections == other.sections;
}
30653070
}; // struct SplitParam
30663071

30673072
inline mxnet::TShape GetSplitIndices(const mxnet::TShape& ishape, int axis, int sections) {
@@ -3451,6 +3456,17 @@ struct hash<mxnet::op::ExpandDimParam> {
34513456
}
34523457
};
34533458

3459+
template <>
3460+
struct hash<mxnet::op::SplitParam> {
3461+
size_t operator()(const mxnet::op::SplitParam& val) {
3462+
size_t ret = 0;
3463+
ret = dmlc::HashCombine(ret, val.indices);
3464+
ret = dmlc::HashCombine(ret, val.axis);
3465+
ret = dmlc::HashCombine(ret, val.squeeze_axis);
3466+
ret = dmlc::HashCombine(ret, val.sections);
3467+
return ret;
3468+
}
3469+
};
34543470
} // namespace std
34553471

34563472
#endif // MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_

src/operator/tensor/matrix_op.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "../nn/dnnl/dnnl_reshape-inl.h"
3131
#include "../nn/dnnl/dnnl_slice-inl.h"
3232
#include "../nn/dnnl/dnnl_transpose-inl.h"
33+
#include "../nn/dnnl/dnnl_split-inl.h"
3334
#endif
3435

3536
namespace mxnet {
@@ -1177,6 +1178,32 @@ Example::
11771178
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
11781179
.add_arguments(DepthToSpaceParam::__FIELDS__());
11791180

1181+
#if MXNET_USE_ONEDNN == 1
1182+
// Dispatch for split's FComputeEx: take the oneDNN path when the input
// dtype is supported, otherwise fall back to the stock CPU kernel.
static void SplitForwardEx(const nnvm::NodeAttrs& attrs,
                           const OpContext& op_ctx,
                           const std::vector<NDArray>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs) {
  CHECK(!inputs.empty());
  if (!SupportDNNLSplit(inputs[0])) {
    FallBackCompute(SplitOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);
    return;
  }
  DNNL_OPCHECK_INIT(/*is backward*/ false, outputs.size(), inputs, outputs);
  DNNLRun(DNNLSplitForward, attrs, op_ctx, inputs, req, outputs);
  DNNL_OPCHECK_RUN(SplitOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);
}
1196+
1197+
// Storage-type inference for split: delegate to the shared oneDNN
// dispatcher, declaring that this operator supports the oneDNN path.
inline static bool SplitInferStorageType(const nnvm::NodeAttrs& attrs,
                                         const int dev_mask,
                                         DispatchMode* dispatch_mode,
                                         std::vector<int>* in_attrs,
                                         std::vector<int>* out_attrs) {
  const bool dispatched = DNNLStorageType(
      attrs, dev_mask, /*support onednn*/ true, dispatch_mode, in_attrs, out_attrs);
  return dispatched;
}
1205+
#endif // MXNET_USE_ONEDNN == 1
1206+
11801207
NNVM_REGISTER_OP(_split_v2)
11811208
.add_alias("_npi_split")
11821209
.add_alias("_npi_array_split")
@@ -1246,6 +1273,11 @@ Example::
12461273
[](const NodeAttrs& n) {
12471274
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
12481275
})
1276+
#if MXNET_USE_ONEDNN == 1
1277+
.set_attr<FComputeEx>("FComputeEx<cpu>", SplitForwardEx)
1278+
.set_attr<bool>("TIsDNNL", true)
1279+
.set_attr<FInferStorageType>("FInferStorageType", SplitInferStorageType)
1280+
#endif
12491281
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
12501282
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_split_v2_backward"})
12511283
.add_argument("data", "NDArray-or-Symbol", "The input")

0 commit comments

Comments
 (0)