diff --git a/eval/pab/experiments/opensora_plan.py b/eval/pab/experiments/opensora_plan.py
index 8e5c5321..3a5716fa 100644
--- a/eval/pab/experiments/opensora_plan.py
+++ b/eval/pab/experiments/opensora_plan.py
@@ -1,6 +1,6 @@
 from utils import generate_func, read_prompt_list

-from videosys import OpenSoraPlanConfig, OpenSoraPlanPABConfig, VideoSysEngine
+from videosys import OpenSoraPlanConfig, OpenSoraPlanV110PABConfig, VideoSysEngine


 def eval_base(prompt_list):
@@ -10,7 +10,7 @@ def eval_base(prompt_list):


 def eval_pab1(prompt_list):
-    pab_config = OpenSoraPlanPABConfig(
+    pab_config = OpenSoraPlanV110PABConfig(
         spatial_gap=2,
         temporal_gap=4,
         cross_gap=6,
@@ -21,7 +21,7 @@ def eval_pab1(prompt_list):


 def eval_pab2(prompt_list):
-    pab_config = OpenSoraPlanPABConfig(
+    pab_config = OpenSoraPlanV110PABConfig(
         spatial_gap=3,
         temporal_gap=5,
         cross_gap=7,
@@ -32,7 +32,7 @@ def eval_pab2(prompt_list):


 def eval_pab3(prompt_list):
-    pab_config = OpenSoraPlanPABConfig(
+    pab_config = OpenSoraPlanV110PABConfig(
         spatial_gap=5,
         temporal_gap=7,
         cross_gap=9,
diff --git a/examples/open_sora_plan/sample.py b/examples/open_sora_plan/sample.py
index 2916c7fc..f98a8016 100644
--- a/examples/open_sora_plan/sample.py
+++ b/examples/open_sora_plan/sample.py
@@ -5,7 +5,7 @@ def run_base():
     # open-sora-plan v1.2.0
     # transformer_type (len, res): 93x480p 93x720p 29x480p 29x720p
     # change num_gpus for multi-gpu inference
-    config = OpenSoraPlanConfig(version="v120", transformer_type="93x480p", num_gpus=1)
+    config = OpenSoraPlanConfig(version="v120", transformer_type="29x480p", num_gpus=1)
     engine = VideoSysEngine(config)

     prompt = "Sunset over the sea."
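The eval sweep and sample script above construct the renamed v1.1.0 config directly. For orientation, a minimal end-to-end sketch of the same wiring outside the benchmark harness (gap values are copied from `eval_pab1`; the output path is illustrative):

```python
from videosys import OpenSoraPlanConfig, OpenSoraPlanV110PABConfig, VideoSysEngine

# Hedged sketch: the eval scripts above route this through generate_func instead.
pab_config = OpenSoraPlanV110PABConfig(spatial_gap=2, temporal_gap=4, cross_gap=6)
config = OpenSoraPlanConfig(version="v110", enable_pab=True, pab_config=pab_config)
engine = VideoSysEngine(config)

video = engine.generate("Sunset over the sea.", seed=0).video[0]
engine.save_video(video, "./outputs/opensora_plan_v110_pab.mp4")
```

Passing `pab_config` explicitly is now optional: as the pipeline changes further down show, `enable_pab=True` without a config falls back to the version-matched default (`OpenSoraPlanV110PABConfig` for v110, `OpenSoraPlanV120PABConfig` for v120).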
diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py
index 3646876d..5c657137 100644
--- a/tests/pipelines/cogvideox/test_cogvideox.py
+++ b/tests/pipelines/cogvideox/test_cogvideox.py
@@ -1,33 +1,40 @@
 import pytest

 from videosys import CogVideoXConfig, VideoSysEngine
+from videosys.utils.test import empty_cache


 @pytest.mark.parametrize("num_gpus", [1, 2])
-def test_base(num_gpus):
-    config = CogVideoXConfig(num_gpus=num_gpus)
+@pytest.mark.parametrize("model_path", ["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"])
+@empty_cache
+def test_base(num_gpus, model_path):
+    config = CogVideoXConfig(model_path=model_path, num_gpus=num_gpus)
     engine = VideoSysEngine(config)

     prompt = "Sunset over the sea."
     video = engine.generate(prompt, seed=0).video[0]
-    engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_{num_gpus}.mp4")
+    engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_{model_path.replace('/', '_')}_base_{num_gpus}.mp4")


 @pytest.mark.parametrize("num_gpus", [1])
-def test_pab(num_gpus):
-    config = CogVideoXConfig(num_gpus=num_gpus, enable_pab=True)
+@pytest.mark.parametrize("model_path", ["THUDM/CogVideoX-2b"])
+@empty_cache
+def test_pab(num_gpus, model_path):
+    config = CogVideoXConfig(model_path=model_path, num_gpus=num_gpus, enable_pab=True)
     engine = VideoSysEngine(config)

     prompt = "Sunset over the sea."
     video = engine.generate(prompt, seed=0).video[0]
-    engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_pab_{num_gpus}.mp4")
+    engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_{model_path.replace('/', '_')}_pab_{num_gpus}.mp4")


 @pytest.mark.parametrize("num_gpus", [1])
-def test_low_mem(num_gpus):
-    config = CogVideoXConfig(num_gpus=num_gpus, cpu_offload=True, vae_tiling=True)
+@pytest.mark.parametrize("model_path", ["THUDM/CogVideoX-2b"])
+@empty_cache
+def test_low_mem(num_gpus, model_path):
+    config = CogVideoXConfig(model_path=model_path, num_gpus=num_gpus, cpu_offload=True, vae_tiling=True)
     engine = VideoSysEngine(config)

     prompt = "Sunset over the sea."
     video = engine.generate(prompt, seed=0).video[0]
-    engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_low_mem_{num_gpus}.mp4")
+    engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_{model_path.replace('/', '_')}_low_mem_{num_gpus}.mp4")
diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py
index 0c4aaf24..74d9428a 100644
--- a/tests/pipelines/latte/test_latte.py
+++ b/tests/pipelines/latte/test_latte.py
@@ -1,9 +1,11 @@
 import pytest

 from videosys import LatteConfig, VideoSysEngine
+from videosys.utils.test import empty_cache


 @pytest.mark.parametrize("num_gpus", [1, 2])
+@empty_cache
 def test_base(num_gpus):
     config = LatteConfig(num_gpus=num_gpus)
     engine = VideoSysEngine(config)
@@ -14,6 +16,7 @@ def test_base(num_gpus):


 @pytest.mark.parametrize("num_gpus", [1])
+@empty_cache
 def test_pab(num_gpus):
     config = LatteConfig(num_gpus=num_gpus, enable_pab=True)
     engine = VideoSysEngine(config)
@@ -24,6 +27,7 @@ def test_pab(num_gpus):


 @pytest.mark.parametrize("num_gpus", [1])
+@empty_cache
 def test_low_mem(num_gpus):
     config = LatteConfig(num_gpus=num_gpus, cpu_offload=True)
     engine = VideoSysEngine(config)
diff --git a/tests/pipelines/open_sora/test_open_sora.py b/tests/pipelines/open_sora/test_open_sora.py
index 5da6c6d4..891d0332 100644
--- a/tests/pipelines/open_sora/test_open_sora.py
+++ b/tests/pipelines/open_sora/test_open_sora.py
@@ -1,9 +1,11 @@
 import pytest

 from videosys import OpenSoraConfig, VideoSysEngine
+from videosys.utils.test import empty_cache


 @pytest.mark.parametrize("num_gpus", [1, 2])
+@empty_cache
 def test_base(num_gpus):
     config = OpenSoraConfig(num_gpus=num_gpus)
     engine = VideoSysEngine(config)
@@ -14,6 +16,7 @@ def test_base(num_gpus):


 @pytest.mark.parametrize("num_gpus", [1])
+@empty_cache
 def test_pab(num_gpus):
     config = OpenSoraConfig(num_gpus=num_gpus, enable_pab=True)
     engine = VideoSysEngine(config)
@@ -24,6 +27,7 @@ def test_pab(num_gpus):


 @pytest.mark.parametrize("num_gpus", [1])
+@empty_cache
 def test_low_mem(num_gpus):
     config = OpenSoraConfig(num_gpus=num_gpus, cpu_offload=True, tiling_size=1)
     engine = VideoSysEngine(config)
diff --git a/tests/pipelines/open_sora_plan/test_open_sora_plan.py b/tests/pipelines/open_sora_plan/test_open_sora_plan.py
index 4bee6ad2..b40170e4 100644
--- a/tests/pipelines/open_sora_plan/test_open_sora_plan.py
+++ b/tests/pipelines/open_sora_plan/test_open_sora_plan.py
@@ -1,33 +1,42 @@
 import pytest

 from videosys import OpenSoraPlanConfig, VideoSysEngine
+from videosys.utils.test import empty_cache


 @pytest.mark.parametrize("num_gpus", [1, 2])
-def test_base(num_gpus):
-    config = OpenSoraPlanConfig(num_gpus=num_gpus)
+@pytest.mark.parametrize("model", [("v120", "29x480p")])
+@empty_cache
+def test_base(num_gpus, model):
+    config = OpenSoraPlanConfig(version=model[0], transformer_type=model[1], num_gpus=num_gpus)
     engine = VideoSysEngine(config)

     prompt = "Sunset over the sea."
     video = engine.generate(prompt, seed=0).video[0]
-    engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_{num_gpus}.mp4")
+    engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_{model[0]}_{model[1]}_{num_gpus}.mp4")


 @pytest.mark.parametrize("num_gpus", [1])
-def test_pab(num_gpus):
-    config = OpenSoraPlanConfig(num_gpus=num_gpus, enable_pab=True)
+@pytest.mark.parametrize("model", [("v120", "29x480p")])
+@empty_cache
+def test_pab(num_gpus, model):
+    config = OpenSoraPlanConfig(version=model[0], transformer_type=model[1], num_gpus=num_gpus, enable_pab=True)
     engine = VideoSysEngine(config)

     prompt = "Sunset over the sea."
     video = engine.generate(prompt, seed=0).video[0]
-    engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_pab_{num_gpus}.mp4")
+    engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_{model[0]}_{model[1]}_pab_{num_gpus}.mp4")


 @pytest.mark.parametrize("num_gpus", [1])
-def test_low_mem(num_gpus):
-    config = OpenSoraPlanConfig(num_gpus=num_gpus, cpu_offload=True, enable_tiling=True)
+@pytest.mark.parametrize("model", [("v120", "29x480p")])
+@empty_cache
+def test_low_mem(num_gpus, model):
+    config = OpenSoraPlanConfig(
+        version=model[0], transformer_type=model[1], num_gpus=num_gpus, cpu_offload=True, enable_tiling=True
+    )
     engine = VideoSysEngine(config)

     prompt = "Sunset over the sea."
     video = engine.generate(prompt, seed=0).video[0]
-    engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_low_mem_{num_gpus}.mp4")
+    engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_{model[0]}_{model[1]}_low_mem_{num_gpus}.mp4")
diff --git a/tests/pipelines/vchitect/test_vchitect.py b/tests/pipelines/vchitect/test_vchitect.py
index 9b9ff678..53cc6c67 100644
--- a/tests/pipelines/vchitect/test_vchitect.py
+++ b/tests/pipelines/vchitect/test_vchitect.py
@@ -1,33 +1,44 @@
 import pytest

 from videosys import VchitectConfig, VideoSysEngine
+from videosys.utils.test import empty_cache


 @pytest.mark.parametrize("num_gpus", [1, 2])
-def test_base(num_gpus):
-    config = VchitectConfig(num_gpus=num_gpus)
+@pytest.mark.parametrize("model_path", ["Vchitect/Vchitect-2.0-2B"])
+@empty_cache
+def test_base(num_gpus, model_path):
+    config = VchitectConfig(model_path=model_path, num_gpus=num_gpus)
     engine = VideoSysEngine(config)

     prompt = "Sunset over the sea."
     video = engine.generate(prompt, seed=0).video[0]
-    engine.save_video(video, f"./test_outputs/{prompt}_vchitect_{num_gpus}.mp4")
+    engine.save_video(video, f"./test_outputs/{prompt}_vchitect_{model_path.replace('/', '_')}_base_{num_gpus}.mp4")


 @pytest.mark.parametrize("num_gpus", [1])
-def test_pab(num_gpus):
-    config = VchitectConfig(num_gpus=num_gpus, enable_pab=True)
+@pytest.mark.parametrize("model_path", ["Vchitect/Vchitect-2.0-2B"])
+@empty_cache
+def test_pab(num_gpus, model_path):
+    config = VchitectConfig(model_path=model_path, num_gpus=num_gpus, enable_pab=True)
     engine = VideoSysEngine(config)

     prompt = "Sunset over the sea."
     video = engine.generate(prompt, seed=0).video[0]
-    engine.save_video(video, f"./test_outputs/{prompt}_vchitect_pab_{num_gpus}.mp4")
+    engine.save_video(video, f"./test_outputs/{prompt}_vchitect_{model_path.replace('/', '_')}_pab_{num_gpus}.mp4")


 @pytest.mark.parametrize("num_gpus", [1])
-def test_low_mem(num_gpus):
-    config = VchitectConfig(num_gpus=num_gpus, cpu_offload=True)
+@pytest.mark.parametrize("model_path", ["Vchitect/Vchitect-2.0-2B"])
+@empty_cache
+def test_low_mem(num_gpus, model_path):
+    config = VchitectConfig(
+        model_path=model_path,
+        num_gpus=num_gpus,
+        cpu_offload=True,
+    )
     engine = VideoSysEngine(config)

     prompt = "Sunset over the sea."
     video = engine.generate(prompt, seed=0).video[0]
-    engine.save_video(video, f"./test_outputs/{prompt}_vchitect_low_mem_{num_gpus}.mp4")
+    engine.save_video(video, f"./test_outputs/{prompt}_vchitect_{model_path.replace('/', '_')}_low_mem_{num_gpus}.mp4")
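All five test modules now share one pattern: the model axis is stacked as a second `pytest.mark.parametrize`, output filenames embed the model identifier, and each test is wrapped with the new `empty_cache` decorator (added in `videosys/utils/test.py` at the end of this diff) so allocator cache left by one case is released before the next model loads. A minimal sketch of the pattern with a hypothetical config class:

```python
import pytest

from videosys import VideoSysEngine
from videosys.utils.test import empty_cache


@pytest.mark.parametrize("num_gpus", [1])
@pytest.mark.parametrize("model_path", ["org/model-a", "org/model-b"])  # hypothetical model IDs
@empty_cache
def test_base(num_gpus, model_path):
    config = SomePipelineConfig(model_path=model_path, num_gpus=num_gpus)  # hypothetical config class
    engine = VideoSysEngine(config)
    video = engine.generate("Sunset over the sea.", seed=0).video[0]
    engine.save_video(video, f"./test_outputs/{model_path.replace('/', '_')}_{num_gpus}.mp4")
```

Because `empty_cache` applies `functools.wraps`, the wrapper preserves the test's signature, so pytest can still resolve the parametrized argument names through it.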
diff --git a/videosys/__init__.py b/videosys/__init__.py
index d88d0580..6c539c0f 100644
--- a/videosys/__init__.py
+++ b/videosys/__init__.py
@@ -3,15 +3,20 @@
 from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
 from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline
 from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
-from .pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
-from .pipelines.vchitect import VchitectConfig, VchitectXLPipeline
+from .pipelines.open_sora_plan import (
+    OpenSoraPlanConfig,
+    OpenSoraPlanPipeline,
+    OpenSoraPlanV110PABConfig,
+    OpenSoraPlanV120PABConfig,
+)
+from .pipelines.vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline

 __all__ = [
     "initialize",
     "VideoSysEngine",
     "LattePipeline", "LatteConfig", "LattePABConfig",
-    "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanPABConfig",
+    "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig",
     "OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig",
     "CogVideoXPipeline", "CogVideoXConfig", "CogVideoXPABConfig",
-    "VchitectXLPipeline", "VchitectConfig",
+    "VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"
 ]  # fmt: skip
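Since the package root re-exports these symbols, downstream imports of the removed `OpenSoraPlanPABConfig` must be updated, just as the eval script at the top of this diff was. A quick smoke check of the new surface:

```python
# These names are exactly the ones added to __all__ above.
from videosys import (
    OpenSoraPlanV110PABConfig,
    OpenSoraPlanV120PABConfig,
    VchitectPABConfig,
)
```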
diff --git a/videosys/models/transformers/open_sora_plan_v120_transformer_3d.py b/videosys/models/transformers/open_sora_plan_v120_transformer_3d.py
index 94eec28a..761d0d9f 100644
--- a/videosys/models/transformers/open_sora_plan_v120_transformer_3d.py
+++ b/videosys/models/transformers/open_sora_plan_v120_transformer_3d.py
@@ -26,7 +26,8 @@
 from torch import nn
 from torch.nn import functional as F

-from videosys.core.comm import all_to_all_comm
+from videosys.core.comm import all_to_all_comm, gather_sequence, split_sequence
+from videosys.core.pab_mgr import enable_pab, if_broadcast_cross, if_broadcast_spatial
 from videosys.core.parallel_mgr import ParallelManager
 from videosys.core.pipeline import VideoSysPipelineOutput
@@ -43,17 +44,13 @@ def __init__(
     ):
         self.cache_positions = {}

-    def __call__(self, b, t, h, w, device, attn):
+    def __call__(self, b, t, h, w, device):
         if not (b, t, h, w) in self.cache_positions:
             x = torch.arange(w, device=device)
             y = torch.arange(h, device=device)
             z = torch.arange(t, device=device)
             pos = torch.cartesian_prod(z, y, x)
-            if attn.parallel_manager.sp_size > 1:
-                # print('PositionGetter3D', PositionGetter3D)
-                pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, -1, 1).contiguous().expand(3, -1, b).clone()
-            else:
-                pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, b, -1).clone()
+            pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, b, -1).clone()
             poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous())
             max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max()))
@@ -89,20 +86,15 @@
     def rotate_half(x):
         x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
         return torch.cat((-x2, x1), dim=-1)

-    def apply_rope1d(self, tokens, pos1d, cos, sin, attn):
+    def apply_rope1d(self, tokens, pos1d, cos, sin):
         assert pos1d.ndim == 2
-        if attn.parallel_manager.sp_size == 1:
-            # for (batch_size x nheads x ntokens x dim)
-            cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
-            sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
-        else:
-            # for (batch_size x ntokens x nheads x dim)
-            cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :]
-            sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :]
+        # for (batch_size x ntokens x nheads x dim)
+        cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+        sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
         return (tokens * cos) + (self.rotate_half(tokens) * sin)

-    def forward(self, tokens, positions, attn):
+    def forward(self, tokens, positions):
         """
         input:
             * tokens: batch_size x nheads x ntokens x dim
@@ -119,9 +111,9 @@ def forward(self, tokens, positions):
         cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w)
         # split features into three along the feature dimension, and apply rope1d on each half
         t, y, x = tokens.chunk(3, dim=-1)
-        t = self.apply_rope1d(t, poses[0], cos_t, sin_t, attn)
-        y = self.apply_rope1d(y, poses[1], cos_y, sin_y, attn)
-        x = self.apply_rope1d(x, poses[2], cos_x, sin_x, attn)
+        t = self.apply_rope1d(t, poses[0], cos_t, sin_t)
+        y = self.apply_rope1d(y, poses[1], cos_y, sin_y)
+        x = self.apply_rope1d(x, poses[2], cos_x, sin_x)
         tokens = torch.cat((t, y, x), dim=-1)
         return tokens
@@ -305,9 +297,6 @@ def __init__(
         self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
         # self.temp_embed_gate = nn.Parameter(torch.tensor([0.0]))

-        # parallel
-        self.parallel_manager: ParallelManager = None
-
     def forward(self, latent, num_frames):
         b, _, _, _, _ = latent.shape
         video_latent, image_latent = None, None
@@ -339,30 +328,7 @@ def forward(self, latent, num_frames):
         pos_embed = self.pos_embed
         if self.num_frames != num_frames:
-            # import ipdb;ipdb.set_trace()
-            # raise NotImplementedError
-            if self.parallel_manager.sp_size > 1:
-                sp_size = self.parallel_manager.sp_size
-                temp_pos_embed = get_1d_sincos_pos_embed(
-                    embed_dim=self.temp_pos_embed.shape[-1],
-                    grid_size=num_frames * sp_size,
-                    base_size=self.base_size_t,
-                    interpolation_scale=self.interpolation_scale_t,
-                )
-                rank = self.parallel_manager.sp_size % sp_size
-                st_frame = rank * num_frames
-                ed_frame = st_frame + num_frames
-                temp_pos_embed = temp_pos_embed[st_frame:ed_frame]
-
-            else:
-                temp_pos_embed = get_1d_sincos_pos_embed(
-                    embed_dim=self.temp_pos_embed.shape[-1],
-                    grid_size=num_frames,
-                    base_size=self.base_size_t,
-                    interpolation_scale=self.interpolation_scale_t,
-                )
-            temp_pos_embed = torch.from_numpy(temp_pos_embed)
-            temp_pos_embed = temp_pos_embed.float().unsqueeze(0).to(latent.device)
+            raise NotImplementedError
         else:
             temp_pos_embed = self.temp_pos_embed
@@ -396,7 +362,7 @@
                 else None
             )

-        if num_frames == 1 and image_latent is None and not (self.parallel_manager.sp_size > 1):
+        if num_frames == 1 and image_latent is None:
             image_latent = video_latent
             video_latent = None
             # print('video_latent is None, image_latent is None', video_latent is None, image_latent is None)
@@ -738,8 +704,6 @@ def prepare_attention_mask(
             `torch.Tensor`: The prepared attention mask.
         """
         head_size = self.heads
-        if self.parallel_manager.sp_size > 1:
-            head_size = head_size // self.parallel_manager.sp_size
         if attention_mask is None:
             return attention_mask

@@ -920,32 +884,15 @@ def __call__(
             batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

-        if attn.parallel_manager.sp_size > 1:
-            if npu_config is not None:
-                sequence_length, batch_size, _ = (
-                    hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-                )
-            else:
-                sequence_length, batch_size, _ = (
-                    hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-                )
-        else:
-            batch_size, sequence_length, _ = (
-                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-            )
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )

         if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(
-                attention_mask, sequence_length * attn.parallel_manager.sp_size, batch_size
-            )
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
-            if attn.parallel_manager.sp_size > 1:
-                attention_mask = attention_mask.view(
-                    batch_size, attn.heads // attn.parallel_manager.sp_size, -1, attention_mask.shape[-1]
-                )
-            else:
-                attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -962,99 +909,38 @@ def __call__(
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads

-        if attn.parallel_manager.sp_size > 1:
-            query = query.reshape(-1, attn.heads, head_dim)  # [s // sp, b, h * d] -> [s // sp * b, h, d]
-            key = key.reshape(-1, attn.heads, head_dim)
-            value = value.reshape(-1, attn.heads, head_dim)
-            # query = attn.q_norm(query)
-            # key = attn.k_norm(key)
-            h_size = attn.heads * head_dim
-            sp_size = attn.parallel_manager.sp_size
-            h_size_sp = h_size // sp_size
-            # apply all_to_all to gather sequence and split attention heads [s // sp * b, h, d] -> [s * b, h // sp, d]
-            query = all_to_all_comm(query, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=0).reshape(
-                -1, batch_size, h_size_sp
-            )
-            key = all_to_all_comm(key, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=0).reshape(
-                -1, batch_size, h_size_sp
-            )
-            value = all_to_all_comm(value, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=0).reshape(
-                -1, batch_size, h_size_sp
-            )
-            query = query.reshape(-1, batch_size, attn.heads // sp_size, head_dim)
-            key = key.reshape(-1, batch_size, attn.heads // sp_size, head_dim)
-            value = value.reshape(-1, batch_size, attn.heads // sp_size, head_dim)
-            # print('query', query.shape, 'key', key.shape, 'value', value.shape)
-            if self.use_rope:
-                # require the shape of (batch_size x nheads x ntokens x dim)
-                pos_thw = self.position_getter(
-                    batch_size, t=frame * sp_size, h=height, w=width, device=query.device, attn=attn
-                )
-                query = self.rope(query, pos_thw, attn)
-                key = self.rope(key, pos_thw, attn)
-
-            # print('after rope query', query.shape, 'key', key.shape, 'value', value.shape)
-            query = rearrange(query, "s b h d -> b h s d")
-            key = rearrange(key, "s b h d -> b h s d")
-            value = rearrange(value, "s b h d -> b h s d")
-            # print('rearrange query', query.shape, 'key', key.shape, 'value', value.shape)
-
-            # 0, -10000 ->(bool) False, True ->(any) True ->(not) False
-            # 0, 0 ->(bool) False, False ->(any) False ->(not) True
-            if attention_mask is None or not torch.any(attention_mask.bool()):  # 0 mean visible
-                attention_mask = None
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            hidden_states = F.scaled_dot_product_attention(
-                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.parallel_manager.sp_size > 1 and query.shape[2] == key.shape[2]:
+            func = lambda x: all_to_all_comm(
+                x, process_group=attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2
             )
-            hidden_states = rearrange(hidden_states, "b h s d -> s b h d")
-            hidden_states = hidden_states.reshape(-1, attn.heads // sp_size, head_dim)
+            query, key, value = map(func, [query, key, value])

-            # [s * b, h // sp, d] -> [s // sp * b, h, d] -> [s // sp, b, h * d]
+        if self.use_rope:
+            # require the shape of (batch_size x nheads x ntokens x dim)
+            pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device)
+            query = self.rope(query, pos_thw)
+            key = self.rope(key, pos_thw)
+
+        # 0, -10000 ->(bool) False, True ->(any) True ->(not) False
+        # 0, 0 ->(bool) False, False ->(any) False ->(not) True
+        if attention_mask is None or not torch.any(attention_mask.bool()):  # 0 mean visible
+            attention_mask = None
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        if attn.parallel_manager.sp_size > 1 and query.shape[2] == key.shape[2]:
             hidden_states = all_to_all_comm(
-                hidden_states, attn.parallel_manager.sp_group, scatter_dim=0, gather_dim=1
-            ).reshape(-1, batch_size, h_size)
-        else:
-            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-            # qk norm
-            # query = attn.q_norm(query)
-            # key = attn.k_norm(key)
-
-            if self.use_rope:
-                # require the shape of (batch_size x nheads x ntokens x dim)
-                pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device, attn=attn)
-                query = self.rope(query, pos_thw, attn)
-                key = self.rope(key, pos_thw, attn)
-
-            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-            # 0, -10000 ->(bool) False, True ->(any) True ->(not) False
-            # 0, 0 ->(bool) False, False ->(any) False ->(not) True
-            if attention_mask is None or not torch.any(attention_mask.bool()):  # 0 mean visible
-                attention_mask = None
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            # import ipdb;ipdb.set_trace()
-            # print(attention_mask)
-            if self.attention_mode == "flash":
-                assert attention_mask is None, "flash-attn do not support attention_mask"
-                with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
-                    hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-            elif self.attention_mode == "xformers":
-                with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
-                    hidden_states = F.scaled_dot_product_attention(
-                        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-                    )
-            elif self.attention_mode == "math":
-                hidden_states = F.scaled_dot_product_attention(
-                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-                )
-            else:
-                raise NotImplementedError(f"Found attention_mode: {self.attention_mode}")
-            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+                hidden_states, process_group=attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1
+            )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

         hidden_states = hidden_states.to(query.dtype)

         # linear proj
@@ -1422,8 +1308,11 @@ def __init__(
         self._chunk_size = None
         self._chunk_dim = 0

-        # parallel
-        self.parallel_manager: ParallelManager = None
+        # pab
+        self.spatial_last = None
+        self.spatial_count = 0
+        self.cross_last = None
+        self.cross_count = 0

     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         # Sets chunk feed-forward
@@ -1443,56 +1332,46 @@ def forward(
         height: int = None,
         width: int = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        org_timestep: Optional[torch.LongTensor] = None,
     ) -> torch.FloatTensor:
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Self-Attention
         batch_size = hidden_states.shape[0]

         # import ipdb;ipdb.set_trace()
-        if self.norm_type == "ada_norm":
-            norm_hidden_states = self.norm1(hidden_states, timestep)
-        elif self.norm_type == "ada_norm_zero":
-            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
-                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
-            )
-        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
-            norm_hidden_states = self.norm1(hidden_states)
-        elif self.norm_type == "ada_norm_continuous":
-            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
-        elif self.norm_type == "ada_norm_single":
-            if self.parallel_manager.sp_size > 1:
-                batch_size = hidden_states.shape[1]
-                # print('hidden_states', hidden_states.shape)
-                # print('timestep', timestep.shape)
-                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-                    self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1)
-                ).chunk(6, dim=0)
-            else:
-                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-                    self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
-                ).chunk(6, dim=1)
-            norm_hidden_states = self.norm1(hidden_states)
-            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
-            # norm_hidden_states = norm_hidden_states.squeeze(1)
+        if self.norm_type == "ada_norm_single":
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+            ).chunk(6, dim=1)
         else:
             raise ValueError("Incorrect norm used")

-        if self.pos_embed is not None:
-            norm_hidden_states = self.pos_embed(norm_hidden_states)
-
         # 1. Prepare GLIGEN inputs
         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

-        attn_output = self.attn1(
-            norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
-            attention_mask=attention_mask,
-            frame=frame,
-            height=height,
-            width=width,
-            **cross_attention_kwargs,
-        )
+        broadcast, self.spatial_count = if_broadcast_spatial(int(org_timestep[0]), self.spatial_count)
+        if broadcast:
+            attn_output = self.spatial_last
+        else:
+            norm_hidden_states = self.norm1(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+            if self.pos_embed is not None:
+                norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+            attn_output = self.attn1(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+                attention_mask=attention_mask,
+                frame=frame,
+                height=height,
+                width=width,
+                **cross_attention_kwargs,
+            )
+
+            if enable_pab():
+                self.spatial_last = attn_output
+
         if self.norm_type == "ada_norm_zero":
             attn_output = gate_msa.unsqueeze(1) * attn_output
         elif self.norm_type == "ada_norm_single":
@@ -1508,28 +1387,36 @@
         # 3. Cross-Attention
         if self.attn2 is not None:
-            if self.norm_type == "ada_norm":
-                norm_hidden_states = self.norm2(hidden_states, timestep)
-            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
-                norm_hidden_states = self.norm2(hidden_states)
-            elif self.norm_type == "ada_norm_single":
-                # For PixArt norm2 isn't applied here:
-                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
-                norm_hidden_states = hidden_states
-            elif self.norm_type == "ada_norm_continuous":
-                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+            broadcast, self.cross_count = if_broadcast_cross(int(org_timestep[0]), self.cross_count)
+            if broadcast:
+                attn_output = self.cross_last
+            else:
-                raise ValueError("Incorrect norm")
+                if self.norm_type == "ada_norm":
+                    norm_hidden_states = self.norm2(hidden_states, timestep)
+                elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
+                    norm_hidden_states = self.norm2(hidden_states)
+                elif self.norm_type == "ada_norm_single":
+                    # For PixArt norm2 isn't applied here:
+                    # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+                    norm_hidden_states = hidden_states
+                elif self.norm_type == "ada_norm_continuous":
+                    norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+                else:
+                    raise ValueError("Incorrect norm")

-            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
-                norm_hidden_states = self.pos_embed(norm_hidden_states)
+                if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+                    norm_hidden_states = self.pos_embed(norm_hidden_states)

-            attn_output = self.attn2(
-                norm_hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=encoder_attention_mask,
-                **cross_attention_kwargs,
-            )
+                attn_output = self.attn2(
+                    norm_hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=encoder_attention_mask,
+                    **cross_attention_kwargs,
+                )
+
+                if enable_pab():
+                    self.cross_last = attn_output
             hidden_states = attn_output + hidden_states

         # 4. Feed-forward
@@ -1728,19 +1615,6 @@ def _init_patched_inputs(self, norm_type):
             if self.config.interpolation_scale_w is not None
             else self.config.sample_size[1] / 40,
         )
-        # if self.config.sample_size_t > 1:
-        #     self.pos_embed = PatchEmbed3D(
-        #         num_frames=self.config.sample_size_t,
-        #         height=self.config.sample_size[0],
-        #         width=self.config.sample_size[1],
-        #         patch_size_t=self.config.patch_size_t,
-        #         patch_size=self.config.patch_size,
-        #         in_channels=self.in_channels,
-        #         embed_dim=self.inner_dim,
-        #         interpolation_scale=interpolation_scale,
-        #         interpolation_scale_t=interpolation_scale_t,
-        #     )
-        # else:
         if self.config.downsampler is not None and len(self.config.downsampler) == 9:
             self.pos_embed = OverlapPatchEmbed3D(
                 num_frames=self.config.sample_size_t,
@@ -1934,14 +1808,8 @@ def forward(
             # b, frame+use_image_num, h, w -> a video with images
             # b, 1, h, w -> only images
             attention_mask = attention_mask.to(self.dtype)
-            if self.parallel_manager.sp_size > 1:
-                attention_mask_vid = attention_mask[:, : frame * self.parallel_manager.sp_size]  # b, frame, h, w
-                attention_mask_img = attention_mask[
-                    :, frame * self.parallel_manager.sp_size :
-                ]  # b, use_image_num, h, w
-            else:
-                attention_mask_vid = attention_mask[:, :frame]  # b, frame, h, w
-                attention_mask_img = attention_mask[:, frame:]  # b, use_image_num, h, w
+            attention_mask_vid = attention_mask[:, :frame]  # b, frame, h, w
+            attention_mask_img = attention_mask[:, frame:]  # b, use_image_num, h, w

             if attention_mask_vid.numel() > 0:
                 attention_mask_vid_first_frame = attention_mask_vid[:, :1].repeat(1, self.patch_size_t - 1, 1, 1)
@@ -1968,7 +1836,7 @@ def forward(
                 (1 - attention_mask_img.bool().to(self.dtype)) * -10000.0 if attention_mask_img.numel() > 0 else None
             )

-            if frame == 1 and use_image_num == 0 and not (self.parallel_manager.sp_size > 1):
+            if frame == 1 and use_image_num == 0:
                 attention_mask_img = attention_mask_vid
                 attention_mask_vid = None
         # convert encoder_attention_mask to a bias the same way we do for attention_mask
@@ -1991,7 +1859,7 @@ def forward(
                 else None
             )

-            if frame == 1 and use_image_num == 0 and not (self.parallel_manager.sp_size > 1):
+            if frame == 1 and use_image_num == 0:
                 encoder_attention_mask_img = encoder_attention_mask_vid
                 encoder_attention_mask_vid = None
@@ -2027,11 +1895,9 @@ def forward(
         # 2. Blocks
         if self.parallel_manager.sp_size > 1:
             if hidden_states_vid is not None:
-                hidden_states_vid = rearrange(hidden_states_vid, "b s h -> s b h", b=batch_size).contiguous()
-                encoder_hidden_states_vid = rearrange(
-                    encoder_hidden_states_vid, "b s h -> s b h", b=batch_size
-                ).contiguous()
-                timestep_vid = timestep_vid.view(batch_size, 6, -1).transpose(0, 1).contiguous()
+                hidden_states_vid = split_sequence(
+                    hidden_states_vid, dim=1, process_group=self.parallel_manager.sp_group, grad_scale="down"
+                )

         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
@@ -2089,6 +1955,7 @@ def custom_forward(*inputs):
                     frame=frame,
                     height=height,
                     width=width,
+                    org_timestep=timestep,
                 )
                 if hidden_states_img is not None:
                     hidden_states_img = block(
@@ -2102,11 +1969,14 @@ def custom_forward(*inputs):
                         frame=1,
                         height=height,
                        width=width,
+                        org_timestep=timestep,
                     )

         if self.parallel_manager.sp_size > 1:
             if hidden_states_vid is not None:
-                hidden_states_vid = rearrange(hidden_states_vid, "s b h -> b s h", b=batch_size).contiguous()
+                hidden_states_vid = gather_sequence(
+                    hidden_states_vid, dim=1, process_group=self.parallel_manager.sp_group, grad_scale="up"
+                )

         # 3. Output
         output_vid, output_img = None, None
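The block-level changes above are the heart of PAB support for v1.2.0: each transformer block keeps `spatial_last`/`cross_last` caches plus hit counters, asks `if_broadcast_spatial` / `if_broadcast_cross` whether the current denoising timestep may reuse the cached attention output, and only recomputes (and re-caches, when `enable_pab()` is on) otherwise. The exact scheduling lives in `videosys.core.pab_mgr`; the toy re-implementation below is an assumption-laden sketch of that control flow, not the library code:

```python
def if_broadcast(timestep, count, threshold, rng):
    """Sketch: return (reuse_cached, new_count).

    Assumed semantics: inside the timestep window, reuse the cached output
    on all but every `rng`-th call; outside the window, always recompute.
    """
    within = threshold[0] < timestep < threshold[1]
    reuse = within and count % rng != 0
    return reuse, count + 1


# Mirrors the shape of the block code above (window/range values taken from the
# OpenSoraPlanV120PABConfig defaults further down: spatial [100, 850] with range 2):
# broadcast, self.spatial_count = if_broadcast(t, self.spatial_count, (100, 850), 2)
# attn_output = self.spatial_last if broadcast else self.attn1(norm_hidden_states, ...)
```

The sequence-parallel path was reworked in the same spirit as VideoSys's other pipelines: the sequence is split across ranks with `split_sequence` before the blocks, each attention trades head shards for full sequences and back via `all_to_all_comm`, and `gather_sequence` reassembles the output afterwards, which is what made the per-call `attn` plumbing and the duplicated single-GPU branch removable.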
diff --git a/videosys/pipelines/open_sora_plan/__init__.py b/videosys/pipelines/open_sora_plan/__init__.py
index 7a1ddb8e..ddf790db 100644
--- a/videosys/pipelines/open_sora_plan/__init__.py
+++ b/videosys/pipelines/open_sora_plan/__init__.py
@@ -1,3 +1,8 @@
-from .pipeline_open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
+from .pipeline_open_sora_plan import (
+    OpenSoraPlanConfig,
+    OpenSoraPlanPipeline,
+    OpenSoraPlanV110PABConfig,
+    OpenSoraPlanV120PABConfig,
+)

-__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanPABConfig"]
+__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig"]
diff --git a/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py b/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
index bdef12c4..fdd8e436 100644
--- a/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
+++ b/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
@@ -38,7 +38,7 @@
 from videosys.utils.utils import save_video, set_seed


-class OpenSoraPlanPABConfig(PABConfig):
+class OpenSoraPlanV110PABConfig(PABConfig):
     def __init__(
         self,
         spatial_broadcast: bool = True,
@@ -100,6 +100,26 @@ def __init__(
         )


+class OpenSoraPlanV120PABConfig(PABConfig):
+    def __init__(
+        self,
+        spatial_broadcast: bool = True,
+        spatial_threshold: list = [100, 850],
+        spatial_range: int = 2,
+        cross_broadcast: bool = True,
+        cross_threshold: list = [100, 850],
+        cross_range: int = 6,
+    ):
+        super().__init__(
+            spatial_broadcast=spatial_broadcast,
+            spatial_threshold=spatial_threshold,
+            spatial_range=spatial_range,
+            cross_broadcast=cross_broadcast,
+            cross_threshold=cross_threshold,
+            cross_range=cross_range,
+        )
+
+
 class OpenSoraPlanConfig:
     """
     This config is to instantiate a `OpenSoraPlanPipeline` class for video generation.
@@ -151,7 +171,7 @@
     def __init__(
         self,
         version: str = "v120",
-        transformer_type: str = "93x480p",
+        transformer_type: str = "29x480p",
         transformer: str = None,
         text_encoder: str = None,
         # ======= distributed ========
@@ -162,7 +182,7 @@
         tile_overlap_factor: float = 0.25,
         # ======= pab ========
         enable_pab: bool = False,
-        pab_config: PABConfig = OpenSoraPlanPABConfig(),
+        pab_config: PABConfig = None,
     ):
         self.pipeline_cls = OpenSoraPlanPipeline
@@ -196,7 +216,13 @@
         self.tile_overlap_factor = tile_overlap_factor
         # ======= pab ========
         self.enable_pab = enable_pab
-        self.pab_config = pab_config
+        if self.enable_pab and pab_config is None:
+            if version == "v110":
+                self.pab_config = OpenSoraPlanV110PABConfig()
+            elif version == "v120":
+                self.pab_config = OpenSoraPlanV120PABConfig()
+        else:
+            self.pab_config = pab_config


 class OpenSoraPlanPipeline(VideoSysPipeline):
@@ -242,13 +268,6 @@ def __init__(
         super().__init__()
         self._config = config

-        # not implemented
-        if config.version == "v120":
-            if config.num_gpus > 1:
-                raise NotImplementedError("v120 does not support multi-gpu inference")
-            if config.enable_pab:
-                raise NotImplementedError("v120 does not support PAB")
-
         # init
         if tokenizer is None:
             if config.version == "v110":
diff --git a/videosys/pipelines/vchitect/__init__.py b/videosys/pipelines/vchitect/__init__.py
index 138c37f8..0ee6ac13 100644
--- a/videosys/pipelines/vchitect/__init__.py
+++ b/videosys/pipelines/vchitect/__init__.py
@@ -1,3 +1,3 @@
-from .pipeline_vchitect import VchitectConfig, VchitectXLPipeline
+from .pipeline_vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline

-__all__ = ["VchitectXLPipeline", "VchitectConfig"]
+__all__ = ["VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"]
diff --git a/videosys/utils/test.py b/videosys/utils/test.py
new file mode 100644
index 00000000..aec6f7b9
--- /dev/null
+++ b/videosys/utils/test.py
@@ -0,0 +1,12 @@
+import functools
+
+import torch
+
+
+def empty_cache(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        torch.cuda.empty_cache()
+        return func(*args, **kwargs)
+
+    return wrapper
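`empty_cache` is a thin decorator: it calls `torch.cuda.empty_cache()` and then runs the wrapped function. One hedged usage note: `empty_cache()` only returns cached, unused allocator blocks to the driver, so it frees memory a previously run test left behind, not tensors that are still alive. A self-contained example of the behavior:

```python
import torch

from videosys.utils.test import empty_cache


@empty_cache
def generate_sample():
    # By the time this body runs, allocator cache left over from earlier
    # calls has been released back to the driver.
    return torch.zeros((1024, 1024), device="cuda")


if torch.cuda.is_available():
    x = generate_sample()
    del x  # freed tensors return to PyTorch's allocator cache, not the driver...
    generate_sample()  # ...until the decorator's empty_cache() call releases them
```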