[gradio] update gradio code and doc (#220)
oahzxl authored Sep 26, 2024
1 parent f99ad20 commit e48a642
Showing 3 changed files with 52 additions and 101 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -4,7 +4,7 @@
<h3 align="center">
An easy and efficient system for video generation
</h3>
<p align="center">| <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#installation">Quick Start</a> | <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#usage">Supported Models</a> | <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#acceleration-techniques">Accelerations</a> | <a href="https://discord.gg/WhPmYm9FeG">Discord</a> | <a href="https://oahzxl.notion.site/VideoSys-News-42391db7e0a44f96a1f0c341450ae472?pvs=4">Media</a> |
<p align="center">| <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#installation">Quick Start</a> | <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#usage">Supported Models</a> | <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#acceleration-techniques">Accelerations</a> | <a href="https://discord.gg/WhPmYm9FeG">Discord</a> | <a href="https://oahzxl.notion.site/VideoSys-News-42391db7e0a44f96a1f0c341450ae472?pvs=4">Media</a> | <a href="https://huggingface.co/VideoSys">HuggingFace Space</a> |
</p>

### Latest News 🔥
@@ -106,6 +106,8 @@ VideoSys supports many diffusion models with our various acceleration techniques
</tr>
</table>

You can also find easy demos on HuggingFace Space <a href="https://huggingface.co/VideoSys">[link]</a> and with Gradio <a href="./gradio">[link]</a>.

## Acceleration Techniques

### Pyramid Attention Broadcast (PAB) [[paper](https://arxiv.org/abs/2408.12588)][[blog](https://arxiv.org/abs/2403.10266)][[doc](./docs/pab.md)]
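As a rough illustration (not part of this commit's diff), the PAB controls exposed in the Gradio demo below map directly onto `CogVideoXPABConfig`. A minimal sketch, reusing the parameter names and default values from `gradio/cogvideox.py`:

```python
from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine

# Broadcast spatial attention between timesteps 850 and 100, refreshing it every 2 steps
# (the demo's default "Start Timestep", "End Timestep", and "Broadcast Range" values).
pab_config = CogVideoXPABConfig(spatial_threshold=[100, 850], spatial_range=2)
config = CogVideoXConfig("THUDM/CogVideoX-2b", enable_pab=True, pab_config=pab_config)
engine = VideoSysEngine(config)
```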
4 changes: 4 additions & 0 deletions gradio/README.md
@@ -0,0 +1,4 @@
## Gradio Demo
Here are local Gradio demos that provide an easy UI for generation and visualization. You can also find online demos on <a href="https://huggingface.co/VideoSys">HuggingFace Space</a>.

Each script is easy to run: `python xxx.py`
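Under the hood, each script simply wraps a `VideoSysEngine` in a Gradio UI. A minimal sketch of the programmatic equivalent follows; the `engine.generate(...).video[0]` and `engine.save_video(...)` calls are assumed from the VideoSys usage examples rather than shown in this diff:

```python
import os

from videosys import CogVideoXConfig, VideoSysEngine

# Build the same engine the demo's load_model() helper constructs (PAB disabled for brevity).
config = CogVideoXConfig("THUDM/CogVideoX-2b", enable_pab=False)
engine = VideoSysEngine(config)

prompt = "Sunset over the sea."
# Assumed VideoSys API: generate() returns an object whose .video[0] holds the rendered clip.
video = engine.generate(prompt, num_inference_steps=50, guidance_scale=6.0).video[0]

os.makedirs("./outputs", exist_ok=True)
engine.save_video(video, "./outputs/sunset_over_the_sea.mp4")  # assumed save helper
```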
145 changes: 45 additions & 100 deletions gradio/cogvideox.py
@@ -3,25 +3,17 @@
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), ".tmp_outputs")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import logging
import uuid

import GPUtil
import psutil
import torch
import spaces

import gradio as gr
from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

dtype = torch.float16


def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_range=2):
def load_model(model_name, enable_video_sys=False, pab_threshold=[100, 850], pab_range=2):
pab_config = CogVideoXPABConfig(spatial_threshold=pab_threshold, spatial_range=pab_range)
config = CogVideoXConfig(num_gpus=1, enable_pab=enable_video_sys, pab_config=pab_config)
config = CogVideoXConfig(model_name, enable_pab=enable_video_sys, pab_config=pab_config)
engine = VideoSysEngine(config)
return engine

@@ -36,33 +28,9 @@ def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
return output_path


def get_server_status():
cpu_percent = psutil.cpu_percent()
memory = psutil.virtual_memory()
disk = psutil.disk_usage("/")
gpus = GPUtil.getGPUs()
gpu_info = []
for gpu in gpus:
gpu_info.append(
{
"id": gpu.id,
"name": gpu.name,
"load": f"{gpu.load*100:.1f}%",
"memory_used": f"{gpu.memoryUsed}MB",
"memory_total": f"{gpu.memoryTotal}MB",
}
)

return {"cpu": f"{cpu_percent}%", "memory": f"{memory.percent}%", "disk": f"{disk.percent}%", "gpu": gpu_info}


def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
engine = load_model()
video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
return video_path


@spaces.GPU(duration=200)
def generate_vs(
model_name,
prompt,
num_inference_steps,
guidance_scale,
@@ -73,38 +41,11 @@ def generate_vs(
):
threshold = [int(threshold_end), int(threshold_start)]
gap = int(gap)
engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_range=gap)
engine = load_model(model_name, enable_video_sys=True, pab_threshold=threshold, pab_range=gap)
video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
return video_path


def get_server_status():
cpu_percent = psutil.cpu_percent()
memory = psutil.virtual_memory()
disk = psutil.disk_usage("/")
try:
gpus = GPUtil.getGPUs()
if gpus:
gpu = gpus[0]
gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)"
else:
gpu_memory = "No GPU found"
except:
gpu_memory = "GPU information unavailable"

return {
"cpu": f"{cpu_percent}%",
"memory": f"{memory.percent}%",
"disk": f"{disk.percent}%",
"gpu_memory": gpu_memory,
}


def update_server_status():
status = get_server_status()
return (status["cpu"], status["memory"], status["disk"], status["gpu_memory"])


css = """
body {
font-family: Arial, sans-serif;
@@ -206,60 +147,64 @@

with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=4)
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=2)

with gr.Column():
gr.Markdown("**Generation Parameters**<br>")
with gr.Row():
num_inference_steps = gr.Number(label="Inference Steps", value=50)
guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
model_name = gr.Radio(["THUDM/CogVideoX-2b"], label="Model Type", value="THUDM/CogVideoX-2b")
with gr.Row():
pab_range = gr.Number(
label="PAB Broadcast Range", value=2, precision=0, info="Broadcast timesteps range."
num_inference_steps = gr.Slider(label="Inference Steps", maximum=50, value=50)
guidance_scale = gr.Slider(label="Guidance Scale", value=6.0, maximum=15.0)
gr.Markdown("**Pyramid Attention Broadcast Parameters**<br>")
with gr.Row():
pab_range = gr.Slider(
label="Broadcast Range",
value=2,
step=1,
minimum=1,
maximum=4,
info="Attention broadcast range.",
)
pab_threshold_start = gr.Slider(
label="Start Timestep",
minimum=500,
maximum=1000,
value=850,
step=1,
info="Broadcast start timestep (1000 is the fisrt).",
)
pab_threshold_end = gr.Slider(
label="End Timestep",
minimum=0,
maximum=500,
step=1,
value=100,
info="Broadcast end timestep (0 is the last).",
)
pab_threshold_start = gr.Number(label="PAB Start Timestep", value=850, info="Start from step 1000.")
pab_threshold_end = gr.Number(label="PAB End Timestep", value=100, info="End at step 0.")
with gr.Row():
generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
generate_button = gr.Button("🎬 Generate Video (Original)")
with gr.Column(elem_classes="server-status"):
gr.Markdown("#### Server Status")

with gr.Row():
cpu_status = gr.Textbox(label="CPU", scale=1)
memory_status = gr.Textbox(label="Memory", scale=1)

with gr.Row():
disk_status = gr.Textbox(label="Disk", scale=1)
gpu_status = gr.Textbox(label="GPU Memory", scale=1)

with gr.Row():
refresh_button = gr.Button("Refresh")
generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys")

with gr.Column():
with gr.Row():
video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
with gr.Row():
video_output = gr.Video(label="CogVideoX", width=720, height=480)

generate_button.click(
generate_vanilla,
inputs=[prompt, num_inference_steps, guidance_scale],
outputs=[video_output],
concurrency_id="gen",
concurrency_limit=1,
)

generate_button_vs.click(
generate_vs,
inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold_start, pab_threshold_end, pab_range],
inputs=[
model_name,
prompt,
num_inference_steps,
guidance_scale,
pab_threshold_start,
pab_threshold_end,
pab_range,
],
outputs=[video_output_vs],
concurrency_id="gen",
concurrency_limit=1,
)

refresh_button.click(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status])
demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1)

if __name__ == "__main__":
demo.queue(max_size=10, default_concurrency_limit=1)