-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathdata_utils.py
88 lines (76 loc) · 3.53 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import keras
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
import cv2
class CIFAR10Data(object):
    """CIFAR-10 train/test arrays loaded through Keras, plus preprocessing helpers.

    Both ``get_*`` methods return fresh float16 copies, so the arrays stored on
    the instance are never modified by mean subtraction.
    """

    def __init__(self):
        """Load the dataset with ``cifar10.load_data()`` and print the array shapes."""
        (self.x_train, self.y_train), (self.x_test, self.y_test) = cifar10.load_data()
        # Class names; their count is used as the one-hot label width.
        self.classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        print('CIFAR10 Training data shape:', self.x_train.shape)
        print('CIFAR10 Training label shape', self.y_train.shape)
        print('CIFAR10 Test data shape', self.x_test.shape)
        print('CIFAR10 Test label shape', self.y_test.shape)

    def get_stretch_data(self, subtract_mean=True):
        """
        Reshape each image to a row vector and convert the labels to one-hot.

        :param subtract_mean: when true, subtract the training-set mean image
            from both the training and the test data.
        :return: x_train, one_hot_y_train, x_test, one_hot_y_test
        """
        num_classes = len(self.classes)
        # One row per image; astype() copies, so the in-place subtraction below
        # cannot touch self.x_train / self.x_test.
        x_train = np.reshape(self.x_train, (self.x_train.shape[0], -1)).astype('float16')
        y_train = keras.utils.to_categorical(self.y_train, num_classes)
        x_test = np.reshape(self.x_test, (self.x_test.shape[0], -1)).astype('float16')
        y_test = keras.utils.to_categorical(self.y_test, num_classes)
        if subtract_mean:
            # NOTE(review): the mean is cast to uint8, which drops its fractional
            # part, whereas get_data() subtracts the un-truncated mean. Kept as
            # is so existing callers get the same values - confirm whether the
            # truncation is intentional.
            mean_image = np.mean(x_train, axis=0).astype('uint8')
            # The test split is centred with the training mean as well.
            x_train -= mean_image
            x_test -= mean_image
        return x_train, y_train, x_test, y_test

    def get_data(self, subtract_mean=True, output_shape=None):
        """
        Return the images without flattening, together with one-hot labels.

        :param subtract_mean: when true, subtract the training-set mean image
            from both the training and the test data.
        :param output_shape: optional target size handed to ``cv2.resize`` as
            its ``dsize`` argument, i.e. ``(width, height)``. ``None`` (the
            default) keeps the images at their original size.
        :return: x_train, one_hot_y_train, x_test, one_hot_y_test
        """
        num_classes = len(self.classes)
        x_train = self.x_train
        x_test = self.x_test
        if output_shape is not None:
            # Resize the stored images before the float conversion below.
            dsize = tuple(output_shape)
            x_train = np.array([cv2.resize(img, dsize) for img in x_train])
            x_test = np.array([cv2.resize(img, dsize) for img in x_test])
        # astype() copies, so the in-place subtraction below cannot touch the
        # arrays stored on the instance.
        x_train = x_train.astype('float16')
        y_train = keras.utils.to_categorical(self.y_train, num_classes)
        x_test = x_test.astype('float16')
        y_test = keras.utils.to_categorical(self.y_test, num_classes)
        if subtract_mean:
            # The test split is centred with the training mean as well.
            mean_image = np.mean(x_train, axis=0)
            x_train -= mean_image
            x_test -= mean_image
        return x_train, y_train, x_test, y_test
def plot_cifar10(cifar_data, num_sample_per_class):
    """
    Show a grid of randomly chosen training images, one column per class.

    :param cifar_data: object exposing ``classes``, ``x_train`` and ``y_train``.
    :param num_sample_per_class: number of images drawn (without replacement)
        for every class; this is also the number of grid rows.
    """
    class_names = cifar_data.classes
    n_cols = len(class_names)
    plt.figure()
    for col, name in enumerate(class_names):
        # Positions of every training label equal to this class index.
        candidates = np.flatnonzero(cifar_data.y_train == col)
        chosen = np.random.choice(candidates, num_sample_per_class, replace=False)
        for row, image in enumerate(cifar_data.x_train[chosen]):
            # Subplot positions are 1-based and run row by row.
            position = row * n_cols + col + 1
            plt.subplot(num_sample_per_class, n_cols, position)
            plt.imshow(image)
            plt.axis('off')
            # Only the top image of a column carries the class name.
            if row == 0:
                plt.title(name)
    plt.show()