forked from sachit-menon/classify_by_description_release
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatasets.py
125 lines (100 loc) · 4.34 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import torch
from torchvision import datasets
class CUBDataset(datasets.ImageFolder):
    """
    Wrapper for the CUB-200-2011 dataset.

    ``CUBDataset.__getitem__()`` returns a tuple of an image and its
    corresponding label. When ``bboxes=True``, the label is instead a tensor
    ``[label, x_rel, y_rel, w_rel, h_rel]`` holding the class index followed
    by rescaled bounding-box values.
    Dataset per https://github.com/slipnitskaya/caltech-birds-advanced-classification

    Files read under ``root``:
    - ``images/<class_dir>/<file>``: the images themselves.
    - ``train_test_split.txt``: lines of ``<image_id> <is_train>``.
    - ``images.txt``: lines of ``<image_id> <class_dir>/<file>``.
    - ``bounding_boxes.txt``: lines of ``<image_id> <x> <y> <w> <h>``. Only
      read when ``bboxes=True``.
    """

    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 loader=datasets.folder.default_loader,
                 is_valid_file=None,
                 train=True,
                 bboxes=False):
        """
        Args:
            root: Dataset root directory; images are read from ``root/images``.
            transform: Callable applied to the loaded image in ``__getitem__``,
                after any bounding-box values have been computed.
            target_transform: Callable applied to the target in ``__getitem__``.
            loader: Image loader forwarded to ``ImageFolder``.
            is_valid_file: File filter forwarded to ``ImageFolder``.
            train: Keep images whose split flag equals this value
                (``True`` = training split, ``False`` = test split).
            bboxes: If true, also load bounding boxes and return them as part
                of the target.

        Raises:
            KeyError: if an image listed in ``images.txt`` for the selected
                split was not found under ``root/images``, or (with
                ``bboxes=True``) has no entry in ``bounding_boxes.txt``.
        """
        img_root = os.path.join(root, 'images')
        # Transforms are deliberately NOT given to ImageFolder. __getitem__
        # needs the untransformed image (its width) for bbox scaling, and it
        # applies the user transforms itself afterwards.
        super().__init__(
            root=img_root,
            transform=None,
            target_transform=None,
            loader=loader,
            is_valid_file=is_valid_file,
        )
        self.redefine_class_to_idx()
        self.transform_ = transform
        self.target_transform_ = target_transform
        self.train = train

        # Image ids in the requested split. A set gives O(1) membership tests.
        split_ids = self._read_split_ids(
            os.path.join(root, 'train_test_split.txt'), train)
        # (image_id, "<class_dir>/<file>") pairs for that split, kept in
        # images.txt order so the dataset order is deterministic.
        selected = [
            (img_id, fn)
            for img_id, fn in self._read_image_index(os.path.join(root, 'images.txt'))
            if img_id in split_ids
        ]

        # Map "<class_dir>/<file>" to its position in the unfiltered sample list.
        position_by_relpath = {
            self._relative_key(img_path): pos
            for pos, (img_path, _) in enumerate(self.imgs)
        }
        imgs_to_use = [self.imgs[position_by_relpath[fn]] for _, fn in selected]
        self.imgs = self.samples = imgs_to_use
        self.targets = [label for _, label in imgs_to_use]

        if bboxes:
            boxes_by_id = self._read_bboxes(
                os.path.join(root, 'bounding_boxes.txt'))
            # Built in exactly the same order as self.imgs, so bboxes[i]
            # belongs to imgs[i].
            self.bboxes = [boxes_by_id[img_id] for img_id, _ in selected]
        else:
            self.bboxes = None

    @staticmethod
    def _read_split_ids(path, train):
        """Return the set of image ids whose split flag equals ``train``."""
        ids = set()
        with open(path, 'r') as in_file:
            for line in in_file:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                idx, use_train = fields
                if bool(int(use_train)) == train:
                    ids.add(int(idx))
        return ids

    @staticmethod
    def _read_image_index(path):
        """Return ``[(image_id, relative_filename), ...]`` in file order."""
        entries = []
        with open(path, 'r') as in_file:
            for line in in_file:
                line = line.strip()
                if not line:
                    continue
                # maxsplit=1: everything after the first space is the filename.
                idx, fn = line.split(' ', 1)
                entries.append((int(idx), fn))
        return entries

    @staticmethod
    def _read_bboxes(path):
        """Return ``{image_id: (x, y, w, h)}`` parsed as floats."""
        boxes = {}
        with open(path, 'r') as in_file:
            for line in in_file:
                fields = line.split()
                if not fields:
                    continue
                idx, x, y, w, h = map(float, fields)
                boxes[int(idx)] = (x, y, w, h)
        return boxes

    @staticmethod
    def _relative_key(img_path):
        """Return the last two path components joined by ``'/'``.

        This gives ``"<class_dir>/<file>"``, the form used in ``images.txt``.
        ``os.path.split`` is used so this also works with OS-native separators.
        """
        parent, fname = os.path.split(img_path)
        return os.path.basename(parent) + '/' + fname

    def __getitem__(self, index):
        """Return ``(sample, target)`` for ``index``.

        If bounding boxes were loaded, ``target`` becomes
        ``torch.tensor([label, x_rel, y_rel, w_rel, h_rel])``. The user
        ``transform_`` / ``target_transform_`` are then applied, if set.
        """
        # The parent holds no transforms, so this is the image as loaded.
        sample, target = super().__getitem__(index)
        if self.bboxes is not None:
            # Rescale the box values relative to the image width.
            width = sample.width
            x, y, w, h = self.bboxes[index]
            scale_resize = 500 / width
            scale_resize_crop = scale_resize * (375 / 500)
            # NOTE(review): algebraically every value below equals
            # coord / width (including y and h). So y_rel / h_rel are only
            # guaranteed to be within [0, 1] when height <= width. The
            # arithmetic is kept unchanged; confirm the intended normalization.
            x_rel = scale_resize_crop * x / 375
            y_rel = scale_resize_crop * y / 375
            w_rel = scale_resize_crop * w / 375
            h_rel = scale_resize_crop * h / 375
            target = torch.tensor([target, x_rel, y_rel, w_rel, h_rel])
        if self.transform_ is not None:
            sample = self.transform_(sample)
        if self.target_transform_ is not None:
            target = self.target_transform_(target)
        return sample, target

    def redefine_class_to_idx(self):
        """Replace folder-style keys of ``self.class_to_idx`` with readable names.

        Each key is transformed as follows:
        - Keep only the text after the last ``'.'``.
        - Turn underscores into spaces.
        - If the name has more than two words, hyphen-join all but the last.

        For example, ``"001.Black_footed_Albatross"`` becomes
        ``"Black-footed Albatross"``. Index values are kept as-is, and
        ``self.classes`` is not modified.
        """
        adjusted_dict = {}
        for k, v in self.class_to_idx.items():
            k = k.split('.')[-1].replace('_', ' ')
            split_key = k.split(' ')
            if len(split_key) > 2:
                k = '-'.join(split_key[:-1]) + " " + split_key[-1]
            adjusted_dict[k] = v
        self.class_to_idx = adjusted_dict
