-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathgradio_app.py
128 lines (105 loc) · 8.73 KB
/
gradio_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import os
from PIL import Image
import subprocess
from gradio_model4dgs import Model4DGS
import numpy
import hashlib
# JavaScript handed to ``gr.Blocks(js=...)``: if the page URL does not
# already carry ``__theme=light``, set that query parameter and reload, so the
# demo is always shown with Gradio's light theme.
js_func = (
    "\n"
    "function refresh() {\n"
    "const url = new URL(window.location);\n"
    "if (url.searchParams.get('__theme') !== 'light') {\n"
    "url.searchParams.set('__theme', 'light');\n"
    "window.location.href = url.href;\n"
    "}\n"
    "}\n"
)
# check if there is a picture uploaded or selected
def check_img_input(control_image):
if control_image is None:
raise gr.Error("Please select or upload an input image")
# Gate for the "Generate 4D" button: stage 2 needs a stage-1 video to exist.
def check_video_input(image_block=None):
    """Raise ``gr.Error`` unless a stage-1 video is available.

    ``optimize_stage_1`` writes its video to
    ``tmp_data/<sha256 of image_block.tobytes()>_rgba_generated.mp4``. This
    check used to look at ``data/tmp_rgba_generated.mp4``, a path
    ``optimize_stage_1`` never writes, so it did not track stage 1 at all.

    Args:
        image_block: optional input image (anything with ``tobytes()``, e.g.
            a PIL image). When given, the video generated from exactly this
            image must exist. When omitted, any ``*_rgba_generated.mp4`` in
            ``tmp_data`` is accepted.

    Raises:
        gr.Error: if no matching video is found.
    """
    if image_block is not None:
        # Same key optimize_stage_1 / optimize_stage_2 derive from the image.
        img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
        found = os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4'))
    else:
        found = os.path.isdir('tmp_data') and any(
            fn.endswith('_rgba_generated.mp4') for fn in os.listdir('tmp_data'))
    if not found:
        raise gr.Error("Please generate a video first")
def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
    """Stage 1: generate a video from the input image.

    Files are keyed by the SHA-256 of ``image_block.tobytes()``, so
    ``optimize_stage_2`` can locate them again from the same image.

    Args:
        image_block: input image from the Gradio ``Image`` component.
        preprocess_chk: if true, the image is saved as
            ``tmp_data/<hash>.png`` and ``scripts/process.py`` is run on it
            (presumably producing ``tmp_data/<hash>_rgba.png``, which is what
            the next step reads). Otherwise the image is saved directly as
            ``tmp_data/<hash>_rgba.png``.
        seed_slider: random seed forwarded to ``scripts/gen_vid.py``.

    Returns:
        Path of the generated video, ``tmp_data/<hash>_rgba_generated.mp4``.

    Raises:
        gr.Error: if a helper script exits non-zero or the video is missing.
    """
    os.makedirs('tmp_data', exist_ok=True)
    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    rgba_path = os.path.join('tmp_data', f'{img_hash}_rgba.png')

    if preprocess_chk:
        raw_path = os.path.join('tmp_data', f'{img_hash}.png')
        image_block.save(raw_path)
        preprocess_cmd = ['python', 'scripts/process.py', raw_path]
        print(' '.join(preprocess_cmd))
        if subprocess.run(preprocess_cmd).returncode != 0:
            raise gr.Error("Image preprocessing failed")
    else:
        image_block.save(rgba_path)

    # Same MKL settings the original exported in a shell prefix, passed via
    # env= so no shell is needed.
    env = dict(os.environ, MKL_THREADING_LAYER='GNU', MKL_SERVICE_FORCE_INTEL='1')
    gen_cmd = ['python', 'scripts/gen_vid.py', '--path', rgba_path,
               '--seed', str(int(seed_slider)), '--bg', 'white']
    if subprocess.run(gen_cmd, env=env).returncode != 0:
        raise gr.Error("Video generation failed")

    video_path = os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')
    if not os.path.exists(video_path):
        raise gr.Error("Video generation did not produce an output video")
    return video_path
