-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy path__init__.py
39 lines (27 loc) · 1.04 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
from typing import Iterable
import numpy as np
import tensorflow as tf
from pose_to_video.utils import set_tensorflow_memory_growth
# NOTE(review): presumably the square (256 x 256) input size the bundled model
# expects. Nothing in this file reads it, so it may exist for callers — verify.
INPUT_RESOLUTION = (256, 256)
def upscale_frame(model, frame):
    """Run a single frame through ``model`` and return the result as uint8.

    The frame is converted to a float32 array scaled to [0, 1] and passed to
    the model as a batch of one. The first (only) output of that batch is
    scaled back by 255, clipped to [0, 255] and cast to uint8.

    Args:
        model: Object exposing a Keras-style ``predict(batch, verbose=...)``
            method returning an indexable batch (in this module, a
            ``tf.keras`` model).
        frame: Array-like image.
            Integer frames are assumed to hold values in [0, 255] and are
            always divided by 255.
            Other frames are divided by 255 only if their maximum exceeds 1,
            i.e. they do not already look normalized.
            The caller's array is never modified.

    Returns:
        ``np.ndarray`` of dtype uint8 holding the model output for this frame,
        scaled by 255 and clipped to [0, 255].
    """
    frame = np.asarray(frame)
    is_integer_frame = np.issubdtype(frame.dtype, np.integer)

    # astype() copies by default, so the in-place division below can never
    # write back into an array owned by the caller.
    model_input = frame.astype(np.float32)

    # Integer pixels are in [0, 255] by construction. Relying on the max alone
    # would leave a nearly-black integer frame (max <= 1) un-normalized.
    if is_integer_frame or model_input.max(initial=0) > 1:
        model_input /= 255.0

    # verbose=0 is Keras' documented "silent" mode for predict().
    model_output = model.predict(np.expand_dims(model_input, axis=0), verbose=0)[0]

    # Clip before casting: negative or >255 floats do not convert to uint8 in a
    # well-defined way (they can wrap around to the opposite end of the range).
    return np.clip(model_output * 255.0, 0, 255).astype(np.uint8)
def process(frames: Iterable[np.ndarray]) -> Iterable[np.ndarray]:
    """Lazily upscale every frame in ``frames`` with the bundled Keras model.

    This is a generator function, so no work is done until the first frame is
    requested. That includes configuring TensorFlow memory growth and loading
    the model from ``dist/model.h5`` beside the resolved path of this module.

    Args:
        frames: Iterable of frames, each in a form accepted by
            ``upscale_frame``.

    Yields:
        One ``upscale_frame`` result per input frame, in input order.
    """
    set_tensorflow_memory_growth()

    package_dir = os.path.dirname(os.path.realpath(__file__))
    weights_file = os.path.join(package_dir, "dist", "model.h5")
    upscaler = tf.keras.models.load_model(weights_file)

    yield from (upscale_frame(upscaler, frame) for frame in frames)