-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathapp.py
108 lines (94 loc) · 3.1 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import functools
import ssl
import uuid
from pickle import load

import numpy as np
from flask import Flask, render_template, url_for, request
from PIL import Image
from tensorflow.keras.applications.xception import Xception
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from werkzeug.utils import secure_filename
import os

# SECURITY NOTE(review): this swaps the process-wide default HTTPS context for
# one that performs no certificate verification, affecting every HTTPS request
# made by this process. Presumably it was added so Keras can download the
# Xception weights on machines without a CA bundle — confirm, and prefer
# installing certificates over keeping this override.
ssl._create_default_https_context = ssl._create_unverified_context

app = Flask(__name__)
# Upper-case file extensions accepted for uploads; the upload handler
# upper-cases the filename's extension before checking membership.
app.config.update(ALLOWED_IMAGE_EXTENSIONS=["JPEG", "JPG", "PNG"])
@app.route('/')
@app.route('/home')
def home():
    """Render the landing page; served at both ``/`` and ``/home``."""
    page = 'home.html'
    return render_template(page)
@app.route('/image')
def index():
    """Render the image-captioning page (the ``index.html`` template)."""
    page = 'index.html'
    return render_template(page)
# The predict endpoint and the private helpers it uses.

# Length that token sequences are padded to, and the maximum number of words
# decoded per caption. NOTE(review): presumably must match the sequence length
# the caption model was trained with — confirm before changing.
_MAX_CAPTION_LENGTH = 32


@functools.lru_cache(maxsize=1)
def _load_captioning_models():
    """Load the tokenizer and both Keras models once per process.

    Loading these on every request (as the route used to) is very slow, so the
    result is cached after the first call.

    Returns:
        A ``(tokenizer, index_to_word, caption_model, encoder)`` tuple, where
        ``index_to_word`` is the reverse of ``tokenizer.word_index``.
    """
    # NOTE(review): unpickling can execute arbitrary code; tokenizer.p must be
    # a trusted local artifact, never user-supplied. Paths are relative to the
    # current working directory.
    with open("tokenizer.p", "rb") as fh:
        tokenizer = load(fh)
    # Reverse lookup built once instead of scanning word_index for every
    # generated token. setdefault keeps the first word seen for an index,
    # matching what a linear scan would return.
    index_to_word = {}
    for word, index in tokenizer.word_index.items():
        index_to_word.setdefault(index, word)
    caption_model = load_model("models/model_9.h5")
    encoder = Xception(include_top=False, pooling="avg")
    return tokenizer, index_to_word, caption_model, encoder


def _extract_features(filename, model):
    """Return ``model.predict`` output for the image stored at *filename*.

    The image is resized to 299x299, converted to 3-channel RGB and scaled
    from [0, 255] to [-1, 1] before prediction.

    Raises:
        OSError: if the file cannot be opened or is not a recognised image
            (PIL's ``UnidentifiedImageError`` subclasses ``OSError``).
    """
    with Image.open(filename) as image:
        # convert("RGB") drops the alpha channel of RGBA images (as slicing
        # [..., :3] did) and also handles grayscale, palette and CMYK images,
        # which previously crashed or produced wrong channels.
        rgb = image.resize((299, 299)).convert("RGB")
    pixels = np.expand_dims(np.array(rgb), axis=0)
    pixels = pixels / 127.5 - 1.0
    return model.predict(pixels)


def _generate_desc(model, tokenizer, index_to_word, photo, max_length):
    """Greedily decode a caption for the image features *photo*.

    Starting from the seed word ``'start'``, repeatedly feeds the features and
    the padded word sequence to *model* and appends the arg-max word, stopping
    after *max_length* steps, on an unknown index, or once ``'end'`` is emitted.

    Returns:
        The raw space-separated text, beginning with ``'start'`` and ending
        with ``'end'`` only if the model produced it.
    """
    in_text = 'start'
    for _ in range(max_length):
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        pred = model.predict([photo, sequence], verbose=0)
        word = index_to_word.get(int(np.argmax(pred)))
        if word is None:
            break
        in_text += ' ' + word
        if word == 'end':
            break
    return in_text


def _strip_caption_markers(description):
    """Remove the leading ``'start'`` and trailing ``'end'`` marker words.

    Works word-wise, so a caption cut off at the length limit (no ``'end'``)
    keeps its last word intact instead of losing three characters.
    """
    words = description.split()
    if words and words[0] == 'start':
        words = words[1:]
    if words and words[-1] == 'end':
        words = words[:-1]
    return ' '.join(words)


def _allowed_image(fname):
    """Return True if *fname* has an extension listed in the app config.

    The comparison is case-insensitive; names without a '.' are rejected.
    """
    if "." not in fname:
        return False
    ext = fname.rsplit(".", 1)[1]
    return ext.upper() in app.config["ALLOWED_IMAGE_EXTENSIONS"]


@app.route('/predict', methods=['POST'])
def upload():
    """Caption the image uploaded in the ``file`` form field.

    Returns:
        ``caption.html`` rendered with ``captionResult`` set to the generated
        caption, or a plain-text error message when the extension is not
        allowed or the file cannot be read as an image.
    """
    f = request.files['file']
    fname = secure_filename(f.filename)
    if not _allowed_image(fname):
        return ("Error occurred, Please ensure you're using a jpeg, jpg or png "
                "file format.")

    tokenizer, index_to_word, caption_model, encoder = _load_captioning_models()

    upload_dir = os.path.join(os.path.dirname(__file__), 'uploads')
    # The directory may not exist on a fresh checkout, which made f.save fail.
    os.makedirs(upload_dir, exist_ok=True)
    # A random prefix keeps concurrent uploads that share a filename from
    # overwriting (and deleting) each other's files.
    file_path = os.path.join(upload_dir, f"{uuid.uuid4().hex}_{fname}")
    try:
        f.save(file_path)
        try:
            photo = _extract_features(file_path, encoder)
        except OSError:
            return 'Error occurred, the uploaded file could not be read as an image.'
        description = _generate_desc(
            caption_model, tokenizer, index_to_word, photo, _MAX_CAPTION_LENGTH
        )
    finally:
        # Always remove the temporary upload, even when processing fails.
        if os.path.exists(file_path):
            os.remove(file_path)
    return render_template(
        "caption.html", captionResult=_strip_caption_markers(description)
    )
if __name__ == '__main__':
    # Development server only: debug mode enables the auto-reloader and the
    # interactive Werkzeug debugger, which must never be exposed publicly.
    run_options = {'debug': True}
    app.run(**run_options)