
Commit ac026a3

feat: schedule free changes
1 parent 546ad7d commit ac026a3

5 files changed

Lines changed: 597 additions & 0 deletions


algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
 from algoperf.workloads.criteo1tb.criteo1tb_jax import models
 from algoperf.workloads.criteo1tb.workload import \
     BaseCriteo1TbDlrmSmallWorkload
+from custom_pytorch_jax_converter import use_pytorch_weights


 class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
@@ -103,6 +104,7 @@ def init_model_fn(
         {'params': params_rng, 'dropout': dropout_rng},
         jnp.ones(input_shape, jnp.float32))
     initial_params = initial_variables['params']
+    initial_params = use_pytorch_weights(initial_params)
     self._param_shapes = param_utils.jax_param_shapes(initial_params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
     return jax_utils.replicate(initial_params), None

algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py

Lines changed: 1 addition & 0 deletions
@@ -88,6 +88,7 @@ def init_model_fn(
         dropout_rate=dropout_rate,
         use_layer_norm=self.use_layer_norm,
         embedding_init_multiplier=self.embedding_init_multiplier)
+    torch.save(model.state_dict(), "/results/pytorch_base_model_criteo1tb_22_may.pth")
     self._param_shapes = param_utils.pytorch_param_shapes(model)
     self._param_types = param_utils.pytorch_param_types(self._param_shapes)
     model.to(DEVICE)

custom_pytorch_jax_converter.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import torch
import numpy as np
from flax.core import freeze, unfreeze

# Load PyTorch state_dict
state_dict = torch.load("/results/pytorch_base_model_criteo1tb_22_may.pth")

# Convert PyTorch tensors to NumPy arrays
numpy_weights = {k: v.numpy() for k, v in state_dict.items()}


"""
Jax default parameter structure:
dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table'])

PyTorch state_dict structure:
dict_keys(['embedding_chunk_0', 'embedding_chunk_1', 'embedding_chunk_2', 'embedding_chunk_3', 'bot_mlp.0.weight', 'bot_mlp.0.bias', 'bot_mlp.2.weight', 'bot_mlp.2.bias', 'bot_mlp.4.weight', 'bot_mlp.4.bias', 'top_mlp.0.weight', 'top_mlp.0.bias', 'top_mlp.2.weight', 'top_mlp.2.bias', 'top_mlp.4.weight', 'top_mlp.4.bias', 'top_mlp.6.weight', 'top_mlp.6.bias', 'top_mlp.8.weight', 'top_mlp.8.bias'])

The following function converts the PyTorch weights to the Jax format
and assigns them to the Jax model parameters.
The function assumes that the Jax model parameters are already initialized
and that the PyTorch weights are in the correct format.
"""


def use_pytorch_weights(jax_params):
    # --- Embedding Table ---
    embedding_table = np.concatenate([
        numpy_weights[f'embedding_chunk_{i}'] for i in range(4)
    ], axis=0)  # adjust axis depending on chunking direction

    jax_params['embedding_table'] = embedding_table

    # --- Bot MLP: Dense_0 to Dense_2 ---
    for i, j in zip([0, 2, 4], range(3)):
        jax_params[f'Dense_{j}']['kernel'] = numpy_weights[f'bot_mlp.{i}.weight'].T
        jax_params[f'Dense_{j}']['bias'] = numpy_weights[f'bot_mlp.{i}.bias']

    # --- Top MLP: Dense_3 to Dense_7 ---
    for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
        jax_params[f'Dense_{j}']['kernel'] = numpy_weights[f'top_mlp.{i}.weight'].T
        jax_params[f'Dense_{j}']['bias'] = numpy_weights[f'top_mlp.{i}.bias']

    return jax_params
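
For context (not part of the commit), a small shape check can catch a wrong transpose or concatenation axis before the converted weights reach the JAX workload. The helper name check_converted_shapes below is hypothetical; it only relies on the key layouts listed in the module docstring above:

    import numpy as np

    def check_converted_shapes(jax_params, numpy_weights):
        # Flax Dense kernels are (in, out); PyTorch Linear weights are (out, in),
        # which is why the converter transposes them. Spot-check one layer.
        assert (jax_params['Dense_0']['kernel'].shape
                == numpy_weights['bot_mlp.0.weight'].T.shape)
        assert (jax_params['Dense_0']['bias'].shape
                == numpy_weights['bot_mlp.0.bias'].shape)
        # The four PyTorch embedding chunks, stacked along axis 0, must match
        # the single Flax embedding table.
        embedding_table = np.concatenate(
            [numpy_weights[f'embedding_chunk_{i}'] for i in range(4)], axis=0)
        assert jax_params['embedding_table'].shape == embedding_table.shape

Running this against the freshly initialized JAX parameters before calling use_pytorch_weights would surface a mismatched chunking axis immediately.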
Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
"""Submission file for a Schedule-Free AdamW optimizer in Jax."""

import functools
from typing import Dict, Iterator, List, Tuple
import optax

from flax import jax_utils
import jax
from jax import lax
import jax.numpy as jnp
from optax.contrib import schedule_free_adamw
from algoperf import spec

_GRAD_CLIP_EPS = 1e-6

HPARAMS = {
    "dropout_rate": 0.1,
    "learning_rate": 0.0025,
    "one_minus_beta1": 0.1,
    "beta2": 0.9955159689799007,
    "weight_decay": 0.08121616522670176,
    "warmup_factor": 0.02,
    "weight_lr_power": 2,
    "label_smoothing": 0.2,
    "r": 0.75,
    "eps": 1e-8,
}


def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  """Creates a Schedule-Free AdamW optimizer."""
  del model_params
  del model_state
  del rng

  lr = HPARAMS['learning_rate']
  betas = (1.0 - HPARAMS['one_minus_beta1'], HPARAMS['beta2'])
  warmup_steps = int(HPARAMS['warmup_factor'] * workload.step_hint * 0.75)
  weight_decay = HPARAMS['weight_decay']
  weight_lr_power = HPARAMS['weight_lr_power']
  r = HPARAMS['r']

  opt_init_fn, opt_update_fn = schedule_free_adamw(
      learning_rate=HPARAMS['learning_rate'],
      warmup_steps=int(HPARAMS['warmup_factor'] * workload.step_hint * 0.75),
      b1=1.0 - HPARAMS['one_minus_beta1'],
      b2=HPARAMS['beta2'],
      eps=HPARAMS['eps'],
      weight_decay=HPARAMS['weight_decay'],
      weight_lr_power=HPARAMS['weight_lr_power'],
  )
  params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple),
                                   workload.param_shapes)
  optimizer_state = opt_init_fn(params_zeros_like)

  return jax_utils.replicate(optimizer_state), opt_update_fn


@functools.partial(
    jax.pmap,
    axis_name='batch',
    in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
    static_broadcasted_argnums=(0, 1),
    donate_argnums=(2, 3, 4))
def pmapped_train_step(workload,
                       opt_update_fn,
                       model_state,
                       optimizer_state,
                       current_param_container,
                       batch,
                       rng,
                       grad_clip,
                       label_smoothing):

  def _loss_fn(params):
    """Loss function used for training."""
    logits, new_model_state = workload.model_fn(
        params,
        batch,
        model_state,
        spec.ForwardPassMode.TRAIN,
        rng,
        update_batch_norm=True)
    loss_dict = workload.loss_fn(
        label_batch=batch['targets'],
        logits_batch=logits,
        mask_batch=batch.get('weights'),
        label_smoothing=label_smoothing)
    summed_loss = loss_dict['summed']
    n_valid_examples = loss_dict['n_valid_examples']
    return summed_loss, (n_valid_examples, new_model_state)

  grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
  (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
      current_param_container)
  # Get correct global mean loss and grad.
  (summed_loss, n_valid_examples, grad) = lax.psum(
      (summed_loss, n_valid_examples, grad), axis_name='batch')
  loss = summed_loss / n_valid_examples
  grad = jax.tree_map(lambda x: x / n_valid_examples, grad)

  grad_norm = jnp.sqrt(
      sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))

  # Extract the leaves of the pytree.
  leaves = jax.tree_util.tree_leaves(grad)
  # Count the total number of elements in all leaves.
  total_size = sum(jnp.size(leaf) for leaf in leaves)

  # jax.debug.print('GRAD NORM {}', grad_norm)
  # jax.debug.print('NUM PARAMS {}', total_size)

  if grad_clip is not None:
    grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
    grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
    grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad)

  updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
                                               current_param_container)
  updated_params = optax.apply_updates(current_param_container, updates)
  return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del loss_type
  del eval_results

  optimizer_state, opt_update_fn = optimizer_state
  per_device_rngs = jax.random.split(rng, jax.local_device_count())
  if hasattr(hyperparameters, 'label_smoothing'):
    label_smoothing = hyperparameters.label_smoothing
  else:
    label_smoothing = 0.0
  if hasattr(hyperparameters, 'grad_clip'):
    grad_clip = hyperparameters.grad_clip
  else:
    grad_clip = None
  outputs = pmapped_train_step(workload,
                               opt_update_fn,
                               model_state,
                               optimizer_state,
                               current_param_container,
                               batch,
                               per_device_rngs,
                               grad_clip,
                               label_smoothing)
  new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs

  # Log loss, grad_norm.
  if global_step % 100 == 0 and workload.metrics_logger is not None:
    workload.metrics_logger.append_scalar_metrics(
        {
            'loss': loss[0],
            'grad_norm': grad_norm[0],
        }, global_step)
  return (new_optimizer_state, opt_update_fn), new_params, new_model_state


def get_batch_size(workload_name):
  # Return the global batch size.
  if workload_name == 'criteo1tb':
    return 262_144
  elif workload_name == 'fastmri':
    return 32
  elif workload_name == 'imagenet_resnet':
    return 1024
  elif workload_name == 'imagenet_resnet_silu':
    return 512
  elif workload_name == 'imagenet_resnet_gelu':
    return 512
  elif workload_name == 'imagenet_vit':
    return 1024
  elif workload_name == 'librispeech_conformer':
    return 256
  elif workload_name == 'librispeech_deepspeech':
    return 256
  elif workload_name == 'ogbg':
    return 512
  elif workload_name == 'wmt':
    return 128
  elif workload_name == 'mnist':
    return 16
  else:
    raise ValueError(f'Unsupported workload name: {workload_name}.')


def data_selection(workload: spec.Workload,
                   input_queue: Iterator[Dict[str, spec.Tensor]],
                   optimizer_state: spec.OptimizerState,
                   current_param_container: spec.ParameterContainer,
                   model_state: spec.ModelAuxiliaryState,
                   hyperparameters: spec.Hyperparameters,
                   global_step: int,
                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
  """Select data from the infinitely repeating, pre-shuffled input queue.

  Each element of the queue is a batch of training examples and labels.
  """
  del workload
  del optimizer_state
  del current_param_container
  del model_state
  del hyperparameters
  del global_step
  del rng
  batch = next(input_queue)
  return batch
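
One property of Schedule-Free methods worth noting: the parameters carried through training are the gradient-evaluation sequence, while the averaged sequence recommended for evaluation lives inside the optimizer state. The submission above evaluates current_param_container directly. Below is a minimal sketch of pulling out evaluation parameters instead, assuming the optax.contrib.schedule_free_eval_params helper that ships alongside schedule_free_adamw in recent optax releases (the function name get_eval_params is illustrative, not part of the commit):

    from flax import jax_utils
    from optax.contrib import schedule_free_eval_params

    def get_eval_params(replicated_optimizer_state, replicated_params):
      # Undo the pmap replication used by the submission, then map the
      # training-time parameters to the averaged sequence that Schedule-Free
      # optimizers recommend for evaluation.
      opt_state = jax_utils.unreplicate(replicated_optimizer_state)
      params = jax_utils.unreplicate(replicated_params)
      return schedule_free_eval_params(opt_state, params)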
