@@ -116,7 +116,8 @@ class SgDNNLSelfAttQKOp {
116116 void Forward (const OpContext& ctx,
117117 const std::vector<NDArray>& inputs,
118118 const std::vector<OpReqType>& req,
119- const std::vector<NDArray>& outputs);
119+ const std::vector<NDArray>& outputs,
120+ bool already_prepared);
120121
121122 void Backward (const OpContext& ctx,
122123 const std::vector<NDArray>& inputs,
@@ -163,10 +164,12 @@ static void SgDNNLSelfAttQKForward(const OpStatePtr& state_pointer,
163164 const std::vector<OpReqType>& req,
164165 const std::vector<NDArray>& outputs) {
165166 SgDNNLSelfAttQKOp& op = state_pointer.get_state <SgDNNLSelfAttQKOp>();
167+ bool already_prepared = false ;
166168 if (!op.IsInitialized ()) {
167169 op.Initialize (ctx, inputs, req, outputs);
170+ already_prepared = true ;
168171 }
169- op.Forward (ctx, inputs, req, outputs);
172+ op.Forward (ctx, inputs, req, outputs, already_prepared );
170173}
171174
172175static bool SgDNNLSelfAttStorageType (const nnvm::NodeAttrs& attrs,
@@ -264,21 +267,23 @@ void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
264267void SgDNNLSelfAttQKOp::Forward (const OpContext& ctx,
265268 const std::vector<NDArray>& inputs,
266269 const std::vector<OpReqType>& req,
267- const std::vector<NDArray>& outputs) {
268- const size_t output_lin_dim = inputs[0 ].shape ()[2 ];
269- const size_t embed_dim = output_lin_dim / QKV_NUM;
270-
271- MSHADOW_TYPE_SWITCH (inputs[0 ].dtype (), DType, {
272- DType* query_mem_ptr = inputs[0 ].data ().dptr <DType>();
273- DType* key_mem_ptr = query_mem_ptr + embed_dim;
274- cached_query_mem_->set_data_handle (query_mem_ptr);
275- cached_key_mem_->set_data_handle (key_mem_ptr);
276- });
277-
278- MSHADOW_TYPE_SWITCH (outputs[0 ].dtype (), DType, {
279- cached_out_mem_->set_data_handle (outputs[0 ].data ().dptr <DType>());
280- });
281-
270+ const std::vector<NDArray>& outputs,
271+ bool already_prepared) {
272+ if (!already_prepared) {
273+ const size_t output_lin_dim = inputs[0 ].shape ()[2 ];
274+ const size_t embed_dim = output_lin_dim / QKV_NUM;
275+
276+ MSHADOW_TYPE_SWITCH (inputs[0 ].dtype (), DType, {
277+ DType* query_mem_ptr = inputs[0 ].data ().dptr <DType>();
278+ DType* key_mem_ptr = query_mem_ptr + embed_dim;
279+ cached_query_mem_->set_data_handle (query_mem_ptr);
280+ cached_key_mem_->set_data_handle (key_mem_ptr);
281+ });
282+
283+ MSHADOW_TYPE_SWITCH (outputs[0 ].dtype (), DType, {
284+ cached_out_mem_->set_data_handle (outputs[0 ].data ().dptr <DType>());
285+ });
286+ }
282287 DNNLStream::Get ()->RegisterPrimArgs (*fwd_, args_);
283288 DNNLStream::Get ()->Submit ();
284289
@@ -484,7 +489,8 @@ class DNNLSelfAttValAttOp {
484489 void Forward (const OpContext& ctx,
485490 const std::vector<NDArray>& inputs,
486491 const std::vector<OpReqType>& req,
487- const std::vector<NDArray>& outputs);
492+ const std::vector<NDArray>& outputs,
493+ bool already_prepared);
488494
489495 void Backward (const OpContext& ctx,
490496 const std::vector<NDArray>& inputs,
@@ -538,10 +544,12 @@ static void DNNLSelfAttValAttForward(const OpStatePtr& state_pointer,
538544 const std::vector<OpReqType>& req,
539545 const std::vector<NDArray>& outputs) {
540546 DNNLSelfAttValAttOp& op = state_pointer.get_state <DNNLSelfAttValAttOp>();
547+ bool already_prepared = false ;
541548 if (!op.IsInitialized ()) {
542549 op.Initialize (ctx, inputs, req, outputs);
550+ already_prepared = true ;
543551 }
544- op.Forward (ctx, inputs, req, outputs);
552+ op.Forward (ctx, inputs, req, outputs, already_prepared );
545553}
546554
547555void DNNLSelfAttValAttOp::Initialize (const OpContext& ctx,
@@ -664,29 +672,31 @@ void DNNLSelfAttValAttOp::Initialize(const OpContext& ctx,
664672void DNNLSelfAttValAttOp::Forward (const OpContext& ctx,
665673 const std::vector<NDArray>& inputs,
666674 const std::vector<OpReqType>& req,
667- const std::vector<NDArray>& outputs) {
668- // multiply by 2 as we need to skip queries and keys
669- const size_t value_offset = inputs[1 ].shape ()[2 ] / QKV_NUM * 2 ;
670-
671- auto att_buffer = inputs[0 ];
672- if (att_buffer.IsDNNLData ())
673- att_buffer = att_buffer.Reorder2Default ();
674-
675- MSHADOW_TYPE_SWITCH (att_buffer.dtype (), DType, {
676- DType* attention_ptr = att_buffer.data ().dptr <DType>();
677- cached_att_mem_->set_data_handle (attention_ptr);
678- });
679-
680- MSHADOW_TYPE_SWITCH (inputs[1 ].dtype (), DType, {
681- DType* qkv_ptr = inputs[1 ].data ().dptr <DType>();
682- DType* value_mem_ptr = qkv_ptr + value_offset;
683- cached_value_mem_->set_data_handle (value_mem_ptr);
684- });
685-
686- MSHADOW_TYPE_SWITCH (outputs[0 ].dtype (), DType, {
687- cached_transposed_mem_->set_data_handle (outputs[0 ].data ().dptr <DType>());
688- });
689-
675+ const std::vector<NDArray>& outputs,
676+ bool already_prepared) {
677+ if (!already_prepared) {
678+ // multiply by 2 as we need to skip queries and keys
679+ const size_t value_offset = inputs[1 ].shape ()[2 ] / QKV_NUM * 2 ;
680+
681+ auto att_buffer = inputs[0 ];
682+ if (att_buffer.IsDNNLData ())
683+ att_buffer = att_buffer.Reorder2Default ();
684+
685+ MSHADOW_TYPE_SWITCH (att_buffer.dtype (), DType, {
686+ DType* attention_ptr = att_buffer.data ().dptr <DType>();
687+ cached_att_mem_->set_data_handle (attention_ptr);
688+ });
689+
690+ MSHADOW_TYPE_SWITCH (inputs[1 ].dtype (), DType, {
691+ DType* qkv_ptr = inputs[1 ].data ().dptr <DType>();
692+ DType* value_mem_ptr = qkv_ptr + value_offset;
693+ cached_value_mem_->set_data_handle (value_mem_ptr);
694+ });
695+
696+ MSHADOW_TYPE_SWITCH (outputs[0 ].dtype (), DType, {
697+ cached_transposed_mem_->set_data_handle (outputs[0 ].data ().dptr <DType>());
698+ });
699+ }
690700 DNNLStream::Get ()->RegisterPrimArgs (*fwd_, args_);
691701 DNNLStream::Get ()->RegisterPrimArgs (*reorder_, reorder_args);
692702 DNNLStream::Get ()->Submit ();
0 commit comments