add options to use AC in float8nocompile linear layers; add support for only compiling linear layers
danielvegamyhre committed Jan 10, 2025
1 parent a4a1f74 commit 9e19e7f
Showing 4 changed files with 28 additions and 8 deletions.
7 changes: 6 additions & 1 deletion torchtitan/config_manager.py
@@ -568,10 +568,15 @@ def __init__(self):
help="float8 scaling for input, dynamic (default) or delayed",
)
self.parser.add_argument(
"--float8.no_compile",
"--float8.float8nocompile",
action="store_true",
help="use the float8nocompile prototype implementation",
)
self.parser.add_argument(
"--float8.float8nocompile_ac",
action="store_true",
help="use activation checkpointing with float8nocompile linear layers",
)

# communications library settings
self.parser.add_argument(
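Both new options are plain store_true flags, so they default to False and flip to True when passed on the command line. Below is a standalone sketch of how such dotted flag names behave; this is ordinary argparse for illustration, not torchtitan's ConfigManager, and the sample invocation is hypothetical.

# Standalone illustration with plain argparse; not torchtitan's ConfigManager.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--float8.float8nocompile",
    action="store_true",
    help="use the float8nocompile prototype implementation",
)
parser.add_argument(
    "--float8.float8nocompile_ac",
    action="store_true",
    help="use activation checkpointing with float8nocompile linear layers",
)

args = parser.parse_args(["--float8.float8nocompile", "--float8.float8nocompile_ac"])
# Dotted option names keep the dot in their dest, so read them with getattr.
print(getattr(args, "float8.float8nocompile"))     # True
print(getattr(args, "float8.float8nocompile_ac"))  # True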
5 changes: 4 additions & 1 deletion torchtitan/float8.py
@@ -47,7 +47,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
"torchao is not installed. Please install it to use float8 linear layers."
) from e

self.use_float8nocompile = float8_config.no_compile
self.use_float8nocompile = float8_config.float8nocompile
self.use_float8nocompile_ac = float8_config.float8nocompile_ac

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
@@ -104,8 +105,10 @@ def convert_to_float8_training(self, model: nn.Module):
                 model,
                 config=self.config,
                 module_filter_fn=lambda mod, fqn: fqn != "output",
+                use_activation_checkpointing=self.use_float8nocompile_ac,
             )
         else:
+            logger.info("Using float8 training")
             from torchao.float8 import convert_to_float8_training
 
             # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
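The new use_activation_checkpointing flag is simply forwarded into the float8nocompile conversion in torchao. As a rough illustration only (this is not torchao's implementation; CheckpointedLinear is a hypothetical wrapper), per-module activation checkpointing means recomputing a layer's forward during the backward pass instead of saving its activations:

# Minimal sketch using only standard PyTorch APIs; CheckpointedLinear is
# hypothetical and not part of the float8nocompile prototype.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedLinear(nn.Module):
    """Recompute the wrapped linear's forward in backward instead of saving activations."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # use_reentrant=False selects the non-reentrant checkpoint implementation
        return checkpoint(self.linear, x, use_reentrant=False)


if __name__ == "__main__":
    layer = CheckpointedLinear(nn.Linear(16, 16))
    out = layer(torch.randn(4, 16, requires_grad=True))
    out.sum().backward()  # the wrapped linear's forward is re-run here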
21 changes: 16 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -7,6 +7,7 @@
 # This file applies the PT-D parallelisms (except pipeline parallelism) and various
 # training techniques (e.g. activation checkpointing and compile) to the Llama model.
 
+import os
 from collections import defaultdict
 
 import torch
@@ -299,11 +300,21 @@ def apply_compile(model: nn.Module):
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
-    for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = torch.compile(transformer_block, fullgraph=True)
-        model.layers.register_module(layer_id, transformer_block)
-
-    logger.info("Compiling each TransformerBlock with torch.compile")
+    compile_linear_only = bool(os.environ.get("TORCHTITAN_COMPILE_LINEAR_ONLY", False))
+
+    if compile_linear_only:
+        logger.info("Compiling linear layers with torch.compile")
+        for name, child in model.named_children():
+            if isinstance(child, torch.nn.Linear):
+                new_child = torch.compile(child)
+                setattr(model, name, new_child)
+            else:
+                apply_compile(child)
+    else:
+        logger.info("Compiling each TransformerBlock with torch.compile")
+        for layer_id, transformer_block in model.layers.named_children():
+            transformer_block = torch.compile(transformer_block, fullgraph=True)
+            model.layers.register_module(layer_id, transformer_block)
 
 
 def apply_fsdp(
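Unlike the float8 options above, the linear-only compile path is toggled by an environment variable rather than a JobConfig flag, so it is set at launch time. A small sketch of the toggle's behavior, assuming it is set in the same process before apply_compile runs:

# TORCHTITAN_COMPILE_LINEAR_ONLY is read via bool(os.environ.get(...)), so any
# non-empty string enables it (even "0" or "false"), while unset or "" leaves
# it disabled.
import os

os.environ["TORCHTITAN_COMPILE_LINEAR_ONLY"] = "1"
print(bool(os.environ.get("TORCHTITAN_COMPILE_LINEAR_ONLY", False)))  # True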
3 changes: 2 additions & 1 deletion train_configs/llama3_8b.toml
@@ -56,4 +56,5 @@ selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac ba

 [float8]
 enable_float8_linear = false
-no_compile = false # TODO: should this go in [experimental]?
+float8nocompile = false # TODO: should this go in [experimental]?
+float8nocompile_ac = false
