This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit ad1ff3a

xziya authored and TaoLv committed
[v1.6.x] Cherry-pick MKL-DNN Rnn operator enhancements to v1.6.x (#17225)
* [MKLDNN] mkldnn RNN operator enhancement (#17075)
  * mkldnn rnn operator enhancement
    * `add` operation support
    * Rename AddTo
    * Add MXNET_USE_MKLDNN_RNN env
    * Add Env var for switching to naive RNN impl and naive add/copy impl
  * Re-run CI, op:test_reduce failed on Unix-CPU
  * Rerun CI, Python2 CPU on Unix-CPU timeout
* MKL-DNN RNN backward path enhancement (#17183)
  * Flush memory before RNN backward primitive
  * Add gluon rnn unit test for gradients check
  * Cache reorder
  * Re-write rnn supporting check
  * Update OpSignature.AddSign to avoid potential hash collision for rnn-packed memory
  * Get the data type from mkldnn memory descriptor when setting grad handle
1 parent 0015fc3 commit ad1ff3a

9 files changed: 466 additions & 219 deletions


docs/static_site/src/pages/api/faq/env_var.md

Lines changed: 8 additions & 4 deletions
@@ -283,11 +283,11 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
   If no such algorithm exists given other constraints, MXNet will error out. This variable affects the choice
   of CUDNN convolution algorithms. Please see [CUDNN developer guide](https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html) for more details.
 
-* MXNET_CPU_PARALLEL_COPY_SIZE
+* MXNET_CPU_PARALLEL_SIZE
   - Values: Int ```(default=200000)```
-  - The minimum size to call parallel copy by OpenMP in CPU2CPU mode.
-  - When the array size is bigger than or equal to this threshold, NDArray::Copy(from, to) is implemented by OpenMP with the Recommended OMP Thread Count.
-  - When the array size is less than this threshold, NDArray::Copy(from, to) is implemented by memcpy in single thread.
+  - The minimum size to call parallel operations by OpenMP for CPU context.
+  - When the array size is bigger than or equal to this threshold, the operation implemented by OpenMP is executed with the Recommended OMP Thread Count.
+  - When the array size is less than this threshold, the operation is implemented naively in single thread.
 
 * MXNET_OPTIMIZER_AGGREGATION_SIZE
   - Values: Int ```(default=4)```

@@ -343,6 +343,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
   - Values: 0(false) or 1(true) ```(default=1)```
   - If this variable is set, MXNet will simplify the computation graph, eliminating duplicated operations on the same inputs.
 
+* MXNET_USE_MKLDNN_RNN
+  - Values: 0(false) or 1(true) ```(default=1)```
+  - This variable controls whether to use the MKL-DNN backend in fused RNN operator for CPU context. There are two fusion implementations of RNN operator in MXNet. The MKL-DNN implementation has a better performance than the naive one, but the latter is more stable in the backward operation currently.
+
 Settings for Minimum Memory Usage
 ---------------------------------
   - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
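As an illustration of the behaviour documented above, the following standalone sketch mimics the gate that MXNET_USE_MKLDNN_RNN controls. std::getenv stands in for dmlc::GetEnv, and the Input struct and main() driver are invented for demonstration; the real check is SupportMKLDNNRnn in the mkldnn_base-inl.h diff further below.

#include <cstdio>
#include <cstdlib>

struct Input { int dtype; int ndim; };   // invented stand-in for NDArray dtype()/shape().ndim()
constexpr int kFloat32 = 0;              // value of mshadow::kFloat32

bool use_mkldnn_rnn(const Input& in) {
  const char* v = std::getenv("MXNET_USE_MKLDNN_RNN");
  const int enabled = v ? std::atoi(v) : 1;  // documented default is 1
  return enabled != 0 && in.dtype == kFloat32 && in.ndim == 3;
}

int main() {
  Input x{kFloat32, 3};                      // float32, 3-D input: eligible for the MKL-DNN path
  std::printf("MKL-DNN fused RNN: %s\n", use_mkldnn_rnn(x) ? "enabled" : "naive fallback");
  return 0;
}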

src/common/utils.h

Lines changed: 19 additions & 1 deletion
@@ -760,7 +760,7 @@ inline void EmplaceBackZeros(const NDArrayStorageType stype, const mxnet::TShape
  */
 template<typename DType>
 inline void ParallelCopy(DType* dst, const DType* src, index_t size) {
-  static index_t copy_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_COPY_SIZE", 200000);
+  static index_t copy_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 200000);
   if (size >= copy_block_size) {
     #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
     for (index_t i = 0; i < size; ++i) {

@@ -771,6 +771,24 @@ inline void ParallelCopy(DType* dst, const DType* src, index_t size) {
   }
 }
 
+/*!
+ * \breif parallelize add by OpenMP
+ */
+template<typename DType>
+inline void ParallelAdd(DType* dst, const DType* src, index_t size) {
+  static index_t add_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 200000);
+  if (size >= add_block_size) {
+    #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+    for (index_t i = 0; i < size; ++i) {
+      dst[i] += src[i];
+    }
+  } else {
+    for (index_t i = 0; i < size; ++i) {
+      dst[i] += src[i];
+    }
+  }
+}
+
 /*!
  * \brief If numpy compatibility is turned off (default), the shapes passed in
  * by users follow the legacy shape definition:
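For context, here is a standalone sketch of the threshold pattern that ParallelCopy and ParallelAdd now share. The MXNET_CPU_PARALLEL_SIZE name and the 200000-element default come from the diff above; the rest (std::getenv in place of dmlc::GetEnv, the default OpenMP thread count, the main() driver) is illustrative only. Compile with -fopenmp to enable the parallel branch.

#include <cstdlib>
#include <vector>

// Threshold lookup: defaults to 200000 elements, as documented for MXNET_CPU_PARALLEL_SIZE.
static long parallel_threshold() {
  const char* v = std::getenv("MXNET_CPU_PARALLEL_SIZE");
  return v ? std::atol(v) : 200000;
}

// Same shape as ParallelAdd above: OpenMP loop for large arrays, plain loop otherwise.
void parallel_add(float* dst, const float* src, long size) {
  static const long threshold = parallel_threshold();
  if (size >= threshold) {
    #pragma omp parallel for
    for (long i = 0; i < size; ++i) dst[i] += src[i];
  } else {
    for (long i = 0; i < size; ++i) dst[i] += src[i];  // naive single-threaded path
  }
}

int main() {
  std::vector<float> dst(1 << 20, 1.0f), src(1 << 20, 2.0f);
  parallel_add(dst.data(), src.data(), static_cast<long>(dst.size()));
  return dst[0] == 3.0f ? 0 : 1;  // every element is now 1.0f + 2.0f
}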

src/operator/nn/mkldnn/mkldnn_base-inl.h

Lines changed: 6 additions & 3 deletions
@@ -132,9 +132,12 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
   return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
 }
 
-static inline bool SupportMKLDNNRNN(const NDArray &input) {
-  int ndim = input.shape().ndim();
-  return (input.dtype() == mshadow::kFloat32) && (ndim == 3);
+static inline bool SupportMKLDNNRnn(const NDArray &input) {
+  if (input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 3
+      && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
+    return true;
+  }
+  return false;
 }
 
 static inline bool SupportMKLDNNQuantize(int dtype) {

src/operator/nn/mkldnn/mkldnn_rnn-inl.h

Lines changed: 21 additions & 17 deletions
@@ -120,33 +120,32 @@ class RnnPrimitive {
   template<typename rnn_fwd, typename... Args>
   static RnnPrimitive Create(Args&&... args) {
     RnnPrimitive rnn_fwd_prim;
-    rnn_fwd_prim.pd_.reset(
-        new typename rnn_fwd::desc(std::forward<Args>(args)...),
-        [](typename rnn_fwd::desc* pd) {
-          delete reinterpret_cast<typename rnn_fwd::desc*>(pd);
+    auto fwd_desc = typename rnn_fwd::desc(std::forward<Args>(args)...);
+    rnn_fwd_prim.fwd_pd_.reset(
+        new typename rnn_fwd::primitive_desc(fwd_desc, CpuEngine::Get()->get_engine()),
+        [](typename rnn_fwd::primitive_desc* pd) {
+          delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd);
         });
-    const typename rnn_fwd::desc& fwd_desc =
-        *(reinterpret_cast<typename rnn_fwd::desc*>(rnn_fwd_prim.pd_.get()));
-    typename rnn_fwd::primitive_desc fwd_pd(fwd_desc, CpuEngine::Get()->get_engine());
-    rnn_fwd_prim.weights_layer_desc_ = fwd_pd.weights_layer_desc();
-    rnn_fwd_prim.weights_iter_desc_ = fwd_pd.weights_iter_desc();
-    rnn_fwd_prim.workspace_desc_ = fwd_pd.workspace_desc();
+    auto fwd_pd = reinterpret_cast<typename rnn_fwd::primitive_desc*>(rnn_fwd_prim.fwd_pd_.get());
+    rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc();
+    rnn_fwd_prim.weights_iter_desc_ = fwd_pd->weights_iter_desc();
+    rnn_fwd_prim.workspace_desc_ = fwd_pd->workspace_desc();
 
-    rnn_fwd_prim.primitive_ = std::shared_ptr<mkldnn::primitive>(new rnn_fwd(fwd_pd));
+    rnn_fwd_prim.primitive_ = std::shared_ptr<mkldnn::primitive>(new rnn_fwd(*fwd_pd));
 
     return rnn_fwd_prim;
   }
 
   RnnPrimitive() {
-    this->pd_ = nullptr;
+    this->fwd_pd_ = nullptr;
     this->primitive_ = nullptr;
     this->weights_layer_desc_ = mkldnn::memory::desc();
     this->weights_iter_desc_ = mkldnn::memory::desc();
     this->workspace_desc_ = mkldnn::memory::desc();
   }
 
   RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) {
-    this->pd_ = rnn_fwd_prim.pd_;
+    this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
     this->primitive_ = rnn_fwd_prim.primitive_;
     this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
     this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_;

@@ -155,7 +154,7 @@ class RnnPrimitive {
 
   RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) {
     if (this != &rnn_fwd_prim) {
-      this->pd_ = rnn_fwd_prim.pd_;
+      this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
       this->primitive_ = rnn_fwd_prim.primitive_;
       this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
       this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_;

@@ -165,7 +164,7 @@ class RnnPrimitive {
     return *this;
   }
 
-  const void* GetPrimDesc() const { return pd_.get(); }
+  const void* GetPrimDesc() const { return fwd_pd_.get(); }
   const mkldnn::primitive& GetPrim() const { return *primitive_; }
 
   const mkldnn::memory::desc& GetLayerDesc() const {

@@ -181,7 +180,7 @@ class RnnPrimitive {
   }
 
  private:
-  std::shared_ptr<void> pd_;
+  std::shared_ptr<void> fwd_pd_;
   std::shared_ptr<mkldnn::primitive> primitive_;
   mkldnn::memory::desc weights_layer_desc_;
   mkldnn::memory::desc weights_iter_desc_;
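One detail worth noting in the RnnPrimitive refactor above: the typed primitive_desc is kept behind a type-erased std::shared_ptr<void> (fwd_pd_), and destruction stays well defined because the deleter lambda captures the concrete type. Below is a minimal standalone sketch of that pattern, using an invented stand-in type rather than the MKL-DNN classes.

#include <cstdio>
#include <memory>

struct FakePrimitiveDesc {                       // invented stand-in for rnn_fwd::primitive_desc
  int cell_kind;
  ~FakePrimitiveDesc() { std::printf("typed destructor ran\n"); }
};

int main() {
  // Store a typed object behind shared_ptr<void>; the deleter keeps the
  // concrete type, so destruction does not go through a void*.
  std::shared_ptr<void> fwd_pd(new FakePrimitiveDesc{42},
                               [](FakePrimitiveDesc* p) { delete p; });

  // Callers that know the type cast it back, much as RnnPrimitive::Create
  // does with reinterpret_cast on fwd_pd_.get().
  auto* typed = static_cast<FakePrimitiveDesc*>(fwd_pd.get());
  std::printf("cell_kind = %d\n", typed->cell_kind);
  return 0;  // the typed deleter runs here
}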
@@ -370,7 +369,10 @@ class MKLDNNRnnBackward {
   void SetDataGradsMem(void* diff_src, void* diff_state, void* diff_statecell,
                        void* diff_out, void* diff_state_out, void* diff_statecell_out,
                        const int dtype = mshadow::kFloat32);
-  void CommitWeightsDiff(void* diff_weights, void* diff_bias, const int dtype = mshadow::kFloat32);
+  void SetNativeWeightsGrads() const;
+  void CommitWeightsGrads(void* diff_weights, void* diff_bias,
+                          const OpReqType req,
+                          const int dtype = mshadow::kFloat32);
 
   const mkldnn::primitive& GetBwd() const { return *bwd_.primitive_; }
   const mkldnn_args_map_t& GetArgsMap() const { return net_args_; }

@@ -385,6 +387,8 @@
 
   mkldnn_shared_mem_t diff_weights_layer_;
   mkldnn_shared_mem_t diff_weights_iter_;
+  mkldnn_shared_mem_t diff_weights_layer_r_;
+  mkldnn_shared_mem_t diff_weights_iter_r_;
   mkldnn_shared_mem_t diff_bias_;
 
   mkldnn_args_map_t net_args_;
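The new CommitWeightsGrads signature threads an OpReqType through to the gradient commit, which lines up with the commit message's "`add` operation support". Its body is not part of this diff; the sketch below only illustrates how such a request flag typically selects between overwriting and accumulating (the enum mirrors mxnet::OpReqType's ordering, everything else is invented).

#include <cstddef>
#include <cstdio>
#include <vector>

enum OpReq { kNullOp, kWriteTo, kWriteInplace, kAddTo };  // mirrors mxnet::OpReqType's ordering

// Invented body: commit a weight-gradient buffer according to the request flag.
void commit_grads(float* dst, const float* src, std::size_t n, OpReq req) {
  switch (req) {
    case kWriteTo:
    case kWriteInplace:
      for (std::size_t i = 0; i < n; ++i) dst[i] = src[i];    // overwrite (ParallelCopy's role)
      break;
    case kAddTo:
      for (std::size_t i = 0; i < n; ++i) dst[i] += src[i];   // accumulate (ParallelAdd's role)
      break;
    case kNullOp:
      break;                                                   // nothing requested
  }
}

int main() {
  std::vector<float> grad(4, 1.0f), update(4, 0.5f);
  commit_grads(grad.data(), update.data(), grad.size(), kAddTo);
  std::printf("grad[0] = %.1f\n", grad[0]);  // prints 1.5
  return 0;
}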
