diff --git a/videosys/models/transformers/cogvideox_transformer_3d.py b/videosys/models/transformers/cogvideox_transformer_3d.py
index be17c5d6..e568e06c 100644
--- a/videosys/models/transformers/cogvideox_transformer_3d.py
+++ b/videosys/models/transformers/cogvideox_transformer_3d.py
@@ -22,7 +22,7 @@
 from diffusers.utils.torch_utils import maybe_allow_in_graph
 from torch import nn
 
-from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
+from videosys.core.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence
 from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
 from videosys.core.parallel_mgr import ParallelManager
 from videosys.models.modules.embeddings import apply_rotary_emb
@@ -52,9 +52,26 @@ def _remove_extra_encoder(self, hidden_states, text_seq_length, attn):
         for i in range(sp_size):
             new_seq.append(split_seq[i][:, :, text_seq_length:])
         hidden_states = torch.cat(new_seq, dim=2)
+
+        # remove the padding that was added for all2all;
+        # if the pad were removed earlier than this,
+        # the split size would be wrong
+        pad = get_pad("pad")
+        if pad > 0:
+            hidden_states = hidden_states.narrow(2, 0, hidden_states.size(2) - pad)
         return hidden_states
 
     def _add_extra_encoder(self, hidden_states, text_seq_length, attn):
+        # add padding for the split and the later all2all;
+        # if the pad were added later than this,
+        # the split size would be wrong
+        pad = get_pad("pad")
+        if pad > 0:
+            pad_shape = list(hidden_states.shape)
+            pad_shape[1] = pad
+            pad_tensor = torch.zeros(pad_shape, device=hidden_states.device, dtype=hidden_states.dtype)
+            hidden_states = torch.cat([hidden_states, pad_tensor], dim=1)
+
         # current layout is [text, seq]
         # we want to add the extra encoder info [text, 1/n seq, text, 1/n seq, ...]
         sp_size = attn.parallel_manager.sp_size
@@ -97,10 +114,10 @@ def __call__(
                 attn.heads % attn.parallel_manager.sp_size == 0
             ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
             attn_heads = attn.heads // attn.parallel_manager.sp_size
+            # normally we would pad and unpad around every all2all, but for a more convenient
+            # implementation we move the pad handling into the encoder add/remove steps for CogVideoX
             query, key, value = map(
-                lambda x: all_to_all_with_pad(
-                    x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1, gather_pad=get_pad("pad")
-                ),
+                lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
                 [query, key, value],
             )
         else:
@@ -145,13 +162,7 @@ def __call__(
         if attn.parallel_manager.sp_size > 1:
             # add extra encoder for all_to_all
            hidden_states = self._add_extra_encoder(hidden_states, text_seq_length, attn)
-            hidden_states = all_to_all_with_pad(
-                hidden_states,
-                attn.parallel_manager.sp_group,
-                scatter_dim=1,
-                gather_dim=2,
-                scatter_pad=get_pad("pad"),
-            )
+            hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
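
For context on why the pad now lives only in `_add_extra_encoder`/`_remove_extra_encoder`: the sequence length has to divide evenly by the sequence-parallel size before any all-to-all, and the patch guarantees that by appending zeros once up front and stripping them once at the end, instead of padding and unpadding around every communication call. The sketch below is a minimal, single-process illustration of that bookkeeping; `pad_for_even_split` and `strip_pad` are hypothetical names (not part of videosys), `torch.chunk` stands in for the scatter half of an all-to-all, and the text-token interleaving that the real `_add_extra_encoder` performs is omitted.

```python
import torch


def pad_for_even_split(x: torch.Tensor, dim: int, sp_size: int):
    """Append zeros along `dim` so its length divides evenly by sp_size.

    Hypothetical helper for illustration; the patch records this length in a
    global registry via set_pad/get_pad instead of returning it.
    """
    pad = (sp_size - x.size(dim) % sp_size) % sp_size
    if pad > 0:
        pad_shape = list(x.shape)
        pad_shape[dim] = pad
        x = torch.cat([x, torch.zeros(pad_shape, dtype=x.dtype, device=x.device)], dim=dim)
    return x, pad


def strip_pad(x: torch.Tensor, dim: int, pad: int):
    """Drop the zeros appended by pad_for_even_split from the end of `dim`."""
    return x.narrow(dim, 0, x.size(dim) - pad) if pad > 0 else x


if __name__ == "__main__":
    sp_size = 4
    # Arbitrary shapes for illustration: [batch, sequence, channels]; 1023 is not divisible by 4.
    hidden = torch.randn(2, 1023, 64)

    # Pad once, up front, so every later split / all-to-all sees equal chunk sizes.
    hidden, pad = pad_for_even_split(hidden, dim=1, sp_size=sp_size)
    chunks = hidden.chunk(sp_size, dim=1)  # stand-in for the scatter half of an all-to-all
    assert all(c.size(1) == chunks[0].size(1) for c in chunks)

    # After the communication is undone, remove the pad exactly once at the very end.
    restored = strip_pad(torch.cat(chunks, dim=1), dim=1, pad=pad)
    assert restored.size(1) == 1023
```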