Skip to content

Commit c8e0a32

Browse files
Merge pull request mlcommons#902 from mlcommons/lm_workload_tf32
LM workload tf32
2 parents 7d9436b + 5d4cee9 commit c8e0a32

32 files changed

Lines changed: 3472 additions & 35 deletions

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ scoring/plots/
2525
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv
2626
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv
2727

28-
algoperf/_version.py
28+
algoperf/_version.py

algoperf/checkpoint_utils.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55
"""
66

77
import os
8-
from typing import Sequence, Tuple
8+
from typing import Optional, Sequence, Tuple
99

1010
import numpy as np
11+
import orbax.checkpoint as ocp
1112
import torch
1213
from absl import logging
1314
from flax import jax_utils
1415
from flax.training import checkpoints as flax_checkpoints
1516
from flax.training.checkpoints import latest_checkpoint
17+
from orbax.checkpoint.type_handlers import NumpyHandler
1618
from tensorflow.io import gfile # pytype: disable=import-error
1719

1820
from algoperf import spec
@@ -30,6 +32,51 @@
3032
]
3133

3234

35+
class BoolHandler(NumpyHandler):
  """Orbax TypeHandler for np.bool_ scalars, built on top of NumpyHandler.

  Scalars are round-tripped through 0-dimensional numpy arrays so that the
  parent handler's array serialization machinery can be reused unchanged.
  """

  def typestr(self) -> str:
    """Unique string identifier for this handler."""
    return 'np.bool_'

  async def serialize(
    self,
    values: Sequence[np.bool_],
    infos: Sequence,
    args: Optional[Sequence[ocp.SaveArgs]] = None,
  ):
    """Serialize np.bool_ scalars via the parent's ndarray path."""
    # Promote each scalar to a 0-d boolean ndarray before delegating.
    as_arrays = [np.asarray(value, dtype=np.bool_) for value in values]
    return await super().serialize(as_arrays, infos, args)

  async def deserialize(
    self,
    infos: Sequence,
    args: Optional[Sequence[ocp.RestoreArgs]] = None,
  ) -> Sequence[np.bool_]:
    """Restore np.bool_ scalars from the parent's 0-d ndarray output."""
    raw = await super().deserialize(infos, args)
    # Collapse each 0-d array back to a scalar np.bool_ via .item().
    return [np.bool_(array.item()) for array in raw]


# Make orbax use this handler for bare np.bool_ leaves in a checkpoint
# tree, replacing any previously registered handler for that type.
ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True)
3380
def maybe_restore_checkpoint(
3481
framework: str,
3582
optimizer_state: spec.OptimizerState,

algoperf/param_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def pytorch_param_types(
4444
param_types[name] = spec.ParameterType.ATTENTION_BIAS
4545
elif 'in_proj' in name:
4646
param_types[name] = spec.ParameterType.ATTENTION_QKV
47+
elif 'qkv' in name:
48+
param_types[name] = spec.ParameterType.ATTENTION_QKV
4749
elif 'kv_proj' in name:
4850
param_types[name] = spec.ParameterType.ATTENTION_KV
4951
elif 'k_proj' in name or 'key' in name:

algoperf/pytorch_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,17 @@
2121

2222
def pytorch_setup() -> Tuple[bool, int, torch.device, int]:
  """Query the distributed/GPU environment for PyTorch workloads.

  Side effect: enables TF32-capable matmuls via
  torch.set_float32_matmul_precision('high').

  Returns:
    Tuple of (use_pytorch_ddp, rank, device, n_gpus) where use_pytorch_ddp
    is True iff LOCAL_RANK is set in the environment, rank is the local
    rank (0 outside DDP), device is the matching cuda/cpu torch.device,
    and n_gpus is torch.cuda.device_count().
  """
  torch.set_float32_matmul_precision('high')

  ddp_active = 'LOCAL_RANK' in os.environ
  local_rank = int(os.environ['LOCAL_RANK']) if ddp_active else 0
  if torch.cuda.is_available():
    device = torch.device(f'cuda:{local_rank}')
  else:
    device = torch.device('cpu')
  return ddp_active, local_rank, device, torch.cuda.device_count()
2930

3031

31-
def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None:
32+
def pytorch_init(
33+
use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads=True
34+
) -> None:
3235
# Make sure no GPU memory is preallocated to Jax.
3336
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
3437
# Only use CPU for Jax to avoid memory issues.
@@ -40,7 +43,7 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None:
4043

4144
if use_pytorch_ddp:
4245
# Avoid tf input pipeline creating too many threads.
43-
if rank != 0:
46+
if rank != 0 and limit_tf_threads:
4447
tf.config.threading.set_intra_op_parallelism_threads(1)
4548
tf.config.threading.set_inter_op_parallelism_threads(1)
4649

algoperf/random_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType:
3535

3636
def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
  """Derive a fresh seed from `seed` and pair it with `data`.

  NOTE(review): drawing with dtype=np.uint32 requires MIN_INT32 >= 0;
  if the module-level MIN_INT32 is negative, randint raises ValueError —
  confirm the bounds defined alongside this function.
  """
  generator = np.random.RandomState(seed=_signed_to_unsigned(seed))
  fresh_seed = generator.randint(MIN_INT32, MAX_INT32, dtype=np.uint32)
  return [fresh_seed, data]
4040

4141

4242
def _split(seed: SeedType, num: int = 2) -> SeedType:
  """Split `seed` into `num` new (seed, seed) pairs as a [num, 2] array.

  NOTE(review): dtype=np.uint32 requires a non-negative MIN_INT32 bound;
  verify the module-level constant, otherwise randint raises ValueError.
  """
  generator = np.random.RandomState(seed=_signed_to_unsigned(seed))
  return generator.randint(MIN_INT32, MAX_INT32, dtype=np.uint32, size=[num, 2])
4545

4646

4747
def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name

algoperf/workloads/finewebedu_lm/__init__.py

Whitespace-only changes.

algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)