Skip to content

Commit a151382

Browse files
committed
pylint fix
1 parent 0f43049 commit a151382

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

tests/test_jax_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def test_jitted_updates(self, dropout_rate, mode):
185185
"""
186186

187187
# initialize models
188-
rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
188+
rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2)
189189
fake_batch = jnp.ones((10,))
190190
orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
191191
cust_model = DropoutModel()

0 commit comments

Comments
 (0)