-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcreate_data.py
77 lines (53 loc) · 2.54 KB
/
create_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import sys
import os
import shutil
import numpy as np
import socket
# Root of the local ImageNet copy. The script reads `train/<class dir>/` and
# `val/<class dir>/` below it. This path is machine specific: adapt it when
# not running on the dgx2 host.
IMAGE_NET_PATH = '/local/data/ImageNet'
# Output root. When run as a script, everything is written below DATA_DIR/<target_class>/.
DATA_DIR = './data/'
def _get_in_dirname(target_class: str, mapping_file: str = 'imagenet_dirs.txt') -> str:
    """Return the ImageNet directory name that holds images of ``target_class``.

    ``mapping_file`` is read as whitespace-separated lines of the form
    ``<dir_name> <ignored field> <class_name>``. Blank lines are skipped. If a
    class name appears more than once, the last occurrence wins.

    Args:
        target_class: Human-readable class name as spelled in ``mapping_file``.
        mapping_file: Path of the mapping file. It is resolved relative to the
            current working directory when not absolute.

    Returns:
        The directory name paired with ``target_class``.

    Raises:
        ValueError: If ``target_class`` is not listed in ``mapping_file``, or a
            non-blank line does not have exactly three fields.
    """
    class_to_dir = {}
    with open(mapping_file, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. an empty last line) instead of
                # failing on the 3-way unpack below.
                continue
            dir_name, _, cls = fields
            class_to_dir[cls] = dir_name
    try:
        return class_to_dir[target_class]
    except KeyError:
        raise ValueError(
            f"Target class {target_class} not found in {mapping_file}! Make sure you spelled correctly."
        ) from None
def _list_val_images(val_root: str) -> dict:
    """Map every non-empty sub-directory of ``val_root`` to the files it contains.

    Non-directory entries of ``val_root`` and empty directories are skipped, so
    they can never be drawn. Each directory is listed once here rather than on
    every random draw. Keys are sorted so the iteration order is deterministic.
    """
    listing = {}
    for entry in sorted(os.listdir(val_root)):
        entry_path = os.path.join(val_root, entry)
        if not os.path.isdir(entry_path):
            continue
        files = os.listdir(entry_path)
        if files:
            listing[entry] = files
    return listing


def _copy_random_images(val_root: str, listing: dict, n_imgs: int, dst_dir: str) -> None:
    """Copy ``n_imgs`` distinct, randomly chosen images from ``val_root`` into ``dst_dir``.

    A directory is drawn uniformly from ``listing``, then a file uniformly within
    it. Draws whose file name was already copied are rejected. Without this,
    repeated draws would overwrite the same destination file and leave fewer
    than ``n_imgs`` images.

    Args:
        val_root: Directory whose sub-directories are the keys of ``listing``.
        listing: Mapping produced by ``_list_val_images(val_root)``.
        n_imgs: Number of distinct images to copy.
        dst_dir: Destination directory. It is created if missing.

    Raises:
        ValueError: If fewer than ``n_imgs`` distinct file names are available.
    """
    available = len({name for files in listing.values() for name in files})
    if available < n_imgs:
        # Guard against an endless rejection loop below.
        raise ValueError(f"Only {available} distinct images in {val_root}, cannot sample {n_imgs}.")
    os.makedirs(dst_dir, exist_ok=True)
    val_dirs = list(listing)
    chosen = set()
    while len(chosen) < n_imgs:
        rand_dir = np.random.choice(val_dirs)
        rand_file = np.random.choice(listing[rand_dir])
        if rand_file in chosen:
            continue
        chosen.add(rand_file)
        shutil.copyfile(src=os.path.join(val_root, rand_dir, rand_file),
                        dst=os.path.join(dst_dir, rand_file))


if __name__ == "__main__":
    # TODO: make those parameters?
    n_exps = 10  # number of random500_<i> folders to create
    n_imgs = 50  # images copied into each random folder
    if socket.gethostname() != "dgx2":
        print("WARNING: You seem to be running this not on dgx2.gmum, please set proper IMAGE_NET_PATH in this file!")
    if len(sys.argv) != 2:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        sys.exit(f"Please pass only one argument: usage `{sys.argv[0]} <target_class>`")
    target_class = sys.argv[1]
    # Per-class output root. It holds the target-class folder next to the random folders.
    DATA_DIR = os.path.join(DATA_DIR, target_class)
    target_dirname = _get_in_dirname(target_class)

    # 1. copy the target class's training images (.JPEG only)
    target_path = os.path.join(DATA_DIR, target_class)
    in_target = os.path.join(IMAGE_NET_PATH, 'train', target_dirname)
    os.makedirs(target_path, exist_ok=True)
    for img_file in os.listdir(in_target):
        if img_file.endswith(".JPEG"):
            shutil.copyfile(src=os.path.join(in_target, img_file), dst=os.path.join(target_path, img_file))

    # The validation tree is listed once and shared by steps 2 and 3.
    val_root = os.path.join(IMAGE_NET_PATH, 'val')
    val_listing = _list_val_images(val_root)

    # 2. copy random validation images to random_discovery
    _copy_random_images(val_root, val_listing, n_imgs, os.path.join(DATA_DIR, 'random_discovery'))

    # 3. copy random validation images to random500_<exp_id>
    #    (samples are independent across experiments and may overlap)
    for exp_id in range(n_exps):
        _copy_random_images(val_root, val_listing, n_imgs, os.path.join(DATA_DIR, f'random500_{exp_id}'))