Commit

refactor: do block mask before torch compile
samsja committed Dec 10, 2024
1 parent 3a7f9cc · commit 3fac57d
Showing 5 changed files with 29 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/zeroband/data.py
@@ -136,7 +136,7 @@ def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.Lo
         inputs_ids.append(sample["input_ids"])
         labels.append(sample["labels"])
 
-        seqlens.extend(torch.Tensor(sample["seqlens"]).long())
+        seqlens.append(torch.Tensor(sample["seqlens"]).long())
 
     return {
         "input_ids": torch.stack(inputs_ids, dim=0),
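
The data.py change swaps `extend` for `append` when collecting per-sample sequence lengths: `extend` iterates over the tensor and flattens the lengths into 0-d scalars, while `append` keeps one lengths tensor per batch element, which is the per-sample shape the block-mask construction in this commit consumes. A minimal sketch of the difference, using made-up toy values rather than the repo's data pipeline:

```python
# Toy illustration of append vs. extend for per-sample "seqlens"; values are made up.
import torch

samples = [
    {"seqlens": [3, 5]},  # sample 0 packs two sequences of lengths 3 and 5
    {"seqlens": [8]},     # sample 1 packs a single sequence of length 8
]

appended, extended = [], []
for sample in samples:
    appended.append(torch.Tensor(sample["seqlens"]).long())  # one 1-D tensor per sample
    extended.extend(torch.Tensor(sample["seqlens"]).long())  # iterates -> 0-d scalar tensors

print([t.tolist() for t in appended])  # [[3, 5], [8]]
print([t.item() for t in extended])    # [3, 5, 8] -- per-sample grouping is lost
```
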
27 changes: 21 additions & 6 deletions src/zeroband/models/llama/model.py
@@ -21,7 +21,22 @@
 
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE
 
-flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
+_flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
+
+
+# copied from https://github.com/pytorch/torchtune/blob/f2bd4bc25b24587aef40f486087412b9da8f1d94/torchtune/modules/attention_utils.py#L27
+# We cannot do nested compile, but flex attention only has perf benefits
+# when compiled. To insulate it from the compiler, we wrap it with
+# compiler.disable so that it can be used regardless of whether the model
+# is compiled or not, and flex attention always remains compiled.
+@torch.compiler.disable(recursive=False)
+def flex_attention_compiled(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    block_mask: BlockMask,
+) -> torch.Tensor:
+    return _flex_attention_compiled(q, k, v, block_mask=block_mask)
 
 
 @dataclass
@@ -252,6 +267,9 @@ def _flex_attention_with_seqlens(self, xq, xk, xv, block_mask: BlockMask) -> tor
         output = flex_attention_compiled(xq, xk, xv, block_mask=block_mask)
         output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
         return output
+        # output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        # output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
+        # return output
 
     def self_attention(
         self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, block_mask: BlockMask | None = None
@@ -462,23 +480,20 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.rope_theta,
         )
 
-    def forward(self, tokens: torch.Tensor, seqlens: torch.Tensor | None = None):
+    def forward(self, tokens: torch.Tensor, block_mask: BlockMask | None = None):
         """
         Perform a forward pass through the Transformer model.
         Args:
             tokens (torch.Tensor): Input token indices.
-            seqlens (torch.Tensor | None): Sequence lengths tensor for packing.
+            block_mask (BlockMask | None): Block mask for attention.
         Returns:
             torch.Tensor: Output logits after applying the Transformer model.
         """
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
-        block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None
-
         for layer in self.layers.values():
             h = layer(h, self.freqs_cis, block_mask=block_mask)
 
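
To make the intent of the new wrapper concrete, here is a small standalone sketch of the compile-insulation pattern; the module name, shapes, and the trivial causal mask are illustrative assumptions, not the repo's code. The outer module may or may not be `torch.compile`'d, while the `torch.compiler.disable(recursive=False)` frame keeps flex attention on its own pre-compiled path.

```python
# Standalone sketch of the compile-insulation pattern; names and shapes are illustrative.
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention

_flex_compiled = torch.compile(flex_attention, dynamic=False)


@torch.compiler.disable(recursive=False)
def flex_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_mask: BlockMask) -> torch.Tensor:
    # Dynamo skips this frame, so the pre-compiled flex kernel is called directly
    # whether or not the surrounding module is itself compiled.
    return _flex_compiled(q, k, v, block_mask=block_mask)


class TinyAttention(torch.nn.Module):
    def forward(self, q, k, v, block_mask):
        return flex_wrapper(q, k, v, block_mask)


if torch.cuda.is_available():
    B, H, S, D = 1, 2, 128, 16
    q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    causal = create_block_mask(lambda b, h, qi, ki: qi >= ki, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
    model = torch.compile(TinyAttention())  # outer compile; flex keeps its own compiled kernel
    print(model(q, k, v, causal).shape)     # torch.Size([1, 2, 128, 16])
```
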
11 changes: 4 additions & 7 deletions src/zeroband/train.py
@@ -20,6 +20,7 @@
 from zeroband.diloco import Diloco, DilocoConfig
 from zeroband.comms import ElasticDeviceMesh
 from zeroband.loss import cross_entropy_max_z_loss
+from zeroband.models.llama.model import create_block_mask_from_seqlens
 
 from zeroband.utils import (
     FakeTokenizer,
@@ -361,15 +362,11 @@ def train(config: Config):
             labels = batch["labels"].to("cuda")
             if config.train.sequence_packing:
                 seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]]
-
-                # seqlens has a dynamic shape but fixed dimension, this allow to still torch compile
-                # https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html
-                # torch._dynamo.mark_dynamic(seqlens, 0)
-                logger.debug(f"seqlens: {seqlens}")
+                block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None
             else:
-                seqlens = None
+                block_mask = None
 
-            logits = model(tokens=input_ids, seqlens=seqlens).contiguous()
+            logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
             flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
             flatten_labels = rearrange(labels, "b seq -> (b seq)")
 
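
With this change the training loop materializes the block mask eagerly, once per batch, before the (possibly compiled) model runs, so `torch.compile` never sees the varying `seqlens`. Below is a hedged sketch of what a helper like `create_block_mask_from_seqlens` plausibly builds, namely a causal, per-document mask over packed sequences; the repo's actual implementation may differ, and the sketch assumes each sample's lengths sum to the full sequence length.

```python
# Illustrative only: one plausible shape of a seqlens -> BlockMask helper for sequence packing.
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask


def block_mask_from_seqlens_sketch(seqlens: list[torch.Tensor]) -> BlockMask:
    # Give every token position the id of the packed sequence ("document") it belongs to,
    # e.g. lengths [3, 5] -> doc ids [0, 0, 0, 1, 1, 1, 1, 1].
    doc_ids = torch.stack(
        [torch.repeat_interleave(torch.arange(len(s), device=s.device), s.long()) for s in seqlens]
    )  # (batch, seq_len); assumes each sample's lengths tile the full sequence
    batch, seq_len = doc_ids.shape

    def mask_mod(b, h, q_idx, kv_idx):
        same_doc = doc_ids[b, q_idx] == doc_ids[b, kv_idx]
        return same_doc & (q_idx >= kv_idx)  # causal attention, restricted to one document

    return create_block_mask(mask_mod, B=batch, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=str(doc_ids.device))
```
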
4 changes: 2 additions & 2 deletions tests/test_model.py
@@ -217,9 +217,9 @@ def test_end_to_end_packing(llama_config: ModelArgs):
     input_ = torch.randint(1, llama_config.vocab_size, (BS, SEQ_LEN)).to("cuda")
 
     seqlens = [torch.Tensor([SEQ_LEN // 4, SEQ_LEN // 4, SEQ_LEN // 2]).int().to("cuda") for _ in range(BS)]
-
+    block_mask = create_block_mask_from_seqlens(seqlens)
     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-        output = model(input_, seqlens=seqlens)
+        output = model(input_, block_mask=block_mask)
 
     assert output.shape == (BS, SEQ_LEN, llama_config.vocab_size)
 
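
A small note on test_end_to_end_packing's setup: the per-sample lengths [SEQ_LEN // 4, SEQ_LEN // 4, SEQ_LEN // 2] sum to SEQ_LEN, so the prebuilt block mask lines up with the (BS, SEQ_LEN) input. A trivial check, with a hypothetical SEQ_LEN value for illustration (the test defines its own constant):

```python
SEQ_LEN = 512  # hypothetical value for illustration; the test module defines its own SEQ_LEN
assert SEQ_LEN // 4 + SEQ_LEN // 4 + SEQ_LEN // 2 == SEQ_LEN  # holds whenever SEQ_LEN % 4 == 0
```
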
2 changes: 1 addition & 1 deletion tests/test_torchrun/test_train.py
@@ -107,7 +107,7 @@ def test_z_loss():
     _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--optim.z_loss"])
 
 
-@pytest.mark.parametrize("packing", [True]) # , False])
+@pytest.mark.parametrize("packing", [True, False])
 def test_packing(packing: bool):
     num_gpus = [2, 1]
     packing_arg = "--train.sequence_packing" if packing else "--no-train.sequence_packing"
