33import contextlib
44import functools
55import itertools
6+ import json
67import math
78import os
89import random
9- from typing import Dict , Iterator , Optional , Tuple
10+ import time
11+ from pathlib import Path
12+ from typing import Any , Callable , Dict , Iterator , Optional , Tuple , Union
1013
1114import numpy as np
1215import torch
1316import torch .distributed as dist
1417import torch .nn .functional as F
1518from torch .nn .parallel import DistributedDataParallel as DDP
1619from torchvision import transforms
17- from torchvision .datasets .folder import ImageFolder
20+ from torchvision .datasets .folder import (
21+ IMG_EXTENSIONS ,
22+ ImageFolder ,
23+ default_loader ,
24+ )
1825
1926import algoperf .random_utils as prng
2027from algoperf import data_utils , param_utils , pytorch_utils , spec
2835USE_PYTORCH_DDP , RANK , DEVICE , N_GPUS = pytorch_utils .pytorch_setup ()
2936
3037
class CachedImageFolder(ImageFolder):
    """ImageFolder that caches the file listing to avoid repeated filesystem scans.

    The first construction scans ``root`` exactly as ``ImageFolder`` would and
    writes the class list and sample index to a JSON cache file. Later
    constructions load that index directly instead of re-walking the (possibly
    huge) directory tree. Under an initialized ``torch.distributed`` process
    group, only rank 0 builds the cache; other ranks poll for the file and then
    synchronize on a barrier before loading it.
    """

    def __init__(
        self,
        root: Union[str, Path],
        cache_file: Optional[Union[str, Path]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
        rebuild_cache: bool = False,
        cache_build_timeout_minutes: int = 30,
    ):
        """Build or load the cached sample index.

        Args:
          root: Dataset root directory (one subdirectory per class).
          cache_file: Path of the JSON index file. Defaults to
            ``<root>/.cache_index.json``.
          transform: Sample transform, as in ``ImageFolder``.
          target_transform: Target transform, as in ``ImageFolder``.
          loader: Function mapping a file path to a sample, as in
            ``ImageFolder``.
          is_valid_file: Optional predicate selecting files; mutually exclusive
            with the default extension filter, as in ``ImageFolder``.
          allow_empty: Whether empty class folders are permitted, as in
            ``ImageFolder``.
          rebuild_cache: Force a rescan even if a cache file already exists.
          cache_build_timeout_minutes: How long non-zero ranks wait for rank 0
            to finish building the cache before raising ``TimeoutError``.
        """
        # NOTE: super().__init__() is deliberately NOT called: DatasetFolder's
        # constructor performs the expensive filesystem scan we are trying to
        # avoid. The attributes it would have set are assigned here instead.
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        # Mirror DatasetFolder semantics: `extensions` and `is_valid_file` are
        # mutually exclusive, so the extension filter is disabled when a
        # predicate is supplied.
        self.extensions = IMG_EXTENSIONS if is_valid_file is None else None

        # Default cache location: .cache_index.json in the root directory.
        if cache_file is None:
            cache_file = os.path.join(self.root, '.cache_index.json')
        self.cache_file = os.fspath(cache_file)

        is_distributed = dist.is_available() and dist.is_initialized()
        rank = dist.get_rank() if is_distributed else 0

        if rebuild_cache or not os.path.exists(self.cache_file):
            # Only one process builds the cache; the others wait for it to
            # finish.
            if rank == 0:
                self._build_and_save_cache(is_valid_file, allow_empty)
            if is_distributed:
                self._wait_for_cache(
                    timeout_minutes=cache_build_timeout_minutes
                )
                # The barrier guarantees rank 0 has finished writing the cache
                # before any other rank proceeds to read it.
                dist.barrier()

        self._load_from_cache()

        # Attributes expected by DatasetFolder / ImageFolder consumers.
        self.targets = [s[1] for s in self.samples]
        self.imgs = self.samples

    def _wait_for_cache(self, timeout_minutes: int):
        """Poll until the cache file exists.

        Raises:
          TimeoutError: If the file does not appear within `timeout_minutes`.
        """
        poll_interval_seconds = 5
        # A monotonic deadline is robust against wall-clock adjustments.
        deadline = time.monotonic() + timeout_minutes * 60
        while not os.path.exists(self.cache_file):
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f'Timed out waiting for cache file after '
                    f'{timeout_minutes} minutes: {self.cache_file}'
                )
            time.sleep(poll_interval_seconds)

    def _load_from_cache(self):
        """Load classes, class_to_idx, and samples from the cache file."""
        with open(self.cache_file, 'r') as f:
            cache = json.load(f)
        self.classes = cache['classes']
        self.class_to_idx = cache['class_to_idx']
        # Paths are cached relative to root so the index survives the dataset
        # directory being moved; convert back to absolute here.
        self.samples = [
            (os.path.join(self.root, rel_path), idx)
            for rel_path, idx in cache['samples']
        ]

    def _build_and_save_cache(self, is_valid_file, allow_empty):
        """Scan the filesystem, build the index, and atomically save it."""
        self.classes, self.class_to_idx = self.find_classes(self.root)
        self.samples = self.make_dataset(
            self.root,
            class_to_idx=self.class_to_idx,
            extensions=self.extensions,
            is_valid_file=is_valid_file,
            allow_empty=allow_empty,
        )

        cache = {
            'classes': self.classes,
            'class_to_idx': self.class_to_idx,
            'samples': [
                (os.path.relpath(path, self.root), idx)
                for path, idx in self.samples
            ],
        }
        # Write to a temporary file and rename into place. os.replace is atomic
        # on POSIX, so processes polling for self.cache_file can never observe
        # a partially written file, and a crash mid-write cannot leave a
        # corrupt cache that later runs would silently load.
        tmp_path = f'{self.cache_file}.tmp.{os.getpid()}'
        with open(tmp_path, 'w') as f:
            json.dump(cache, f)
        os.replace(tmp_path, self.cache_file)
131+
31132def imagenet_v2_to_torch (
32133 batch : Dict [str , spec .Tensor ],
33134) -> Dict [str , spec .Tensor ]:
@@ -119,8 +220,10 @@ def _build_dataset(
119220 )
120221
121222 folder = 'train' if 'train' in split else 'val'
122- dataset = ImageFolder (
123- os .path .join (data_dir , folder ), transform = transform_config
223+ dataset = CachedImageFolder (
224+ os .path .join (data_dir , folder ),
225+ transform = transform_config ,
226+ cache_file = '.imagenet_cache_index.json' ,
124227 )
125228
126229 if split == 'eval_train' :
@@ -151,10 +254,11 @@ def _build_dataset(
151254 batch_size = ds_iter_batch_size ,
152255 shuffle = not USE_PYTORCH_DDP and is_train ,
153256 sampler = sampler ,
154- num_workers = 4 if is_train else self .eval_num_workers ,
257+ num_workers = 5 * N_GPUS if is_train else self .eval_num_workers ,
155258 pin_memory = True ,
156259 drop_last = is_train ,
157260 persistent_workers = is_train ,
261+ prefetch_factor = N_GPUS ,
158262 )
159263 dataloader = data_utils .PrefetchedWrapper (dataloader , DEVICE )
160264 dataloader = data_utils .cycle (
@@ -163,7 +267,6 @@ def _build_dataset(
163267 use_mixup = use_mixup ,
164268 mixup_alpha = 0.2 ,
165269 )
166-
167270 return dataloader
168271
169272 def init_model_fn (self , rng : spec .RandomState ) -> spec .ModelInitState :
0 commit comments