|
6 | 6 | import copy |
7 | 7 | import copy |
8 | 8 | from jax.tree_util import tree_map |
| 9 | + |
9 | 10 | """ |
10 | 11 | Jax default parameter structure: |
11 | 12 | dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table']) |
@@ -44,11 +45,45 @@ def use_pytorch_weights_inplace(jax_params, file_name=None, replicate=False): |
44 | 45 | for i, j in zip([0, 2, 4, 6, 8], range(3, 8)): |
45 | 46 | jax_params[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T) |
46 | 47 | jax_params[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias']) |
47 | | - #jax_params = tree_map(lambda x: jnp.array(x), jax_params) |
| 48 | + |
48 | 49 | del state_dict |
49 | 50 | return jax_params |
50 | 51 |
|
51 | 52 |
|
def use_pytorch_weights_cpu_copy(jax_params, file_name=None, replicate=False,
                                 num_embedding_chunks=4):
    """Return a CPU-resident copy of ``jax_params`` filled with PyTorch weights.

    Unlike the in-place variant, the caller's pytree is left untouched: every
    leaf is first deep-copied onto the host CPU device, then the checkpoint
    weights are written into that copy.

    Args:
        jax_params: Flax-style parameter dict with keys 'Dense_0'..'Dense_7'
            and 'embedding_table'.
        file_name: Path to a PyTorch checkpoint whose state_dict contains
            'embedding_chunk_{i}', 'bot_mlp.{i}.weight'/'bias' and
            'top_mlp.{i}.weight'/'bias' entries.
        replicate: Unused; kept for signature compatibility with the other
            loader helpers in this module.
        num_embedding_chunks: Number of 'embedding_chunk_{i}' entries to
            concatenate into the embedding table. Defaults to 4, the
            previously hard-coded value.

    Returns:
        A new parameter pytree on the CPU device holding the checkpoint
        weights.
    """
    def _deep_copy_to_cpu(pytree):
        # Host-side deep copy so mutating the result can never alias the
        # caller's (possibly device-resident) parameters.
        cpu = jax.devices("cpu")[0]
        return tree_map(
            lambda x: jax.device_put(jnp.array(copy.deepcopy(x)), device=cpu),
            pytree,
        )

    jax_copy = _deep_copy_to_cpu(jax_params)

    # Load the PyTorch state_dict straight onto the CPU, then convert every
    # tensor to a NumPy array for jnp consumption.
    state_dict = torch.load(file_name, map_location='cpu')
    numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}

    # --- Embedding table: reassemble the chunked table ---
    # NOTE(review): axis=0 assumes the table was chunked along rows — confirm
    # against the code that produced the checkpoint.
    embedding_table = np.concatenate(
        [numpy_weights[f'embedding_chunk_{i}'] for i in range(num_embedding_chunks)],
        axis=0,
    )
    jax_copy['embedding_table'] = jnp.array(embedding_table)

    # --- Bot MLP: torch Sequential indices 0/2/4 -> Dense_0..Dense_2 ---
    # PyTorch Linear stores weight as (out, in); Flax Dense expects
    # (in, out), hence the transpose.
    for i, j in zip([0, 2, 4], range(3)):
        jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T)
        jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias'])

    # --- Top MLP: torch Sequential indices 0/2/4/6/8 -> Dense_3..Dense_7 ---
    for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
        jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
        jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])

    # Release the torch tensors before returning to keep peak memory down.
    del state_dict
    return jax_copy
| 85 | + |
| 86 | + |
52 | 87 | def use_pytorch_weights_inplace_mnist(jax_params, file_name=None, replicate=False): |
53 | 88 | # Load the PyTorch checkpoint |
54 | 89 | ckpt = torch.load(file_name) |
@@ -78,34 +113,6 @@ def use_pytorch_weights_inplace_mnist(jax_params, file_name=None, replicate=Fals |
78 | 113 | return jax_params |
79 | 114 |
|
80 | 115 |
|
81 | | -# def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6): |
82 | | -# """Compares two JAX PyTrees of weights and prints where they differ.""" |
83 | | -# all_equal = True |
84 | | - |
85 | | -# def compare_fn(p1, p2): |
86 | | -# nonlocal all_equal |
87 | | -# #if not jnp.allclose(p1, p2): |
88 | | -# if not jnp.allclose(p1, p2, atol=atol, rtol=rtol): |
89 | | -# logging.info("❌ Mismatch found:") |
90 | | -# logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}") |
91 | | -# logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}") |
92 | | -# all_equal = False |
93 | | -# return jnp.allclose(p1, p2, atol=atol, rtol=rtol) |
94 | | - |
95 | | -# try: |
96 | | -# _ = jax.tree_util.tree_map(compare_fn, params1, params2) |
97 | | -# except Exception as e: |
98 | | -# logging.info("❌ Structure mismatch or error during comparison:", e) |
99 | | -# return False |
100 | | - |
101 | | -# if all_equal: |
102 | | -# logging.info("✅ All weights are equal (within tolerance)") |
103 | | -# return all_equal |
104 | | - |
105 | | -import jax |
106 | | -import jax.numpy as jnp |
107 | | -import logging |
108 | | - |
109 | 116 | def maybe_unreplicate(pytree): |
110 | 117 | """If leading axis matches device count, strip it assuming it's pmap replication.""" |
111 | 118 | num_devices = jax.device_count() |
@@ -150,37 +157,27 @@ def compare_fn(p1, p2): |
150 | 157 |
|
151 | 158 |
|
152 | 159 |
|
153 | | -def use_pytorch_weights2(jax_params, file_name=None, replicate=False): |
154 | | - |
155 | | - def deep_copy_to_cpu(pytree): |
156 | | - return tree_map(lambda x: jax.device_put(jnp.array(copy.deepcopy(x)), device=jax.devices("cpu")[0]), pytree) |
157 | | - |
158 | | - breakpoint() |
159 | | - jax_copy = deep_copy_to_cpu(jax_params) |
160 | | - # Load PyTorch state_dict lazily to CPU |
161 | | - state_dict = torch.load(file_name, map_location='cpu') |
162 | | - print(state_dict.keys()) |
163 | | - # Convert PyTorch tensors to NumPy arrays |
164 | | - numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()} |
165 | | - |
166 | | - # --- Embedding Table --- |
167 | | - embedding_table = np.concatenate([ |
168 | | - numpy_weights[f'embedding_chunk_{i}'] for i in range(4) |
169 | | - ], axis=0) # adjust axis depending on chunking direction |
170 | | - |
171 | | - jax_copy['embedding_table'] = jnp.array(embedding_table) |
| 160 | +# def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6): |
| 161 | +# """Compares two JAX PyTrees of weights and prints where they differ.""" |
| 162 | +# all_equal = True |
172 | 163 |
|
173 | | - # --- Bot MLP: Dense_0 to Dense_2 --- |
174 | | - for i, j in zip([0, 2, 4], range(3)): |
175 | | - jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T) |
176 | | - jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias']) |
| 164 | +# def compare_fn(p1, p2): |
| 165 | +# nonlocal all_equal |
| 166 | +# #if not jnp.allclose(p1, p2): |
| 167 | +# if not jnp.allclose(p1, p2, atol=atol, rtol=rtol): |
| 168 | +# logging.info("❌ Mismatch found:") |
| 169 | +# logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}") |
| 170 | +# logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}") |
| 171 | +# all_equal = False |
| 172 | +# return jnp.allclose(p1, p2, atol=atol, rtol=rtol) |
177 | 173 |
|
178 | | - # --- Top MLP: Dense_3 to Dense_7 --- |
179 | | - for i, j in zip([0, 2, 4, 6, 8], range(3, 8)): |
180 | | - jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T) |
181 | | - jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias']) |
182 | | - #jax_copy = tree_map(lambda x: jnp.array(x), jax_copy) |
183 | | - del state_dict |
| 174 | +# try: |
| 175 | +# _ = jax.tree_util.tree_map(compare_fn, params1, params2) |
| 176 | +# except Exception as e: |
| 177 | +# logging.info("❌ Structure mismatch or error during comparison:", e) |
| 178 | +# return False |
184 | 179 |
|
185 | | - return jax_copy |
| 180 | +# if all_equal: |
| 181 | +# logging.info("✅ All weights are equal (within tolerance)") |
| 182 | +# return all_equal |
186 | 183 |
|
0 commit comments