integrate float8nocompile, an experimental feature for high performance float8 training in eager mode
danielvegamyhre committed Jan 7, 2025
1 parent 2a44370 commit a4a1f74
Showing 4 changed files with 30 additions and 7 deletions.
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
+3.10.15
5 changes: 5 additions & 0 deletions torchtitan/config_manager.py
@@ -567,6 +567,11 @@ def __init__(self):
             default="dynamic",
             help="float8 scaling for input, dynamic (default) or delayed",
         )
+        self.parser.add_argument(
+            "--float8.no_compile",
+            action="store_true",
+            help="use the float8nocompile prototype implementation",
+        )

         # communications library settings
         self.parser.add_argument(
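The new option is a plain store_true flag: it defaults to False and flips to True only when passed on the command line, and the handler reads it back as float8_config.no_compile (see torchtitan/float8.py below). A minimal standalone sketch of that argparse behavior, with an illustrative parser rather than torchtitan's JobConfig:

# Standalone sketch of the flag's argparse behavior; this parser is
# illustrative and not torchtitan's JobConfig parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--float8.no_compile",
    action="store_true",
    help="use the float8nocompile prototype implementation",
)

args = parser.parse_args(["--float8.no_compile"])
# argparse keeps the dot in the destination name, so the value is read back
# with getattr; it is False whenever the flag is omitted.
print(getattr(args, "float8.no_compile"))  # True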
30 changes: 23 additions & 7 deletions torchtitan/float8.py
@@ -47,6 +47,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
                 "torchao is not installed. Please install it to use float8 linear layers."
             ) from e

+        self.use_float8nocompile = float8_config.no_compile
+
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
             parallel_dims.dp_shard_enabled
@@ -90,14 +92,28 @@ def convert_to_float8_training(self, model: nn.Module):
         if not self.enabled:
             return

-        from torchao.float8 import convert_to_float8_training
+        # TODO: should we implicitly use this if self.compile is False, rather
+        # than having an explicit flag?
+        if self.use_float8nocompile:
+            logger.info("Using float8nocompile prototype")
+            from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
+                convert_to_float8_nocompile_training,
+            )

-        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
-        convert_to_float8_training(
-            model,
-            config=self.config,
-            module_filter_fn=lambda mod, fqn: fqn != "output",
-        )
+            convert_to_float8_nocompile_training(
+                model,
+                config=self.config,
+                module_filter_fn=lambda mod, fqn: fqn != "output",
+            )
+        else:
+            from torchao.float8 import convert_to_float8_training
+
+            # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
+            convert_to_float8_training(
+                model,
+                config=self.config,
+                module_filter_fn=lambda mod, fqn: fqn != "output",
+            )
         logger.info(
             "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
             f"{self.config.enable_fsdp_float8_all_gather}"
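Both branches above pass the same module_filter_fn, so the final "output" projection stays in high precision while every other nn.Linear is swapped. A standalone sketch of that swap pattern, assuming torchao is installed; ToyModel and its layer names are illustrative stand-ins, not torchtitan's model:

# Sketch of the nn.Linear -> Float8Linear swap with the same filter as above.
import torch.nn as nn
from torchao.float8 import convert_to_float8_training


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(4096, 4096, bias=False)
        self.w2 = nn.Linear(4096, 4096, bias=False)
        self.output = nn.Linear(4096, 1024, bias=False)

    def forward(self, x):
        return self.output(self.w2(self.w1(x)))


model = ToyModel()
convert_to_float8_training(
    model,
    module_filter_fn=lambda mod, fqn: fqn != "output",  # skip the output projection
)
print(type(model.w1).__name__, type(model.output).__name__)  # Float8Linear Linear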
1 change: 1 addition & 0 deletions train_configs/llama3_8b.toml
@@ -56,3 +56,4 @@ selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

 [float8]
 enable_float8_linear = false
+no_compile = false # TODO: should this go in [experimental]?
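To actually exercise the prototype, both keys in this section would be set to true: enable_float8_linear turns float8 training on, and no_compile routes it through the eager-mode prototype. A standalone sketch of reading such a section with a stock TOML parser; torchtitan's real loader in config_manager.py also layers command-line overrides on top, so this is illustrative only:

# Illustrative only: parse a [float8] section with both options enabled.
try:
    import tomllib  # standard library on Python 3.11+
except ModuleNotFoundError:
    import tomli as tomllib  # needed on Python 3.10, e.g. the 3.10.15 pinned above

toml_text = """
[float8]
enable_float8_linear = true
no_compile = true  # opt in to the eager-mode float8nocompile prototype
"""

float8_cfg = tomllib.loads(toml_text)["float8"]
print(float8_cfg["enable_float8_linear"], float8_cfg["no_compile"])  # True True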
