Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 9266a91

Browse files
Authored by bgawrych
Optimize preparation of selfattn operators (#20682)
1 parent 30734fb commit 9266a91

1 file changed

Lines changed: 52 additions & 42 deletions

File tree

src/operator/subgraph/dnnl/dnnl_transformer.cc

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ class SgDNNLSelfAttQKOp {
116116
void Forward(const OpContext& ctx,
117117
const std::vector<NDArray>& inputs,
118118
const std::vector<OpReqType>& req,
119-
const std::vector<NDArray>& outputs);
119+
const std::vector<NDArray>& outputs,
120+
bool already_prepared);
120121

121122
void Backward(const OpContext& ctx,
122123
const std::vector<NDArray>& inputs,
@@ -163,10 +164,12 @@ static void SgDNNLSelfAttQKForward(const OpStatePtr& state_pointer,
163164
const std::vector<OpReqType>& req,
164165
const std::vector<NDArray>& outputs) {
165166
SgDNNLSelfAttQKOp& op = state_pointer.get_state<SgDNNLSelfAttQKOp>();
167+
bool already_prepared = false;
166168
if (!op.IsInitialized()) {
167169
op.Initialize(ctx, inputs, req, outputs);
170+
already_prepared = true;
168171
}
169-
op.Forward(ctx, inputs, req, outputs);
172+
op.Forward(ctx, inputs, req, outputs, already_prepared);
170173
}
171174

172175
static bool SgDNNLSelfAttStorageType(const nnvm::NodeAttrs& attrs,
@@ -264,21 +267,23 @@ void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
264267
void SgDNNLSelfAttQKOp::Forward(const OpContext& ctx,
265268
const std::vector<NDArray>& inputs,
266269
const std::vector<OpReqType>& req,
267-
const std::vector<NDArray>& outputs) {
268-
const size_t output_lin_dim = inputs[0].shape()[2];
269-
const size_t embed_dim = output_lin_dim / QKV_NUM;
270-
271-
MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
272-
DType* query_mem_ptr = inputs[0].data().dptr<DType>();
273-
DType* key_mem_ptr = query_mem_ptr + embed_dim;
274-
cached_query_mem_->set_data_handle(query_mem_ptr);
275-
cached_key_mem_->set_data_handle(key_mem_ptr);
276-
});
277-
278-
MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
279-
cached_out_mem_->set_data_handle(outputs[0].data().dptr<DType>());
280-
});
281-
270+
const std::vector<NDArray>& outputs,
271+
bool already_prepared) {
272+
if (!already_prepared) {
273+
const size_t output_lin_dim = inputs[0].shape()[2];
274+
const size_t embed_dim = output_lin_dim / QKV_NUM;
275+
276+
MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
277+
DType* query_mem_ptr = inputs[0].data().dptr<DType>();
278+
DType* key_mem_ptr = query_mem_ptr + embed_dim;
279+
cached_query_mem_->set_data_handle(query_mem_ptr);
280+
cached_key_mem_->set_data_handle(key_mem_ptr);
281+
});
282+
283+
MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
284+
cached_out_mem_->set_data_handle(outputs[0].data().dptr<DType>());
285+
});
286+
}
282287
DNNLStream::Get()->RegisterPrimArgs(*fwd_, args_);
283288
DNNLStream::Get()->Submit();
284289

@@ -484,7 +489,8 @@ class DNNLSelfAttValAttOp {
484489
void Forward(const OpContext& ctx,
485490
const std::vector<NDArray>& inputs,
486491
const std::vector<OpReqType>& req,
487-
const std::vector<NDArray>& outputs);
492+
const std::vector<NDArray>& outputs,
493+
bool already_prepared);
488494

489495
void Backward(const OpContext& ctx,
490496
const std::vector<NDArray>& inputs,
@@ -538,10 +544,12 @@ static void DNNLSelfAttValAttForward(const OpStatePtr& state_pointer,
538544
const std::vector<OpReqType>& req,
539545
const std::vector<NDArray>& outputs) {
540546
DNNLSelfAttValAttOp& op = state_pointer.get_state<DNNLSelfAttValAttOp>();
547+
bool already_prepared = false;
541548
if (!op.IsInitialized()) {
542549
op.Initialize(ctx, inputs, req, outputs);
550+
already_prepared = true;
543551
}
544-
op.Forward(ctx, inputs, req, outputs);
552+
op.Forward(ctx, inputs, req, outputs, already_prepared);
545553
}
546554

547555
void DNNLSelfAttValAttOp::Initialize(const OpContext& ctx,
@@ -664,29 +672,31 @@ void DNNLSelfAttValAttOp::Initialize(const OpContext& ctx,
664672
void DNNLSelfAttValAttOp::Forward(const OpContext& ctx,
665673
const std::vector<NDArray>& inputs,
666674
const std::vector<OpReqType>& req,
667-
const std::vector<NDArray>& outputs) {
668-
// multiply by 2 as we need to skip queries and keys
669-
const size_t value_offset = inputs[1].shape()[2] / QKV_NUM * 2;
670-
671-
auto att_buffer = inputs[0];
672-
if (att_buffer.IsDNNLData())
673-
att_buffer = att_buffer.Reorder2Default();
674-
675-
MSHADOW_TYPE_SWITCH(att_buffer.dtype(), DType, {
676-
DType* attention_ptr = att_buffer.data().dptr<DType>();
677-
cached_att_mem_->set_data_handle(attention_ptr);
678-
});
679-
680-
MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
681-
DType* qkv_ptr = inputs[1].data().dptr<DType>();
682-
DType* value_mem_ptr = qkv_ptr + value_offset;
683-
cached_value_mem_->set_data_handle(value_mem_ptr);
684-
});
685-
686-
MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
687-
cached_transposed_mem_->set_data_handle(outputs[0].data().dptr<DType>());
688-
});
689-
675+
const std::vector<NDArray>& outputs,
676+
bool already_prepared) {
677+
if (!already_prepared) {
678+
// multiply by 2 as we need to skip queries and keys
679+
const size_t value_offset = inputs[1].shape()[2] / QKV_NUM * 2;
680+
681+
auto att_buffer = inputs[0];
682+
if (att_buffer.IsDNNLData())
683+
att_buffer = att_buffer.Reorder2Default();
684+
685+
MSHADOW_TYPE_SWITCH(att_buffer.dtype(), DType, {
686+
DType* attention_ptr = att_buffer.data().dptr<DType>();
687+
cached_att_mem_->set_data_handle(attention_ptr);
688+
});
689+
690+
MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
691+
DType* qkv_ptr = inputs[1].data().dptr<DType>();
692+
DType* value_mem_ptr = qkv_ptr + value_offset;
693+
cached_value_mem_->set_data_handle(value_mem_ptr);
694+
});
695+
696+
MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
697+
cached_transposed_mem_->set_data_handle(outputs[0].data().dptr<DType>());
698+
});
699+
}
690700
DNNLStream::Get()->RegisterPrimArgs(*fwd_, args_);
691701
DNNLStream::Get()->RegisterPrimArgs(*reorder_, reorder_args);
692702
DNNLStream::Get()->Submit();

0 commit comments

Comments
 (0)