[hotfix] fix CogVideoX parallel bug with 4 gpus (#221)
* [UPD] 1. fix 4-process bug;

* fix pad problem

---------

Co-authored-by: Xuanlei Zhao <[email protected]>
gttiankai and oahzxl authored Sep 26, 2024
1 parent e48a642 commit ff918ec
Showing 1 changed file with 22 additions and 11 deletions.
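For context on the "pad problem": with sequence parallelism, the visual sequence can only be split evenly across ranks when its length is divisible by the sequence-parallel size, so a pad is added before the split and has to be stripped again after the gather. The snippet below is a minimal single-process illustration of that round trip in plain PyTorch, with arbitrary shapes; it is not the videosys implementation:

import torch

sp_size = 4                                     # e.g. sequence parallelism across 4 GPUs
x = torch.randn(1, 30, 64)                      # (batch, seq_len, dim); 30 is not divisible by 4

# pad so the sequence splits evenly across the sp_size ranks
pad = (sp_size - x.size(1) % sp_size) % sp_size
x_padded = torch.cat([x, x.new_zeros(1, pad, 64)], dim=1)

chunks = x_padded.chunk(sp_size, dim=1)         # one equal-sized chunk per rank
assert all(c.size(1) == chunks[0].size(1) for c in chunks)

# after gathering, the pad must be removed again to recover the original tensor
gathered = torch.cat(chunks, dim=1)
restored = gathered.narrow(1, 0, gathered.size(1) - pad)
assert torch.equal(restored, x)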
videosys/models/transformers/cogvideox_transformer_3d.py (33 changes: 22 additions & 11 deletions)
@@ -22,7 +22,7 @@
 from diffusers.utils.torch_utils import maybe_allow_in_graph
 from torch import nn
 
-from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
+from videosys.core.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence
 from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
 from videosys.core.parallel_mgr import ParallelManager
 from videosys.models.modules.embeddings import apply_rotary_emb
@@ -52,9 +52,26 @@ def _remove_extra_encoder(self, hidden_states, text_seq_length, attn):
         for i in range(sp_size):
             new_seq.append(split_seq[i][:, :, text_seq_length:])
         hidden_states = torch.cat(new_seq, dim=2)
+
+        # remove padding added when all2all
+        # if pad is removed earlier than this
+        # the split size will be wrong
+        pad = get_pad("pad")
+        if pad > 0:
+            hidden_states = hidden_states.narrow(2, 0, hidden_states.size(2) - pad)
         return hidden_states
 
     def _add_extra_encoder(self, hidden_states, text_seq_length, attn):
+        # add padding for split and later all2all
+        # if pad is added later than this
+        # the split size will be wrong
+        pad = get_pad("pad")
+        if pad > 0:
+            pad_shape = list(hidden_states.shape)
+            pad_shape[1] = pad
+            pad_tensor = torch.zeros(pad_shape, device=hidden_states.device, dtype=hidden_states.dtype)
+            hidden_states = torch.cat([hidden_states, pad_tensor], dim=1)
+
         # current layout is [text, seq]
         # we want to add the extra encoder info [text, 1/n seq, text, 1/n seq, ...]
         sp_size = attn.parallel_manager.sp_size
@@ -97,10 +114,10 @@ def __call__(
                 attn.heads % attn.parallel_manager.sp_size == 0
             ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
             attn_heads = attn.heads // attn.parallel_manager.sp_size
+            # normally we would pad for every all2all, but for a more convenient implementation
+            # we move the pad operation to the encoder add and remove in cogvideo
             query, key, value = map(
-                lambda x: all_to_all_with_pad(
-                    x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1, gather_pad=get_pad("pad")
-                ),
+                lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
                 [query, key, value],
             )
         else:
@@ -145,13 +162,7 @@ def __call__(
         if attn.parallel_manager.sp_size > 1:
             # add extra encoder for all_to_all
             hidden_states = self._add_extra_encoder(hidden_states, text_seq_length, attn)
-            hidden_states = all_to_all_with_pad(
-                hidden_states,
-                attn.parallel_manager.sp_group,
-                scatter_dim=1,
-                gather_dim=2,
-                scatter_pad=get_pad("pad"),
-            )
+            hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
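To make the new control flow easier to follow, here is a hedged single-process sketch of what the two helpers do after this commit: the pad is added exactly once in _add_extra_encoder, before the chunks for all_to_all are formed, and is dropped exactly once in _remove_extra_encoder, only after the gathered chunks have been concatenated back. The helper names pad_for_split and unpad_after_gather are illustrative stand-ins, not the videosys API:

import torch

def pad_for_split(hidden_states, pad):
    # mirrors the added code in _add_extra_encoder: append a zero pad on the
    # sequence dimension (dim 1 at that point) before any split/all_to_all
    if pad > 0:
        pad_shape = list(hidden_states.shape)
        pad_shape[1] = pad
        pad_tensor = torch.zeros(pad_shape, device=hidden_states.device, dtype=hidden_states.dtype)
        hidden_states = torch.cat([hidden_states, pad_tensor], dim=1)
    return hidden_states

def unpad_after_gather(hidden_states, pad):
    # mirrors the added code in _remove_extra_encoder: drop the pad from the
    # sequence dimension (dim 2 at that point) only after the chunks are rejoined
    if pad > 0:
        hidden_states = hidden_states.narrow(2, 0, hidden_states.size(2) - pad)
    return hidden_states

With this ordering, all_to_all_comm no longer needs the scatter_pad/gather_pad arguments that all_to_all_with_pad took: every rank should already see equal-sized slices, and the single pad survives the scatter/gather pair until it is stripped at the end.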
