Skip to content

Commit a775004

Browse files
committed
shuffling the functions for better understanding
1 parent 7ec211f commit a775004

2 files changed

Lines changed: 58 additions & 61 deletions

File tree

custom_pytorch_jax_converter.py

Lines changed: 56 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import copy
77
import copy
88
from jax.tree_util import tree_map
9+
910
"""
1011
Jax default parameter structure:
1112
dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table'])
@@ -44,11 +45,45 @@ def use_pytorch_weights_inplace(jax_params, file_name=None, replicate=False):
4445
for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
4546
jax_params[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
4647
jax_params[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])
47-
#jax_params = tree_map(lambda x: jnp.array(x), jax_params)
48+
4849
del state_dict
4950
return jax_params
5051

5152

53+
def use_pytorch_weights_cpu_copy(jax_params, file_name=None, replicate=False,
                                 num_embedding_chunks=4):
    """Return a CPU-resident deep copy of ``jax_params`` overwritten with
    weights loaded from a PyTorch checkpoint.

    Unlike the in-place variant, the input pytree is never mutated: every
    leaf is deep-copied onto the host CPU device first, and only the copy is
    updated.

    Args:
        jax_params: Flax-style parameter dict with keys ``'Dense_0'`` ..
            ``'Dense_7'`` (each holding ``'kernel'``/``'bias'``) and
            ``'embedding_table'``.
        file_name: Path to a PyTorch checkpoint containing a state_dict with
            ``embedding_chunk_{i}``, ``bot_mlp.{i}.weight/bias`` and
            ``top_mlp.{i}.weight/bias`` entries.
        replicate: Unused placeholder kept for signature compatibility with
            the sibling ``use_pytorch_weights_inplace`` converter.
        num_embedding_chunks: Number of ``embedding_chunk_{i}`` pieces to
            concatenate into the full embedding table (default 4, matching
            the original hard-coded checkpoint layout).

    Returns:
        A new parameter pytree on the CPU device with the PyTorch weights
        converted and installed; ``jax_params`` is left untouched.
    """

    def deep_copy_to_cpu(pytree):
        # Materialize each leaf as a fresh jnp array pinned to the host CPU,
        # so the copy shares no buffers with (and cannot mutate) jax_params.
        return tree_map(
            lambda x: jax.device_put(jnp.array(copy.deepcopy(x)),
                                     device=jax.devices("cpu")[0]),
            pytree)

    jax_copy = deep_copy_to_cpu(jax_params)

    # Load the PyTorch state_dict onto CPU and convert tensors to NumPy.
    state_dict = torch.load(file_name, map_location='cpu')
    numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}

    # --- Embedding table: reassemble from checkpoint chunks ---
    # axis=0 assumes the table was chunked along rows — TODO confirm against
    # the checkpoint writer.
    embedding_table = np.concatenate(
        [numpy_weights[f'embedding_chunk_{i}']
         for i in range(num_embedding_chunks)],
        axis=0)
    jax_copy['embedding_table'] = jnp.array(embedding_table)

    # --- Bottom MLP: torch Sequential layers 0/2/4 -> Dense_0..Dense_2 ---
    # PyTorch Linear stores (out, in); Flax Dense expects (in, out), hence .T
    for torch_idx, dense_idx in zip((0, 2, 4), range(3)):
        jax_copy[f'Dense_{dense_idx}']['kernel'] = jnp.array(
            numpy_weights[f'bot_mlp.{torch_idx}.weight'].T)
        jax_copy[f'Dense_{dense_idx}']['bias'] = jnp.array(
            numpy_weights[f'bot_mlp.{torch_idx}.bias'])

    # --- Top MLP: torch Sequential layers 0/2/4/6/8 -> Dense_3..Dense_7 ---
    for torch_idx, dense_idx in zip((0, 2, 4, 6, 8), range(3, 8)):
        jax_copy[f'Dense_{dense_idx}']['kernel'] = jnp.array(
            numpy_weights[f'top_mlp.{torch_idx}.weight'].T)
        jax_copy[f'Dense_{dense_idx}']['bias'] = jnp.array(
            numpy_weights[f'top_mlp.{torch_idx}.bias'])

    del state_dict  # release torch tensors promptly; numpy copies remain
    return jax_copy
85+
86+
5287
def use_pytorch_weights_inplace_mnist(jax_params, file_name=None, replicate=False):
5388
# Load the PyTorch checkpoint
5489
ckpt = torch.load(file_name)
@@ -78,34 +113,6 @@ def use_pytorch_weights_inplace_mnist(jax_params, file_name=None, replicate=Fals
78113
return jax_params
79114

80115

81-
# def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
82-
# """Compares two JAX PyTrees of weights and prints where they differ."""
83-
# all_equal = True
84-
85-
# def compare_fn(p1, p2):
86-
# nonlocal all_equal
87-
# #if not jnp.allclose(p1, p2):
88-
# if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
89-
# logging.info("❌ Mismatch found:")
90-
# logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
91-
# logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
92-
# all_equal = False
93-
# return jnp.allclose(p1, p2, atol=atol, rtol=rtol)
94-
95-
# try:
96-
# _ = jax.tree_util.tree_map(compare_fn, params1, params2)
97-
# except Exception as e:
98-
# logging.info("❌ Structure mismatch or error during comparison:", e)
99-
# return False
100-
101-
# if all_equal:
102-
# logging.info("✅ All weights are equal (within tolerance)")
103-
# return all_equal
104-
105-
import jax
106-
import jax.numpy as jnp
107-
import logging
108-
109116
def maybe_unreplicate(pytree):
110117
"""If leading axis matches device count, strip it assuming it's pmap replication."""
111118
num_devices = jax.device_count()
@@ -150,37 +157,27 @@ def compare_fn(p1, p2):
150157

