
Commit 5abdc77

[FEATURE] Add _npi_power_scalar and _npi_multiply_scalar fuse (#20976)
* [FEATURE] Add _npi_power_scalar and _npi_multiply_scalar fuse
* Merge _npi_power_scalar implementation with implementation of this fuse
* Fix clang
* Fix CI
* Fix review and simplify the implementation
* Add checks for the amount of inputs and outputs
* Fix CI
* Add Reset() function
* Fix DNNLPowMulScalarShape and Type functions
* Fix DNNLPowMulScalarType
* Fix DNNLPowMulScalarType
* Add generic implementation for sq_pow_mul_scalar operator
* Fix sanity
* Fix req
* Add Filter method to property
* Add new line
* Fix gpu CI
* Add '_sg_pow_mul_scalar' to symbol_fp16.py
* Fix CI on MacOS
* Fix SupportDNNL*
* Make PowMulScalarCompute more readable
* Fix PowMulScalarCompute
* Fix memory usage
* Fix build
1 parent 9745d36 commit 5abdc77

23 files changed: 547 additions & 152 deletions
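Taken together, the changes fuse an _npi_power_scalar op followed by an _npi_multiply_scalar op into a single _sg_pow_mul_scalar node backed by one oneDNN eltwise primitive. As a minimal reference sketch of the fused semantics (pow_mul_scalar_ref is a hypothetical helper written for this note, not code from the commit), the operator computes out[i] = multiplier * x[i]^exponent:

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Reference semantics of the fused operator: out[i] = multiplier * pow(x[i], exponent).
std::vector<float> pow_mul_scalar_ref(const std::vector<float>& x,
                                      float exponent,
                                      float multiplier) {
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = multiplier * std::pow(x[i], exponent);
  }
  return out;
}

int main() {
  // (x ** 2) * 3 for x = [1, 2, 3] gives [3, 12, 27].
  for (float v : pow_mul_scalar_ref({1.f, 2.f, 3.f}, 2.f, 3.f)) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}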

include/mxnet/op_attr_types.h

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ enum OpReqType {
 };

 /*!
- * \brief All the possible information needed by Operator.Forward and Backward
+ * \brief All the possible information needed by Operator.
  *  This is the superset of RunContext.
  *  We use this data structure to bookkeep everything needed by Forward and Backward.
  * \sa Resource

include/mxnet/tensor_blob.h

Lines changed: 2 additions & 2 deletions
@@ -210,7 +210,7 @@ class TBlob {
     CHECK(Device::kDevMask == this->dev_mask())
         << "TBlob.get: device type do not match specified type";
     CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
-        << "TBlob.get_with_shape: data type do not match specified type."
+        << "TBlob.get_with_shape: data type do not match specified type. "
         << "Expected: " << mshadow::dtype_string(type_flag_) << " v.s. given "
         << mshadow::dtype_string(mshadow::DataType<DType>::kFlag);
     return mshadow::Tensor<Device, 2, DType>(static_cast<DType*>(dptr_), shape_.FlatTo2D(), stream);
@@ -248,7 +248,7 @@ class TBlob {
   template <typename DType>
   inline DType* dptr() const {
     CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
-        << "TBlob.get_with_shape: data type do not match specified type."
+        << "TBlob.get_with_shape: data type do not match specified type. "
         << "Expected: " << mshadow::dtype_string(type_flag_) << " v.s. given "
         << mshadow::dtype_string(mshadow::DataType<DType>::kFlag);
     return static_cast<DType*>(dptr_);

python/mxnet/amp/lists/symbol_fp16.py

Lines changed: 2 additions & 1 deletion
@@ -636,7 +636,8 @@
     '_sg_onednn_fully_connected',
     '_sg_onednn_selfatt_qk',
     '_sg_onednn_selfatt_valatt',
-    '_sg_onednn_batch_dot'
+    '_sg_onednn_batch_dot',
+    '_sg_pow_mul_scalar'
     ])

 # Functions that have to be cast to FP32 only for

src/operator/nn/dnnl/dnnl_base-inl.h

Lines changed: 0 additions & 1 deletion
@@ -254,7 +254,6 @@ bool SupportDNNLLeakyRelu(const LeakyReLUParam& param);
 bool SupportDNNLLeakyRelu(const LeakyReLUParam& param, const NDArray& input);
 bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& input);
 bool SupportDNNLMaskedSoftmax(const MaskedSoftmaxParam& param, const std::vector<NDArray>& input);
-bool SupportDNNLPower(const NDArray& input);
 bool SupportDNNLQuantizedAct(const ActivationParam& param);
 bool SupportDNNLReshape(const NDArray& input);
 bool SupportDNNLSlice(const SliceParam& param, const NDArray& input, const NDArray& output);
src/operator/nn/dnnl/dnnl_pow_mul_scalar-inl.h (new file)

Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_pow_mul_scalar-inl.h
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_POW_MUL_SCALAR_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_POW_MUL_SCALAR_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <vector>
+
+#include "operator/tensor/elemwise_binary_scalar_op.h"
+
+namespace mxnet {
+namespace op {
+
+struct DNNLPowMulScalarParam : public dmlc::Parameter<DNNLPowMulScalarParam> {
+  float exponent;
+  float multiplier;
+  bool exp_is_int;
+  bool mul_is_int;
+
+  DMLC_DECLARE_PARAMETER(DNNLPowMulScalarParam) {
+    DMLC_DECLARE_FIELD(exponent).describe("Exponent for power operation.").set_default(1);
+    DMLC_DECLARE_FIELD(multiplier).describe("Multiplier for multiply operation.").set_default(1);
+    DMLC_DECLARE_FIELD(exp_is_int)
+        .describe("Indicate whether exponent is int type.")
+        .set_default(true);
+    DMLC_DECLARE_FIELD(mul_is_int)
+        .describe("Indicate whether multiplier is int type.")
+        .set_default(true);
+  }
+
+  bool operator==(const DNNLPowMulScalarParam& other) const {
+    return this->exponent == other.exponent && this->multiplier == other.multiplier &&
+           this->exp_is_int == other.exp_is_int && this->mul_is_int == other.mul_is_int;
+  }
+};
+
+using eltwise_fwd_t    = dnnl::eltwise_forward;
+using eltwise_fwd_pd_t = dnnl::eltwise_forward::primitive_desc;
+
+typedef ParamOpSign<DNNLPowMulScalarParam> DNNLPowMulScalarSignature;
+
+class DNNLPowMulScalarFwd {
+ public:
+  static DNNLPowMulScalarFwd& GetCached(const DNNLPowMulScalarParam& param,
+                                        const NDArray& input,
+                                        const NDArray& output);
+
+  DNNLPowMulScalarFwd(const DNNLPowMulScalarParam& param, const NDArray& input);
+
+  void Execute(const NDArray& input, const OpReqType& req, const NDArray& output);
+
+ private:
+  std::shared_ptr<eltwise_fwd_t> fwd;
+  std::shared_ptr<eltwise_fwd_pd_t> fwd_pd;
+};
+
+template <bool subgraph>
+inline void DNNLPowMulScalarForward(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<NDArray>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<NDArray>& outputs) {
+  DNNLPowMulScalarParam param;
+  if (subgraph) {
+    param = nnvm::get<DNNLPowMulScalarParam>(attrs.parsed);
+  } else {
+    param.multiplier = 1;
+    param.exponent   = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed).scalar;
+  }
+  DNNLPowMulScalarFwd& fwd = DNNLPowMulScalarFwd::GetCached(param, inputs[0], outputs[0]);
+  fwd.Execute(inputs[0], req[0], outputs[0]);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_ONEDNN == 1
+#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_POW_MUL_SCALAR_INL_H_
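Note how the template parameter picks the parameter source: the subgraph path reads an already-parsed DNNLPowMulScalarParam, while the stand-alone power-scalar path synthesizes one with multiplier = 1. The heavy lifting is done by oneDNN's eltwise_pow algorithm, which computes dst = alpha * src^beta, with alpha playing the multiplier and beta the exponent. A minimal stand-alone sketch of that primitive, assuming the oneDNN v2.x API this header targets (the engine, stream, shapes, and variable names are illustrative, not from the commit):

#include <cstdio>
#include <vector>

#include <dnnl.hpp>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  // Four floats in a 1x4 buffer; eltwise_pow computes dst = alpha * src^beta.
  std::vector<float> src = {1.f, 2.f, 3.f, 4.f}, dst(4, 0.f);
  dnnl::memory::desc md({1, 4}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);

  const float multiplier = 3.f;  // alpha
  const float exponent   = 2.f;  // beta
  dnnl::eltwise_forward::desc fwd_desc(
      dnnl::prop_kind::forward_scoring, dnnl::algorithm::eltwise_pow, md, multiplier, exponent);
  dnnl::eltwise_forward::primitive_desc fwd_pd(fwd_desc, eng);

  dnnl::memory src_mem(md, eng, src.data());
  dnnl::memory dst_mem(md, eng, dst.data());
  dnnl::eltwise_forward(fwd_pd).execute(strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});
  strm.wait();

  for (float v : dst) {
    std::printf("%g ", v);  // 3 12 27 48
  }
  std::printf("\n");
  return 0;
}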

src/operator/nn/dnnl/dnnl_power_scalar.cc renamed to src/operator/nn/dnnl/dnnl_pow_mul_scalar.cc

Lines changed: 33 additions & 32 deletions
@@ -18,49 +18,54 @@
  */

 /*!
- * \file dnnl_power_scalar.cc
- * \author: Adam Grabowski, adam.grabowski@intel.com
+ * \file dnnl_pow_mul_scalar.cc
  */

 #if MXNET_USE_ONEDNN == 1

-#include "dnnl_power_scalar-inl.h"
+#include "dnnl_pow_mul_scalar-inl.h"

 namespace mxnet {
 namespace op {

-DNNLPowerFwd& DNNLPowerFwd::GetPowerForward(const nnvm::NodeAttrs& attrs,
-                                            const NDArray& input,
-                                            const NDArray& output) {
-  const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+DMLC_REGISTER_PARAMETER(DNNLPowMulScalarParam);
+
+DNNLPowMulScalarFwd& DNNLPowMulScalarFwd::GetCached(const DNNLPowMulScalarParam& param,
+                                                    const NDArray& input,
+                                                    const NDArray& output) {
 #if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<DNNLPowerSignature, DNNLPowerFwd, OpHash> fwds;
+  static thread_local std::unordered_map<DNNLPowMulScalarSignature, DNNLPowMulScalarFwd, OpHash>
+      fwds;
 #else
-  static MX_THREAD_LOCAL std::unordered_map<DNNLPowerSignature, DNNLPowerFwd, OpHash> fwds;
+  static MX_THREAD_LOCAL std::unordered_map<DNNLPowMulScalarSignature, DNNLPowMulScalarFwd, OpHash>
+      fwds;
 #endif
-  DNNLPowerSignature key;
-  key.AddSign(static_cast<float>(param.scalar));
+  DNNLPowMulScalarSignature key(param);
   key.AddSign(input);
   key.AddSign(output);

   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    const DNNLPowerFwd fwd(input, static_cast<float>(param.scalar));
+    const DNNLPowMulScalarFwd fwd(param, input);
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }

-DNNLPowerFwd::DNNLPowerFwd(const NDArray& input, const float exponent) {
+DNNLPowMulScalarFwd::DNNLPowMulScalarFwd(const DNNLPowMulScalarParam& param, const NDArray& input) {
   auto src_desc = input.GetDNNLData()->get_desc();
-  dnnl::eltwise_forward::desc fwd_desc(
-      dnnl::prop_kind::forward_scoring, dnnl::algorithm::eltwise_pow, src_desc, 1, exponent);
+  dnnl::eltwise_forward::desc fwd_desc(dnnl::prop_kind::forward_scoring,
+                                       dnnl::algorithm::eltwise_pow,
+                                       src_desc,
+                                       param.multiplier,
+                                       param.exponent);
   fwd_pd = std::make_shared<eltwise_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
   fwd    = std::make_shared<eltwise_fwd_t>(*fwd_pd);
 }

-void DNNLPowerFwd::Execute(const NDArray& input, const OpReqType& req, const NDArray& output) {
-  auto engine = mxnet::CpuEngine::Get()->get_engine();
+void DNNLPowMulScalarFwd::Execute(const NDArray& input,
+                                  const OpReqType& req,
+                                  const NDArray& output) {
   auto src = input.GetDNNLData();
   dnnl_output_t out_mem = CreateDNNLMem(output, fwd_pd->dst_desc(), req, &input);

@@ -73,22 +78,18 @@ void DNNLPowerFwd::Execute(const NDArray& input, const OpReqType& req, const NDA
   CommitOutput(output, out_mem);
   DNNLStream::Get()->Submit();
 }
-
-void DNNLPowerForward(const nnvm::NodeAttrs& attrs,
-                      const OpContext& ctx,
-                      const NDArray& input,
-                      const OpReqType& req,
-                      const NDArray& output) {
-  DNNLPowerFwd& fwd = DNNLPowerFwd::GetPowerForward(attrs, input, output);
-  fwd.Execute(input, req, output);
-}
-
-bool SupportDNNLPower(const NDArray& input) {
-  return input.shape().Size() != 0 && input.shape().ndim() > 0 && input.shape().ndim() <= 6 &&
-         input.dtype() == mshadow::kFloat32;
-}
-
 }  // namespace op
 }  // namespace mxnet

+namespace std {
+template <>
+struct hash<mxnet::op::DNNLPowMulScalarParam> {
+  size_t operator()(const mxnet::op::DNNLPowMulScalarParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.exponent);
+    ret = dmlc::HashCombine(ret, val.multiplier);
+    return ret;
+  }
+};
+}  // namespace std
 #endif  // MXNET_USE_ONEDNN == 1
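The std::hash<DNNLPowMulScalarParam> specialization above is what lets the param feed the ParamOpSign-keyed cache in GetCached, so each distinct (exponent, multiplier) pair builds its primitive once and reuses it on later calls. A simplified sketch of that caching pattern, where HashCombine is a hypothetical stand-in for dmlc::HashCombine and an int stands in for the cached primitive:

#include <cstddef>
#include <cstdio>
#include <functional>
#include <unordered_map>

// Hypothetical stand-in for dmlc::HashCombine: mixes one value into a running seed.
inline std::size_t HashCombine(std::size_t seed, float v) {
  return seed ^ (std::hash<float>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

struct Params {
  float exponent;
  float multiplier;
  bool operator==(const Params& o) const {
    return exponent == o.exponent && multiplier == o.multiplier;
  }
};

struct ParamsHash {
  std::size_t operator()(const Params& p) const {
    std::size_t ret = 0;
    ret = HashCombine(ret, p.exponent);
    ret = HashCombine(ret, p.multiplier);
    return ret;
  }
};

int main() {
  // One cached entry per distinct (exponent, multiplier) pair.
  std::unordered_map<Params, int, ParamsHash> cache;
  cache[{2.f, 3.f}] = 1;
  cache[{2.f, 3.f}] = 2;                // same key: overwrites, no new entry
  std::printf("%zu\n", cache.size());   // 1
  return 0;
}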

src/operator/nn/dnnl/dnnl_power_scalar-inl.h

Lines changed: 0 additions & 66 deletions
This file was deleted.

src/operator/subgraph/common.h

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@
 namespace mxnet {
 namespace op {

+enum SelectStatus { kFail = 0, kStart, kSuccess };
+
 inline uint32_t DefaultSubgraphOpNumInputs(const nnvm::NodeAttrs& attrs) {
   const nnvm::Symbol& sym = *attrs.subgraphs[0];
   return sym.ListInputNames(nnvm::Symbol::kAll).size();
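This enum was previously private to individual selectors (see the two diffs below); hoisting it into common.h lets new selectors, such as the pow+mul one, share the same three-state pattern matcher: kStart once an anchor op is seen, kSuccess when the whole pattern is matched, kFail otherwise. A rough stand-alone sketch of the idea for the power-then-multiply pattern (MatchPowMul is a hypothetical simplification over a flat op list, not the real SubgraphSelector interface):

#include <cstdio>
#include <string>
#include <vector>

enum SelectStatus { kFail = 0, kStart, kSuccess };

// Hypothetical linear walk over op names; the real selector walks graph nodes.
SelectStatus MatchPowMul(const std::vector<std::string>& ops) {
  SelectStatus status = kFail;
  for (const std::string& op : ops) {
    if (status == kStart && op == "_npi_multiply_scalar") {
      return kSuccess;  // full power -> multiply pattern matched
    }
    status = (op == "_npi_power_scalar") ? kStart : kFail;  // (re)anchor or reset
  }
  return kFail;  // a lone power_scalar does not complete the pattern
}

int main() {
  std::printf("%d\n", MatchPowMul({"_npi_power_scalar", "_npi_multiply_scalar"}));  // 2 (kSuccess)
  std::printf("%d\n", MatchPowMul({"_npi_power_scalar", "relu"}));                  // 0 (kFail)
  return 0;
}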

src/operator/subgraph/dnnl/dnnl_bn_relu_property.h

Lines changed: 0 additions & 2 deletions
@@ -35,8 +35,6 @@ namespace op {

 class SgDNNLBNReLUSelector : public SubgraphSelector {
  public:
-  enum SelectStatus { kStart, kSuccess, kFail };
-
   explicit SgDNNLBNReLUSelector(const bool disable_bn_relu)
       : disable_bn_relu_(disable_bn_relu), status_(kStart) {}

src/operator/subgraph/dnnl/dnnl_conv_property.h

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@ namespace op {
 class SgDNNLConvSelector : public SubgraphSelector {
  public:
   /*! \brief pattern match status_ */
-  enum SelectStatus {
+  enum SelectStatusConv {
     kFail = 0,
     kStart,
     kBN,
@@ -51,7 +51,7 @@ class SgDNNLConvSelector : public SubgraphSelector {
   bool disable_conv_act_;
   bool disable_conv_sum_;
   bool quantize_;
-  SelectStatus status_;
+  SelectStatusConv status_;
   std::vector<const nnvm::Node*> matched_list_;

  public:
