@@ -208,18 +208,27 @@ def test_fc_identity_eltwise(identity_node):
208208 class FCIdentityEltwise (nn .HybridBlock ):
209209 def __init__ (self , identity_node , ** kwargs ):
210210 super (FCIdentityEltwise , self ).__init__ (** kwargs )
211- self .fc = nn .Dense (units = 64 , use_bias = False , weight_initializer = None , flatten = True )
211+ self .fc1 = nn .Dense (units = 64 , use_bias = False , weight_initializer = None , flatten = True )
212+ self .fc2 = nn .Dense (units = 64 , use_bias = False , weight_initializer = None , flatten = True )
212213 self .identity_node = identity_node
214+
213215 def forward (self , x ):
214- fc_out = self .fc (x )
216+ out = self .fc1 (x )
215217 if self .identity_node == 'copy' :
216- fc_out = mx .np .copy (fc_out )
218+ out = mx .np .copy (out )
217219 else :
218- fc_out = mx .npx .dropout (fc_out )
219- out = mx .npx .activation (fc_out , act_type = 'relu' )
220+ out = mx .npx .dropout (out )
221+ out = mx .npx .activation (out , act_type = 'relu' )
222+ out = self .fc2 (out )
223+ if self .identity_node == 'copy' :
224+ out = mx .np .copy (out )
225+ else :
226+ out = mx .npx .dropout (out )
227+ out = mx .npx .activation (out , act_type = 'relu' )
220228 return out
221229
222230 data_shape = (64 , 4 , 10 , 10 )
223- attrs = {'fc' : {'with_eltwise' : 'true' }}
231+ attrs = {'sg_onednn_fully_connected_eltwise_0' : {'with_eltwise' : 'true' },
232+ 'sg_onednn_fully_connected_eltwise_1' : {'with_eltwise' : 'true' }}
224233 net = FCIdentityEltwise (identity_node )
225234 check_fusion (net , data_shape , attrs , check_quantization = False )
0 commit comments