151158

152159

153-
def use_pytorch_weights2(jax_params, file_name=None, replicate=False):
154-
155-
def deep_copy_to_cpu(pytree):
156-
return tree_map(lambda x: jax.device_put(jnp.array(copy.deepcopy(x)), device=jax.devices("cpu")[0]), pytree)
157-
158-
breakpoint()
159-
jax_copy = deep_copy_to_cpu(jax_params)
160-
# Load PyTorch state_dict lazily to CPU
161-
state_dict = torch.load(file_name, map_location='cpu')
162-
print(state_dict.keys())
163-
# Convert PyTorch tensors to NumPy arrays
164-
numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
165-
166-
# --- Embedding Table ---
167-
embedding_table = np.concatenate([
168-
numpy_weights[f'embedding_chunk_{i}'] for i in range(4)
169-
], axis=0) # adjust axis depending on chunking direction
170-
171-
jax_copy['embedding_table'] = jnp.array(embedding_table)
160+
# def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
161+
# """Compares two JAX PyTrees of weights and prints where they differ."""
162+
# all_equal = True
172163

173-
# --- Bot MLP: Dense_0 to Dense_2 ---
174-
for i, j in zip([0, 2, 4], range(3)):
175-
jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T)
176-
jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias'])
164+
# def compare_fn(p1, p2):
165+
# nonlocal all_equal
166+
# #if not jnp.allclose(p1, p2):
167+
# if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
168+
# logging.info("❌ Mismatch found:")
169+
# logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
170+
# logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
171+
# all_equal = False
172+
# return jnp.allclose(p1, p2, atol=atol, rtol=rtol)
177173

178-
# --- Top MLP: Dense_3 to Dense_7 ---
179-
for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
180-
jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
181-
jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])
182-
#jax_copy = tree_map(lambda x: jnp.array(x), jax_copy)
183-
del state_dict
174+
# try:
175+
# _ = jax.tree_util.tree_map(compare_fn, params1, params2)
176+
# except Exception as e:
177+
# logging.info("❌ Structure mismatch or error during comparison:", e)
178+
# return False
184179

185-
return jax_copy
180+
# if all_equal:
181+
# logging.info("✅ All weights are equal (within tolerance)")
182+
# return all_equal
186183

reference_algorithms/schedule_free/jax/submission.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import jax.numpy as jnp
1111
from optax.contrib import schedule_free_adamw
1212
from algoperf import spec
13-
from custom_pytorch_jax_converter import use_pytorch_weights2, are_weights_equal
13+
from custom_pytorch_jax_converter import use_pytorch_weights_cpu_copy, are_weights_equal
1414

1515
_GRAD_CLIP_EPS = 1e-6
1616

@@ -174,7 +174,7 @@ def update_params(workload: spec.Workload,
174174
if global_step % 100 == 0 and workload.metrics_logger is not None:
175175
date_ = "2025-06-14"
176176
file_name = f"/results/schedule_free_test_pytorch_weights/criteo1tb_{date_}_after_{global_step}_steps.pth"
177-
params = use_pytorch_weights2(new_params, file_name=file_name, replicate=True)
177+
params = use_pytorch_weights_cpu_copy(new_params, file_name=file_name, replicate=True)
178178
are_weights_equal(new_params, params)
179179
del params
180180

0 commit comments

Comments
 (0)