22import numpy as np
33import jax
44import jax .numpy as jnp
5+ import logging
6+ import copy
7+ import copy
8+ from jax .tree_util import tree_map
59"""
610Jax default parameter structure:
711dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table'])
1620The function assumes that the Jax model parameters are already initialized
1721and that the PyTorch weights are in the correct format.
1822"""
def use_pytorch_weights_inplace(jax_params, file_name=None, replicate=False):
    """Overwrite ``jax_params`` in place with weights from a PyTorch checkpoint.

    Expected checkpoint layout (state_dict keys):
      * ``embedding_chunk_0`` .. ``embedding_chunk_3`` — chunks of the embedding table
      * ``bot_mlp.{0,2,4}.{weight,bias}``     — mapped to ``Dense_0``..``Dense_2``
      * ``top_mlp.{0,2,4,6,8}.{weight,bias}`` — mapped to ``Dense_3``..``Dense_7``

    PyTorch ``Linear.weight`` is stored as (out, in); Flax ``Dense`` expects
    (in, out), hence the ``.T`` on every kernel.

    Args:
        jax_params: Flax-style parameter dict to mutate (keys ``Dense_0``..
            ``Dense_7`` and ``embedding_table``).
        file_name: path to the ``torch.save``-d state_dict.
        replicate: unused; kept for call-site compatibility.

    Returns:
        The same ``jax_params`` dict, mutated in place.
    """
    # map_location='cpu' keeps checkpoint tensors off any accelerator
    # (consistent with use_pytorch_weights2 below).
    state_dict = torch.load(file_name, map_location='cpu')
    logging.info("checkpoint keys: %s", list(state_dict.keys()))

    # Convert every PyTorch tensor to a NumPy array on host.
    numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}

    # --- Embedding table: reassemble the chunked embedding ---
    # NOTE(review): axis=0 assumes the table was chunked along rows —
    # confirm against the code that wrote the checkpoint.
    embedding_table = np.concatenate(
        [numpy_weights[f'embedding_chunk_{i}'] for i in range(4)], axis=0
    )
    jax_params['embedding_table'] = jnp.array(embedding_table)

    # --- Bot MLP: torch layers 0/2/4 -> Dense_0..Dense_2 ---
    for i, j in zip([0, 2, 4], range(3)):
        jax_params[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T)
        jax_params[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias'])

    # --- Top MLP: torch layers 0/2/4/6/8 -> Dense_3..Dense_7 ---
    for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
        jax_params[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
        jax_params[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])

    # Drop the torch tensors promptly to keep peak host memory down.
    del state_dict
    return jax_params
4450
4551
52+ # def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
53+ # """Compares two JAX PyTrees of weights and prints where they differ."""
54+ # all_equal = True
55+
56+ # def compare_fn(p1, p2):
57+ # nonlocal all_equal
58+ # #if not jnp.allclose(p1, p2):
59+ # if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
60+ # logging.info("❌ Mismatch found:")
61+ # logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
62+ # logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
63+ # all_equal = False
64+ # return jnp.allclose(p1, p2, atol=atol, rtol=rtol)
65+
66+ # try:
67+ # _ = jax.tree_util.tree_map(compare_fn, params1, params2)
68+ # except Exception as e:
69+ # logging.info("❌ Structure mismatch or error during comparison:", e)
70+ # return False
71+
72+ # if all_equal:
73+ # logging.info("✅ All weights are equal (within tolerance)")
74+ # return all_equal
75+
76+ import jax
77+ import jax .numpy as jnp
78+ import logging
79+
def maybe_unreplicate(pytree):
    """Strip a pmap-replication leading axis from every array leaf, if present.

    Heuristic: a leaf whose leading dimension equals ``jax.device_count()`` is
    assumed to be device-replicated, so its first slice is taken.  This can
    misfire when a genuine batch axis happens to equal the device count
    (notably with device_count == 1), so only apply it to parameter pytrees.
    """
    num_devices = jax.device_count()

    def _unreplicate(x):
        # ndim guard: 0-d (scalar) arrays have an empty shape tuple, and
        # x.shape[0] would raise IndexError; pass them through unchanged.
        if isinstance(x, jax.Array) and x.ndim > 0 and x.shape[0] == num_devices:
            return x[0]
        return x

    return jax.tree_util.tree_map(_unreplicate, pytree)
87+
def move_to_cpu(tree):
    """Transfer every leaf of *tree* onto the first CPU device."""
    cpu = jax.devices("cpu")[0]

    def _to_cpu(leaf):
        return jax.device_put(leaf, device=cpu)

    return jax.tree_util.tree_map(_to_cpu, tree)
90+
91+
def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
    """Compare two JAX parameter pytrees leaf-by-leaf within tolerance.

    Both trees are first un-replicated (a pmap leading axis is stripped) and
    moved to CPU so the comparison never allocates on an accelerator.
    Mismatching leaves are logged with their shapes and the maximum absolute
    difference.

    Args:
        params1, params2: pytrees of arrays with identical structure.
        atol, rtol: tolerances forwarded to ``jnp.allclose``.

    Returns:
        True when every leaf pair is allclose; False on any mismatch, or when
        the two trees have different structures.
    """
    params1 = move_to_cpu(maybe_unreplicate(params1))
    params2 = move_to_cpu(maybe_unreplicate(params2))

    all_equal = True

    def compare_fn(p1, p2):
        nonlocal all_equal
        # Compute allclose once per leaf — it scans the whole array, and the
        # original computed it twice.
        close = bool(jnp.allclose(p1, p2, atol=atol, rtol=rtol))
        if not close:
            logging.info("❌ Mismatch found:")
            logging.info("Shape 1: %s, Shape 2: %s", p1.shape, p2.shape)
            logging.info("Max diff: %s", jnp.max(jnp.abs(p1 - p2)))
            all_equal = False
        return close

    try:
        jax.tree_util.tree_map(compare_fn, params1, params2)
    except Exception:
        # Structure mismatch (or incompatible leaves) counts as unequal.
        logging.info("❌ Structure mismatch or error during comparison:", exc_info=True)
        return False

    if all_equal:
        logging.info("✅ All weights are equal (within tolerance)")
    return all_equal
121+
122+
123+
def use_pytorch_weights2(jax_params, file_name=None, replicate=False):
    """Return a CPU copy of ``jax_params`` overwritten with PyTorch checkpoint weights.

    Unlike the in-place variant, the caller's pytree is left untouched: a
    fresh tree of CPU-resident arrays is built and the checkpoint weights are
    written into the copy.  (A leftover ``breakpoint()`` was removed — it
    halted every call under pdb.)

    Args:
        jax_params: Flax-style parameter dict (``Dense_0``..``Dense_7`` and
            ``embedding_table``); not mutated.
        file_name: path to the ``torch.save``-d state_dict.
        replicate: unused; kept for call-site compatibility.

    Returns:
        A new parameter dict holding the loaded weights on CPU.
    """
    cpu = jax.devices("cpu")[0]
    # jnp.array copies its input, and tree_map rebuilds every container, so
    # no copy.deepcopy is needed to keep the caller's tree pristine.
    jax_copy = tree_map(lambda x: jax.device_put(jnp.array(x), device=cpu), jax_params)

    # Load lazily onto CPU so the checkpoint never touches an accelerator.
    state_dict = torch.load(file_name, map_location='cpu')
    logging.info("checkpoint keys: %s", list(state_dict.keys()))

    # Convert PyTorch tensors to NumPy arrays on host.
    numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}

    # --- Embedding table: reassemble the chunked embedding ---
    # NOTE(review): axis=0 assumes row-wise chunking — confirm against the
    # code that wrote the checkpoint.
    embedding_table = np.concatenate(
        [numpy_weights[f'embedding_chunk_{i}'] for i in range(4)], axis=0
    )
    jax_copy['embedding_table'] = jnp.array(embedding_table)

    # --- Bot MLP: torch layers 0/2/4 -> Dense_0..Dense_2 ---
    # Transpose because torch Linear.weight is (out, in); Flax wants (in, out).
    for i, j in zip([0, 2, 4], range(3)):
        jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T)
        jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias'])

    # --- Top MLP: torch layers 0/2/4/6/8 -> Dense_3..Dense_7 ---
    for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
        jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
        jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])

    # Drop the torch tensors promptly to keep peak host memory down.
    del state_dict
    return jax_copy
157+
0 commit comments