def optimize_stage_2(image_block: Image.Image, seed_slider: int):
    """Stage 2: build the 4D Gaussian model from stage-1 outputs.

    Runs ``lgm/infer.py big`` on ``tmp_data/<hash>_rgba.png``, then
    ``main_4d.py`` with ``configs/4d_demo.yaml`` on the same image, where
    ``<hash>`` is the SHA-256 of ``image_block.tobytes()`` (the same key
    ``optimize_stage_1`` uses).

    Args:
        image_block: the same input image used for stage 1.
        seed_slider: currently unused; kept so the UI wiring stays valid.

    Returns:
        Paths ``logs/<hash>_rgba_frames/000.ply`` ... ``027.ply``.

    Raises:
        gr.Error: if no image is given, stage-1 outputs for this image are
            missing, or a helper script exits non-zero.
    """
    if image_block is None:
        raise gr.Error("Please select or upload an input image")
    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    rgba_path = os.path.join('tmp_data', f'{img_hash}_rgba.png')
    video_path = os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')
    # Stage 2 consumes stage-1 files for *this* image; fail early if absent.
    if not (os.path.exists(rgba_path) and os.path.exists(video_path)):
        raise gr.Error("Please generate a video first")

    if subprocess.run(['python', 'lgm/infer.py', 'big', '--test_path', rgba_path]).returncode != 0:
        raise gr.Error("LGM inference failed")

    config_path = os.path.join('configs', '4d_demo.yaml')
    if subprocess.run(['python', 'main_4d.py', '--config', config_path,
                       f'input={rgba_path}']).returncode != 0:
        raise gr.Error("4D optimization failed")

    image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
    # 28 frames is hard-coded; presumably matches the frame count produced
    # with 4d_demo.yaml — TODO confirm against main_4d.py.
    num_frames = 28
    return [os.path.join(image_dir, f'{t:03d}.ply') for t in range(num_frames)]
if __name__ == "__main__":
    _TITLE = '''DreamGaussian4D: Generative 4D Gaussian Splatting'''
    _DESCRIPTION = '''
<div>
<a style="display:inline-block" href="https://jiawei-ren.github.io/projects/dreamgaussian4d/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2312.17142"><img src="https://img.shields.io/badge/2312.17142-f9f7f7?logo="></a>
<a style="display:inline-block; margin-left: .5em" href='https://github.com/jiawei-ren/dreamgaussian4d'><img src='https://img.shields.io/github/stars/jiawei-ren/dreamgaussian4d?style=social'/></a>
</div>
We present DreamGaussian4D, an efficient 4D generation framework that builds on Gaussian Splatting.
'''
    _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), select a random seed, and click **Generate Video**. After having the video generated, please click **Generate 4D**."

    # Every PNG in the 'data' folder next to this script becomes an example.
    example_folder = os.path.join(os.path.dirname(__file__), 'data')
    examples_full = [os.path.join(example_folder, fn)
                     for fn in sorted(os.listdir(example_folder))
                     if fn.endswith('.png')]

    # Compose demo layout & data flow.
    with gr.Blocks(title=_TITLE, theme=gr.themes.Soft(), js=js_func) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        # Image-to-4D: inputs and controls on the left, outputs on the right.
        with gr.Row(variant='panel'):
            with gr.Column(scale=4):
                image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')
                seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed')
                gr.Markdown("random seed for video generation.")
                preprocess_chk = gr.Checkbox(
                    True,
                    label='Preprocess image automatically (remove background and recenter object)')
                gr.Examples(
                    examples=examples_full,  # NOTE: elements must match inputs list!
                    inputs=[image_block],
                    outputs=[image_block],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=40,
                )
                img_run_btn = gr.Button("Generate Video")
                fourd_run_btn = gr.Button("Generate 4D")
                gr.Markdown(_IMG_USER_GUIDE, visible=True)
            with gr.Column(scale=5):
                obj3d = gr.Video(label="video", height=290)
                obj4d = Model4DGS(label="4D Model", height=500, fps=14)

        # Each stage runs only if its gate check succeeds (.success chaining).
        img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(
            optimize_stage_1,
            inputs=[image_block, preprocess_chk, seed_slider],
            outputs=[obj3d])
        fourd_run_btn.click(check_video_input, inputs=[], queue=False).success(
            optimize_stage_2,
            inputs=[image_block, seed_slider],
            outputs=[obj4d])

    # Enable the request queue, capped at 10 pending events.
    demo.queue(max_size=10)
    demo.launch(share=True)