Skip to content

Commit e4eacea

Browse files
committed
fix linting
1 parent ac76d4f commit e4eacea

1 file changed

Lines changed: 4 additions & 7 deletions

File tree

tests/test_jax_utils.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,12 @@
2222

2323
def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8):
2424
"""
25-
A custom function to check if two PyTrees are equal, handling floats with a tolerance.
25+
A custom function to check if two PyTrees are equal, handling floats with
26+
a tolerance.
2627
"""
27-
# 1. Check if the structures are the same
2828
if tree_structure(a) != tree_structure(b):
2929
return False
3030

31-
# 2. Define a comparison function for leaves
3231
def leaf_comparator(x, y):
3332
# Use allclose for floating-point JAX arrays
3433
if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating):
@@ -37,8 +36,6 @@ def leaf_comparator(x, y):
3736
else:
3837
return x == y
3938

40-
# 3. Map the comparison function over the trees and check if all results are True
41-
# We also need to flatten the results of the tree_map and check if all are True
4239
comparison_tree = tree_map(leaf_comparator, a, b)
4340
all_equal = all(tree_leaves(comparison_tree))
4441

@@ -80,7 +77,7 @@ def test_forward(self, dropout_rate, mode):
8077
"""
8178

8279
# initialize models
83-
rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
80+
rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2)
8481
fake_batch = jnp.ones((10,))
8582
orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
8683
cust_model = DropoutModel()
@@ -130,7 +127,7 @@ def test_dropout_update(self, dropout_rate, mode):
130127
eval mode.
131128
"""
132129
# init model
133-
rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
130+
rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2)
134131
fake_batch = jnp.ones((10,))
135132
orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
136133
cust_model = DropoutModel()

0 commit comments

Comments
 (0)