From 9e19e7fa91933e5c3010a215813e5237ce30c5b1 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Fri, 10 Jan 2025 09:39:08 -0800
Subject: [PATCH] add options to use AC in float8nocompile linear layers; add support for only compiling linear layers

---
 torchtitan/config_manager.py                 |  7 ++++++-
 torchtitan/float8.py                         |  5 ++++-
 torchtitan/parallelisms/parallelize_llama.py | 21 ++++++++++++++++-----
 train_configs/llama3_8b.toml                 |  3 ++-
 4 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 2177000d..1fc6916b 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -568,10 +568,15 @@ def __init__(self):
             help="float8 scaling for input, dynamic (default) or delayed",
         )
         self.parser.add_argument(
-            "--float8.no_compile",
+            "--float8.float8nocompile",
             action="store_true",
             help="use the float8nocompile prototype implementation",
         )
+        self.parser.add_argument(
+            "--float8.float8nocompile_ac",
+            action="store_true",
+            help="use activation checkpointing with float8nocompile linear layers",
+        )
 
         # communications library settings
         self.parser.add_argument(
diff --git a/torchtitan/float8.py b/torchtitan/float8.py
index 638c1ddc..20540d7b 100644
--- a/torchtitan/float8.py
+++ b/torchtitan/float8.py
@@ -47,7 +47,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
                 "torchao is not installed. Please install it to use float8 linear layers."
             ) from e
 
-        self.use_float8nocompile = float8_config.no_compile
+        self.use_float8nocompile = float8_config.float8nocompile
+        self.use_float8nocompile_ac = float8_config.float8nocompile_ac
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
@@ -104,8 +105,10 @@ def convert_to_float8_training(self, model: nn.Module):
                 model,
                 config=self.config,
                 module_filter_fn=lambda mod, fqn: fqn != "output",
+                use_activation_checkpointing=self.use_float8nocompile_ac,
             )
         else:
+            logger.info("Using float8 training")
             from torchao.float8 import convert_to_float8_training
 
             # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 9728569a..f2888350 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -7,6 +7,7 @@
 # This file applies the PT-D parallelisms (except pipeline parallelism) and various
 # training techniques (e.g. activation checkpointing and compile) to the Llama model.
 
+import os
 from collections import defaultdict
 
 import torch
@@ -299,11 +300,21 @@ def apply_compile(model: nn.Module):
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
-    for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = torch.compile(transformer_block, fullgraph=True)
-        model.layers.register_module(layer_id, transformer_block)
-
-    logger.info("Compiling each TransformerBlock with torch.compile")
+    compile_linear_only = bool(os.environ.get("TORCHTITAN_COMPILE_LINEAR_ONLY", False))
+
+    if compile_linear_only:
+        logger.info("Compiling linear layers with torch.compile")
+        for name, child in model.named_children():
+            if isinstance(child, torch.nn.Linear):
+                new_child = torch.compile(child)
+                setattr(model, name, new_child)
+            else:
+                apply_compile(child)
+    else:
+        logger.info("Compiling each TransformerBlock with torch.compile")
+        for layer_id, transformer_block in model.layers.named_children():
+            transformer_block = torch.compile(transformer_block, fullgraph=True)
+            model.layers.register_module(layer_id, transformer_block)
 
 
 def apply_fsdp(
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index 9873d129..7890eb0c 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -56,4 +56,5 @@ selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac ba
 
 [float8]
 enable_float8_linear = false
-no_compile = false # TODO: should this go in [experimental]?
+float8nocompile = false # TODO: should this go in [experimental]?
+float8nocompile_ac = false