2222
2323def pytrees_are_equal (a , b , rtol = 1e-5 , atol = 1e-8 ):
2424 """
25- A custom function to check if two PyTrees are equal, handling floats with a tolerance.
25+ A custom function to check if two PyTrees are equal, handling floats with
26+ a tolerance.
2627 """
27- # 1. Check if the structures are the same
2828 if tree_structure (a ) != tree_structure (b ):
2929 return False
3030
31- # 2. Define a comparison function for leaves
3231 def leaf_comparator (x , y ):
3332 # Use allclose for floating-point JAX arrays
3433 if isinstance (x , jnp .ndarray ) and jnp .issubdtype (x .dtype , jnp .floating ):
@@ -37,8 +36,6 @@ def leaf_comparator(x, y):
3736 else :
3837 return x == y
3938
40- # 3. Map the comparison function over the trees and check if all results are True
41- # We also need to flatten the results of the tree_map and check if all are True
4239 comparison_tree = tree_map (leaf_comparator , a , b )
4340 all_equal = all (tree_leaves (comparison_tree ))
4441
@@ -80,7 +77,7 @@ def test_forward(self, dropout_rate, mode):
8077 """
8178
8279 # initialize models
83- rng , data_rng , dropout_rng = jax .random .split (jax .random .key (SEED ), 3 )
80+ rng , dropout_rng = jax .random .split (jax .random .key (SEED ), 2 )
8481 fake_batch = jnp .ones ((10 ,))
8582 orig_model = LegacyDropoutModel (dropout_rate = dropout_rate )
8683 cust_model = DropoutModel ()
@@ -130,7 +127,7 @@ def test_dropout_update(self, dropout_rate, mode):
130127 eval mode.
131128 """
132129 # init model
133- rng , data_rng , dropout_rng = jax .random .split (jax .random .key (SEED ), 3 )
130+ rng , dropout_rng = jax .random .split (jax .random .key (SEED ), 2 )
134131 fake_batch = jnp .ones ((10 ,))
135132 orig_model = LegacyDropoutModel (dropout_rate = dropout_rate )
136133 cust_model = DropoutModel ()
0 commit comments