Skip to content

Commit ac76d4f

Browse files
committed
formatting fixes
1 parent 161c264 commit ac76d4f

1 file changed

Lines changed: 21 additions & 22 deletions

File tree

tests/test_jax_utils.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
"""
55

66
from functools import partial
7-
import os
87

98
from absl.testing import absltest
109
from absl.testing import parameterized
@@ -63,7 +62,7 @@ def __call__(self, x, train, dropout_rate=DEFAULT_DROPOUT):
6362
x, rate=dropout_rate)
6463

6564

66-
class ModelEquivalenceTest(parameterized.TestCase):
65+
class DropoutTest(parameterized.TestCase):
6766

6867
@parameterized.named_parameters(
6968
dict(
@@ -185,8 +184,7 @@ def test_dropout_update(self, dropout_rate, mode):
185184
dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"),
186185
)
187186
def test_jitted_updates(self, dropout_rate, mode):
188-
""" Compare forward pass of Dropout layer to flax.linen.Dropout in train and
189-
eval mode.
187+
""" Compare jitted updates with dropout.
190188
"""
191189

192190
# initialize models
@@ -214,24 +212,25 @@ def test_jitted_updates(self, dropout_rate, mode):
214212
jitted_custom_apply = jax.jit(
215213
partial(cust_model.apply), static_argnames=['train'])
216214

217-
def multiple_fwd_passes_custom_layer():
218-
for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
219-
y2 = jitted_custom_apply(
220-
initial_variables_custom,
221-
x,
222-
train=train,
223-
dropout_rate=d,
224-
rngs={"dropout": dropout_rng},
225-
)
226-
return y2
227-
228-
def multiple_fwd_passes_original_layer():
229-
for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
230-
y1 = jitted_original_apply(
231-
initial_variables_original,
232-
x,
233-
train=train,
234-
rngs={"dropout": dropout_rng})
215+
for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
216+
y1 = jitted_original_apply(
217+
initial_variables_original,
218+
x,
219+
train=train,
220+
rngs={"dropout": dropout_rng})
221+
return y1
222+
223+
for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
224+
y2 = jitted_custom_apply(
225+
initial_variables_custom,
226+
x,
227+
train=train,
228+
dropout_rate=d,
229+
rngs={"dropout": dropout_rng},
230+
)
231+
return y2
232+
233+
assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
235234

236235

237236
if __name__ == "__main__":

0 commit comments

Comments
 (0)