from PIL import Image
import torchvision.transforms as transforms
def _convert_image_to_rgb(image):
    """Return ``image.convert("RGB")``.

    This is a named module-level function rather than a lambda so that the
    composed transform can be pickled, e.g. when a DataLoader sends the
    dataset (and its transform) to worker processes.
    """
    return image.convert("RGB")


def _transform(n_px):
    """Build the image preprocessing pipeline for ``n_px``-pixel square inputs.

    Steps, in order:
    1. Bicubic resize to ``n_px``. With an int size, ``Resize`` matches the
       shorter edge.
    2. Center crop to ``n_px`` x ``n_px``.
    3. Convert to RGB.
    4. Convert to a tensor.
    5. Per-channel normalization.

    Args:
        n_px: Target side length in pixels.

    Returns:
        A ``transforms.Compose`` callable. It expects a PIL image, since
        ``convert("RGB")`` is called on the cropped image.
    """
    # Prefer the InterpolationMode enum (torchvision >= 0.9). Passing the PIL
    # integer constant is deprecated in some torchvision releases, so
    # Image.BICUBIC is only used as a fallback for older torchvision.
    interpolation_mode = getattr(transforms, 'InterpolationMode', None)
    if interpolation_mode is not None:
        bicubic = interpolation_mode.BICUBIC
    else:
        bicubic = Image.BICUBIC
    return transforms.Compose([
        transforms.Resize(n_px, interpolation=bicubic),
        transforms.CenterCrop(n_px),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        # NOTE(review): these appear to be OpenAI CLIP's mean/std constants;
        # confirm they match the backbone in use.
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])