Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 4d96671

Browse files
stephenrawls
authored and szha committed
fixing var-seq-len rnn backward() operator (#15278)
* fixing var-seq-len rnn backward() operator
* updating var-length lstm to test backward pass
* removing bit of dbg print to stderr i forgot to remove earlier
* resolving TODO about using int32 for sequence_length
* setting rtol and atol similar to other tests in this file
1 parent 145f82d commit 4d96671

2 files changed

Lines changed: 51 additions & 25 deletions

File tree

src/operator/rnn-inl.h

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1583,8 +1583,11 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
15831583
int dtype = in_types[rnn_enum::kData];
15841584
int itype = dtype;
15851585
if (param.use_sequence_length) {
1586-
itype = in_types[rnn_enum::kSequenceLength];
1587-
if (param.mode == rnn_enum::kLstm) itype -= 1;
1586+
size_t seq_len_input_idx = rnn_enum::kSequenceLength;
1587+
if (param.mode != rnn_enum::kLstm) {
1588+
seq_len_input_idx -= 1;
1589+
}
1590+
itype = in_types[seq_len_input_idx];
15881591
}
15891592

15901593
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
@@ -1649,7 +1652,7 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
16491652
// Hacky. This relies on fact that seq-len type is either the last input,
16501653
// or we aren't using seq-len input and this type should be same as dtype.
16511654
// Would prefer direct access to RNNParam object here but not sure how to get.
1652-
int itype = inputs[inputs.size()-1].type_flag_;
1655+
int itype = outputs[outputs.size()-1].type_flag_;
16531656

16541657
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
16551658
MSHADOW_TYPE_SWITCH(itype, IType, {
@@ -1669,6 +1672,15 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
16691672
}
16701673
}
16711674

1675+
1676+
if (param.use_sequence_length) {
1677+
size_t seq_len_input_idx = rnn_enum::kSequenceLength;
1678+
if (param.mode != rnn_enum::kLstm) {
1679+
seq_len_input_idx -= 1;
1680+
}
1681+
in_data.push_back(outputs[seq_len_input_idx]);
1682+
}
1683+
16721684
op.Backward(ctx, out_grad, in_data, out_data, req, in_grad);
16731685
});
16741686
});

tests/python/gpu/test_gluon_gpu.py

Lines changed: 36 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -227,19 +227,6 @@ def forward(self, inpt):
227227

228228

229229
def check_layer_bidirectional_varseqlen(size, in_size):
230-
class RefBiLSTMVarSeqLen(gluon.Block):
231-
def __init__(self, size, **kwargs):
232-
super(RefBiLSTMVarSeqLen, self).__init__(**kwargs)
233-
with self.name_scope():
234-
self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='l0')
235-
self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='r0')
236-
237-
def forward(self, inpt, sequence_length):
238-
fwd = self._lstm_fwd(inpt)
239-
bwd_inpt = nd.SequenceReverse(inpt, sequence_length=sequence_length, use_sequence_length=True)
240-
bwd = self._lstm_bwd(bwd_inpt)
241-
bwd = nd.SequenceReverse(bwd, sequence_length=sequence_length, use_sequence_length=True)
242-
return nd.concat(fwd, bwd, dim=2)
243230
weights = {}
244231
for d in ['l', 'r']:
245232
weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size))
@@ -248,31 +235,58 @@ def forward(self, inpt, sequence_length):
248235
weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,))
249236

250237
net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=True, prefix='lstm_')
251-
ref_net = RefBiLSTMVarSeqLen(size, prefix='lstm_')
238+
ref_net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=False, prefix='lstm_ref_')
252239
net.initialize()
253240
ref_net.initialize()
254241
net_params = net.collect_params()
255242
ref_net_params = ref_net.collect_params()
256243
for k in weights:
257244
net_params[k].set_data(weights[k])
258-
ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k])
259-
245+
ref_net_params[k.replace("lstm_", "lstm_ref_")].set_data(weights[k])
260246

261247
batch_size = 10
262248
num_timesteps = 11
263249
data = mx.random.uniform(shape=(num_timesteps, batch_size, in_size))
250+
data_np = data.asnumpy()
264251

265-
# TODO: figure out why int32 doesn't work here
266-
sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("float")
267-
268-
net_output = net(data, sequence_length=sequence_length).asnumpy()
269-
ref_net_output = ref_net(data, sequence_length).asnumpy()
252+
sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("int32")
270253
sequence_length_np = sequence_length.asnumpy().astype("int32")
271254

255+
# Reference net is processing batch elements one at a time, so that it is "perfectly sized"
256+
# Because of that, we need to accumulate gradients in reference net.
257+
for p in ref_net.collect_params().values():
258+
p.grad_req = 'add'
259+
260+
ref_net_output = []
261+
with autograd.record():
262+
net_output = net(data.copy(), sequence_length=sequence_length.copy())
263+
264+
for b in range(batch_size):
265+
data_slice = mx.nd.array(data_np[:sequence_length_np[b], b, :]).reshape(sequence_length_np[b], 1, in_size)
266+
ref_output_slice = ref_net(data_slice)
267+
ref_net_output.append(ref_output_slice)
268+
269+
net_output_np = net_output.asnumpy()
270+
272271
# TODO: test state return value as well output
273272
# Only compare the valid sections for each batch entry
274273
for b in range(batch_size):
275-
assert_allclose(net_output[:sequence_length_np[b], b], ref_net_output[:sequence_length_np[b], b])
274+
assert_allclose(net_output_np[:sequence_length_np[b], b], ref_net_output[b].asnumpy().squeeze(1),
275+
rtol=1e-2, atol=1e-6)
276+
277+
# Now test backward
278+
net_output.backward()
279+
280+
for ref_output_slice in ref_net_output:
281+
ref_output_slice.backward()
282+
283+
ref_net_params = ref_net.collect_params()
284+
285+
for k in weights:
286+
net_grad = net_params[k].grad()
287+
ref_net_grad = ref_net_params[k.replace('lstm_', 'lstm_ref_')].grad()
288+
assert_almost_equal(net_grad.asnumpy(), ref_net_grad.asnumpy(),
289+
rtol=1e-2, atol=1e-6)
276290

277291

278292
@with_seed()

0 commit comments

Comments
 (0)