@@ -50,7 +50,7 @@ def __init__(
         rebuild_cache: bool = False,
         cache_build_timeout_minutes: int = 30,
     ):
-        self.root = os.path.expanduser(root)
+        self.root = os.path.abspath(root)
         self.transform = transform
         self.target_transform = target_transform
         self.loader = loader
@@ -223,7 +223,7 @@ def _build_dataset(
         dataset = CachedImageFolder(
             os.path.join(data_dir, folder),
             transform=transform_config,
-            cache_file='.imagenet_cache_index.json',
+            cache_file='.imagenet_{}_cache_index.json'.format(split),
         )

         if split == 'eval_train':
@@ -248,16 +248,16 @@ def _build_dataset(
             sampler = data_utils.DistributedEvalSampler(
                 dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False
             )
-
         dataloader = torch.utils.data.DataLoader(
             dataset,
             batch_size=ds_iter_batch_size,
             shuffle=not USE_PYTORCH_DDP and is_train,
             sampler=sampler,
-            num_workers=4 if is_train else self.eval_num_workers,
+            num_workers=5 * N_GPUS if is_train else self.eval_num_workers,
             pin_memory=True,
             drop_last=is_train,
             persistent_workers=is_train,
+            prefetch_factor=N_GPUS if is_train else None,
         )
         dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
         dataloader = data_utils.cycle(
@@ -266,7 +266,6 @@ def _build_dataset(
             use_mixup=use_mixup,
             mixup_alpha=0.2,
         )
-
         return dataloader

     def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
0 commit comments