Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 9fa75b4

Browse files
bgawrychBartlomiej Gawrych
andauthored
Fix identity fuse for oneDNN (#20767)
* Fix identity fuse * add new line * apply review comments * fix Co-authored-by: Bartlomiej Gawrych <barlomiej.gawrych@intel.com>
1 parent 0c3ef7a commit 9fa75b4

2 files changed

Lines changed: 24 additions & 10 deletions

File tree

src/operator/subgraph/dnnl/dnnl_identity_property.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ namespace op {
4040
class SgDNNLIdentitySelector : public SubgraphSelectorV2 {
4141
private:
4242
std::vector<const BiDirectedNode*> matched_list_;
43+
bool pattern_found = false;
4344

4445
public:
4546
bool Select(const BiDirectedNode& seed_node,
@@ -65,10 +66,11 @@ class SgDNNLIdentitySelector : public SubgraphSelectorV2 {
6566
}
6667

6768
bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
68-
if (input_node.node->is_variable()) {
69+
if (pattern_found || input_node.node->is_variable()) {
6970
return false;
7071
} else if (input_node.node->op()) {
7172
matched_list_.emplace_back(&input_node);
73+
pattern_found = true;
7274
return true;
7375
}
7476
return false;
@@ -80,7 +82,8 @@ class SgDNNLIdentitySelector : public SubgraphSelectorV2 {
8082

8183
std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
8284
// candidates should contain only two nodes - custom node and identity node
83-
if (candidates.size() == 2 && candidates.size() == matched_list_.size()) {
85+
if (pattern_found && candidates.size() == matched_list_.size()) {
86+
CHECK_EQ(candidates.size(), 2);
8487
return candidates;
8588
} else {
8689
return std::vector<BiDirectedNode*>(0);
@@ -134,8 +137,10 @@ class SgDNNLIdentityProperty : public SubgraphProperty {
134137
// Create copy of original node
135138
nnvm::ObjectPtr n = nnvm::Node::Create();
136139
n->attrs = org_node->attrs;
137-
CHECK(n->op());
138-
n->op()->attr_parser(&(n->attrs));
140+
CHECK(n->op()) << "WRTF";
141+
if (n->op()->attr_parser) {
142+
n->op()->attr_parser(&(n->attrs));
143+
}
139144
return n;
140145
}
141146

tests/python/dnnl/subgraphs/test_fc_subgraph.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,18 +208,27 @@ def test_fc_identity_eltwise(identity_node):
208208
class FCIdentityEltwise(nn.HybridBlock):
209209
def __init__(self, identity_node, **kwargs):
210210
super(FCIdentityEltwise, self).__init__(**kwargs)
211-
self.fc = nn.Dense(units=64, use_bias=False, weight_initializer=None, flatten=True)
211+
self.fc1 = nn.Dense(units=64, use_bias=False, weight_initializer=None, flatten=True)
212+
self.fc2 = nn.Dense(units=64, use_bias=False, weight_initializer=None, flatten=True)
212213
self.identity_node = identity_node
214+
213215
def forward(self, x):
214-
fc_out = self.fc(x)
216+
out = self.fc1(x)
215217
if self.identity_node == 'copy':
216-
fc_out = mx.np.copy(fc_out)
218+
out = mx.np.copy(out)
217219
else:
218-
fc_out = mx.npx.dropout(fc_out)
219-
out = mx.npx.activation(fc_out, act_type='relu')
220+
out = mx.npx.dropout(out)
221+
out = mx.npx.activation(out, act_type='relu')
222+
out = self.fc2(out)
223+
if self.identity_node == 'copy':
224+
out = mx.np.copy(out)
225+
else:
226+
out = mx.npx.dropout(out)
227+
out = mx.npx.activation(out, act_type='relu')
220228
return out
221229

222230
data_shape = (64, 4, 10, 10)
223-
attrs = {'fc': {'with_eltwise': 'true'}}
231+
attrs = {'sg_onednn_fully_connected_eltwise_0' : {'with_eltwise': 'true'},
232+
'sg_onednn_fully_connected_eltwise_1' : {'with_eltwise': 'true'}}
224233
net = FCIdentityEltwise(identity_node)
225234
check_fusion(net, data_shape, attrs, check_quantization=False)

0 commit comments

Comments
 (0)