-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
38 lines (33 loc) · 989 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# --------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2022/8/29 17:40
# @Author : wzy
# @File : data.py
# ---------------------------------------
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from arg_parse import parse_args
# Parsed command-line options; within this module only ``args.bs`` (the batch
# size) is read.
args = parse_args()

# Directory CIFAR-10 is downloaded to / loaded from (relative to the CWD).
DATA_ROOT = '../data/'

# Per-channel (R, G, B) mean and standard deviation of the CIFAR-10 training
# set. The previous values, (0.1307,) / (0.3081,), are the MNIST grayscale
# statistics: they broadcast over CIFAR-10's three channels without error but
# standardize every channel with the wrong mean/std.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Shared preprocessing for both splits: PIL image -> float tensor in [0, 1]
# -> per-channel standardization. Built once so train and val cannot diverge.
cifar10_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

train_data = torchvision.datasets.CIFAR10(
    root=DATA_ROOT,
    train=True,
    transform=cifar10_transform,
    download=True,
)
val_data = torchvision.datasets.CIFAR10(
    root=DATA_ROOT,
    train=False,
    transform=cifar10_transform,
    download=True,
)

# Only the training split is shuffled. Validation order does not affect
# aggregate metrics, and a fixed order keeps evaluation reproducible.
train_loader = DataLoader(dataset=train_data, batch_size=args.bs, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=args.bs, shuffle=False)

if __name__ == '__main__':
    # Shape of the first raw (un-transformed) training image as stored in
    # ``train_data.data``.
    print(train_data.data[0].shape)