File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1212 Criteo1TbDlrmSmallWorkload as JaxWorkload
1313from algoperf .workloads .criteo1tb .criteo1tb_pytorch .workload import \
1414 Criteo1TbDlrmSmallWorkload as PyTorchWorkload
15- from tests .modeldiffs .diff import out_diff
15+ from tests .modeldiffs .diff import ModelDiffRunner
1616
1717
1818def key_transform (k ):
@@ -74,11 +74,11 @@ def sd_transform(sd):
7474 rng = jax .random .PRNGKey (0 ),
7575 update_batch_norm = False )
7676
77- out_diff (
77+ ModelDiffRunner (
7878 jax_workload = jax_workload ,
7979 pytorch_workload = pytorch_workload ,
8080 jax_model_kwargs = jax_model_kwargs ,
8181 pytorch_model_kwargs = pytorch_model_kwargs ,
8282 key_transform = key_transform ,
8383 sd_transform = sd_transform ,
84- out_transform = None )
84+ out_transform = None ). run ()
Original file line number Diff line number Diff line change @@ -64,3 +64,40 @@ def out_diff(jax_workload,
6464
6565 print (f'Max fprop difference between jax and pytorch: { max_diff } ' )
6666 print (f'Min fprop difference between jax and pytorch: { min_diff } ' )
67+
68+
69+ class ModelDiffRunner :
70+ def __init__ (self , jax_workload ,
71+ pytorch_workload ,
72+ jax_model_kwargs ,
73+ pytorch_model_kwargs ,
74+ key_transform = None ,
75+ sd_transform = None ,
76+ out_transform = None ) -> None :
77+ """Initializes the instance based on diffing logic.
78+ Args:
79+ jax_workload: Workload implementation using JAX
80+ pytorch_workload: Workload implementation using PyTorch
81+ jax_model_kwargs: Arguments to be used for model_fn in jax workload
82+ pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch workload
83+ key_transform: Transformation function for keys.
84+ sd_transform: Transformation function for State Dictionary.
85+ out_transform: Transformation function for the output.
86+ """
87+
88+ self .jax_workload = jax_workload
89+ self .pytorch_workload = pytorch_workload
90+ self .jax_model_kwargs = jax_model_kwargs
91+ self .pytorch_model_kwargs = pytorch_model_kwargs
92+ self .key_transform = key_transform
93+ self .sd_transform = sd_transform
94+ self .out_transform = out_transform
95+
96+ def run (self ):
97+ out_diff (self .jax_workload ,
98+ self .pytorch_workload ,
99+ self .jax_model_kwargs ,
100+ self .pytorch_model_kwargs ,
101+ self .key_transform ,
102+ self .sd_transform ,
103+ self .out_transform )
You can’t perform that action at this time.
0 commit comments