-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
52 lines (38 loc) · 1.24 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from torch import nn
import torchvision.datasets as dsets
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import glob
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
class audioDataset(Dataset):
    """Dataset of ``*.png`` files grouped into one sub-folder per class.

    Expected layout: ``root_dir/<label>/*.png`` for each entry of ``labels``.
    Sample ``i`` is the ``i``-th path in the concatenation of all folders
    (folders in ``labels`` order, files sorted by path within each folder).
    Its target is the integer position of its folder in ``labels``.

    Attributes:
        root_dir: directory that holds the per-class folders.
        labels: folder names; a sample's target is its folder's index here.
        videos: paths of all sample files, grouped by class.
        numbers: number of files found in each class folder.
        info: unused placeholder, kept for backward compatibility.
        len: total number of samples.
        pos: cumulative end offsets; class ``i`` owns the index range
            ``[pos[i-1], pos[i])`` of ``videos`` (``pos[-1]`` is 0 for i == 0).
    """

    # Default class folders, in target order (0 .. 10).
    DEFAULT_LABELS = ('cel', 'cla', 'flu', 'gac', 'gel', 'org',
                      'pia', 'sax', 'tru', 'vio', 'voi')

    def __init__(self, root_dir, labels=None):
        """Index every ``*.png`` under ``root_dir/<label>/``.

        Args:
            root_dir: directory containing one sub-folder per label.
            labels: optional sequence of folder names; defaults to
                ``DEFAULT_LABELS``. Its order defines the integer targets.
        """
        self.root_dir = root_dir
        self.labels = list(self.DEFAULT_LABELS if labels is None else labels)
        self.videos = []
        self.numbers = []
        self.info = []
        self.pos = []
        for folder in self.labels:
            pattern = os.path.join(self.root_dir, folder, '*.png')
            # glob returns files in arbitrary (filesystem) order; sort so the
            # index -> file mapping is reproducible across runs and machines.
            files = sorted(glob.glob(pattern))
            self.videos.extend(files)
            self.numbers.append(len(files))
            # Exclusive end index of this class inside self.videos.
            self.pos.append(len(self.videos))
        self.len = len(self.videos)

    def __len__(self):
        """Return the total number of samples across all class folders."""
        return self.len

    def _resolve_index(self, idx):
        """Map a possibly negative index onto ``[0, len)``; raise IndexError otherwise."""
        if idx < 0:
            idx += self.len
        if not 0 <= idx < self.len:
            raise IndexError(
                f"index out of range for dataset of size {self.len}")
        return idx

    def _label_for(self, idx):
        """Return the class whose ``[start, end)`` range in ``videos`` contains ``idx``."""
        for label, end in enumerate(self.pos):
            if idx < end:
                return label
        # Unreachable for indices validated by _resolve_index.
        raise IndexError(f"no class contains index {idx}")

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the result of ``mpimg.imread`` on the sample's file;
        ``label`` is the index of the sample's folder in ``self.labels``.
        Negative indices count from the end, as for a list.

        Raises:
            IndexError: if ``idx`` is outside ``[-len, len)``; checked before
                any file is read.
        """
        idx = self._resolve_index(idx)
        im = mpimg.imread(self.videos[idx])
        return im, self._label_for(idx)
if __name__ == "__main__":
    # Smoke test: index ./input and load one sample. Guarded so that importing
    # this module (e.g. from a training script) does not scan the disk or crash
    # when ./input is absent. os.path.join keeps the path portable; on Windows
    # it resolves to the same '.\input\<label>\*.png' patterns as before.
    s = audioDataset(os.path.join('.', 'input'))
    sample_image, sample_label = s[2]
    print(len(s), sample_label)