Skip to content

Commit 5efdfad

Browse files
committed
change var name to enhance readibility
pyt --> pytorch k --> eval_metric_key
1 parent e5b7930 commit 5efdfad

32 files changed

Lines changed: 94 additions & 94 deletions

File tree

tests/modeldiffs/criteo1tb/compare.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,16 @@ def sd_transform(sd):
5353
jax_workload = JaxWorkload()
5454
pytorch_workload = PyTorchWorkload()
5555

56-
pyt_batch = {
56+
pytorch_batch = {
5757
'inputs': torch.ones((2, 13 + 26)),
5858
'targets': torch.randint(low=0, high=1, size=(2,)),
5959
'weights': torch.ones(2),
6060
}
61-
jax_batch = {k: np.array(v) for k, v in pyt_batch.items()}
61+
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
6262

6363
# Test outputs for identical weights and inputs.
6464
pytorch_model_kwargs = dict(
65-
augmented_and_preprocessed_input_batch=pyt_batch,
65+
augmented_and_preprocessed_input_batch=pytorch_batch,
6666
model_state=None,
6767
mode=spec.ForwardPassMode.EVAL,
6868
rng=None,

tests/modeldiffs/criteo1tb_embed_init/compare.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,16 @@ def sd_transform(sd):
5252
jax_workload = JaxWorkload()
5353
pytorch_workload = PyTorchWorkload()
5454

55-
pyt_batch = {
55+
pytorch_batch = {
5656
'inputs': torch.ones((2, 13 + 26)),
5757
'targets': torch.randint(low=0, high=1, size=(2,)),
5858
'weights': torch.ones(2),
5959
}
60-
jax_batch = {k: np.array(v) for k, v in pyt_batch.items()}
60+
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
6161

6262
# Test outputs for identical weights and inputs.
6363
pytorch_model_kwargs = dict(
64-
augmented_and_preprocessed_input_batch=pyt_batch,
64+
augmented_and_preprocessed_input_batch=pytorch_batch,
6565
model_state=None,
6666
mode=spec.ForwardPassMode.EVAL,
6767
rng=None,

tests/modeldiffs/criteo1tb_layernorm/compare.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,16 @@ def sd_transform(sd):
6464
jax_workload = JaxWorkload()
6565
pytorch_workload = PyTorchWorkload()
6666

67-
pyt_batch = {
67+
pytorch_batch = {
6868
'inputs': torch.ones((2, 13 + 26)),
6969
'targets': torch.randint(low=0, high=1, size=(2,)),
7070
'weights': torch.ones(2),
7171
}
72-
jax_batch = {k: np.array(v) for k, v in pyt_batch.items()}
72+
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
7373

7474
# Test outputs for identical weights and inputs.
7575
pytorch_model_kwargs = dict(
76-
augmented_and_preprocessed_input_batch=pyt_batch,
76+
augmented_and_preprocessed_input_batch=pytorch_batch,
7777
model_state=None,
7878
mode=spec.ForwardPassMode.EVAL,
7979
rng=None,

tests/modeldiffs/criteo1tb_resnet/compare.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def sd_transform(sd):
6464
jax_workload = JaxWorkload()
6565
pytorch_workload = PyTorchWorkload()
6666

67-
pyt_batch = {
67+
pytorch_batch = {
6868
'inputs': torch.ones((2, 13 + 26)),
6969
'targets': torch.randint(low=0, high=1, size=(2,)),
7070
'weights': torch.ones(2),
@@ -75,12 +75,12 @@ def sd_transform(sd):
7575
input_size = 13 + num_categorical_features
7676
input_shape = (init_fake_batch_size, input_size)
7777
fake_inputs = jnp.ones(input_shape, jnp.float32)
78-
jax_batch = {k: np.array(v) for k, v in pyt_batch.items()}
78+
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
7979
jax_batch['inputs'] = fake_inputs
8080

8181
# Test outputs for identical weights and inputs.
8282
pytorch_model_kwargs = dict(
83-
augmented_and_preprocessed_input_batch=pyt_batch,
83+
augmented_and_preprocessed_input_batch=pytorch_batch,
8484
model_state=None,
8585
mode=spec.ForwardPassMode.EVAL,
8686
rng=None,

tests/modeldiffs/fastmri/compare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ def sort_key(k):
6161
image = torch.randn(2, 320, 320)
6262

6363
jax_batch = {'inputs': image.detach().numpy()}
64-
pyt_batch = {'inputs': image}
64+
pytorch_batch = {'inputs': image}
6565

6666
pytorch_model_kwargs = dict(
67-
augmented_and_preprocessed_input_batch=pyt_batch,
67+
augmented_and_preprocessed_input_batch=pytorch_batch,
6868
model_state=None,
6969
mode=spec.ForwardPassMode.EVAL,
7070
rng=None,

tests/modeldiffs/fastmri_layernorm/compare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ def sort_key(k):
6868
image = torch.randn(2, 320, 320)
6969

7070
jax_batch = {'inputs': image.detach().numpy()}
71-
pyt_batch = {'inputs': image}
71+
pytorch_batch = {'inputs': image}
7272

7373
pytorch_model_kwargs = dict(
74-
augmented_and_preprocessed_input_batch=pyt_batch,
74+
augmented_and_preprocessed_input_batch=pytorch_batch,
7575
model_state=None,
7676
mode=spec.ForwardPassMode.EVAL,
7777
rng=None,

tests/modeldiffs/fastmri_model_size/compare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ def sort_key(k):
6161
image = torch.randn(2, 320, 320)
6262

6363
jax_batch = {'inputs': image.detach().numpy()}
64-
pyt_batch = {'inputs': image}
64+
pytorch_batch = {'inputs': image}
6565

6666
pytorch_model_kwargs = dict(
67-
augmented_and_preprocessed_input_batch=pyt_batch,
67+
augmented_and_preprocessed_input_batch=pytorch_batch,
6868
model_state=None,
6969
mode=spec.ForwardPassMode.EVAL,
7070
rng=None,

tests/modeldiffs/fastmri_tanh/compare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ def sort_key(k):
6161
image = torch.randn(2, 320, 320)
6262

6363
jax_batch = {'inputs': image.detach().numpy()}
64-
pyt_batch = {'inputs': image}
64+
pytorch_batch = {'inputs': image}
6565

6666
pytorch_model_kwargs = dict(
67-
augmented_and_preprocessed_input_batch=pyt_batch,
67+
augmented_and_preprocessed_input_batch=pytorch_batch,
6868
model_state=None,
6969
mode=spec.ForwardPassMode.EVAL,
7070
rng=None,

tests/modeldiffs/imagenet_resnet/compare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ def sd_transform(sd):
7878
image = torch.randn(2, 3, 224, 224)
7979

8080
jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()}
81-
pyt_batch = {'inputs': image}
81+
pytorch_batch = {'inputs': image}
8282

8383
pytorch_model_kwargs = dict(
84-
augmented_and_preprocessed_input_batch=pyt_batch,
84+
augmented_and_preprocessed_input_batch=pytorch_batch,
8585
model_state=None,
8686
mode=spec.ForwardPassMode.EVAL,
8787
rng=None,

tests/modeldiffs/imagenet_resnet/gelu_compare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
image = torch.randn(2, 3, 224, 224)
2626

2727
jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()}
28-
pyt_batch = {'inputs': image}
28+
pytorch_batch = {'inputs': image}
2929

3030
pytorch_model_kwargs = dict(
31-
augmented_and_preprocessed_input_batch=pyt_batch,
31+
augmented_and_preprocessed_input_batch=pytorch_batch,
3232
model_state=None,
3333
mode=spec.ForwardPassMode.EVAL,
3434
rng=None,

0 commit comments

Comments
 (0)