@@ -47,15 +47,6 @@ inline int GetRnnGatesNum(int mode) {
4747 }
4848}
4949
50- // Bug in oneDNN <= 1.6 in memory descriptor comparision operators.
51- // for specific dims and strides in descriptors == operator can return `true`
52- // but get_size() function will return different size
53- // TODO(bgawrych): Remove with oneDNN 1.7 upgrade
54- static inline bool CheckMemDescEquality (const mkldnn::memory::desc &left,
55- const mkldnn::memory::desc &right) {
56- return left == right && left.get_size () == right.get_size ();
57- }
58-
5950void MKLDNNRnnLayerParam::SetDims () {
6051 const int ngates = GetRnnGatesNum (mode);
6152 // * NOTES: LBR-GRU's new gate formula needs two bias. So it has one more bias with LBR-GRU
@@ -599,13 +590,13 @@ void MKLDNNRnnForwardTraining::SetTrnMem(const MKLDNNRnnForward& fwd) {
599590 weights_iter_ = mkldnn_shared_mem_t (new memory (fwd_trn_.GetIterDesc (), cpu_engine));
600591
601592 // fill weights memory using the reordered weights of fwd_inference primitive
602- if (CheckMemDescEquality ( fwd.weights_layer_r_ ->get_desc (), fwd_trn_.GetLayerDesc () )) {
593+ if (fwd.weights_layer_r_ ->get_desc () == fwd_trn_.GetLayerDesc ()) {
603594 weights_layer_->set_data_handle (fwd.weights_layer_r_ ->get_data_handle ());
604595 } else {
605596 MKLDNNMemoryReorder (*fwd.weights_layer_r_ , *weights_layer_);
606597 }
607598
608- if (CheckMemDescEquality ( fwd.weights_iter_r_ ->get_desc (), fwd_trn_.GetIterDesc () )) {
599+ if (fwd.weights_iter_r_ ->get_desc () == fwd_trn_.GetIterDesc ()) {
609600 weights_iter_->set_data_handle (fwd.weights_iter_r_ ->get_data_handle ());
610601 } else {
611602 MKLDNNMemoryReorder (*fwd.weights_iter_r_ , *weights_iter_);
@@ -729,15 +720,15 @@ void MKLDNNRnnBackward::FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd)
729720 const mkldnn::memory* valid_mem;
730721 switch (kv.first ) {
731722 case MKLDNN_ARG_WEIGHTS_LAYER: {
732- if (CheckMemDescEquality ( bwd_.weights_layer_desc_ , fwd.fwd_trn_ .GetLayerDesc () )) {
723+ if (bwd_.weights_layer_desc_ == fwd.fwd_trn_ .GetLayerDesc ()) {
733724 this ->weights_layer_ ->set_data_handle (kv.second .get_data_handle ());
734725 } else {
735726 MKLDNNMemoryReorder (*fwd.weights_layer_ , *this ->weights_layer_ );
736727 }
737728 valid_mem = this ->weights_layer_ .get ();
738729 } break ;
739730 case MKLDNN_ARG_WEIGHTS_ITER: {
740- if (CheckMemDescEquality ( bwd_.weights_iter_desc_ , fwd.fwd_trn_ .GetIterDesc () )) {
731+ if (bwd_.weights_iter_desc_ == fwd.fwd_trn_ .GetIterDesc ()) {
741732 this ->weights_iter_ ->set_data_handle (kv.second .get_data_handle ());
742733 } else {
743734 MKLDNNMemoryReorder (*fwd.weights_iter_ , *this ->weights_iter_ );
@@ -771,14 +762,14 @@ void MKLDNNRnnBackward::SetWeightsGradsMem() {
771762 this ->diff_weights_iter_r_ = std::make_shared<mkldnn::memory>(
772763 native_iter_desc, cpu_engine);
773764
774- if (CheckMemDescEquality ( native_layer_desc, bwd_.diff_weights_layer_desc_ ) ) {
765+ if (native_layer_desc == bwd_.diff_weights_layer_desc_ ) {
775766 this ->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
776767 bwd_.diff_weights_layer_desc_ , cpu_engine, diff_weights_layer_r_->get_data_handle ());
777768 } else {
778769 this ->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
779770 bwd_.diff_weights_layer_desc_ , cpu_engine);
780771 }
781- if (CheckMemDescEquality ( native_iter_desc, bwd_.diff_weights_iter_desc_ ) ) {
772+ if (native_iter_desc == bwd_.diff_weights_iter_desc_ ) {
782773 this ->diff_weights_iter_ = std::make_shared<mkldnn::memory>(
783774 bwd_.diff_weights_iter_desc_ , cpu_engine, diff_weights_iter_r_->get_data_handle ());
784775 } else {
@@ -830,12 +821,10 @@ void MKLDNNRnnBackward::SetDataGradsMem(
830821}
831822
832823void MKLDNNRnnBackward::SetNativeWeightsGrads () const {
833- if (!CheckMemDescEquality (this ->diff_weights_layer_ ->get_desc (),
834- this ->diff_weights_layer_r_ ->get_desc ())) {
824+ if (this ->diff_weights_layer_ ->get_desc () != this ->diff_weights_layer_r_ ->get_desc ()) {
835825 MKLDNNMemoryReorder (*this ->diff_weights_layer_ , *this ->diff_weights_layer_r_ );
836826 }
837- if (!CheckMemDescEquality (this ->diff_weights_iter_ ->get_desc (),
838- this ->diff_weights_iter_r_ ->get_desc ())) {
827+ if (this ->diff_weights_iter_ ->get_desc () != this ->diff_weights_iter_r_ ->get_desc ()) {
839828 MKLDNNMemoryReorder (*this ->diff_weights_iter_ , *this ->diff_weights_iter_r_ );
840829 }
841830}
@@ -854,11 +843,9 @@ void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias,
854843
855844 void * diff_weights_layer_ptr = this ->diff_weights_layer_ ->get_data_handle ();
856845 void * diff_weights_iter_ptr = this ->diff_weights_iter_ ->get_data_handle ();
857- if (!CheckMemDescEquality (this ->diff_weights_layer_ ->get_desc (),
858- this ->diff_weights_layer_r_ ->get_desc ()))
846+ if (this ->diff_weights_layer_ ->get_desc () != this ->diff_weights_layer_r_ ->get_desc ())
859847 diff_weights_layer_ptr = this ->diff_weights_layer_r_ ->get_data_handle ();
860- if (!CheckMemDescEquality (this ->diff_weights_iter_ ->get_desc (),
861- this ->diff_weights_iter_r_ ->get_desc ()))
848+ if (this ->diff_weights_iter_ ->get_desc () != this ->diff_weights_iter_r_ ->get_desc ())
862849 diff_weights_iter_ptr = this ->diff_weights_iter_r_ ->get_data_handle ();
863850
864851 const int num_layer = param.num_layer ;
0 commit comments