|
4 | 4 | """ |
5 | 5 |
|
6 | 6 | from functools import partial |
7 | | -import os |
8 | 7 |
|
9 | 8 | from absl.testing import absltest |
10 | 9 | from absl.testing import parameterized |
@@ -63,7 +62,7 @@ def __call__(self, x, train, dropout_rate=DEFAULT_DROPOUT): |
63 | 62 | x, rate=dropout_rate) |
64 | 63 |
|
65 | 64 |
|
66 | | -class ModelEquivalenceTest(parameterized.TestCase): |
| 65 | +class DropoutTest(parameterized.TestCase): |
67 | 66 |
|
68 | 67 | @parameterized.named_parameters( |
69 | 68 | dict( |
@@ -185,8 +184,7 @@ def test_dropout_update(self, dropout_rate, mode): |
185 | 184 | dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"), |
186 | 185 | ) |
187 | 186 | def test_jitted_updates(self, dropout_rate, mode): |
188 | | - """ Compare forward pass of Dropout layer to flax.linen.Dropout in train and |
189 | | - eval mode. |
| 187 | + """ Compare jitted updates with dropout. |
190 | 188 | """ |
191 | 189 |
|
192 | 190 | # initialize models |
@@ -214,24 +212,25 @@ def test_jitted_updates(self, dropout_rate, mode): |
214 | 212 | jitted_custom_apply = jax.jit( |
215 | 213 | partial(cust_model.apply), static_argnames=['train']) |
216 | 214 |
|
217 | | - def multiple_fwd_passes_custom_layer(): |
218 | | - for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: |
219 | | - y2 = jitted_custom_apply( |
220 | | - initial_variables_custom, |
221 | | - x, |
222 | | - train=train, |
223 | | - dropout_rate=d, |
224 | | - rngs={"dropout": dropout_rng}, |
225 | | - ) |
226 | | - return y2 |
227 | | - |
228 | | - def multiple_fwd_passes_original_layer(): |
229 | | - for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: |
230 | | - y1 = jitted_original_apply( |
231 | | - initial_variables_original, |
232 | | - x, |
233 | | - train=train, |
234 | | - rngs={"dropout": dropout_rng}) |
| 215 | + for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: |
| 216 | + y1 = jitted_original_apply( |
| 217 | + initial_variables_original, |
| 218 | + x, |
| 219 | + train=train, |
| 220 | + rngs={"dropout": dropout_rng}) |
| 221 | + return y1 |
| 222 | + |
| 223 | + for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: |
| 224 | + y2 = jitted_custom_apply( |
| 225 | + initial_variables_custom, |
| 226 | + x, |
| 227 | + train=train, |
| 228 | + dropout_rate=d, |
| 229 | + rngs={"dropout": dropout_rng}, |
| 230 | + ) |
| 231 | + return y2 |
| 232 | + |
| 233 | + assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) |
235 | 234 |
|
236 | 235 |
|
237 | 236 | if __name__ == "__main__": |
|
0 commit comments