Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 2e33e96

Browse files
author
bgawrych
authored
Remove temporary fix for RNN (#19451)
1 parent 6d5d8b9 commit 2e33e96

1 file changed

Lines changed: 10 additions & 23 deletions

File tree

src/operator/nn/mkldnn/mkldnn_rnn.cc

Lines changed: 10 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -47,15 +47,6 @@ inline int GetRnnGatesNum(int mode) {
4747
}
4848
}
4949

50-
// Bug in oneDNN <= 1.6 in memory descriptor comparison operators.
51-
// for specific dims and strides in descriptors == operator can return `true`
52-
// but get_size() function will return different size
53-
// TODO(bgawrych): Remove with oneDNN 1.7 upgrade
54-
static inline bool CheckMemDescEquality(const mkldnn::memory::desc &left,
55-
const mkldnn::memory::desc &right) {
56-
return left == right && left.get_size() == right.get_size();
57-
}
58-
5950
void MKLDNNRnnLayerParam::SetDims() {
6051
const int ngates = GetRnnGatesNum(mode);
6152
//* NOTES: LBR-GRU's new gate formula needs two bias. So it has one more bias with LBR-GRU
@@ -599,13 +590,13 @@ void MKLDNNRnnForwardTraining::SetTrnMem(const MKLDNNRnnForward& fwd) {
599590
weights_iter_ = mkldnn_shared_mem_t(new memory(fwd_trn_.GetIterDesc(), cpu_engine));
600591

601592
// fill weights memory using the reordered weights of fwd_inference primitive
602-
if (CheckMemDescEquality(fwd.weights_layer_r_->get_desc(), fwd_trn_.GetLayerDesc())) {
593+
if (fwd.weights_layer_r_->get_desc() == fwd_trn_.GetLayerDesc()) {
603594
weights_layer_->set_data_handle(fwd.weights_layer_r_->get_data_handle());
604595
} else {
605596
MKLDNNMemoryReorder(*fwd.weights_layer_r_, *weights_layer_);
606597
}
607598

608-
if (CheckMemDescEquality(fwd.weights_iter_r_->get_desc(), fwd_trn_.GetIterDesc())) {
599+
if (fwd.weights_iter_r_->get_desc() == fwd_trn_.GetIterDesc()) {
609600
weights_iter_->set_data_handle(fwd.weights_iter_r_->get_data_handle());
610601
} else {
611602
MKLDNNMemoryReorder(*fwd.weights_iter_r_, *weights_iter_);
@@ -729,15 +720,15 @@ void MKLDNNRnnBackward::FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd)
729720
const mkldnn::memory* valid_mem;
730721
switch (kv.first) {
731722
case MKLDNN_ARG_WEIGHTS_LAYER: {
732-
if (CheckMemDescEquality(bwd_.weights_layer_desc_, fwd.fwd_trn_.GetLayerDesc())) {
723+
if (bwd_.weights_layer_desc_ == fwd.fwd_trn_.GetLayerDesc()) {
733724
this->weights_layer_->set_data_handle(kv.second.get_data_handle());
734725
} else {
735726
MKLDNNMemoryReorder(*fwd.weights_layer_, *this->weights_layer_);
736727
}
737728
valid_mem = this->weights_layer_.get();
738729
} break;
739730
case MKLDNN_ARG_WEIGHTS_ITER: {
740-
if (CheckMemDescEquality(bwd_.weights_iter_desc_, fwd.fwd_trn_.GetIterDesc())) {
731+
if (bwd_.weights_iter_desc_ == fwd.fwd_trn_.GetIterDesc()) {
741732
this->weights_iter_->set_data_handle(kv.second.get_data_handle());
742733
} else {
743734
MKLDNNMemoryReorder(*fwd.weights_iter_, *this->weights_iter_);
@@ -771,14 +762,14 @@ void MKLDNNRnnBackward::SetWeightsGradsMem() {
771762
this->diff_weights_iter_r_ = std::make_shared<mkldnn::memory>(
772763
native_iter_desc, cpu_engine);
773764

774-
if (CheckMemDescEquality(native_layer_desc, bwd_.diff_weights_layer_desc_)) {
765+
if (native_layer_desc == bwd_.diff_weights_layer_desc_) {
775766
this->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
776767
bwd_.diff_weights_layer_desc_, cpu_engine, diff_weights_layer_r_->get_data_handle());
777768
} else {
778769
this->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
779770
bwd_.diff_weights_layer_desc_, cpu_engine);
780771
}
781-
if (CheckMemDescEquality(native_iter_desc, bwd_.diff_weights_iter_desc_)) {
772+
if (native_iter_desc == bwd_.diff_weights_iter_desc_) {
782773
this->diff_weights_iter_ = std::make_shared<mkldnn::memory>(
783774
bwd_.diff_weights_iter_desc_, cpu_engine, diff_weights_iter_r_->get_data_handle());
784775
} else {
@@ -830,12 +821,10 @@ void MKLDNNRnnBackward::SetDataGradsMem(
830821
}
831822

832823
void MKLDNNRnnBackward::SetNativeWeightsGrads() const {
833-
if (!CheckMemDescEquality(this->diff_weights_layer_->get_desc(),
834-
this->diff_weights_layer_r_->get_desc())) {
824+
if (this->diff_weights_layer_->get_desc() != this->diff_weights_layer_r_->get_desc()) {
835825
MKLDNNMemoryReorder(*this->diff_weights_layer_, *this->diff_weights_layer_r_);
836826
}
837-
if (!CheckMemDescEquality(this->diff_weights_iter_->get_desc(),
838-
this->diff_weights_iter_r_->get_desc())) {
827+
if (this->diff_weights_iter_->get_desc() != this->diff_weights_iter_r_->get_desc()) {
839828
MKLDNNMemoryReorder(*this->diff_weights_iter_, *this->diff_weights_iter_r_);
840829
}
841830
}
@@ -854,11 +843,9 @@ void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias,
854843

855844
void* diff_weights_layer_ptr = this->diff_weights_layer_->get_data_handle();
856845
void* diff_weights_iter_ptr = this->diff_weights_iter_->get_data_handle();
857-
if (!CheckMemDescEquality(this->diff_weights_layer_->get_desc(),
858-
this->diff_weights_layer_r_->get_desc()))
846+
if (this->diff_weights_layer_->get_desc() != this->diff_weights_layer_r_->get_desc())
859847
diff_weights_layer_ptr = this->diff_weights_layer_r_->get_data_handle();
860-
if (!CheckMemDescEquality(this->diff_weights_iter_->get_desc(),
861-
this->diff_weights_iter_r_->get_desc()))
848+
if (this->diff_weights_iter_->get_desc() != this->diff_weights_iter_r_->get_desc())
862849
diff_weights_iter_ptr = this->diff_weights_iter_r_->get_data_handle();
863850

864851
const int num_layer = param.num_layer;

0 commit comments

Comments (0)