Skip to content

Commit b4d742c

Browse files
committed
Change num_workers for imagenet, add validation tests for step times
1 parent: 9c93fc2 · commit: b4d742c

7 files changed

Lines changed: 214 additions & 292 deletions

File tree

algoperf/workloads/cifar/cifar_pytorch/workload.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -110,12 +110,12 @@ def _build_dataset(
110110
batch_size=ds_iter_batch_size,
111111
shuffle=not USE_PYTORCH_DDP and is_train,
112112
sampler=sampler,
113-
num_workers=4 if is_train else self.eval_num_workers,
113+
num_workers=2 * N_GPUS if is_train else self.eval_num_workers,
114114
pin_memory=True,
115115
drop_last=is_train,
116116
)
117-
dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
118117
dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)
118+
dataloader = data_utils.dataloader_iterator_wrapper(dataloader, DEVICE)
119119
return dataloader
120120

121121
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:

algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ def __init__(
5050
rebuild_cache: bool = False,
5151
cache_build_timeout_minutes: int = 30,
5252
):
53-
self.root = os.path.expanduser(root)
53+
self.root = os.path.abspath(root)
5454
self.transform = transform
5555
self.target_transform = target_transform
5656
self.loader = loader
@@ -223,7 +223,7 @@ def _build_dataset(
223223
dataset = CachedImageFolder(
224224
os.path.join(data_dir, folder),
225225
transform=transform_config,
226-
cache_file='.imagenet_cache_index.json',
226+
cache_file='.imagenet_{}_cache_index.json'.format(split),
227227
)
228228

229229
if split == 'eval_train':
@@ -248,16 +248,16 @@ def _build_dataset(
248248
sampler = data_utils.DistributedEvalSampler(
249249
dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False
250250
)
251-
252251
dataloader = torch.utils.data.DataLoader(
253252
dataset,
254253
batch_size=ds_iter_batch_size,
255254
shuffle=not USE_PYTORCH_DDP and is_train,
256255
sampler=sampler,
257-
num_workers=4 if is_train else self.eval_num_workers,
256+
num_workers=5 * N_GPUS if is_train else self.eval_num_workers,
258257
pin_memory=True,
259258
drop_last=is_train,
260259
persistent_workers=is_train,
260+
prefetch_factor=N_GPUS if is_train else None,
261261
)
262262
dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
263263
dataloader = data_utils.cycle(
@@ -266,7 +266,6 @@ def _build_dataset(
266266
use_mixup=use_mixup,
267267
mixup_alpha=0.2,
268268
)
269-
270269
return dataloader
271270

272271
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:

algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
and https://github.com/lucidrains/vit-pytorch.
66
"""
77

8-
import math
98
from typing import Any, Optional, Tuple, Union
109

1110
import torch
@@ -126,13 +125,14 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
126125
value_layer = self.transpose_for_scores(self.value(x))
127126
query_layer = self.transpose_for_scores(mixed_query_layer)
128127

129-
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
130-
attention_scores = attention_scores / math.sqrt(self.head_dim)
131-
132-
attention_probs = F.softmax(attention_scores, dim=-1)
133-
attention_probs = F.dropout(attention_probs, dropout_rate, self.training)
128+
# Use built-in scaled_dot_product_attention (Flash Attention when available)
129+
context_layer = F.scaled_dot_product_attention(
130+
query_layer,
131+
key_layer,
132+
value_layer,
133+
dropout_p=dropout_rate if self.training else 0.0,
134+
)
134135

135-
context_layer = torch.matmul(attention_probs, value_layer)
136136
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
137137
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,)
138138
context_layer = context_layer.view(new_context_layer_shape)

algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55

66
import torch
77
import torch.distributed.nn as dist_nn
8-
from absl import logging
98
from torch import Tensor
109
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
1110

benchmark_step_times.py

Lines changed: 0 additions & 274 deletions
This file was deleted.

submission_runner.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -256,7 +256,6 @@ def train_once(
256256
'librispeech_conformer',
257257
'ogbg',
258258
'criteo1tb',
259-
'imagenet_vit',
260259
'librispeech_deepspeech',
261260
]
262261
eager_backend_workloads = []
@@ -266,6 +265,7 @@ def train_once(
266265
'librispeech_deepspeech',
267266
'ogbg',
268267
'wmt',
268+
'imagenet_vit',
269269
]
270270
base_workload = workloads.get_base_workload_name(workload_name)
271271
if base_workload in compile_error_workloads:
@@ -411,9 +411,8 @@ def train_once(
411411
train_step_end_time = get_time()
412412
if global_step == 11:
413413
step_10_end_time = train_step_end_time
414-
414+
415415
# Log step time every 100 steps
416-
# Note: global_step was incremented, so use (global_step - 1) to match
417416
if (global_step - 1) % 100 == 0 and workload.metrics_logger is not None:
418417
if step_10_end_time is not None and global_step > 11:
419418
elapsed_time_ms = (train_step_end_time - step_10_end_time) * 1000.0

0 commit comments

Comments
 (0)