Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 84b1626

Browse files
author
bgawrych
authored
oneDNN FullyConnected weight caching & refactor (#21047)
* FC weight and bias caching * prepare output for sum * check initialization conditions * create output mem desc * PrepareQuantization * remove unused variables * cleanup * Enable BRGEMM * Reorder functions * make minmax enum anonymous * node identificator & env flag * fix sanity * fix sanity * apply review * rename variable
1 parent b713dc5 commit 84b1626

4 files changed

Lines changed: 401 additions & 308 deletions

File tree

src/operator/nn/dnnl/dnnl_base-inl.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -351,13 +351,6 @@ inline static dnnl::memory::desc GetMemDesc(const NDArray& arr, int dtype = -1)
351351
return dnnl::memory::desc{dims, get_dnnl_type(dtype), dnnl::memory::format_tag::any};
352352
}
353353

354-
inline static bool ChooseBRGEMMImpl(const dnnl::memory::dims& weight_dims, size_t batch_size) {
355-
// Conditions based on measurement results done on CLX8280
356-
// https://github.com/apache/incubator-mxnet/pull/20533
357-
return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 16384 &&
358-
weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
359-
}
360-
361354
inline static dnnl::memory::desc GetFCWeightDesc(const NDArray& arr,
362355
size_t batch_size,
363356
int dtype = -1) {
@@ -370,7 +363,7 @@ inline static dnnl::memory::desc GetFCWeightDesc(const NDArray& arr,
370363
// for batch 256 alexnet benchmark test
371364
const bool force_fc_ab_format = dmlc::GetEnv("MXNET_ONEDNN_FORCE_FC_AB_FORMAT", false);
372365
if (dims.size() == 2) {
373-
if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
366+
if (force_fc_ab_format || dtype != mshadow::kInt8) {
374367
format = dnnl::memory::format_tag::ab;
375368
}
376369
}

src/operator/operator_common.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ class OpSignature {
547547

548548
#if MXNET_USE_ONEDNN == 1
549549
void AddSign(const dnnl::memory::desc& desc) {
550-
hash = hash * 2 + desc.data.format_kind;
550+
hash = hash * 2 + desc.data.format_kind;
551551
eles.push_back(desc.data.format_kind);
552552
hash = hash * 2 + desc.data.data_type;
553553
eles.push_back(desc.data.data_type);
@@ -617,6 +617,11 @@ class OpSignature {
617617

618618
#endif
619619

620+
void AddSign(const std::string& s) {
621+
uint64_t key = static_cast<uint64_t>(std::hash<std::string>{}(s));
622+
eles.push_back(key);
623+
}
624+
620625
void AddSign(const std::vector<NDArray>& arrs) {
621626
for (auto& arr : arrs) {
622627
AddSign(arr);

0 commit comments

Comments
 (0)