forked from PatrickLib/captcha_recognize
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path captcha_input.py
59 lines (49 loc) · 2.29 KB
/
captcha_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
import config
# Settings re-exported from the project-level ``config`` module so the rest of
# this file can refer to them by short module-level names.
(
    # Where the TFRecord files live, and which file to read for each split.
    RECORD_DIR, TRAIN_FILE, VALID_FILE,
    # Geometry of a single decoded image.
    IMAGE_WIDTH, IMAGE_HEIGHT,
    # Label layout: CHARS_NUM rows of CLASSES_NUM entries each.
    CLASSES_NUM, CHARS_NUM,
) = (
    config.RECORD_DIR, config.TRAIN_FILE, config.VALID_FILE,
    config.IMAGE_WIDTH, config.IMAGE_HEIGHT,
    config.CLASSES_NUM, config.CHARS_NUM,
)
def read_and_decode(filename_queue):
    """Read one serialized example from *filename_queue* and decode it.

    Each record is expected to hold two raw-byte string features,
    ``image_raw`` and ``label_raw``.

    Args:
      filename_queue: queue of TFRecord file names, e.g. the one built by
        ``tf.compat.v1.train.string_input_producer`` in ``inputs``.

    Returns:
      An ``(image, label)`` pair of float32 tensors with static shapes
      ``[IMAGE_HEIGHT, IMAGE_WIDTH, 1]`` and ``[CHARS_NUM, CLASSES_NUM]``.
    """
    record_reader = tf.compat.v1.TFRecordReader()
    _key, record = record_reader.read(filename_queue)

    feature_spec = {
        'image_raw': tf.compat.v1.FixedLenFeature([], tf.string),
        'label_raw': tf.compat.v1.FixedLenFeature([], tf.string),
    }
    parsed = tf.compat.v1.parse_single_example(record, features=feature_spec)

    # Pixels are decoded as int16; presumably that matches how the record
    # writer serialized them -- confirm against the writer.  Pinning the
    # static length lets the reshape below produce a fully known shape.
    pixels = tf.io.decode_raw(parsed['image_raw'], tf.int16)
    pixels.set_shape([IMAGE_HEIGHT * IMAGE_WIDTH])
    # Scale by 1/255 and shift by -0.5.  If pixels are 8-bit values (0..255),
    # this centres them in [-0.5, 0.5].
    pixels = tf.cast(pixels, tf.float32) * (1. / 255) - 0.5
    image = tf.reshape(pixels, [IMAGE_HEIGHT, IMAGE_WIDTH, 1])

    # One uint8 per (character, class) cell; presumably a one-hot encoding
    # per character -- verify against the record writer.
    label_bytes = tf.io.decode_raw(parsed['label_raw'], tf.uint8)
    label_bytes.set_shape([CHARS_NUM * CLASSES_NUM])
    label = tf.reshape(label_bytes, [CHARS_NUM, CLASSES_NUM])

    return tf.cast(image, tf.float32), tf.cast(label, tf.float32)
def inputs(train, batch_size, *, num_threads=6, min_after_dequeue=2000):
    """Build a queue-based input pipeline that yields batches of examples.

    Reads the training or validation TFRecord file under ``RECORD_DIR``,
    decodes each record with ``read_and_decode`` and groups the results into
    batches.  The pipeline uses TF1 queue runners, so the caller must start
    them (e.g. ``tf.compat.v1.train.start_queue_runners``) before evaluating
    the returned tensors.

    Args:
      train: if true, read ``TRAIN_FILE`` and shuffle examples with
        ``shuffle_batch``.  Otherwise read ``VALID_FILE`` and batch without
        shuffling via ``batch``.
      batch_size: number of examples per batch.
      num_threads: number of threads enqueuing decoded examples
        (default 6, as before).
      min_after_dequeue: minimum number of examples kept in the training
        shuffle buffer after a dequeue (default 2000, as before).  It is
        also the base of the queue capacity in both modes.

    Returns:
      ``(images, labels)``: float32 tensors shaped
      ``[batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 1]`` and
      ``[batch_size, CHARS_NUM, CLASSES_NUM]``.  The labels are dense, not
      sparse.
    """
    filename = os.path.join(RECORD_DIR, TRAIN_FILE if train else VALID_FILE)

    with tf.name_scope('input'):
        filename_queue = tf.compat.v1.train.string_input_producer([filename])
        image, label = read_and_decode(filename_queue)

        # Capacity must exceed the shuffle buffer.  Leave room for a few
        # batches in flight on top of it.  With the defaults this equals the
        # original 2000 + 3 * batch_size.
        capacity = min_after_dequeue + 3 * batch_size

        if train:
            images, labels = tf.compat.v1.train.shuffle_batch(
                [image, label],
                batch_size=batch_size,
                num_threads=num_threads,
                capacity=capacity,
                min_after_dequeue=min_after_dequeue)
        else:
            images, labels = tf.compat.v1.train.batch(
                [image, label],
                batch_size=batch_size,
                num_threads=num_threads,
                capacity=capacity)

        return images, labels