@@ -227,19 +227,6 @@ def forward(self, inpt):
227227
228228
229229def check_layer_bidirectional_varseqlen (size , in_size ):
230- class RefBiLSTMVarSeqLen (gluon .Block ):
231- def __init__ (self , size , ** kwargs ):
232- super (RefBiLSTMVarSeqLen , self ).__init__ (** kwargs )
233- with self .name_scope ():
234- self ._lstm_fwd = gluon .rnn .LSTM (size , bidirectional = False , prefix = 'l0' )
235- self ._lstm_bwd = gluon .rnn .LSTM (size , bidirectional = False , prefix = 'r0' )
236-
237- def forward (self , inpt , sequence_length ):
238- fwd = self ._lstm_fwd (inpt )
239- bwd_inpt = nd .SequenceReverse (inpt , sequence_length = sequence_length , use_sequence_length = True )
240- bwd = self ._lstm_bwd (bwd_inpt )
241- bwd = nd .SequenceReverse (bwd , sequence_length = sequence_length , use_sequence_length = True )
242- return nd .concat (fwd , bwd , dim = 2 )
243230 weights = {}
244231 for d in ['l' , 'r' ]:
245232 weights ['lstm_{}0_i2h_weight' .format (d )] = mx .random .uniform (shape = (size * 4 , in_size ))
@@ -248,31 +235,58 @@ def forward(self, inpt, sequence_length):
248235 weights ['lstm_{}0_h2h_bias' .format (d )] = mx .random .uniform (shape = (size * 4 ,))
249236
250237 net = gluon .rnn .LSTM (size , bidirectional = True , use_sequence_length = True , prefix = 'lstm_' )
251- ref_net = RefBiLSTMVarSeqLen (size , prefix = 'lstm_ ' )
238+ ref_net = gluon . rnn . LSTM (size , bidirectional = True , use_sequence_length = False , prefix = 'lstm_ref_ ' )
252239 net .initialize ()
253240 ref_net .initialize ()
254241 net_params = net .collect_params ()
255242 ref_net_params = ref_net .collect_params ()
256243 for k in weights :
257244 net_params [k ].set_data (weights [k ])
258- ref_net_params [k .replace ('l0' , 'l0l0' ).replace ('r0' , 'r0l0' )].set_data (weights [k ])
259-
245+ ref_net_params [k .replace ("lstm_" , "lstm_ref_" )].set_data (weights [k ])
260246
261247 batch_size = 10
262248 num_timesteps = 11
263249 data = mx .random .uniform (shape = (num_timesteps , batch_size , in_size ))
250+ data_np = data .asnumpy ()
264251
265- # TODO: figure out why int32 doesn't work here
266- sequence_length = nd .random .randint (1 , num_timesteps + 1 , shape = (batch_size )).astype ("float" )
267-
268- net_output = net (data , sequence_length = sequence_length ).asnumpy ()
269- ref_net_output = ref_net (data , sequence_length ).asnumpy ()
252+ sequence_length = nd .random .randint (1 , num_timesteps + 1 , shape = (batch_size )).astype ("int32" )
270253 sequence_length_np = sequence_length .asnumpy ().astype ("int32" )
271254
255+ # Reference net is processing batch elements one at a time, so that it is "perfectly sized"
256+ # Because of that, we need to accumulate gradients in reference net.
257+ for p in ref_net .collect_params ().values ():
258+ p .grad_req = 'add'
259+
260+ ref_net_output = []
261+ with autograd .record ():
262+ net_output = net (data .copy (), sequence_length = sequence_length .copy ())
263+
264+ for b in range (batch_size ):
265+ data_slice = mx .nd .array (data_np [:sequence_length_np [b ], b , :]).reshape (sequence_length_np [b ], 1 , in_size )
266+ ref_output_slice = ref_net (data_slice )
267+ ref_net_output .append (ref_output_slice )
268+
269+ net_output_np = net_output .asnumpy ()
270+
272271 # TODO: test state return value as well output
273272 # Only compare the valid sections for each batch entry
274273 for b in range (batch_size ):
275- assert_allclose (net_output [:sequence_length_np [b ], b ], ref_net_output [:sequence_length_np [b ], b ])
274+ assert_allclose (net_output_np [:sequence_length_np [b ], b ], ref_net_output [b ].asnumpy ().squeeze (1 ),
275+ rtol = 1e-2 , atol = 1e-6 )
276+
277+ # Now test backward
278+ net_output .backward ()
279+
280+ for ref_output_slice in ref_net_output :
281+ ref_output_slice .backward ()
282+
283+ ref_net_params = ref_net .collect_params ()
284+
285+ for k in weights :
286+ net_grad = net_params [k ].grad ()
287+ ref_net_grad = ref_net_params [k .replace ('lstm_' , 'lstm_ref_' )].grad ()
288+ assert_almost_equal (net_grad .asnumpy (), ref_net_grad .asnumpy (),
289+ rtol = 1e-2 , atol = 1e-6 )
276290
277291
278292@with_seed ()
0 commit comments