Skip to content

Commit

Permalink
Implemented cpu offload (#197)
Browse files Browse the repository at this point in the history
* Implemented cpu offload for other frameworks

* polish example code

* Added test for low mem settings.

* polish

* format

* fix arg

* update pipeline and test

---------

Co-authored-by: ExtremeViscent <[email protected]>
  • Loading branch information
oahzxl and ExtremeViscent authored Sep 13, 2024
1 parent 4ad17b5 commit 2e81cd8
Show file tree
Hide file tree
Showing 14 changed files with 136 additions and 27 deletions.
10 changes: 10 additions & 0 deletions examples/latte/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ def run_base():
engine.save_video(video, f"./outputs/{prompt}.mp4")


def run_low_mem():
    """Generate the sample video with the Latte pipeline using ``cpu_offload=True``."""
    prompt = "Sunset over the sea."

    # Same model as the base example, but with CPU offload switched on.
    low_mem_config = LatteConfig("maxin-cn/Latte-1", cpu_offload=True)
    low_mem_engine = VideoSysEngine(low_mem_config)

    result = low_mem_engine.generate(prompt)
    first_video = result.video[0]
    low_mem_engine.save_video(first_video, f"./outputs/{prompt}.mp4")


def run_pab():
config = LatteConfig("maxin-cn/Latte-1", enable_pab=True)
engine = VideoSysEngine(config)
Expand All @@ -27,4 +36,5 @@ def run_pab():

if __name__ == "__main__":
    run_base()
    # Alternative demos, disabled by default; uncomment to run them.
    # run_low_mem()
    # run_pab()
10 changes: 10 additions & 0 deletions examples/open_sora/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ def run_base():
engine.save_video(video, f"./outputs/{prompt}.mp4")


def run_low_mem():
    """Generate the sample video with Open-Sora using ``cpu_offload=True`` and ``tiling_size=1``."""
    prompt = "Sunset over the sea."

    # CPU offload plus the smallest tiling size used in these examples.
    low_mem_config = OpenSoraConfig(cpu_offload=True, tiling_size=1)
    low_mem_engine = VideoSysEngine(low_mem_config)

    result = low_mem_engine.generate(prompt)
    first_video = result.video[0]
    low_mem_engine.save_video(first_video, f"./outputs/{prompt}.mp4")


def run_pab():
config = OpenSoraConfig(enable_pab=True)
engine = VideoSysEngine(config)
Expand All @@ -31,4 +40,5 @@ def run_pab():

if __name__ == "__main__":
    run_base()
    # Alternative demos, disabled by default; uncomment to run them.
    # run_low_mem()
    # run_pab()
12 changes: 11 additions & 1 deletion examples/open_sora_plan/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,17 @@ def run_base():
engine.save_video(video, f"./outputs/{prompt}.mp4")


def run_low_mem():
    """Generate the sample video with Open-Sora-Plan using ``cpu_offload=True`` and ``enable_tiling=True``."""
    prompt = "Sunset over the sea."

    # CPU offload combined with tiling enabled.
    low_mem_config = OpenSoraPlanConfig(cpu_offload=True, enable_tiling=True)
    low_mem_engine = VideoSysEngine(low_mem_config)

    result = low_mem_engine.generate(prompt)
    first_video = result.video[0]
    low_mem_engine.save_video(first_video, f"./outputs/{prompt}.mp4")


def run_pab():
config = OpenSoraPlanConfig(num_gpus=1, enable_pab=True)
config = OpenSoraPlanConfig(enable_pab=True)
engine = VideoSysEngine(config)

prompt = "Sunset over the sea."
Expand All @@ -27,4 +36,5 @@ def run_pab():

if __name__ == "__main__":
    run_base()
    # Alternative demos, disabled by default; uncomment to run them.
    # run_low_mem()
    # run_pab()
23 changes: 15 additions & 8 deletions tests/examples/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,19 @@
import examples.open_sora.sample as open_sora
import examples.open_sora_plan.sample as open_sora_plan

files = [cogvideox, latte, open_sora, open_sora_plan]
members = []

@pytest.mark.parametrize("file", [cogvideox, latte, open_sora, open_sora_plan])
def test_examples(file):
    """Call every function found in one example module, one after another.

    Any exception from an example is re-raised as a generic ``Exception``
    naming the function and the module file, with the original error chained
    as its cause. The first failure aborts the remaining functions.
    """
    for example_name, example_fn in inspect.getmembers(file, inspect.isfunction):
        try:
            example_fn()
        except Exception as err:
            raise Exception(f"Failed to run {example_name} in {file.__file__}") from err
# Collect every (name, function) pair defined in the example modules so each
# example function can be parametrized as its own test case. No debug output
# is printed here: this code runs at import time, i.e. during test collection.
for file in files:
    members.extend(inspect.getmembers(file, inspect.isfunction))


@pytest.mark.parametrize("members", members)
def test_examples(members):
    """Run a single example function collected from the sample modules.

    Args:
        members: A ``(name, function)`` pair as produced by
            ``inspect.getmembers(module, inspect.isfunction)``.

    Raises:
        Exception: If the example function raises. The message names the
            function and the file that defines it, and the original error is
            chained as the cause.
    """
    name, func = members
    try:
        func()
    except Exception as e:
        # Derive the path from the failing function itself. The module-level
        # loop variable `file` is stale by the time a test runs (it always
        # points at the last module iterated), so using it would blame the
        # wrong file.
        raise Exception(f"Failed to run {name} in {inspect.getsourcefile(func)}") from e
10 changes: 10 additions & 0 deletions tests/pipelines/cogvideox/test_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,13 @@ def test_pab(num_gpus):
prompt = "Sunset over the sea."
video = engine.generate(prompt).video[0]
engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_pab_{num_gpus}.mp4")


@pytest.mark.parametrize("num_gpus", [1])
def test_low_mem(num_gpus):
    """Smoke-test CogVideoX generation with ``cpu_offload=True`` and ``vae_tiling=True``."""
    prompt = "Sunset over the sea."

    low_mem_config = CogVideoXConfig(num_gpus=num_gpus, cpu_offload=True, vae_tiling=True)
    low_mem_engine = VideoSysEngine(low_mem_config)

    first_video = low_mem_engine.generate(prompt).video[0]
    low_mem_engine.save_video(first_video, f"./test_outputs/{prompt}_cogvideo_low_mem_{num_gpus}.mp4")
10 changes: 10 additions & 0 deletions tests/pipelines/latte/test_latte.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,13 @@ def test_pab(num_gpus):
prompt = "Sunset over the sea."
video = engine.generate(prompt).video[0]
engine.save_video(video, f"./test_outputs/{prompt}_latte_pab_{num_gpus}.mp4")


@pytest.mark.parametrize("num_gpus", [1])
def test_low_mem(num_gpus):
    """Smoke-test Latte generation with ``cpu_offload=True``."""
    prompt = "Sunset over the sea."

    low_mem_config = LatteConfig(num_gpus=num_gpus, cpu_offload=True)
    low_mem_engine = VideoSysEngine(low_mem_config)

    first_video = low_mem_engine.generate(prompt).video[0]
    low_mem_engine.save_video(first_video, f"./test_outputs/{prompt}_latte_low_mem_{num_gpus}.mp4")
10 changes: 10 additions & 0 deletions tests/pipelines/open_sora/test_open_sora.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,13 @@ def test_pab(num_gpus):
prompt = "Sunset over the sea."
video = engine.generate(prompt).video[0]
engine.save_video(video, f"./test_outputs/{prompt}_open_sora_pab_{num_gpus}.mp4")


@pytest.mark.parametrize("num_gpus", [1])
def test_low_mem(num_gpus):
    """Smoke-test Open-Sora generation with ``cpu_offload=True`` and ``tiling_size=1``."""
    prompt = "Sunset over the sea."

    low_mem_config = OpenSoraConfig(num_gpus=num_gpus, cpu_offload=True, tiling_size=1)
    low_mem_engine = VideoSysEngine(low_mem_config)

    first_video = low_mem_engine.generate(prompt).video[0]
    low_mem_engine.save_video(first_video, f"./test_outputs/{prompt}_open_sora_low_mem_{num_gpus}.mp4")
10 changes: 10 additions & 0 deletions tests/pipelines/open_sora_plan/test_open_sora_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,13 @@ def test_pab(num_gpus):
prompt = "Sunset over the sea."
video = engine.generate(prompt).video[0]
engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_pab_{num_gpus}.mp4")


@pytest.mark.parametrize("num_gpus", [1])
def test_low_mem(num_gpus):
    """Smoke-test Open-Sora-Plan generation with ``cpu_offload=True`` and ``enable_tiling=True``."""
    prompt = "Sunset over the sea."

    low_mem_config = OpenSoraPlanConfig(num_gpus=num_gpus, cpu_offload=True, enable_tiling=True)
    low_mem_engine = VideoSysEngine(low_mem_config)

    first_video = low_mem_engine.generate(prompt).video[0]
    low_mem_engine.save_video(first_video, f"./test_outputs/{prompt}_open_sora_plan_low_mem_{num_gpus}.mp4")
6 changes: 2 additions & 4 deletions videosys/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Optional

import torch
import torch.distributed as dist

import videosys

Expand All @@ -22,9 +23,6 @@ def __init__(self, config):
def _init_worker(self, pipeline_cls):
world_size = self.config.num_gpus

if "CUDA_VISIBLE_DEVICES" not in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size))

# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

Expand Down Expand Up @@ -124,7 +122,7 @@ def save_video(self, video, output_path):
def shutdown(self):
    """Close the worker monitor, if any, and tear down the default process group.

    The process group is destroyed only when one is actually initialized, so
    a repeated call (for example from ``__del__`` after an explicit
    ``shutdown()``) does not fail on that step.
    """
    if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
        worker_monitor.close()
    # destroy_process_group() raises when no default group exists, so check
    # first; is_available() guards builds without distributed support.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def __del__(self):
    """Run ``shutdown`` when the engine object is finalized."""
    self.shutdown()
3 changes: 3 additions & 0 deletions videosys/models/autoencoders/autoencoder_kl_open_sora.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,9 @@ def encode(self, x):
return (z - self.shift) / self.scale

def decode(self, z, num_frames=None):
device = z.device
self.scale = self.scale.to(device)
self.shift = self.shift.to(device)
if not self.cal_loss:
z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype)

Expand Down
2 changes: 1 addition & 1 deletion videosys/pipelines/cogvideox/pipeline_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(


class CogVideoXPipeline(VideoSysPipeline):
_optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
_optional_components = ["text_encoder", "tokenizer"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
Expand Down
14 changes: 12 additions & 2 deletions videosys/pipelines/latte/pipeline_latte.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def __init__(
beta_end: float = 0.02,
beta_schedule: str = "linear",
variance_type: str = "learned_range",
# ======= memory =======
cpu_offload: bool = False,
# ======= pab ========
enable_pab: bool = False,
pab_config: PABConfig = LattePABConfig(),
Expand All @@ -148,6 +150,8 @@ def __init__(
self.num_gpus = num_gpus
# ======= vae ========
self.enable_vae_temporal_decoder = enable_vae_temporal_decoder
# ======= memory ========
self.cpu_offload = cpu_offload
# ======= scheduler ========
self.beta_start = beta_start
self.beta_end = beta_end
Expand Down Expand Up @@ -235,12 +239,18 @@ def __init__(
set_pab_manager(config.pab_config)

# set eval and device
self.set_eval_and_device(device, text_encoder, vae, transformer)
self.set_eval_and_device(device, vae, transformer)

self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)

# cpu offload
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
self.set_eval_and_device(device, text_encoder)

self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

Expand Down Expand Up @@ -744,7 +754,7 @@ def generate(
else:
batch_size = prompt_embeds.shape[0]

device = self.text_encoder.device or self._execution_device
device = self._execution_device

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
Expand Down
25 changes: 19 additions & 6 deletions videosys/pipelines/open_sora/pipeline_open_sora.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ def __init__(
# ======== scheduler ========
num_sampling_steps: int = 30,
cfg_scale: float = 7.0,
# ======= memory =======
cpu_offload: bool = False,
# ======== vae ========
tiling_size: int = 4,
# ======== speedup ========
Expand All @@ -151,6 +153,8 @@ def __init__(
self.cfg_scale = cfg_scale
# ======== vae ========
self.tiling_size = tiling_size
# ======= memory ========
self.cpu_offload = cpu_offload
# ======== speedup ========
self.enable_flash_attn = enable_flash_attn
# ======== pab ========
Expand Down Expand Up @@ -184,7 +188,10 @@ class OpenSoraPipeline(VideoSysPipeline):
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa

_optional_components = ["tokenizer", "text_encoder"]
_optional_components = [
"text_encoder",
"tokenizer",
]
model_cpu_offload_seq = "text_encoder->transformer->vae"

def __init__(
Expand Down Expand Up @@ -228,12 +235,18 @@ def __init__(
set_pab_manager(config.pab_config)

# set eval and device
self.set_eval_and_device(device, text_encoder, vae, transformer)
self.set_eval_and_device(device, vae, transformer)

self.register_modules(
text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, tokenizer=tokenizer
)

# cpu offload
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
self.set_eval_and_device(self._device, text_encoder)

def get_text_embeddings(self, texts):
text_tokens_and_mask = self.tokenizer(
texts,
Expand All @@ -244,9 +257,9 @@ def get_text_embeddings(self, texts):
add_special_tokens=True,
return_tensors="pt",
)

input_ids = text_tokens_and_mask["input_ids"].to(self.device)
attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
device = self._execution_device
input_ids = text_tokens_and_mask["input_ids"].to(device)
attention_mask = text_tokens_and_mask["attention_mask"].to(device)
with torch.no_grad():
text_encoder_embs = self.text_encoder(
input_ids=input_ids,
Expand All @@ -260,7 +273,7 @@ def encode_prompt(self, text):
return dict(y=caption_embs, mask=emb_masks)

def null_embed(self, n):
    """Return ``n`` copies of the transformer's ``y_embedder.y_embedding`` tensor.

    Args:
        n: Number of copies to produce along the new leading axis.

    Returns:
        The embedding with a leading axis of size ``n`` and a singleton axis
        inserted at position 1, moved to ``self._execution_device``.
        (Assumes ``y_embedding`` is 2-D, as the three-argument ``repeat``
        implies.)
    """
    # Build the repeated embedding once, then move it to the execution device
    # so it is on the expected device when model CPU offload is in use.
    null_y = self.transformer.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
    return null_y.to(self._execution_device)

@staticmethod
Expand Down
18 changes: 13 additions & 5 deletions videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def __init__(
num_frames: int = 65,
# ======= distributed ========
num_gpus: int = 1,
# ======= vae =======
# ======= memory =======
cpu_offload: bool = False,
enable_tiling: bool = True,
tile_overlap_factor: float = 0.25,
# ======= pab ========
Expand All @@ -185,7 +186,8 @@ def __init__(
self.version = f"{num_frames}x512x512"
# ======= distributed ========
self.num_gpus = num_gpus
# ======= vae ========
# ======= memory ========
self.cpu_offload = cpu_offload
self.enable_tiling = enable_tiling
self.tile_overlap_factor = tile_overlap_factor
# ======= pab ========
Expand Down Expand Up @@ -256,7 +258,7 @@ def __init__(
transformer.force_images = False

# set eval and device
self.set_eval_and_device(device, text_encoder, vae, transformer)
self.set_eval_and_device(device, vae, transformer)

# pab
if config.enable_pab:
Expand All @@ -266,6 +268,12 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)

# cpu offload
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
self.set_eval_and_device(device, text_encoder)

# self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
Expand Down Expand Up @@ -320,7 +328,7 @@ def encode_prompt(
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None

if device is None:
device = self.text_encoder.device or self._execution_device
device = self._execution_device

if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -774,7 +782,7 @@ def generate(
else:
batch_size = prompt_embeds.shape[0]

device = self.text_encoder.device or self._execution_device
device = self._execution_device

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
Expand Down

0 comments on commit 2e81cd8

Please sign in to comment.