Skip to content

Commit 7d9436b

Browse files
Merge pull request mlcommons#892 from mlcommons/a100
Migrate workloads to A100 hardware weightclass
2 parents 7d8f609 + f7ce628 commit 7d9436b

27 files changed

Lines changed: 383 additions & 79 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ The MLCommons™ **AlgoPerf: Training Algorithms benchmark** is designed to find
3131
When training neural nets, practitioners face many critical yet often opaque decisions: What optimizer to choose? How should its learning rate be tuned? What learning rate schedule should be used? These choices can make or break training, yet the community has lacked a clear, standardized way to identify the state of the art.
3232
Unlike benchmarks focused on hardware or model architecture, AlgoPerf isolates the **training algorithm** itself, which includes the optimizer, regularization, data selection, and hyperparameters like the learning rate schedule. By standardizing the benchmark process, AlgoPerf offers a meaningful apples-to-apples comparison of training algorithms and follows the following **key principles**:
3333

34-
- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](/docs/DOCUMENTATION.md#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](/docs/DOCUMENTATION.md#benchmarking-hardware) (8x NVIDIA V100 GPUs). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison.
34+
- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](/docs/DOCUMENTATION.md#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](/docs/DOCUMENTATION.md#benchmarking-hardware) (4x A100 (40GB) GPUs). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison.
3535
- ⏱️ **Time-To-Result:** Submissions are evaluated based on the total wall-clock time required to reach the target, rewarding practical and efficient algorithms.
3636
- 🧠 **Diverse Workloads:** The benchmark includes [**8 diverse deep learning workloads**](/docs/DOCUMENTATION.md#workloads) across domains like image classification, speech recognition, and machine translation. A submission's score is computed by aggregating its performance, using [**performance profiles**](/docs/DOCUMENTATION.md#benchmark-score-using-performance-profiles), across all workloads to ensure general-purpose algorithms.
3737
- 📦 **Fully-Specified Algorithms:** Submissions must be complete procedures and thus hyperparameter tuning is treated as part of the algorithm. Submissions can either provide a search space for automated tuning ([**External tuning ruleset**](/docs/DOCUMENTATION.md#external-tuning-ruleset)) or be hyperparameter-free ([**Self-tuning ruleset**](/docs/DOCUMENTATION.md#self-tuning-ruleset)) with any tuning done automatically and "on the clock". This measures an algorithm's _total_ practical cost and provides practitioners with a complete method, eliminating the guesswork of how to apply it.

algoperf/pytorch_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
def pytorch_setup() -> Tuple[bool, int, torch.device, int]:
23+
torch.set_float32_matmul_precision('high')
2324
use_pytorch_ddp = 'LOCAL_RANK' in os.environ
2425
rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0
2526
device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')

algoperf/workloads/cifar/cifar_pytorch/workload.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,12 @@ def _build_dataset(
110110
batch_size=ds_iter_batch_size,
111111
shuffle=not USE_PYTORCH_DDP and is_train,
112112
sampler=sampler,
113-
num_workers=4 if is_train else self.eval_num_workers,
113+
num_workers=2 * N_GPUS if is_train else self.eval_num_workers,
114114
pin_memory=True,
115115
drop_last=is_train,
116116
)
117-
dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
118117
dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)
118+
dataloader = data_utils.dataloader_iterator_wrapper(dataloader, DEVICE)
119119
return dataloader
120120

121121
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:

algoperf/workloads/criteo1tb/workload.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ def train_stddev(self):
9595

9696
@property
9797
def max_allowed_runtime_sec(self) -> int:
98-
return 7_703 # ~2.1 hours.
98+
return 8_915 # ~2.4 hours.
9999

100100
@property
101101
def eval_period_time_sec(self) -> int:
102-
return 2 * 60 # 2 mins.
102+
return 356 # approx 25 evals
103103

104104
def _build_input_queue(
105105
self,

algoperf/workloads/fastmri/workload.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ def accelerations(self):
9595

9696
@property
9797
def max_allowed_runtime_sec(self) -> int:
98-
return 4_430 # ~1.2 hours
98+
return 2_745 # ~0.7 hours
9999

100100
@property
101101
def eval_period_time_sec(self) -> int:
102-
return 80
102+
return 110 # approx 25 evals
103103

104104
@property
105105
def step_hint(self) -> int:

algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py

Lines changed: 109 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,25 @@
33
import contextlib
44
import functools
55
import itertools
6+
import json
67
import math
78
import os
89
import random
9-
from typing import Dict, Iterator, Optional, Tuple
10+
import time
11+
from pathlib import Path
12+
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union
1013

1114
import numpy as np
1215
import torch
1316
import torch.distributed as dist
1417
import torch.nn.functional as F
1518
from torch.nn.parallel import DistributedDataParallel as DDP
1619
from torchvision import transforms
17-
from torchvision.datasets.folder import ImageFolder
20+
from torchvision.datasets.folder import (
21+
IMG_EXTENSIONS,
22+
ImageFolder,
23+
default_loader,
24+
)
1825

1926
import algoperf.random_utils as prng
2027
from algoperf import data_utils, param_utils, pytorch_utils, spec
@@ -28,6 +35,100 @@
2835
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
2936

3037

38+
class CachedImageFolder(ImageFolder):
39+
"""ImageFolder that caches the file listing to avoid repeated filesystem scans."""
40+
41+
def __init__(
42+
self,
43+
root: Union[str, Path],
44+
cache_file: Optional[Union[str, Path]] = None,
45+
transform: Optional[Callable] = None,
46+
target_transform: Optional[Callable] = None,
47+
loader: Callable[[str], Any] = default_loader,
48+
is_valid_file: Optional[Callable[[str], bool]] = None,
49+
allow_empty: bool = False,
50+
rebuild_cache: bool = False,
51+
cache_build_timeout_minutes: int = 30,
52+
):
53+
self.root = os.path.abspath(root)
54+
self.transform = transform
55+
self.target_transform = target_transform
56+
self.loader = loader
57+
self.extensions = IMG_EXTENSIONS if is_valid_file is None else None
58+
59+
# Default cache location: .cache_index.json in the root directory
60+
if cache_file is None:
61+
cache_file = os.path.join(self.root, '.cache_index.json')
62+
self.cache_file = cache_file
63+
64+
is_distributed = dist.is_available() and dist.is_initialized()
65+
rank = dist.get_rank() if is_distributed else 0
66+
67+
cache_exists = os.path.exists(self.cache_file)
68+
needs_rebuild = rebuild_cache or not cache_exists
69+
70+
if needs_rebuild:
71+
# We only want one process to build the cache
72+
# and others to wait for it to finish.
73+
if rank == 0:
74+
self._build_and_save_cache(is_valid_file, allow_empty)
75+
if is_distributed:
76+
self._wait_for_cache(timeout_minutes=cache_build_timeout_minutes)
77+
dist.barrier()
78+
79+
self._load_from_cache()
80+
81+
self.targets = [s[1] for s in self.samples]
82+
self.imgs = self.samples
83+
84+
def _wait_for_cache(self, timeout_minutes: int):
85+
"""Poll for cache file to exist."""
86+
timeout_seconds = timeout_minutes * 60
87+
poll_interval = 5
88+
elapsed = 0
89+
90+
while not os.path.exists(self.cache_file):
91+
if elapsed >= timeout_seconds:
92+
raise TimeoutError(
93+
f'Timed out waiting for cache file after {timeout_minutes} minutes: {self.cache_file}'
94+
)
95+
time.sleep(poll_interval)
96+
elapsed += poll_interval
97+
98+
def _load_from_cache(self):
99+
"""Load classes and samples from cache file."""
100+
with open(os.path.abspath(self.cache_file), 'r') as f:
101+
cache = json.load(f)
102+
self.classes = cache['classes']
103+
self.class_to_idx = cache['class_to_idx']
104+
# Convert relative paths back to absolute
105+
self.samples = [
106+
(os.path.join(self.root, rel_path), idx)
107+
for rel_path, idx in cache['samples']
108+
]
109+
110+
def _build_and_save_cache(self, is_valid_file, allow_empty):
111+
"""Scan filesystem, build index, and save to cache."""
112+
self.classes, self.class_to_idx = self.find_classes(self.root)
113+
self.samples = self.make_dataset(
114+
self.root,
115+
class_to_idx=self.class_to_idx,
116+
extensions=self.extensions,
117+
is_valid_file=is_valid_file,
118+
allow_empty=allow_empty,
119+
)
120+
121+
cache = {
122+
'classes': self.classes,
123+
'class_to_idx': self.class_to_idx,
124+
'samples': [
125+
(os.path.relpath(path, self.root), idx) for path, idx in self.samples
126+
],
127+
}
128+
with open(os.path.abspath(self.cache_file), 'w') as f:
129+
json.dump(cache, f)
130+
131+
31132
def imagenet_v2_to_torch(
32133
batch: Dict[str, spec.Tensor],
33134
) -> Dict[str, spec.Tensor]:
@@ -119,8 +220,10 @@ def _build_dataset(
119220
)
120221

121222
folder = 'train' if 'train' in split else 'val'
122-
dataset = ImageFolder(
123-
os.path.join(data_dir, folder), transform=transform_config
223+
dataset = CachedImageFolder(
224+
os.path.join(data_dir, folder),
225+
transform=transform_config,
226+
cache_file='.imagenet_{}_cache_index.json'.format(split),
124227
)
125228

126229
if split == 'eval_train':
@@ -145,16 +248,16 @@ def _build_dataset(
145248
sampler = data_utils.DistributedEvalSampler(
146249
dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False
147250
)
148-
149251
dataloader = torch.utils.data.DataLoader(
150252
dataset,
151253
batch_size=ds_iter_batch_size,
152254
shuffle=not USE_PYTORCH_DDP and is_train,
153255
sampler=sampler,
154-
num_workers=4 if is_train else self.eval_num_workers,
256+
num_workers=5 * N_GPUS if is_train else self.eval_num_workers,
155257
pin_memory=True,
156258
drop_last=is_train,
157259
persistent_workers=is_train,
260+
prefetch_factor=N_GPUS,
158261
)
159262
dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
160263
dataloader = data_utils.cycle(
@@ -163,7 +266,6 @@ def _build_dataset(
163266
use_mixup=use_mixup,
164267
mixup_alpha=0.2,
165268
)
166-
167269
return dataloader
168270

169271
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:

algoperf/workloads/imagenet_resnet/workload.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,11 @@ def resize_size(self) -> int:
103103

104104
@property
105105
def max_allowed_runtime_sec(self) -> int:
106-
return 66_159 # ~18.4 hours
106+
return 49_918 # ~13.8 hours
107107

108108
@property
109109
def eval_period_time_sec(self) -> int:
110-
return 510 # 8.5 minutes.
110+
return 1_996 # approx 25 evals
111111

112112
def _build_dataset(
113113
self,

algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
and https://github.com/lucidrains/vit-pytorch.
66
"""
77

8-
import math
98
from typing import Any, Optional, Tuple, Union
109

1110
import torch
@@ -126,13 +125,14 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
126125
value_layer = self.transpose_for_scores(self.value(x))
127126
query_layer = self.transpose_for_scores(mixed_query_layer)
128127

129-
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
130-
attention_scores = attention_scores / math.sqrt(self.head_dim)
131-
132-
attention_probs = F.softmax(attention_scores, dim=-1)
133-
attention_probs = F.dropout(attention_probs, dropout_rate, self.training)
128+
# Use built-in scaled_dot_product_attention (Flash Attention when available)
129+
context_layer = F.scaled_dot_product_attention(
130+
query_layer,
131+
key_layer,
132+
value_layer,
133+
dropout_p=dropout_rate if self.training else 0.0,
134+
)
134135

135-
context_layer = torch.matmul(attention_probs, value_layer)
136136
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
137137
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,)
138138
context_layer = context_layer.view(new_context_layer_shape)

algoperf/workloads/imagenet_vit/workload.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,11 @@ def eval_batch_size(self) -> int:
8888

8989
@property
9090
def max_allowed_runtime_sec(self) -> int:
91-
return 69_768 # ~19.4 hours
91+
return 64_292 # ~17.8 hours
9292

9393
@property
9494
def eval_period_time_sec(self) -> int:
95-
return 7 * 60 # 7 mins.
95+
return 2_571 # approx 25 evals
9696

9797
def _build_dataset(
9898
self,

algoperf/workloads/librispeech_conformer/workload.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,11 @@ def train_stddev(self):
8080

8181
@property
8282
def max_allowed_runtime_sec(self) -> int:
83-
return 58_015 # ~16.1 hours
83+
return 43_680 # ~12.1 hours
8484

8585
@property
8686
def eval_period_time_sec(self) -> int:
87-
return 24 * 60
87+
return 1747 # approx 25 evals
8888

8989
@property
9090
def step_hint(self) -> int:

0 commit comments

Comments
 (0)