-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset_all.py
More file actions
69 lines (54 loc) · 2.01 KB
/
dataset_all.py
File metadata and controls
69 lines (54 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import torch
from torchvision.transforms import v2
# When True, samples from the 'train' split go through the augmentation
# pipeline; other splits always use the plain preprocessing pipeline.
DATA_AUG = False
# Directory prefix for image files (kept with its trailing slash because it is
# joined to each record's relative image path).
IMG = '/data/users/lkang/RVL-CDIP/images/'
# NumPy file holding the per-split records; loaded with np.load by RVL.
LABEL = '/data/users/lkang/RVL-CDIP/RVL_CDIP_full.npy'
class RVL(Dataset):
    """Image-classification dataset for one split of a pickled label file.

    The label file is a NumPy ``.npy`` file wrapping a mapping from split
    name to an indexable collection of records. Each record is indexable,
    with the image path (relative to ``img_dir``) at position 0 and the class
    label at position 1.
    """

    def __init__(self, img_dir, label_dir, split):  # split: train, valid, test
        """Load the records of ``split`` and build the preprocessing pipelines.

        Args:
            img_dir: Directory that the record image paths are relative to.
            label_dir: Path of the ``.npy`` label file (despite the name, a
                file path, not a directory).
            split: Key selecting the records inside the label file.

        Raises:
            KeyError: If ``split`` is not a key of the loaded label mapping.
        """
        self.split = split
        # os.path.join(x, '') appends a separator only when one is missing, so
        # the directory can be prefixed to record paths by plain concatenation.
        self.img_dir = os.path.join(img_dir, '')

        # NOTE(security): allow_pickle=True unpickles arbitrary objects, which
        # can execute code. Only load label files from a trusted source.
        data_all = np.load(label_dir, allow_pickle=True).item()
        self.data = data_all[split]

        # Deterministic pipeline: tensor conversion, [0, 1] float scaling,
        # fixed 224x224 resize, then per-channel normalization.
        self.img_proc = torch.nn.Sequential(
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((224, 224), antialias=True),
            v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        )
        # Augmenting pipeline: same conversion and normalization, with random
        # crop / affine / blur / sharpness steps in between.
        self.img_proc_aug = torch.nn.Sequential(
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((224, 224), antialias=True),
            v2.RandomResizedCrop((224, 224), scale=(0.5, 1), ratio=(0.75, 1.25), antialias=True),
            v2.RandomAffine(degrees=5, shear=(-10, 10, -10, 10)),
            v2.GaussianBlur(kernel_size=5, sigma=(0.1, 5)),
            v2.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
            v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        )

    def __len__(self):
        """Return the number of records in this split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        Returns:
            dict with keys ``'img_url'`` (the record's image path as stored),
            ``'img'`` (the preprocessed image tensor) and ``'label'`` (the
            class as an ``int``).
        """
        record = self.data[idx]
        img_url = record[0]
        gt = int(record[1])

        # Use the directory given to the constructor rather than the module
        # global, and close the file handle as soon as the pixels are read;
        # convert() returns a loaded copy that stays valid after the close.
        with Image.open(f'{self.img_dir}{img_url}') as raw:
            img = raw.convert('RGB')

        # Augmentation applies only to training samples and only when the
        # module-level DATA_AUG switch is on.
        if self.split == 'train' and DATA_AUG:
            img_feat = self.img_proc_aug(img)
        else:
            img_feat = self.img_proc(img)

        return {
            'img_url': img_url,
            'img': img_feat,
            'label': gt,
        }
def loadData():
    """Build one RVL dataset per split and return them keyed by split name.

    Returns:
        dict mapping 'train', 'valid' and 'test' (in that order) to an RVL
        instance constructed from the module-level IMG and LABEL paths.
    """
    split_names = ('train', 'valid', 'test')
    return {name: RVL(IMG, LABEL, name) for name in split_names}
# Running this module as a script is intentionally a no-op; it only provides
# definitions for import.
if __name__ == '__main__':
    pass