@@ -50,7 +50,7 @@ def __init__(
5050 rebuild_cache : bool = False ,
5151 cache_build_timeout_minutes : int = 30 ,
5252 ):
53- self .root = os .path .expanduser (root )
53+ self .root = os .path .abspath (root )
5454 self .transform = transform
5555 self .target_transform = target_transform
5656 self .loader = loader
@@ -223,7 +223,7 @@ def _build_dataset(
223223 dataset = CachedImageFolder (
224224 os .path .join (data_dir , folder ),
225225 transform = transform_config ,
226- cache_file = '.imagenet_cache_index .json' ,
226+ cache_file = '.imagenet_{}_cache_index .json' . format ( split ) ,
227227 )
228228
229229 if split == 'eval_train' :
@@ -248,13 +248,13 @@ def _build_dataset(
248248 sampler = data_utils .DistributedEvalSampler (
249249 dataset , num_replicas = N_GPUS , rank = RANK , shuffle = False
250250 )
251-
252251 dataloader = torch .utils .data .DataLoader (
253252 dataset ,
254253 batch_size = ds_iter_batch_size ,
255254 shuffle = not USE_PYTORCH_DDP and is_train ,
256255 sampler = sampler ,
257256 num_workers = 5 * N_GPUS if is_train else self .eval_num_workers ,
257+ num_workers = 5 * N_GPUS if is_train else self .eval_num_workers ,
258258 pin_memory = True ,
259259 drop_last = is_train ,
260260 persistent_workers = is_train ,
0 commit comments